tungstenite/protocol/frame/mod.rs

1//! Utilities to work with raw WebSocket frames.
2
3pub mod coding;
4
5#[allow(clippy::module_inception)]
6mod frame;
7mod mask;
8mod utf8;
9
10pub use self::{
11    frame::{CloseFrame, Frame, FrameHeader},
12    utf8::Utf8Bytes,
13};
14
15use crate::{
16    error::{CapacityError, Error, ProtocolError, Result},
17    protocol::frame::mask::apply_mask,
18    Message,
19};
20use bytes::BytesMut;
21use log::*;
22use std::io::{self, Cursor, Error as IoError, ErrorKind as IoErrorKind, Read, Write};
23
/// Read buffer length, in bytes, that every [`FrameSocket`] is created with (128 KiB).
const READ_BUF_LEN: usize = 128 << 10;
26
27/// A reader and writer for WebSocket frames.
28#[derive(Debug)]
29pub struct FrameSocket<Stream> {
30    /// The underlying network stream.
31    stream: Stream,
32    /// Codec for reading/writing frames.
33    codec: FrameCodec,
34}
35
36impl<Stream> FrameSocket<Stream> {
37    /// Create a new frame socket.
38    pub fn new(stream: Stream) -> Self {
39        FrameSocket { stream, codec: FrameCodec::new(READ_BUF_LEN) }
40    }
41
42    /// Create a new frame socket from partially read data.
43    pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
44        FrameSocket { stream, codec: FrameCodec::from_partially_read(part, READ_BUF_LEN) }
45    }
46
47    /// Extract a stream from the socket.
48    pub fn into_inner(self) -> (Stream, BytesMut) {
49        (self.stream, self.codec.in_buffer)
50    }
51
52    /// Returns a shared reference to the inner stream.
53    pub fn get_ref(&self) -> &Stream {
54        &self.stream
55    }
56
57    /// Returns a mutable reference to the inner stream.
58    pub fn get_mut(&mut self) -> &mut Stream {
59        &mut self.stream
60    }
61}
62
63impl<Stream> FrameSocket<Stream>
64where
65    Stream: Read,
66{
67    /// Read a frame from stream.
68    pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
69        self.codec.read_frame(&mut self.stream, max_size, false, true)
70    }
71}
72
73impl<Stream> FrameSocket<Stream>
74where
75    Stream: Write,
76{
77    /// Writes and immediately flushes a frame.
78    /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
79    pub fn send(&mut self, frame: Frame) -> Result<()> {
80        self.write(frame)?;
81        self.flush()
82    }
83
84    /// Write a frame to stream.
85    ///
86    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
87    ///
88    /// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`]
89    /// is returned.
90    /// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards.
91    pub fn write(&mut self, frame: Frame) -> Result<()> {
92        self.codec.buffer_frame(&mut self.stream, frame)
93    }
94
95    /// Flush writes.
96    pub fn flush(&mut self) -> Result<()> {
97        self.codec.write_out_buffer(&mut self.stream)?;
98        Ok(self.stream.flush()?)
99    }
100}
101
102/// A codec for WebSocket frames.
103#[derive(Debug)]
104pub(super) struct FrameCodec {
105    /// Buffer to read data from the stream.
106    in_buffer: BytesMut,
107    in_buf_max_read: usize,
108    /// Buffer to send packets to the network.
109    out_buffer: Vec<u8>,
110    /// Capacity limit for `out_buffer`.
111    max_out_buffer_len: usize,
112    /// Buffer target length to reach before writing to the stream
113    /// on calls to `buffer_frame`.
114    ///
115    /// Setting this to non-zero will buffer small writes from hitting
116    /// the stream.
117    out_buffer_write_len: usize,
118    /// Header and remaining size of the incoming packet being processed.
119    header: Option<(FrameHeader, u64)>,
120}
121
122impl FrameCodec {
123    /// Create a new frame codec.
124    pub(super) fn new(in_buf_len: usize) -> Self {
125        Self {
126            in_buffer: BytesMut::with_capacity(in_buf_len),
127            in_buf_max_read: in_buf_len.max(FrameHeader::MAX_SIZE),
128            out_buffer: <_>::default(),
129            max_out_buffer_len: usize::MAX,
130            out_buffer_write_len: 0,
131            header: None,
132        }
133    }
134
135    /// Create a new frame codec from partially read data.
136    pub(super) fn from_partially_read(part: Vec<u8>, min_in_buf_len: usize) -> Self {
137        let mut in_buffer = BytesMut::from_iter(part);
138        in_buffer.reserve(min_in_buf_len.saturating_sub(in_buffer.len()));
139        Self {
140            in_buffer,
141            in_buf_max_read: min_in_buf_len.max(FrameHeader::MAX_SIZE),
142            out_buffer: <_>::default(),
143            max_out_buffer_len: usize::MAX,
144            out_buffer_write_len: 0,
145            header: None,
146        }
147    }
148
149    /// Sets a maximum size for the out buffer.
150    pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
151        self.max_out_buffer_len = max;
152    }
153
154    /// Sets [`Self::buffer_frame`] buffer target length to reach before
155    /// writing to the stream.
156    pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
157        self.out_buffer_write_len = len;
158    }
159
160    /// Read a frame from the provided stream.
161    pub(super) fn read_frame(
162        &mut self,
163        stream: &mut impl Read,
164        max_size: Option<usize>,
165        unmask: bool,
166        accept_unmasked: bool,
167    ) -> Result<Option<Frame>> {
168        let max_size = max_size.unwrap_or_else(usize::max_value);
169
170        let mut payload = loop {
171            if self.header.is_none() {
172                let mut cursor = Cursor::new(&mut self.in_buffer);
173                self.header = FrameHeader::parse(&mut cursor)?;
174                let advanced = cursor.position();
175                bytes::Buf::advance(&mut self.in_buffer, advanced as _);
176
177                if let Some((_, len)) = &self.header {
178                    let len = *len as usize;
179
180                    // Enforce frame size limit early
181                    if len > max_size {
182                        return Err(Error::Capacity(CapacityError::MessageTooLong {
183                            size: len,
184                            max_size,
185                        }));
186                    }
187
188                    // Reserve full message length only once, even for multiple
189                    // loops or if WouldBlock errors cause multiple fn calls.
190                    self.in_buffer.reserve(len);
191                } else {
192                    self.in_buffer.reserve(FrameHeader::MAX_SIZE);
193                }
194            }
195
196            if let Some((_, len)) = &self.header {
197                let len = *len as usize;
198                if len <= self.in_buffer.len() {
199                    break self.in_buffer.split_to(len);
200                }
201            }
202
203            // Not enough data in buffer.
204            if self.read_in(stream)? == 0 {
205                trace!("no frame received");
206                return Ok(None);
207            }
208        };
209
210        let (mut header, length) = self.header.take().expect("Bug: no frame header");
211        debug_assert_eq!(payload.len() as u64, length);
212
213        if unmask {
214            if let Some(mask) = header.mask.take() {
215                // A server MUST remove masking for data frames received from a client
216                // as described in Section 5.3. (RFC 6455)
217                apply_mask(&mut payload, mask);
218            } else if !accept_unmasked {
219                // The server MUST close the connection upon receiving a
220                // frame that is not masked. (RFC 6455)
221                // The only exception here is if the user explicitly accepts given
222                // stream by setting WebSocketConfig.accept_unmasked_frames to true
223                return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
224            }
225        }
226
227        let frame = Frame::from_payload(header, payload.freeze());
228        trace!("received frame {frame}");
229        Ok(Some(frame))
230    }
231
232    /// Read into available `in_buffer` capacity.
233    fn read_in(&mut self, stream: &mut impl Read) -> io::Result<usize> {
234        let len = self.in_buffer.len();
235        debug_assert!(self.in_buffer.capacity() > len);
236        self.in_buffer.resize(self.in_buffer.capacity().min(len + self.in_buf_max_read), 0);
237        let size = stream.read(&mut self.in_buffer[len..]);
238        self.in_buffer.truncate(len + size.as_ref().copied().unwrap_or(0));
239        size
240    }
241
242    /// Writes a frame into the `out_buffer`.
243    /// If the out buffer size is over the `out_buffer_write_len` will also write
244    /// the out buffer into the provided `stream`.
245    ///
246    /// To ensure buffered frames are written call [`Self::write_out_buffer`].
247    ///
248    /// May write to the stream, will **not** flush.
249    pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
250    where
251        Stream: Write,
252    {
253        if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
254            return Err(Error::WriteBufferFull(Message::Frame(frame)));
255        }
256
257        trace!("writing frame {frame}");
258
259        self.out_buffer.reserve(frame.len());
260        frame.format_into_buf(&mut self.out_buffer).expect("Bug: can't write to vector");
261
262        if self.out_buffer.len() > self.out_buffer_write_len {
263            self.write_out_buffer(stream)
264        } else {
265            Ok(())
266        }
267    }
268
269    /// Writes the out_buffer to the provided stream.
270    ///
271    /// Does **not** flush.
272    pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
273    where
274        Stream: Write,
275    {
276        while !self.out_buffer.is_empty() {
277            let len = stream.write(&self.out_buffer)?;
278            if len == 0 {
279                // This is the same as "Connection reset by peer"
280                return Err(IoError::new(
281                    IoErrorKind::ConnectionReset,
282                    "Connection reset while sending",
283                )
284                .into());
285            }
286            self.out_buffer.drain(0..len);
287        }
288
289        Ok(())
290    }
291}
292
#[cfg(test)]
mod tests {

    use crate::error::{CapacityError, Error};

    use super::{Frame, FrameSocket};

    use std::io::Cursor;

    #[test]
    fn read_frames() {
        // `try_init` instead of `init`: another test in the same binary may
        // already have installed the global logger, and `init` would panic.
        let _ = env_logger::try_init();

        let raw = Cursor::new(vec![
            0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
            0x99,
        ]);
        let mut sock = FrameSocket::new(raw);

        assert_eq!(
            sock.read(None).unwrap().unwrap().into_payload(),
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]
        );
        assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01][..]);
        assert!(sock.read(None).unwrap().is_none());

        let (_, rest) = sock.into_inner();
        assert_eq!(rest, vec![0x99]);
    }

    #[test]
    fn from_partially_read() {
        let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
        assert_eq!(
            sock.read(None).unwrap().unwrap().into_payload(),
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]
        );
    }

    #[test]
    fn write_frames() {
        let mut sock = FrameSocket::new(Vec::new());

        let frame = Frame::ping(vec![0x04, 0x05]);
        sock.send(frame).unwrap();

        let frame = Frame::pong(vec![0x01]);
        sock.send(frame).unwrap();

        let (buf, _) = sock.into_inner();
        assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
    }

    #[test]
    fn parse_overflow() {
        let raw = Cursor::new(vec![
            0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
        ]);
        let mut sock = FrameSocket::new(raw);
        let _ = sock.read(None); // should not crash
    }

    #[test]
    fn size_limit_hit() {
        let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let mut sock = FrameSocket::new(raw);
        assert!(matches!(
            sock.read(Some(5)),
            Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
        ));
    }

    #[test]
    fn size_limit_hit_with_64bit_length() {
        // Payload length 2^32 + 5 in the 64-bit extended length field. It must
        // be rejected on every target, not truncated to 5 on 32-bit ones.
        let raw = Cursor::new(vec![0x82, 0x7f, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05]);
        let mut sock = FrameSocket::new(raw);
        assert!(matches!(
            sock.read(Some(5)),
            Err(Error::Capacity(CapacityError::MessageTooLong { max_size: 5, .. }))
        ));
    }
}