// tungstenite/protocol/frame/frame.rs

1use log::*;
2use std::{
3    default::Default,
4    fmt,
5    io::{Cursor, ErrorKind, Read, Write},
6    mem,
7    result::Result as StdResult,
8    str::Utf8Error,
9    string::String,
10};
11
12use super::{
13    coding::{CloseCode, Control, Data, OpCode},
14    mask::{apply_mask, generate_mask},
15};
16use crate::{
17    error::{Error, ProtocolError, Result},
18    protocol::frame::Utf8Bytes,
19};
20use bytes::{Bytes, BytesMut};
21
22/// A struct representing the close command.
23#[derive(Debug, Clone, Eq, PartialEq)]
24pub struct CloseFrame {
25    /// The reason as a code.
26    pub code: CloseCode,
27    /// The reason as text string.
28    pub reason: Utf8Bytes,
29}
30
31impl fmt::Display for CloseFrame {
32    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33        write!(f, "{} ({})", self.reason, self.code)
34    }
35}
36
37/// A struct representing a WebSocket frame header.
38#[allow(missing_copy_implementations)]
39#[derive(Debug, Clone, Eq, PartialEq)]
40pub struct FrameHeader {
41    /// Indicates that the frame is the last one of a possibly fragmented message.
42    pub is_final: bool,
43    /// Reserved for protocol extensions.
44    pub rsv1: bool,
45    /// Reserved for protocol extensions.
46    pub rsv2: bool,
47    /// Reserved for protocol extensions.
48    pub rsv3: bool,
49    /// WebSocket protocol opcode.
50    pub opcode: OpCode,
51    /// A frame mask, if any.
52    pub mask: Option<[u8; 4]>,
53}
54
55impl Default for FrameHeader {
56    fn default() -> Self {
57        FrameHeader {
58            is_final: true,
59            rsv1: false,
60            rsv2: false,
61            rsv3: false,
62            opcode: OpCode::Control(Control::Close),
63            mask: None,
64        }
65    }
66}
67
68impl FrameHeader {
69    /// > The longest possible header is 14 bytes, which would represent a message sent from
70    /// > the client to the server with a payload greater than 64KB.
71    pub(crate) const MAX_SIZE: usize = 14;
72
73    /// Parse a header from an input stream.
74    /// Returns `None` if insufficient data and does not consume anything in this case.
75    /// Payload size is returned along with the header.
76    pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77        let initial = cursor.position();
78        match Self::parse_internal(cursor) {
79            ret @ Ok(None) => {
80                cursor.set_position(initial);
81                ret
82            }
83            ret => ret,
84        }
85    }
86
87    /// Get the size of the header formatted with given payload length.
88    #[allow(clippy::len_without_is_empty)]
89    pub fn len(&self, length: u64) -> usize {
90        2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
91    }
92
93    /// Format a header for given payload size.
94    pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
95        let code: u8 = self.opcode.into();
96
97        let one = {
98            code | if self.is_final { 0x80 } else { 0 }
99                | if self.rsv1 { 0x40 } else { 0 }
100                | if self.rsv2 { 0x20 } else { 0 }
101                | if self.rsv3 { 0x10 } else { 0 }
102        };
103
104        let lenfmt = LengthFormat::for_length(length);
105
106        let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
107
108        output.write_all(&[one, two])?;
109        match lenfmt {
110            LengthFormat::U8(_) => (),
111            LengthFormat::U16 => {
112                output.write_all(&(length as u16).to_be_bytes())?;
113            }
114            LengthFormat::U64 => {
115                output.write_all(&length.to_be_bytes())?;
116            }
117        }
118
119        if let Some(ref mask) = self.mask {
120            output.write_all(mask)?;
121        }
122
123        Ok(())
124    }
125
126    /// Generate a random frame mask and store this in the header.
127    ///
128    /// Of course this does not change frame contents. It just generates a mask.
129    pub(crate) fn set_random_mask(&mut self) {
130        self.mask = Some(generate_mask());
131    }
132}
133
134impl FrameHeader {
135    /// Internal parse engine.
136    /// Returns `None` if insufficient data.
137    /// Payload size is returned along with the header.
138    fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139        let (first, second) = {
140            let mut head = [0u8; 2];
141            if cursor.read(&mut head)? != 2 {
142                return Ok(None);
143            }
144            trace!("Parsed headers {:?}", head);
145            (head[0], head[1])
146        };
147
148        trace!("First: {:b}", first);
149        trace!("Second: {:b}", second);
150
151        let is_final = first & 0x80 != 0;
152
153        let rsv1 = first & 0x40 != 0;
154        let rsv2 = first & 0x20 != 0;
155        let rsv3 = first & 0x10 != 0;
156
157        let opcode = OpCode::from(first & 0x0F);
158        trace!("Opcode: {:?}", opcode);
159
160        let masked = second & 0x80 != 0;
161        trace!("Masked: {:?}", masked);
162
163        let length = {
164            let length_byte = second & 0x7F;
165            let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
166            if length_length > 0 {
167                const SIZE: usize = mem::size_of::<u64>();
168                assert!(length_length <= SIZE, "length exceeded size of u64");
169                let start = SIZE - length_length;
170                let mut buffer = [0; SIZE];
171                match cursor.read_exact(&mut buffer[start..]) {
172                    Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => return Ok(None),
173                    Err(err) => return Err(err.into()),
174                    Ok(()) => u64::from_be_bytes(buffer),
175                }
176            } else {
177                u64::from(length_byte)
178            }
179        };
180
181        let mask = if masked {
182            let mut mask_bytes = [0u8; 4];
183            if cursor.read(&mut mask_bytes)? != 4 {
184                return Ok(None);
185            } else {
186                Some(mask_bytes)
187            }
188        } else {
189            None
190        };
191
192        // Disallow bad opcode
193        match opcode {
194            OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
195                return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
196            }
197            _ => (),
198        }
199
200        let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
201
202        Ok(Some((hdr, length)))
203    }
204}
205
206/// A struct representing a WebSocket frame.
207#[derive(Debug, Clone, Eq, PartialEq)]
208pub struct Frame {
209    header: FrameHeader,
210    payload: Bytes,
211}
212
213impl Frame {
214    /// Get the length of the frame.
215    /// This is the length of the header + the length of the payload.
216    #[inline]
217    pub fn len(&self) -> usize {
218        let length = self.payload.len();
219        self.header.len(length as u64) + length
220    }
221
222    /// Check if the frame is empty.
223    #[inline]
224    pub fn is_empty(&self) -> bool {
225        self.len() == 0
226    }
227
228    /// Get a reference to the frame's header.
229    #[inline]
230    pub fn header(&self) -> &FrameHeader {
231        &self.header
232    }
233
234    /// Get a mutable reference to the frame's header.
235    #[inline]
236    pub fn header_mut(&mut self) -> &mut FrameHeader {
237        &mut self.header
238    }
239
240    /// Get a reference to the frame's payload.
241    #[inline]
242    pub fn payload(&self) -> &[u8] {
243        &self.payload
244    }
245
246    /// Test whether the frame is masked.
247    #[inline]
248    pub(crate) fn is_masked(&self) -> bool {
249        self.header.mask.is_some()
250    }
251
252    /// Generate a random mask for the frame.
253    ///
254    /// This just generates a mask, payload is not changed. The actual masking is performed
255    /// either on `format()` or on `apply_mask()` call.
256    #[inline]
257    pub(crate) fn set_random_mask(&mut self) {
258        self.header.set_random_mask();
259    }
260
261    /// Consume the frame into its payload as string.
262    #[inline]
263    pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
264        self.payload.try_into()
265    }
266
267    /// Consume the frame into its payload.
268    #[inline]
269    pub fn into_payload(self) -> Bytes {
270        self.payload
271    }
272
273    /// Get frame payload as `&str`.
274    #[inline]
275    pub fn to_text(&self) -> Result<&str, Utf8Error> {
276        std::str::from_utf8(&self.payload)
277    }
278
279    /// Consume the frame into a closing frame.
280    #[inline]
281    pub(crate) fn into_close(self) -> Result<Option<CloseFrame>> {
282        match self.payload.len() {
283            0 => Ok(None),
284            1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
285            _ => {
286                let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into();
287                let reason = Utf8Bytes::try_from(self.payload.slice(2..))?;
288                Ok(Some(CloseFrame { code, reason }))
289            }
290        }
291    }
292
293    /// Create a new data frame.
294    #[inline]
295    pub fn message(data: impl Into<Bytes>, opcode: OpCode, is_final: bool) -> Frame {
296        debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
297        Frame {
298            header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
299            payload: data.into(),
300        }
301    }
302
303    /// Create a new Pong control frame.
304    #[inline]
305    pub fn pong(data: impl Into<Bytes>) -> Frame {
306        Frame {
307            header: FrameHeader {
308                opcode: OpCode::Control(Control::Pong),
309                ..FrameHeader::default()
310            },
311            payload: data.into(),
312        }
313    }
314
315    /// Create a new Ping control frame.
316    #[inline]
317    pub fn ping(data: impl Into<Bytes>) -> Frame {
318        Frame {
319            header: FrameHeader {
320                opcode: OpCode::Control(Control::Ping),
321                ..FrameHeader::default()
322            },
323            payload: data.into(),
324        }
325    }
326
327    /// Create a new Close control frame.
328    #[inline]
329    pub fn close(msg: Option<CloseFrame>) -> Frame {
330        let payload = if let Some(CloseFrame { code, reason }) = msg {
331            let mut p = BytesMut::with_capacity(reason.len() + 2);
332            p.extend(u16::from(code).to_be_bytes());
333            p.extend_from_slice(reason.as_bytes());
334            p
335        } else {
336            <_>::default()
337        };
338
339        Frame { header: FrameHeader::default(), payload: payload.into() }
340    }
341
342    /// Create a frame from given header and data.
343    pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self {
344        Frame { header, payload }
345    }
346
347    /// Write a frame out to a buffer
348    pub fn format(mut self, output: &mut impl Write) -> Result<()> {
349        self.header.format(self.payload.len() as u64, output)?;
350
351        if let Some(mask) = self.header.mask.take() {
352            let mut data = Vec::from(mem::take(&mut self.payload));
353            apply_mask(&mut data, mask);
354            output.write_all(&data)?;
355        } else {
356            output.write_all(&self.payload)?;
357        }
358
359        Ok(())
360    }
361
362    pub(crate) fn format_into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
363        self.header.format(self.payload.len() as u64, buf)?;
364
365        let len = buf.len();
366        buf.extend_from_slice(&self.payload);
367
368        if let Some(mask) = self.header.mask.take() {
369            apply_mask(&mut buf[len..], mask);
370        }
371
372        Ok(())
373    }
374}
375
376impl fmt::Display for Frame {
377    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
378        use std::fmt::Write;
379
380        write!(
381            f,
382            "
383<FRAME>
384final: {}
385reserved: {} {} {}
386opcode: {}
387length: {}
388payload length: {}
389payload: 0x{}
390            ",
391            self.header.is_final,
392            self.header.rsv1,
393            self.header.rsv2,
394            self.header.rsv3,
395            self.header.opcode,
396            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
397            self.len(),
398            self.payload.len(),
399            self.payload.iter().fold(String::new(), |mut output, byte| {
400                _ = write!(output, "{byte:02x}");
401                output
402            })
403        )
404    }
405}
406
/// Handling of the length format.
///
/// A frame encodes its payload length in a 7-bit field, optionally followed by a 16-bit or
/// 64-bit extended length.
enum LengthFormat {
    /// The length fits directly in the 7-bit field (0..=125).
    U8(u8),
    /// The 7-bit field holds 126 and 2 extra length bytes follow.
    U16,
    /// The 7-bit field holds 127 and 8 extra length bytes follow.
    U64,
}

impl LengthFormat {
    /// Choose the smallest encoding able to represent a payload of `length` bytes.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => Self::U8(length as u8),
            126..=0xFFFF => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of extended-length bytes that follow the 7-bit length field.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value to place in the 7-bit length field.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            Self::U8(short) => *short,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Decode the length format from a length byte; the high (mask) bit is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let seven_bits = byte & 0x7F;
        match seven_bits {
            126 => Self::U16,
            127 => Self::U64,
            short => Self::U8(short),
        }
    }
}
457
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    #[test]
    fn parse() {
        let mut raw: Cursor<Vec<u8>> =
            Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let frame = Frame::from_payload(header, payload.into());
        assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]);
    }

    #[test]
    fn parse_incomplete_header_consumes_nothing() {
        let partials: [&[u8]; 3] = [
            &[0x82],                   // only the first header byte
            &[0x82, 0x7E, 0x01],       // 16-bit extended length cut short
            &[0x82, 0x85, 0x01, 0x02], // masking key cut short
        ];
        for &partial in partials.iter() {
            let mut cursor = Cursor::new(partial);
            assert!(FrameHeader::parse(&mut cursor).unwrap().is_none());
            assert_eq!(cursor.position(), 0);
        }
    }

    #[test]
    fn parse_rejects_reserved_opcode() {
        let mut raw = Cursor::new(vec![0x83u8, 0x00]);
        assert!(FrameHeader::parse(&mut raw).is_err());
    }

    #[test]
    fn header_len() {
        let mut header = FrameHeader::default();
        assert_eq!(header.len(0), 2);
        assert_eq!(header.len(125), 2);
        assert_eq!(header.len(126), 4);
        assert_eq!(header.len(65535), 4);
        assert_eq!(header.len(65536), 10);
        header.mask = Some([0; 4]);
        assert_eq!(header.len(0), 6);
        assert_eq!(header.len(65536), FrameHeader::MAX_SIZE);
    }

    #[test]
    fn extended_length_round_trip() {
        for &(size, header_len) in &[(300usize, 4usize), (70_000, 10)] {
            let frame = Frame::message(vec![0xABu8; size], OpCode::Data(Data::Binary), true);
            let mut buf = Vec::new();
            frame.format(&mut buf).unwrap();
            assert_eq!(buf.len(), header_len + size);

            let mut cursor = Cursor::new(buf);
            let (header, length) = FrameHeader::parse(&mut cursor).unwrap().unwrap();
            assert_eq!(length, size as u64);
            assert_eq!(cursor.position(), header_len as u64);
            assert_eq!(header.opcode, OpCode::Data(Data::Binary));
            assert!(header.is_final);
            assert_eq!(header.mask, None);
        }
    }

    #[test]
    fn masked_round_trip() {
        let key = [0x01, 0x02, 0x03, 0x04];
        let mut frame = Frame::message(b"hello".to_vec(), OpCode::Data(Data::Text), true);
        frame.header_mut().mask = Some(key);
        let mut buf = Vec::new();
        frame.format(&mut buf).unwrap();

        let mut cursor = Cursor::new(buf);
        let (header, length) = FrameHeader::parse(&mut cursor).unwrap().unwrap();
        assert_eq!(header.mask, Some(key));
        assert_eq!(length, 5);

        let mut payload = Vec::new();
        cursor.read_to_end(&mut payload).unwrap();
        apply_mask(&mut payload, key);
        assert_eq!(payload, b"hello");
    }

    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    #[test]
    fn format_into_buf() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format_into_buf(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    #[test]
    fn format_and_format_into_buf_agree_when_masked() {
        let mut frame = Frame::message(b"abc".to_vec(), OpCode::Data(Data::Binary), false);
        frame.header_mut().mask = Some([0xFF, 0x00, 0xAA, 0x55]);

        let mut via_writer = Vec::new();
        frame.clone().format(&mut via_writer).unwrap();
        let mut via_buf = Vec::new();
        frame.format_into_buf(&mut via_buf).unwrap();

        assert_eq!(via_writer, via_buf);
    }

    #[test]
    fn into_close() {
        let frame = Frame::from_payload(
            FrameHeader::default(),
            Bytes::from_static(&[0x03, 0xE8, b'o', b'k']),
        );
        let close = frame.into_close().unwrap().expect("payload carries a close code");
        assert_eq!(u16::from(close.code), 1000);
        assert_eq!(close.reason.as_bytes(), b"ok");

        let truncated = Frame::from_payload(FrameHeader::default(), Bytes::from_static(&[0x03]));
        assert!(truncated.into_close().is_err());

        assert!(Frame::close(None).into_close().unwrap().is_none());
    }

    #[test]
    fn close_payload_layout() {
        let reason = Utf8Bytes::try_from(Bytes::from_static(b"bye")).unwrap();
        let frame = Frame::close(Some(CloseFrame { code: CloseCode::from(1001u16), reason }));
        assert_eq!(frame.header().opcode, OpCode::Control(Control::Close));
        assert_eq!(frame.payload(), &[0x03, 0xE9, b'b', b'y', b'e'][..]);
    }

    #[test]
    fn display() {
        let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true);
        let view = format!("{f}");
        assert!(view.contains("payload:"));
    }
}