1use super::*;
2use crate::common::IoSession;
3
/// A client-side TLS stream: the underlying I/O object `IO` paired with the
/// rustls client session (`ClientConnection`) that runs TLS over it.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying transport that TLS records are read from and written to.
    pub(crate) io: IO,
    /// The client-side TLS session.
    pub(crate) session: ClientConnection,
    /// Read/write shutdown progress and, with the `early-data` feature, the
    /// early-data phase (bytes handed to the session so far, plus the position
    /// up to which they have been resent as ordinary data).
    pub(crate) state: TlsState,

    /// Waker of a task whose `poll_read` returned `Pending` while the stream was
    /// still in the early-data phase; it is woken once that phase ends.
    #[cfg(feature = "early-data")]
    pub(crate) early_waker: Option<std::task::Waker>,
}
15
16impl<IO> TlsStream<IO> {
17 #[inline]
18 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
19 (&self.io, &self.session)
20 }
21
22 #[inline]
23 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
24 (&mut self.io, &mut self.session)
25 }
26
27 #[inline]
28 pub fn into_inner(self) -> (IO, ClientConnection) {
29 (self.io, self.session)
30 }
31}
32
/// Drives the TLS early-data phase of a client stream, if it is active.
///
/// When `state` is `TlsState::EarlyData(pos, data)`:
///
/// 1. If the session still offers an early-data writer
///    (`session.early_data()` is `Some`), as much of `bufs` as it accepts is
///    written into it and appended to `data`. If that was more than zero bytes,
///    `Ready(Ok(n))` is returned immediately. No I/O on the underlying stream
///    happens in this step.
/// 2. Otherwise the handshake is driven to completion.
/// 3. If the session reports that early data was *not* accepted, everything in
///    `data` from `pos` onward is rewritten as ordinary application data. `pos`
///    records progress, so a `Pending` here can be resumed by a later call.
/// 4. `state` becomes `TlsState::Stream`, and any reader parked in
///    `early_waker` is woken.
///
/// Returns `Ready(Ok(0))` when no early data was written, including when
/// `state` is not `EarlyData` at all. The caller then proceeds with a normal
/// write. Returns `Pending` while the handshake or fallback resend is still in
/// progress. Errors from the handshake, the fallback write, or an initial
/// early-data write are propagated.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<std::task::Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // Write early data while the session still accepts it.
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs {
                if buf.is_empty() {
                    continue;
                }

                let len = match early_data.write(buf) {
                    // Early-data capacity exhausted: stop and report what fit.
                    Ok(0) => break,
                    Ok(n) => n,
                    // Bytes from earlier slices were already handed to the
                    // session and recorded in `data`. Per the `io::Write`
                    // contract, report them instead of an error, so the caller
                    // does not resend them. The error resurfaces on the next
                    // call, if it persists.
                    Err(_) if written != 0 => break,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                // Keep a copy so it can be resent if the server rejects early data.
                data.extend_from_slice(&buf[..len]);

                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // Complete the handshake.
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // Fallback: early data was rejected, so resend it as normal data.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                *pos += len;
            }
        }

        // The early-data phase is over.
        *state = TlsState::Stream;

        // Wake a reader that was parked while the phase was active.
        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}
98
99#[cfg(unix)]
100impl<S> AsRawFd for TlsStream<S>
101where
102 S: AsRawFd,
103{
104 fn as_raw_fd(&self) -> RawFd {
105 self.get_ref().0.as_raw_fd()
106 }
107}
108
/// Exposes the raw socket of the underlying I/O object.
///
/// Reading or writing through this socket bypasses TLS.
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        let inner: &S = &self.io;
        inner.as_raw_socket()
    }
}
118
119impl<IO> IoSession for TlsStream<IO> {
120 type Io = IO;
121 type Session = ClientConnection;
122
123 #[inline]
124 fn skip_handshake(&self) -> bool {
125 self.state.is_early_data()
126 }
127
128 #[inline]
129 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
130 (&mut self.state, &mut self.io, &mut self.session)
131 }
132
133 #[inline]
134 fn into_io(self) -> Self::Io {
135 self.io
136 }
137}
138
139impl<IO> AsyncRead for TlsStream<IO>
140where
141 IO: AsyncRead + AsyncWrite + Unpin,
142{
143 fn poll_read(
144 self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 buf: &mut [u8],
147 ) -> Poll<io::Result<usize>> {
148 match self.state {
149 #[cfg(feature = "early-data")]
150 TlsState::EarlyData(..) => {
151 let this = self.get_mut();
152
153 if this
160 .early_waker
161 .as_ref()
162 .filter(|waker| cx.waker().will_wake(waker))
163 .is_none()
164 {
165 this.early_waker = Some(cx.waker().clone());
166 }
167
168 Poll::Pending
169 }
170 TlsState::Stream | TlsState::WriteShutdown => {
171 let this = self.get_mut();
172 let mut stream =
173 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
174
175 match stream.as_mut_pin().poll_read(cx, buf) {
176 Poll::Ready(Ok(n)) => {
177 if n == 0 || stream.eof {
178 this.state.shutdown_read();
179 }
180
181 Poll::Ready(Ok(n))
182 }
183 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
184 this.state.shutdown_read();
185 Poll::Ready(Err(err))
186 }
187 output => output,
188 }
189 }
190 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
191 }
192 }
193}
194
195impl<IO> AsyncWrite for TlsStream<IO>
196where
197 IO: AsyncRead + AsyncWrite + Unpin,
198{
199 fn poll_write(
202 self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 buf: &[u8],
205 ) -> Poll<io::Result<usize>> {
206 let this = self.get_mut();
207 let mut stream =
208 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
209
210 #[cfg(feature = "early-data")]
211 {
212 let bufs = [io::IoSlice::new(buf)];
213 let written = ready!(poll_handle_early_data(
214 &mut this.state,
215 &mut stream,
216 &mut this.early_waker,
217 cx,
218 &bufs
219 ))?;
220 if written != 0 {
221 return Poll::Ready(Ok(written));
222 }
223 }
224 stream.as_mut_pin().poll_write(cx, buf)
225 }
226
227 fn poll_write_vectored(
228 self: Pin<&mut Self>,
229 cx: &mut Context<'_>,
230 bufs: &[io::IoSlice<'_>],
231 ) -> Poll<io::Result<usize>> {
232 let this = self.get_mut();
233 let mut stream =
234 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
235
236 #[cfg(feature = "early-data")]
237 {
238 let written = ready!(poll_handle_early_data(
239 &mut this.state,
240 &mut stream,
241 &mut this.early_waker,
242 cx,
243 bufs
244 ))?;
245 if written != 0 {
246 return Poll::Ready(Ok(written));
247 }
248 }
249
250 stream.as_mut_pin().poll_write_vectored(cx, bufs)
251 }
252
253 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
254 let this = self.get_mut();
255 let mut stream =
256 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
257
258 #[cfg(feature = "early-data")]
259 {
260 if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
261 while stream.session.is_handshaking() {
263 ready!(stream.handshake(cx))?;
264 }
265
266 if !stream.session.is_early_data_accepted() {
268 while *pos < data.len() {
269 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
270 *pos += len;
271 }
272 }
273
274 this.state = TlsState::Stream;
275
276 if let Some(waker) = this.early_waker.take() {
277 waker.wake();
278 }
279 }
280 }
281
282 stream.as_mut_pin().poll_flush(cx)
283 }
284
285 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286 #[cfg(feature = "early-data")]
287 {
288 if matches!(self.state, TlsState::EarlyData(..)) {
290 ready!(self.as_mut().poll_flush(cx))?;
291 }
292 }
293
294 if self.state.writeable() {
295 self.session.send_close_notify();
296 self.state.shutdown_write();
297 }
298
299 let this = self.get_mut();
300 let mut stream =
301 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
302 stream.as_mut_pin().poll_close(cx)
303 }
304}