futures_rustls/
server.rs

1use super::*;
2use crate::common::IoSession;
3
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// This is the server-side stream: the TLS state lives in a
/// [`ServerConnection`], and `state` records which halves (read / write) of
/// the stream have been shut down.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying transport stream that the TLS session runs over.
    pub(crate) io: IO,
    /// The server-side TLS connection.
    pub(crate) session: ServerConnection,
    /// Read/write shutdown state, consulted by the `AsyncRead` / `AsyncWrite` impls.
    pub(crate) state: TlsState,
}
12
13impl<IO> TlsStream<IO> {
14    #[inline]
15    pub fn get_ref(&self) -> (&IO, &ServerConnection) {
16        (&self.io, &self.session)
17    }
18
19    #[inline]
20    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
21        (&mut self.io, &mut self.session)
22    }
23
24    #[inline]
25    pub fn into_inner(self) -> (IO, ServerConnection) {
26        (self.io, self.session)
27    }
28}
29
30impl<IO> IoSession for TlsStream<IO> {
31    type Io = IO;
32    type Session = ServerConnection;
33
34    #[inline]
35    fn skip_handshake(&self) -> bool {
36        false
37    }
38
39    #[inline]
40    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
41        (&mut self.state, &mut self.io, &mut self.session)
42    }
43
44    #[inline]
45    fn into_io(self) -> Self::Io {
46        self.io
47    }
48}
49
50impl<IO> AsyncRead for TlsStream<IO>
51where
52    IO: AsyncRead + AsyncWrite + Unpin,
53{
54    fn poll_read(
55        self: Pin<&mut Self>,
56        cx: &mut Context<'_>,
57        buf: &mut [u8],
58    ) -> Poll<io::Result<usize>> {
59        let this = self.get_mut();
60        let mut stream =
61            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
62
63        match &this.state {
64            TlsState::Stream | TlsState::WriteShutdown => {
65                match stream.as_mut_pin().poll_read(cx, buf) {
66                    Poll::Ready(Ok(n)) => {
67                        if n == 0 || stream.eof {
68                            this.state.shutdown_read();
69                        }
70
71                        Poll::Ready(Ok(n))
72                    }
73                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
74                        this.state.shutdown_read();
75                        Poll::Ready(Err(err))
76                    }
77                    output => output,
78                }
79            }
80            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
81            #[cfg(feature = "early-data")]
82            s => unreachable!("server TLS can not hit this state: {:?}", s),
83        }
84    }
85}
86
87impl<IO> AsyncWrite for TlsStream<IO>
88where
89    IO: AsyncRead + AsyncWrite + Unpin,
90{
91    /// Note: that it does not guarantee the final data to be sent.
92    /// To be cautious, you must manually call `flush`.
93    fn poll_write(
94        self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96        buf: &[u8],
97    ) -> Poll<io::Result<usize>> {
98        let this = self.get_mut();
99        let mut stream =
100            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
101        stream.as_mut_pin().poll_write(cx, buf)
102    }
103
104    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
105        let this = self.get_mut();
106        let mut stream =
107            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
108        stream.as_mut_pin().poll_flush(cx)
109    }
110
111    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
112        if self.state.writeable() {
113            self.session.send_close_notify();
114            self.state.shutdown_write();
115        }
116
117        let this = self.get_mut();
118        let mut stream =
119            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120        stream.as_mut_pin().poll_close(cx)
121    }
122
123    fn poll_write_vectored(
124        self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126        bufs: &[io::IoSlice<'_>],
127    ) -> Poll<io::Result<usize>> {
128        let this = self.get_mut();
129        let mut stream =
130            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
131        stream.as_mut_pin().poll_write_vectored(cx, bufs)
132    }
133}
134
135#[cfg(unix)]
136impl<IO> AsRawFd for TlsStream<IO>
137where
138    IO: AsRawFd,
139{
140    fn as_raw_fd(&self) -> RawFd {
141        self.get_ref().0.as_raw_fd()
142    }
143}
144
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Exposes the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}