1use super::*;
2use crate::common::IoSession;
3
4#[derive(Debug)]
7pub struct TlsStream<IO> {
8 pub(crate) io: IO,
9 pub(crate) session: ServerConnection,
10 pub(crate) state: TlsState,
11}
12
13impl<IO> TlsStream<IO> {
14 #[inline]
15 pub fn get_ref(&self) -> (&IO, &ServerConnection) {
16 (&self.io, &self.session)
17 }
18
19 #[inline]
20 pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
21 (&mut self.io, &mut self.session)
22 }
23
24 #[inline]
25 pub fn into_inner(self) -> (IO, ServerConnection) {
26 (self.io, self.session)
27 }
28}
29
30impl<IO> IoSession for TlsStream<IO> {
31 type Io = IO;
32 type Session = ServerConnection;
33
34 #[inline]
35 fn skip_handshake(&self) -> bool {
36 false
37 }
38
39 #[inline]
40 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
41 (&mut self.state, &mut self.io, &mut self.session)
42 }
43
44 #[inline]
45 fn into_io(self) -> Self::Io {
46 self.io
47 }
48}
49
50impl<IO> AsyncRead for TlsStream<IO>
51where
52 IO: AsyncRead + AsyncWrite + Unpin,
53{
54 fn poll_read(
55 self: Pin<&mut Self>,
56 cx: &mut Context<'_>,
57 buf: &mut [u8],
58 ) -> Poll<io::Result<usize>> {
59 let this = self.get_mut();
60 let mut stream =
61 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
62
63 match &this.state {
64 TlsState::Stream | TlsState::WriteShutdown => {
65 match stream.as_mut_pin().poll_read(cx, buf) {
66 Poll::Ready(Ok(n)) => {
67 if n == 0 || stream.eof {
68 this.state.shutdown_read();
69 }
70
71 Poll::Ready(Ok(n))
72 }
73 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
74 this.state.shutdown_read();
75 Poll::Ready(Err(err))
76 }
77 output => output,
78 }
79 }
80 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
81 #[cfg(feature = "early-data")]
82 s => unreachable!("server TLS can not hit this state: {:?}", s),
83 }
84 }
85}
86
87impl<IO> AsyncWrite for TlsStream<IO>
88where
89 IO: AsyncRead + AsyncWrite + Unpin,
90{
91 fn poll_write(
94 self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 buf: &[u8],
97 ) -> Poll<io::Result<usize>> {
98 let this = self.get_mut();
99 let mut stream =
100 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
101 stream.as_mut_pin().poll_write(cx, buf)
102 }
103
104 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
105 let this = self.get_mut();
106 let mut stream =
107 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
108 stream.as_mut_pin().poll_flush(cx)
109 }
110
111 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
112 if self.state.writeable() {
113 self.session.send_close_notify();
114 self.state.shutdown_write();
115 }
116
117 let this = self.get_mut();
118 let mut stream =
119 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120 stream.as_mut_pin().poll_close(cx)
121 }
122
123 fn poll_write_vectored(
124 self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 bufs: &[io::IoSlice<'_>],
127 ) -> Poll<io::Result<usize>> {
128 let this = self.get_mut();
129 let mut stream =
130 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
131 stream.as_mut_pin().poll_write_vectored(cx, bufs)
132 }
133}
134
135#[cfg(unix)]
136impl<IO> AsRawFd for TlsStream<IO>
137where
138 IO: AsRawFd,
139{
140 fn as_raw_fd(&self) -> RawFd {
141 self.get_ref().0.as_raw_fd()
142 }
143}
144
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Exposes the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        let (transport, _session) = self.get_ref();
        transport.as_raw_socket()
    }
}