futures_rustls/
client.rs

1use super::*;
2use crate::common::IoSession;
3
4/// A wrapper around an underlying raw stream which implements the TLS or SSL
5/// protocol.
6#[derive(Debug)]
7pub struct TlsStream<IO> {
8    pub(crate) io: IO,
9    pub(crate) session: ClientConnection,
10    pub(crate) state: TlsState,
11
12    #[cfg(feature = "early-data")]
13    pub(crate) early_waker: Option<std::task::Waker>,
14}
15
16impl<IO> TlsStream<IO> {
17    #[inline]
18    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
19        (&self.io, &self.session)
20    }
21
22    #[inline]
23    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
24        (&mut self.io, &mut self.session)
25    }
26
27    #[inline]
28    pub fn into_inner(self) -> (IO, ClientConnection) {
29        (self.io, self.session)
30    }
31}
32
/// Drives the client's TLS early-data state on behalf of a write call.
///
/// When `state` is `TlsState::EarlyData(pos, data)`:
///
/// 1. If the session can currently send early data, the non-empty slices of
///    `bufs` are offered to it in order. Every accepted byte is also appended
///    to `data`, so it can be re-sent later if the server rejects early data.
///    If any bytes were accepted, their count is returned immediately.
/// 2. If nothing was written as early data, the handshake is driven to
///    completion. If the server did not accept early data, the buffered bytes
///    from `pos` onward are re-sent over the regular stream.
/// 3. The state becomes `TlsState::Stream`, and any reader waker parked in
///    `early_waker` is woken.
///
/// Returns `Ok(0)` when `state` is not (or is no longer) the early-data
/// state. The caller should then perform a regular write.
///
/// # Errors
///
/// Propagates I/O errors from the early-data writer, the handshake, and the
/// fallback writes. Returns an `io::ErrorKind::WriteZero` error if the
/// stream accepts zero bytes while re-sending buffered early data, rather
/// than spinning forever.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<std::task::Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // Try to send the caller's bytes as early data first.
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs {
                if buf.is_empty() {
                    continue;
                }

                let len = match early_data.write(buf) {
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                // Keep a copy so it can be re-sent if early data is rejected.
                data.extend_from_slice(&buf[..len]);

                // Stop at a partial write so `written` always counts a
                // contiguous prefix of the concatenated `bufs`.
                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // Complete the handshake.
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // Fallback: early data was not accepted, so re-send it as ordinary
        // application data.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                if len == 0 {
                    // A zero-length write of a non-empty buffer means no progress
                    // can be made; retrying immediately would loop forever
                    // without ever returning `Pending`.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to re-send rejected early data",
                    )));
                }
                *pos += len;
            }
        }

        // Leave the early-data state and wake any reader parked in `poll_read`.
        *state = TlsState::Stream;

        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}
98
99#[cfg(unix)]
100impl<S> AsRawFd for TlsStream<S>
101where
102    S: AsRawFd,
103{
104    fn as_raw_fd(&self) -> RawFd {
105        self.get_ref().0.as_raw_fd()
106    }
107}
108
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped I/O object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _session) = self.get_ref();
        io.as_raw_socket()
    }
}
118
119impl<IO> IoSession for TlsStream<IO> {
120    type Io = IO;
121    type Session = ClientConnection;
122
123    #[inline]
124    fn skip_handshake(&self) -> bool {
125        self.state.is_early_data()
126    }
127
128    #[inline]
129    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
130        (&mut self.state, &mut self.io, &mut self.session)
131    }
132
133    #[inline]
134    fn into_io(self) -> Self::Io {
135        self.io
136    }
137}
138
139impl<IO> AsyncRead for TlsStream<IO>
140where
141    IO: AsyncRead + AsyncWrite + Unpin,
142{
143    fn poll_read(
144        self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146        buf: &mut [u8],
147    ) -> Poll<io::Result<usize>> {
148        match self.state {
149            #[cfg(feature = "early-data")]
150            TlsState::EarlyData(..) => {
151                let this = self.get_mut();
152
153                // In the EarlyData state, we have not really established a Tls connection.
154                // Before writing data through `AsyncWrite` and completing the tls handshake,
155                // we ignore read readiness and return to pending.
156                //
157                // In order to avoid event loss,
158                // we need to register a waker and wake it up after tls is connected.
159                if this
160                    .early_waker
161                    .as_ref()
162                    .filter(|waker| cx.waker().will_wake(waker))
163                    .is_none()
164                {
165                    this.early_waker = Some(cx.waker().clone());
166                }
167
168                Poll::Pending
169            }
170            TlsState::Stream | TlsState::WriteShutdown => {
171                let this = self.get_mut();
172                let mut stream =
173                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
174
175                match stream.as_mut_pin().poll_read(cx, buf) {
176                    Poll::Ready(Ok(n)) => {
177                        if n == 0 || stream.eof {
178                            this.state.shutdown_read();
179                        }
180
181                        Poll::Ready(Ok(n))
182                    }
183                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
184                        this.state.shutdown_read();
185                        Poll::Ready(Err(err))
186                    }
187                    output => output,
188                }
189            }
190            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
191        }
192    }
193}
194
195impl<IO> AsyncWrite for TlsStream<IO>
196where
197    IO: AsyncRead + AsyncWrite + Unpin,
198{
199    /// Note: that it does not guarantee the final data to be sent.
200    /// To be cautious, you must manually call `flush`.
201    fn poll_write(
202        self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204        buf: &[u8],
205    ) -> Poll<io::Result<usize>> {
206        let this = self.get_mut();
207        let mut stream =
208            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
209
210        #[cfg(feature = "early-data")]
211        {
212            let bufs = [io::IoSlice::new(buf)];
213            let written = ready!(poll_handle_early_data(
214                &mut this.state,
215                &mut stream,
216                &mut this.early_waker,
217                cx,
218                &bufs
219            ))?;
220            if written != 0 {
221                return Poll::Ready(Ok(written));
222            }
223        }
224        stream.as_mut_pin().poll_write(cx, buf)
225    }
226
227    fn poll_write_vectored(
228        self: Pin<&mut Self>,
229        cx: &mut Context<'_>,
230        bufs: &[io::IoSlice<'_>],
231    ) -> Poll<io::Result<usize>> {
232        let this = self.get_mut();
233        let mut stream =
234            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
235
236        #[cfg(feature = "early-data")]
237        {
238            let written = ready!(poll_handle_early_data(
239                &mut this.state,
240                &mut stream,
241                &mut this.early_waker,
242                cx,
243                bufs
244            ))?;
245            if written != 0 {
246                return Poll::Ready(Ok(written));
247            }
248        }
249
250        stream.as_mut_pin().poll_write_vectored(cx, bufs)
251    }
252
253    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
254        let this = self.get_mut();
255        let mut stream =
256            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
257
258        #[cfg(feature = "early-data")]
259        {
260            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
261                // complete handshake
262                while stream.session.is_handshaking() {
263                    ready!(stream.handshake(cx))?;
264                }
265
266                // write early data (fallback)
267                if !stream.session.is_early_data_accepted() {
268                    while *pos < data.len() {
269                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
270                        *pos += len;
271                    }
272                }
273
274                this.state = TlsState::Stream;
275
276                if let Some(waker) = this.early_waker.take() {
277                    waker.wake();
278                }
279            }
280        }
281
282        stream.as_mut_pin().poll_flush(cx)
283    }
284
285    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286        #[cfg(feature = "early-data")]
287        {
288            // complete handshake
289            if matches!(self.state, TlsState::EarlyData(..)) {
290                ready!(self.as_mut().poll_flush(cx))?;
291            }
292        }
293
294        if self.state.writeable() {
295            self.session.send_close_notify();
296            self.state.shutdown_write();
297        }
298
299        let this = self.get_mut();
300        let mut stream =
301            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
302        stream.as_mut_pin().poll_close(cx)
303    }
304}