futures_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for futures using [Rustls](https://github.com/rustls/rustls).
2
// Evaluates a `Poll<T>` expression: yields the `T` from `Poll::Ready(T)`, or makes the
// enclosing function return `Poll::Pending` immediately when the value is not ready.
// Absolute `::std` paths keep the expansion independent of what is in scope at the call site.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
11
12pub mod client;
13mod common;
14pub mod server;
15
16use common::{MidHandshake, Stream, TlsState};
17use futures_io::{AsyncRead, AsyncWrite};
18use rustls::server::AcceptedAlert;
19use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
20use std::future::Future;
21use std::io;
22#[cfg(unix)]
23use std::os::unix::io::{AsRawFd, RawFd};
24#[cfg(windows)]
25use std::os::windows::io::{AsRawSocket, RawSocket};
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30pub use pki_types;
31pub use rustls;
32
33/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
34#[derive(Clone)]
35pub struct TlsConnector {
36    inner: Arc<ClientConfig>,
37    #[cfg(feature = "early-data")]
38    early_data: bool,
39}
40
41/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
42#[derive(Clone)]
43pub struct TlsAcceptor {
44    inner: Arc<ServerConfig>,
45}
46
47impl From<Arc<ClientConfig>> for TlsConnector {
48    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
49        TlsConnector {
50            inner,
51            #[cfg(feature = "early-data")]
52            early_data: false,
53        }
54    }
55}
56
57impl From<Arc<ServerConfig>> for TlsAcceptor {
58    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
59        TlsAcceptor { inner }
60    }
61}
62
63impl TlsConnector {
64    /// Enable 0-RTT.
65    ///
66    /// If you want to use 0-RTT,
67    /// You must also set `ClientConfig.enable_early_data` to `true`.
68    #[cfg(feature = "early-data")]
69    pub fn early_data(mut self, flag: bool) -> TlsConnector {
70        self.early_data = flag;
71        self
72    }
73
74    #[inline]
75    pub fn connect<IO>(&self, domain: pki_types::ServerName<'static>, stream: IO) -> Connect<IO>
76    where
77        IO: AsyncRead + AsyncWrite + Unpin,
78    {
79        self.connect_with(domain, stream, |_| ())
80    }
81
82    pub fn connect_with<IO, F>(
83        &self,
84        domain: pki_types::ServerName<'static>,
85        stream: IO,
86        f: F,
87    ) -> Connect<IO>
88    where
89        IO: AsyncRead + AsyncWrite + Unpin,
90        F: FnOnce(&mut ClientConnection),
91    {
92        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
93            Ok(session) => session,
94            Err(error) => {
95                return Connect(MidHandshake::Error {
96                    io: stream,
97                    // TODO(eliza): should this really return an `io::Error`?
98                    // Probably not...
99                    error: io::Error::new(io::ErrorKind::Other, error),
100                });
101            }
102        };
103        f(&mut session);
104
105        Connect(MidHandshake::Handshaking(client::TlsStream {
106            io: stream,
107
108            #[cfg(not(feature = "early-data"))]
109            state: TlsState::Stream,
110
111            #[cfg(feature = "early-data")]
112            state: if self.early_data && session.early_data().is_some() {
113                TlsState::EarlyData(0, Vec::new())
114            } else {
115                TlsState::Stream
116            },
117
118            #[cfg(feature = "early-data")]
119            early_waker: None,
120
121            session,
122        }))
123    }
124}
125
126impl TlsAcceptor {
127    #[inline]
128    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
129    where
130        IO: AsyncRead + AsyncWrite + Unpin,
131    {
132        self.accept_with(stream, |_| ())
133    }
134
135    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
136    where
137        IO: AsyncRead + AsyncWrite + Unpin,
138        F: FnOnce(&mut ServerConnection),
139    {
140        let mut session = match ServerConnection::new(self.inner.clone()) {
141            Ok(session) => session,
142            Err(error) => {
143                return Accept(MidHandshake::Error {
144                    io: stream,
145                    // TODO(eliza): should this really return an `io::Error`?
146                    // Probably not...
147                    error: io::Error::new(io::ErrorKind::Other, error),
148                });
149            }
150        };
151        f(&mut session);
152
153        Accept(MidHandshake::Handshaking(server::TlsStream {
154            session,
155            io: stream,
156            state: TlsState::Stream,
157        }))
158    }
159}
160
161pub struct LazyConfigAcceptor<IO> {
162    acceptor: rustls::server::Acceptor,
163    io: Option<IO>,
164    alert: Option<(rustls::Error, AcceptedAlert)>,
165}
166
167impl<IO> LazyConfigAcceptor<IO>
168where
169    IO: AsyncRead + AsyncWrite + Unpin,
170{
171    #[inline]
172    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
173        Self {
174            acceptor,
175            io: Some(io),
176            alert: None,
177        }
178    }
179}
180
181impl<IO> Future for LazyConfigAcceptor<IO>
182where
183    IO: AsyncRead + AsyncWrite + Unpin,
184{
185    type Output = Result<StartHandshake<IO>, io::Error>;
186
187    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188        let this = self.get_mut();
189        loop {
190            let io = match this.io.as_mut() {
191                Some(io) => io,
192                None => {
193                    return Poll::Ready(Err(io::Error::new(
194                        io::ErrorKind::Other,
195                        "acceptor cannot be polled after acceptance",
196                    )))
197                }
198            };
199
200            if let Some((err, mut alert)) = this.alert.take() {
201                match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
202                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
203                        this.alert = Some((err, alert));
204                        return Poll::Pending;
205                    }
206                    Ok(0) | Err(_) => {
207                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
208                    }
209                    Ok(_) => {
210                        this.alert = Some((err, alert));
211                        continue;
212                    }
213                };
214            }
215
216            let mut reader = common::SyncReadAdapter { io, cx };
217            match this.acceptor.read_tls(&mut reader) {
218                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
219                Ok(_) => {}
220                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
221                Err(e) => return Err(e).into(),
222            }
223
224            match this.acceptor.accept() {
225                Ok(Some(accepted)) => {
226                    let io = this.io.take().unwrap();
227                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
228                }
229                Ok(None) => {}
230                Err((err, alert)) => {
231                    this.alert = Some((err, alert));
232                }
233            }
234        }
235    }
236}
237
238pub struct StartHandshake<IO> {
239    accepted: rustls::server::Accepted,
240    io: IO,
241}
242
243impl<IO> StartHandshake<IO>
244where
245    IO: AsyncRead + AsyncWrite + Unpin,
246{
247    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
248        self.accepted.client_hello()
249    }
250
251    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
252        self.into_stream_with(config, |_| ())
253    }
254
255    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
256    where
257        F: FnOnce(&mut ServerConnection),
258    {
259        let mut conn = match self.accepted.into_connection(config) {
260            Ok(conn) => conn,
261            Err((error, alert)) => {
262                return Accept(MidHandshake::SendAlert {
263                    io: self.io,
264                    // TODO(eliza): should this really return an `io::Error`?
265                    // Probably not...
266                    error: io::Error::new(io::ErrorKind::Other, error),
267                    alert,
268                });
269            }
270        };
271        f(&mut conn);
272
273        Accept(MidHandshake::Handshaking(server::TlsStream {
274            session: conn,
275            io: self.io,
276            state: TlsState::Stream,
277        }))
278    }
279}
280
281/// Future returned from `TlsConnector::connect` which will resolve
282/// once the connection handshake has finished.
283pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
284
285/// Future returned from `TlsAcceptor::accept` which will resolve
286/// once the accept handshake has finished.
287pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
288
289/// Like [Connect], but returns `IO` on failure.
290pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
291
292/// Like [Accept], but returns `IO` on failure.
293pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
294
295impl<IO> Connect<IO> {
296    #[inline]
297    pub fn into_fallible(self) -> FallibleConnect<IO> {
298        FallibleConnect(self.0)
299    }
300}
301
302impl<IO> Accept<IO> {
303    #[inline]
304    pub fn into_fallible(self) -> FallibleAccept<IO> {
305        FallibleAccept(self.0)
306    }
307}
308
309impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
310    type Output = io::Result<client::TlsStream<IO>>;
311
312    #[inline]
313    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
314        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
315    }
316}
317
318impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
319    type Output = io::Result<server::TlsStream<IO>>;
320
321    #[inline]
322    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
323        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
324    }
325}
326
327impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
328    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
329
330    #[inline]
331    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
332        Pin::new(&mut self.0).poll(cx)
333    }
334}
335
336impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
337    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
338
339    #[inline]
340    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
341        Pin::new(&mut self.0).poll(cx)
342    }
343}
344
345/// Unified TLS stream type
346///
347/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
348/// a single type to keep both client- and server-initiated TLS-encrypted connections.
349#[derive(Debug)]
350pub enum TlsStream<T> {
351    Client(client::TlsStream<T>),
352    Server(server::TlsStream<T>),
353}
354
355impl<T> TlsStream<T> {
356    pub fn get_ref(&self) -> (&T, &CommonState) {
357        use TlsStream::*;
358        match self {
359            Client(io) => {
360                let (io, session) = io.get_ref();
361                (io, session)
362            }
363            Server(io) => {
364                let (io, session) = io.get_ref();
365                (io, session)
366            }
367        }
368    }
369
370    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
371        use TlsStream::*;
372        match self {
373            Client(io) => {
374                let (io, session) = io.get_mut();
375                (io, &mut *session)
376            }
377            Server(io) => {
378                let (io, session) = io.get_mut();
379                (io, &mut *session)
380            }
381        }
382    }
383}
384
385impl<T> From<client::TlsStream<T>> for TlsStream<T> {
386    fn from(s: client::TlsStream<T>) -> Self {
387        Self::Client(s)
388    }
389}
390
391impl<T> From<server::TlsStream<T>> for TlsStream<T> {
392    fn from(s: server::TlsStream<T>) -> Self {
393        Self::Server(s)
394    }
395}
396
397#[cfg(unix)]
398impl<S> AsRawFd for TlsStream<S>
399where
400    S: AsRawFd,
401{
402    fn as_raw_fd(&self) -> RawFd {
403        self.get_ref().0.as_raw_fd()
404    }
405}
406
407#[cfg(windows)]
408impl<S> AsRawSocket for TlsStream<S>
409where
410    S: AsRawSocket,
411{
412    fn as_raw_socket(&self) -> RawSocket {
413        self.get_ref().0.as_raw_socket()
414    }
415}
416
417impl<T> AsyncRead for TlsStream<T>
418where
419    T: AsyncRead + AsyncWrite + Unpin,
420{
421    #[inline]
422    fn poll_read(
423        self: Pin<&mut Self>,
424        cx: &mut Context<'_>,
425        buf: &mut [u8],
426    ) -> Poll<io::Result<usize>> {
427        match self.get_mut() {
428            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
429            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
430        }
431    }
432}
433
434impl<T> AsyncWrite for TlsStream<T>
435where
436    T: AsyncRead + AsyncWrite + Unpin,
437{
438    #[inline]
439    fn poll_write(
440        self: Pin<&mut Self>,
441        cx: &mut Context<'_>,
442        buf: &[u8],
443    ) -> Poll<io::Result<usize>> {
444        match self.get_mut() {
445            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
446            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
447        }
448    }
449
450    #[inline]
451    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
452        match self.get_mut() {
453            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
454            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
455        }
456    }
457
458    #[inline]
459    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
460        match self.get_mut() {
461            TlsStream::Client(x) => Pin::new(x).poll_close(cx),
462            TlsStream::Server(x) => Pin::new(x).poll_close(cx),
463        }
464    }
465
466    #[inline]
467    fn poll_write_vectored(
468        self: Pin<&mut Self>,
469        cx: &mut Context<'_>,
470        bufs: &[io::IoSlice<'_>],
471    ) -> Poll<io::Result<usize>> {
472        match self.get_mut() {
473            TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
474            TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
475        }
476    }
477}