tungstenite/protocol/frame/
frame.rs1use log::*;
2use std::{
3 default::Default,
4 fmt,
5 io::{Cursor, ErrorKind, Read, Write},
6 mem,
7 result::Result as StdResult,
8 str::Utf8Error,
9 string::String,
10};
11
12use super::{
13 coding::{CloseCode, Control, Data, OpCode},
14 mask::{apply_mask, generate_mask},
15};
16use crate::{
17 error::{Error, ProtocolError, Result},
18 protocol::frame::Utf8Bytes,
19};
20use bytes::{Bytes, BytesMut};
21
22#[derive(Debug, Clone, Eq, PartialEq)]
24pub struct CloseFrame {
25 pub code: CloseCode,
27 pub reason: Utf8Bytes,
29}
30
31impl fmt::Display for CloseFrame {
32 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33 write!(f, "{} ({})", self.reason, self.code)
34 }
35}
36
37#[allow(missing_copy_implementations)]
39#[derive(Debug, Clone, Eq, PartialEq)]
40pub struct FrameHeader {
41 pub is_final: bool,
43 pub rsv1: bool,
45 pub rsv2: bool,
47 pub rsv3: bool,
49 pub opcode: OpCode,
51 pub mask: Option<[u8; 4]>,
53}
54
55impl Default for FrameHeader {
56 fn default() -> Self {
57 FrameHeader {
58 is_final: true,
59 rsv1: false,
60 rsv2: false,
61 rsv3: false,
62 opcode: OpCode::Control(Control::Close),
63 mask: None,
64 }
65 }
66}
67
68impl FrameHeader {
69 pub(crate) const MAX_SIZE: usize = 14;
72
73 pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77 let initial = cursor.position();
78 match Self::parse_internal(cursor) {
79 ret @ Ok(None) => {
80 cursor.set_position(initial);
81 ret
82 }
83 ret => ret,
84 }
85 }
86
87 #[allow(clippy::len_without_is_empty)]
89 pub fn len(&self, length: u64) -> usize {
90 2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
91 }
92
93 pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
95 let code: u8 = self.opcode.into();
96
97 let one = {
98 code | if self.is_final { 0x80 } else { 0 }
99 | if self.rsv1 { 0x40 } else { 0 }
100 | if self.rsv2 { 0x20 } else { 0 }
101 | if self.rsv3 { 0x10 } else { 0 }
102 };
103
104 let lenfmt = LengthFormat::for_length(length);
105
106 let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
107
108 output.write_all(&[one, two])?;
109 match lenfmt {
110 LengthFormat::U8(_) => (),
111 LengthFormat::U16 => {
112 output.write_all(&(length as u16).to_be_bytes())?;
113 }
114 LengthFormat::U64 => {
115 output.write_all(&length.to_be_bytes())?;
116 }
117 }
118
119 if let Some(ref mask) = self.mask {
120 output.write_all(mask)?;
121 }
122
123 Ok(())
124 }
125
126 pub(crate) fn set_random_mask(&mut self) {
130 self.mask = Some(generate_mask());
131 }
132}
133
134impl FrameHeader {
135 fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139 let (first, second) = {
140 let mut head = [0u8; 2];
141 if cursor.read(&mut head)? != 2 {
142 return Ok(None);
143 }
144 trace!("Parsed headers {:?}", head);
145 (head[0], head[1])
146 };
147
148 trace!("First: {:b}", first);
149 trace!("Second: {:b}", second);
150
151 let is_final = first & 0x80 != 0;
152
153 let rsv1 = first & 0x40 != 0;
154 let rsv2 = first & 0x20 != 0;
155 let rsv3 = first & 0x10 != 0;
156
157 let opcode = OpCode::from(first & 0x0F);
158 trace!("Opcode: {:?}", opcode);
159
160 let masked = second & 0x80 != 0;
161 trace!("Masked: {:?}", masked);
162
163 let length = {
164 let length_byte = second & 0x7F;
165 let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
166 if length_length > 0 {
167 const SIZE: usize = mem::size_of::<u64>();
168 assert!(length_length <= SIZE, "length exceeded size of u64");
169 let start = SIZE - length_length;
170 let mut buffer = [0; SIZE];
171 match cursor.read_exact(&mut buffer[start..]) {
172 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => return Ok(None),
173 Err(err) => return Err(err.into()),
174 Ok(()) => u64::from_be_bytes(buffer),
175 }
176 } else {
177 u64::from(length_byte)
178 }
179 };
180
181 let mask = if masked {
182 let mut mask_bytes = [0u8; 4];
183 if cursor.read(&mut mask_bytes)? != 4 {
184 return Ok(None);
185 } else {
186 Some(mask_bytes)
187 }
188 } else {
189 None
190 };
191
192 match opcode {
194 OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
195 return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
196 }
197 _ => (),
198 }
199
200 let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
201
202 Ok(Some((hdr, length)))
203 }
204}
205
206#[derive(Debug, Clone, Eq, PartialEq)]
208pub struct Frame {
209 header: FrameHeader,
210 payload: Bytes,
211}
212
213impl Frame {
214 #[inline]
217 pub fn len(&self) -> usize {
218 let length = self.payload.len();
219 self.header.len(length as u64) + length
220 }
221
222 #[inline]
224 pub fn is_empty(&self) -> bool {
225 self.len() == 0
226 }
227
228 #[inline]
230 pub fn header(&self) -> &FrameHeader {
231 &self.header
232 }
233
234 #[inline]
236 pub fn header_mut(&mut self) -> &mut FrameHeader {
237 &mut self.header
238 }
239
240 #[inline]
242 pub fn payload(&self) -> &[u8] {
243 &self.payload
244 }
245
246 #[inline]
248 pub(crate) fn is_masked(&self) -> bool {
249 self.header.mask.is_some()
250 }
251
252 #[inline]
257 pub(crate) fn set_random_mask(&mut self) {
258 self.header.set_random_mask();
259 }
260
261 #[inline]
263 pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
264 self.payload.try_into()
265 }
266
267 #[inline]
269 pub fn into_payload(self) -> Bytes {
270 self.payload
271 }
272
273 #[inline]
275 pub fn to_text(&self) -> Result<&str, Utf8Error> {
276 std::str::from_utf8(&self.payload)
277 }
278
279 #[inline]
281 pub(crate) fn into_close(self) -> Result<Option<CloseFrame>> {
282 match self.payload.len() {
283 0 => Ok(None),
284 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
285 _ => {
286 let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into();
287 let reason = Utf8Bytes::try_from(self.payload.slice(2..))?;
288 Ok(Some(CloseFrame { code, reason }))
289 }
290 }
291 }
292
293 #[inline]
295 pub fn message(data: impl Into<Bytes>, opcode: OpCode, is_final: bool) -> Frame {
296 debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
297 Frame {
298 header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
299 payload: data.into(),
300 }
301 }
302
303 #[inline]
305 pub fn pong(data: impl Into<Bytes>) -> Frame {
306 Frame {
307 header: FrameHeader {
308 opcode: OpCode::Control(Control::Pong),
309 ..FrameHeader::default()
310 },
311 payload: data.into(),
312 }
313 }
314
315 #[inline]
317 pub fn ping(data: impl Into<Bytes>) -> Frame {
318 Frame {
319 header: FrameHeader {
320 opcode: OpCode::Control(Control::Ping),
321 ..FrameHeader::default()
322 },
323 payload: data.into(),
324 }
325 }
326
327 #[inline]
329 pub fn close(msg: Option<CloseFrame>) -> Frame {
330 let payload = if let Some(CloseFrame { code, reason }) = msg {
331 let mut p = BytesMut::with_capacity(reason.len() + 2);
332 p.extend(u16::from(code).to_be_bytes());
333 p.extend_from_slice(reason.as_bytes());
334 p
335 } else {
336 <_>::default()
337 };
338
339 Frame { header: FrameHeader::default(), payload: payload.into() }
340 }
341
342 pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self {
344 Frame { header, payload }
345 }
346
347 pub fn format(mut self, output: &mut impl Write) -> Result<()> {
349 self.header.format(self.payload.len() as u64, output)?;
350
351 if let Some(mask) = self.header.mask.take() {
352 let mut data = Vec::from(mem::take(&mut self.payload));
353 apply_mask(&mut data, mask);
354 output.write_all(&data)?;
355 } else {
356 output.write_all(&self.payload)?;
357 }
358
359 Ok(())
360 }
361
362 pub(crate) fn format_into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
363 self.header.format(self.payload.len() as u64, buf)?;
364
365 let len = buf.len();
366 buf.extend_from_slice(&self.payload);
367
368 if let Some(mask) = self.header.mask.take() {
369 apply_mask(&mut buf[len..], mask);
370 }
371
372 Ok(())
373 }
374}
375
376impl fmt::Display for Frame {
377 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
378 use std::fmt::Write;
379
380 write!(
381 f,
382 "
383<FRAME>
384final: {}
385reserved: {} {} {}
386opcode: {}
387length: {}
388payload length: {}
389payload: 0x{}
390 ",
391 self.header.is_final,
392 self.header.rsv1,
393 self.header.rsv2,
394 self.header.rsv3,
395 self.header.opcode,
396 self.len(),
398 self.payload.len(),
399 self.payload.iter().fold(String::new(), |mut output, byte| {
400 _ = write!(output, "{byte:02x}");
401 output
402 })
403 )
404 }
405}
406
/// How a payload length is encoded: inline in the 7-bit field of the second
/// header byte, or via a marker value plus an extended-length field.
enum LengthFormat {
    /// Length 0..=125 stored directly in the 7-bit field; no extra bytes.
    U8(u8),
    /// Marker 126 followed by a 2-byte big-endian length.
    U16,
    /// Marker 127 followed by an 8-byte big-endian length.
    U64,
}

impl LengthFormat {
    /// Pick the smallest encoding able to represent `length`.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => Self::U8(length as u8),
            126..=0xFFFF => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of extended-length bytes that follow the second header byte.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value written into the 7-bit length field for this encoding.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            Self::U8(inline) => *inline,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Classify a raw second header byte; the MASK bit (0x80) is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        if field == 126 {
            Self::U16
        } else if field == 127 {
            Self::U64
        } else {
            Self::U8(field)
        }
    }
}
457
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    /// Parse `bytes` as a complete header; returns header, payload length and
    /// the cursor position after parsing.
    fn parse_complete(bytes: Vec<u8>) -> (FrameHeader, u64, u64) {
        let mut raw = Cursor::new(bytes);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        (header, length, raw.position())
    }

    /// Assert that `bytes` is an incomplete header and that the cursor was rewound.
    fn assert_incomplete(bytes: Vec<u8>) {
        let mut raw = Cursor::new(bytes);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);
    }

    #[test]
    fn parse() {
        let mut raw: Cursor<Vec<u8>> =
            Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let frame = Frame::from_payload(header, payload.into());
        assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]);
    }

    #[test]
    fn parse_incomplete_header_rewinds() {
        // Nothing at all.
        assert_incomplete(vec![]);
        // Only the first fixed byte.
        assert_incomplete(vec![0x82]);
        // 16-bit extended length announced, only one of its bytes present.
        assert_incomplete(vec![0x82, 0x7E, 0x01]);
        // 64-bit extended length announced, truncated.
        assert_incomplete(vec![0x82, 0x7F, 0x00, 0x00, 0x00]);
        // MASK bit set, masking key truncated.
        assert_incomplete(vec![0x81, 0x85, 0x37, 0xfa]);
    }

    #[test]
    fn parse_extended_lengths() {
        let (header, length, pos) = parse_complete(vec![0x82, 0x7E, 0x01, 0x00]);
        assert_eq!(length, 256);
        assert_eq!(header.opcode, OpCode::Data(Data::Binary));
        assert_eq!(header.mask, None);
        assert_eq!(pos, 4);

        let (_, length, pos) =
            parse_complete(vec![0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]);
        assert_eq!(length, 65536);
        assert_eq!(pos, 10);
    }

    #[test]
    fn parse_masked_header() {
        // RFC 6455 section 5.7: masked "Hello" text frame (header part).
        let (header, length, pos) =
            parse_complete(vec![0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f]);
        assert!(header.is_final);
        assert_eq!(header.opcode, OpCode::Data(Data::Text));
        assert_eq!(header.mask, Some([0x37, 0xfa, 0x21, 0x3d]));
        assert_eq!(length, 5);
        assert_eq!(pos, 6);
    }

    #[test]
    fn parse_rejects_reserved_opcodes() {
        let mut raw = Cursor::new(vec![0x83u8, 0x00]);
        assert!(matches!(
            FrameHeader::parse(&mut raw),
            Err(Error::Protocol(ProtocolError::InvalidOpcode(3)))
        ));

        let mut raw = Cursor::new(vec![0x8Bu8, 0x00]);
        assert!(matches!(
            FrameHeader::parse(&mut raw),
            Err(Error::Protocol(ProtocolError::InvalidOpcode(11)))
        ));
    }

    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    #[test]
    fn format_into_buf() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format_into_buf(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    #[test]
    fn format_masked() {
        let mut frame = Frame::ping(vec![0x01, 0x02]);
        frame.header_mut().mask = Some([0xFF, 0x00, 0xFF, 0x00]);
        let expected = vec![0x89, 0x82, 0xFF, 0x00, 0xFF, 0x00, 0xFE, 0x02];

        let mut buf = Vec::new();
        frame.clone().format(&mut buf).unwrap();
        assert_eq!(buf, expected);

        let mut buf = Vec::new();
        frame.format_into_buf(&mut buf).unwrap();
        assert_eq!(buf, expected);
    }

    #[test]
    fn len_accounts_for_extended_length() {
        let small = Frame::message(vec![0u8; 125], OpCode::Data(Data::Binary), true);
        assert_eq!(small.len(), 2 + 125);

        let medium = Frame::message(vec![0u8; 126], OpCode::Data(Data::Binary), true);
        assert_eq!(medium.len(), 4 + 126);
        let mut buf = Vec::new();
        medium.format(&mut buf).unwrap();
        assert_eq!(buf.len(), 130);
        assert_eq!(buf[..4], [0x82, 0x7E, 0x00, 0x7E]);
    }

    #[test]
    fn into_close_payload_lengths() {
        assert!(Frame::close(None).into_close().unwrap().is_none());

        let truncated = Frame::from_payload(FrameHeader::default(), Bytes::from_static(&[0x03]));
        assert!(matches!(
            truncated.into_close(),
            Err(Error::Protocol(ProtocolError::InvalidCloseSequence))
        ));

        let normal = Frame::from_payload(
            FrameHeader::default(),
            Bytes::from_static(&[0x03, 0xE8, b'o', b'k']),
        );
        let close = normal.into_close().unwrap().unwrap();
        assert_eq!(close.code, CloseCode::Normal);
        assert_eq!(close.reason.as_bytes(), b"ok");
    }

    #[test]
    fn display() {
        let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true);
        let view = format!("{f}");
        assert!(view.contains("payload:"));
    }
}
499}