tungstenite/protocol/frame/
mod.rs1pub mod coding;
4
5#[allow(clippy::module_inception)]
6mod frame;
7mod mask;
8mod utf8;
9
10pub use self::{
11 frame::{CloseFrame, Frame, FrameHeader},
12 utf8::Utf8Bytes,
13};
14
15use crate::{
16 error::{CapacityError, Error, ProtocolError, Result},
17 protocol::frame::mask::apply_mask,
18 Message,
19};
20use bytes::BytesMut;
21use log::*;
22use std::io::{self, Cursor, Error as IoError, ErrorKind as IoErrorKind, Read, Write};
23
/// Input-buffer capacity (and per-read byte cap) used by the [`FrameSocket`] constructors: 128 KiB.
const READ_BUF_LEN: usize = 128 << 10;
26
27#[derive(Debug)]
29pub struct FrameSocket<Stream> {
30 stream: Stream,
32 codec: FrameCodec,
34}
35
36impl<Stream> FrameSocket<Stream> {
37 pub fn new(stream: Stream) -> Self {
39 FrameSocket { stream, codec: FrameCodec::new(READ_BUF_LEN) }
40 }
41
42 pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
44 FrameSocket { stream, codec: FrameCodec::from_partially_read(part, READ_BUF_LEN) }
45 }
46
47 pub fn into_inner(self) -> (Stream, BytesMut) {
49 (self.stream, self.codec.in_buffer)
50 }
51
52 pub fn get_ref(&self) -> &Stream {
54 &self.stream
55 }
56
57 pub fn get_mut(&mut self) -> &mut Stream {
59 &mut self.stream
60 }
61}
62
63impl<Stream> FrameSocket<Stream>
64where
65 Stream: Read,
66{
67 pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
69 self.codec.read_frame(&mut self.stream, max_size, false, true)
70 }
71}
72
73impl<Stream> FrameSocket<Stream>
74where
75 Stream: Write,
76{
77 pub fn send(&mut self, frame: Frame) -> Result<()> {
80 self.write(frame)?;
81 self.flush()
82 }
83
84 pub fn write(&mut self, frame: Frame) -> Result<()> {
92 self.codec.buffer_frame(&mut self.stream, frame)
93 }
94
95 pub fn flush(&mut self) -> Result<()> {
97 self.codec.write_out_buffer(&mut self.stream)?;
98 Ok(self.stream.flush()?)
99 }
100}
101
102#[derive(Debug)]
104pub(super) struct FrameCodec {
105 in_buffer: BytesMut,
107 in_buf_max_read: usize,
108 out_buffer: Vec<u8>,
110 max_out_buffer_len: usize,
112 out_buffer_write_len: usize,
118 header: Option<(FrameHeader, u64)>,
120}
121
122impl FrameCodec {
123 pub(super) fn new(in_buf_len: usize) -> Self {
125 Self {
126 in_buffer: BytesMut::with_capacity(in_buf_len),
127 in_buf_max_read: in_buf_len.max(FrameHeader::MAX_SIZE),
128 out_buffer: <_>::default(),
129 max_out_buffer_len: usize::MAX,
130 out_buffer_write_len: 0,
131 header: None,
132 }
133 }
134
135 pub(super) fn from_partially_read(part: Vec<u8>, min_in_buf_len: usize) -> Self {
137 let mut in_buffer = BytesMut::from_iter(part);
138 in_buffer.reserve(min_in_buf_len.saturating_sub(in_buffer.len()));
139 Self {
140 in_buffer,
141 in_buf_max_read: min_in_buf_len.max(FrameHeader::MAX_SIZE),
142 out_buffer: <_>::default(),
143 max_out_buffer_len: usize::MAX,
144 out_buffer_write_len: 0,
145 header: None,
146 }
147 }
148
149 pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
151 self.max_out_buffer_len = max;
152 }
153
154 pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
157 self.out_buffer_write_len = len;
158 }
159
160 pub(super) fn read_frame(
162 &mut self,
163 stream: &mut impl Read,
164 max_size: Option<usize>,
165 unmask: bool,
166 accept_unmasked: bool,
167 ) -> Result<Option<Frame>> {
168 let max_size = max_size.unwrap_or_else(usize::max_value);
169
170 let mut payload = loop {
171 if self.header.is_none() {
172 let mut cursor = Cursor::new(&mut self.in_buffer);
173 self.header = FrameHeader::parse(&mut cursor)?;
174 let advanced = cursor.position();
175 bytes::Buf::advance(&mut self.in_buffer, advanced as _);
176
177 if let Some((_, len)) = &self.header {
178 let len = *len as usize;
179
180 if len > max_size {
182 return Err(Error::Capacity(CapacityError::MessageTooLong {
183 size: len,
184 max_size,
185 }));
186 }
187
188 self.in_buffer.reserve(len);
191 } else {
192 self.in_buffer.reserve(FrameHeader::MAX_SIZE);
193 }
194 }
195
196 if let Some((_, len)) = &self.header {
197 let len = *len as usize;
198 if len <= self.in_buffer.len() {
199 break self.in_buffer.split_to(len);
200 }
201 }
202
203 if self.read_in(stream)? == 0 {
205 trace!("no frame received");
206 return Ok(None);
207 }
208 };
209
210 let (mut header, length) = self.header.take().expect("Bug: no frame header");
211 debug_assert_eq!(payload.len() as u64, length);
212
213 if unmask {
214 if let Some(mask) = header.mask.take() {
215 apply_mask(&mut payload, mask);
218 } else if !accept_unmasked {
219 return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
224 }
225 }
226
227 let frame = Frame::from_payload(header, payload.freeze());
228 trace!("received frame {frame}");
229 Ok(Some(frame))
230 }
231
232 fn read_in(&mut self, stream: &mut impl Read) -> io::Result<usize> {
234 let len = self.in_buffer.len();
235 debug_assert!(self.in_buffer.capacity() > len);
236 self.in_buffer.resize(self.in_buffer.capacity().min(len + self.in_buf_max_read), 0);
237 let size = stream.read(&mut self.in_buffer[len..]);
238 self.in_buffer.truncate(len + size.as_ref().copied().unwrap_or(0));
239 size
240 }
241
242 pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
250 where
251 Stream: Write,
252 {
253 if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
254 return Err(Error::WriteBufferFull(Message::Frame(frame)));
255 }
256
257 trace!("writing frame {frame}");
258
259 self.out_buffer.reserve(frame.len());
260 frame.format_into_buf(&mut self.out_buffer).expect("Bug: can't write to vector");
261
262 if self.out_buffer.len() > self.out_buffer_write_len {
263 self.write_out_buffer(stream)
264 } else {
265 Ok(())
266 }
267 }
268
269 pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
273 where
274 Stream: Write,
275 {
276 while !self.out_buffer.is_empty() {
277 let len = stream.write(&self.out_buffer)?;
278 if len == 0 {
279 return Err(IoError::new(
281 IoErrorKind::ConnectionReset,
282 "Connection reset while sending",
283 )
284 .into());
285 }
286 self.out_buffer.drain(0..len);
287 }
288
289 Ok(())
290 }
291}
292
#[cfg(test)]
mod tests {
    use crate::error::{CapacityError, Error};

    use super::{Frame, FrameSocket};

    use std::io::Cursor;

    #[test]
    fn read_frames() {
        // `try_init` instead of `init`: the global logger can be installed only once
        // per test binary, and `init` panics if it is already set.
        let _ = env_logger::try_init();

        let raw = Cursor::new(vec![
            0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
            0x99,
        ]);
        let mut sock = FrameSocket::new(raw);

        assert_eq!(
            sock.read(None).unwrap().unwrap().into_payload(),
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]
        );
        assert_eq!(sock.read(None).unwrap().unwrap().into_payload(), &[0x03, 0x02, 0x01][..]);
        assert!(sock.read(None).unwrap().is_none());

        // The trailing partial byte stays in the input buffer.
        let (_, rest) = sock.into_inner();
        assert_eq!(rest, vec![0x99]);
    }

    #[test]
    fn from_partially_read() {
        let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
        assert_eq!(
            sock.read(None).unwrap().unwrap().into_payload(),
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]
        );
    }

    #[test]
    fn write_frames() {
        let mut sock = FrameSocket::new(Vec::new());

        let frame = Frame::ping(vec![0x04, 0x05]);
        sock.send(frame).unwrap();

        let frame = Frame::pong(vec![0x01]);
        sock.send(frame).unwrap();

        let (buf, _) = sock.into_inner();
        assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
    }

    #[test]
    fn parse_overflow() {
        let raw = Cursor::new(vec![
            0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
        ]);
        let mut sock = FrameSocket::new(raw);
        // Only checks that an absurd 64-bit payload length does not panic.
        let _ = sock.read(None);
    }

    #[test]
    fn size_limit_hit() {
        let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let mut sock = FrameSocket::new(raw);
        assert!(matches!(
            sock.read(Some(5)),
            Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
        ));
    }
}