rustls/
stream.rs

1use core::ops::{Deref, DerefMut};
2use std::io::{BufRead, IoSlice, Read, Result, Write};
3
4use crate::conn::{ConnectionCommon, SideData};
5
/// Adapter implementing `io::Read` and `io::Write` on top of a borrowed
/// Connection `C` and a borrowed underlying transport `T`, such as a socket.
///
/// With it, a rustls Connection can be used like an ordinary stream.
#[derive(Debug)]
pub struct Stream<'a, C, T>
where
    C: 'a + ?Sized,
    T: 'a + Read + Write + ?Sized,
{
    /// The TLS connection being driven.
    pub conn: &'a mut C,

    /// The underlying transport, for example a socket.
    pub sock: &'a mut T,
}
18
19impl<'a, C, T, S> Stream<'a, C, T>
20where
21    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
22    T: 'a + Read + Write,
23    S: SideData,
24{
25    /// Make a new Stream using the Connection `conn` and socket-like object
26    /// `sock`.  This does not fail and does no IO.
27    pub fn new(conn: &'a mut C, sock: &'a mut T) -> Self {
28        Self { conn, sock }
29    }
30
31    /// If we're handshaking, complete all the IO for that.
32    /// If we have data to write, write it all.
33    fn complete_prior_io(&mut self) -> Result<()> {
34        if self.conn.is_handshaking() {
35            self.conn.complete_io(self.sock)?;
36        }
37
38        if self.conn.wants_write() {
39            self.conn.complete_io(self.sock)?;
40        }
41
42        Ok(())
43    }
44
45    fn prepare_read(&mut self) -> Result<()> {
46        self.complete_prior_io()?;
47
48        // We call complete_io() in a loop since a single call may read only
49        // a partial packet from the underlying transport. A full packet is
50        // needed to get more plaintext, which we must do if EOF has not been
51        // hit.
52        while self.conn.wants_read() {
53            if self.conn.complete_io(self.sock)?.0 == 0 {
54                break;
55            }
56        }
57
58        Ok(())
59    }
60
61    // Implements `BufRead::fill_buf` but with more flexible lifetimes, so StreamOwned can reuse it
62    fn fill_buf(mut self) -> Result<&'a [u8]>
63    where
64        S: 'a,
65    {
66        self.prepare_read()?;
67        self.conn.reader().into_first_chunk()
68    }
69}
70
71impl<'a, C, T, S> Read for Stream<'a, C, T>
72where
73    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
74    T: 'a + Read + Write,
75    S: SideData,
76{
77    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
78        self.prepare_read()?;
79        self.conn.reader().read(buf)
80    }
81
82    #[cfg(read_buf)]
83    fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> {
84        self.prepare_read()?;
85        self.conn.reader().read_buf(cursor)
86    }
87}
88
89impl<'a, C, T, S> BufRead for Stream<'a, C, T>
90where
91    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
92    T: 'a + Read + Write,
93    S: 'a + SideData,
94{
95    fn fill_buf(&mut self) -> Result<&[u8]> {
96        // reborrow to get an owned `Stream`
97        Stream {
98            conn: self.conn,
99            sock: self.sock,
100        }
101        .fill_buf()
102    }
103
104    fn consume(&mut self, amt: usize) {
105        self.conn.reader().consume(amt)
106    }
107}
108
109impl<'a, C, T, S> Write for Stream<'a, C, T>
110where
111    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
112    T: 'a + Read + Write,
113    S: SideData,
114{
115    fn write(&mut self, buf: &[u8]) -> Result<usize> {
116        self.complete_prior_io()?;
117
118        let len = self.conn.writer().write(buf)?;
119
120        // Try to write the underlying transport here, but don't let
121        // any errors mask the fact we've consumed `len` bytes.
122        // Callers will learn of permanent errors on the next call.
123        let _ = self.conn.complete_io(self.sock);
124
125        Ok(len)
126    }
127
128    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
129        self.complete_prior_io()?;
130
131        let len = self
132            .conn
133            .writer()
134            .write_vectored(bufs)?;
135
136        // Try to write the underlying transport here, but don't let
137        // any errors mask the fact we've consumed `len` bytes.
138        // Callers will learn of permanent errors on the next call.
139        let _ = self.conn.complete_io(self.sock);
140
141        Ok(len)
142    }
143
144    fn flush(&mut self) -> Result<()> {
145        self.complete_prior_io()?;
146
147        self.conn.writer().flush()?;
148        if self.conn.wants_write() {
149            self.conn.complete_io(self.sock)?;
150        }
151        Ok(())
152    }
153}
154
/// Adapter implementing `io::Read` and `io::Write` while owning both a
/// Connection `C` and an underlying blocking transport `T`, such as a socket.
///
/// With it, a rustls Connection can be used like an ordinary stream.
#[derive(Debug)]
pub struct StreamOwned<C, T>
where
    T: Read + Write,
{
    /// The TLS connection being driven.
    pub conn: C,

    /// The underlying transport, for example a socket.
    pub sock: T,
}
168
169impl<C, T, S> StreamOwned<C, T>
170where
171    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
172    T: Read + Write,
173    S: SideData,
174{
175    /// Make a new StreamOwned taking the Connection `conn` and socket-like
176    /// object `sock`.  This does not fail and does no IO.
177    ///
178    /// This is the same as `Stream::new` except `conn` and `sock` are
179    /// moved into the StreamOwned.
180    pub fn new(conn: C, sock: T) -> Self {
181        Self { conn, sock }
182    }
183
184    /// Get a reference to the underlying socket
185    pub fn get_ref(&self) -> &T {
186        &self.sock
187    }
188
189    /// Get a mutable reference to the underlying socket
190    pub fn get_mut(&mut self) -> &mut T {
191        &mut self.sock
192    }
193
194    /// Extract the `conn` and `sock` parts from the `StreamOwned`
195    pub fn into_parts(self) -> (C, T) {
196        (self.conn, self.sock)
197    }
198}
199
200impl<'a, C, T, S> StreamOwned<C, T>
201where
202    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
203    T: Read + Write,
204    S: SideData,
205{
206    fn as_stream(&'a mut self) -> Stream<'a, C, T> {
207        Stream {
208            conn: &mut self.conn,
209            sock: &mut self.sock,
210        }
211    }
212}
213
214impl<C, T, S> Read for StreamOwned<C, T>
215where
216    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
217    T: Read + Write,
218    S: SideData,
219{
220    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
221        self.as_stream().read(buf)
222    }
223
224    #[cfg(read_buf)]
225    fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> {
226        self.as_stream().read_buf(cursor)
227    }
228}
229
230impl<C, T, S> BufRead for StreamOwned<C, T>
231where
232    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
233    T: Read + Write,
234    S: 'static + SideData,
235{
236    fn fill_buf(&mut self) -> Result<&[u8]> {
237        self.as_stream().fill_buf()
238    }
239
240    fn consume(&mut self, amt: usize) {
241        self.as_stream().consume(amt)
242    }
243}
244
245impl<C, T, S> Write for StreamOwned<C, T>
246where
247    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
248    T: Read + Write,
249    S: SideData,
250{
251    fn write(&mut self, buf: &[u8]) -> Result<usize> {
252        self.as_stream().write(buf)
253    }
254
255    fn flush(&mut self) -> Result<()> {
256        self.as_stream().flush()
257    }
258}
259
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Read, Write};
    use std::net::TcpStream;

    use super::{Stream, StreamOwned};
    use crate::client::ClientConnection;
    use crate::server::ServerConnection;

    /// Fails to compile unless `X` implements the full stream IO surface.
    ///
    /// A bare `type` alias is not enough: aliases are not checked against the
    /// bounds or trait impls of the type they name, so one only proves that
    /// the name resolves.
    fn assert_stream_io<X: Read + Write + BufRead>() {}

    #[test]
    fn stream_can_be_created_for_connection_and_tcpstream() {
        assert_stream_io::<Stream<'_, ClientConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_client_and_tcpstream() {
        assert_stream_io::<StreamOwned<ClientConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_server_and_tcpstream() {
        assert_stream_io::<StreamOwned<ServerConnection, TcpStream>>();
    }
}