// tungstenite/protocol/message.rs

1use super::frame::{CloseFrame, Frame};
2use crate::{
3    error::{CapacityError, Error, Result},
4    protocol::frame::Utf8Bytes,
5};
6use std::{fmt, result::Result as StdResult, str};
7
8mod string_collect {
9    use utf8::DecodeError;
10
11    use crate::error::{Error, Result};
12
13    #[derive(Debug)]
14    pub struct StringCollector {
15        data: String,
16        incomplete: Option<utf8::Incomplete>,
17    }
18
19    impl StringCollector {
20        pub fn new() -> Self {
21            StringCollector { data: String::new(), incomplete: None }
22        }
23
24        pub fn len(&self) -> usize {
25            self.data
26                .len()
27                .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
28        }
29
30        pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
31            let mut input: &[u8] = tail.as_ref();
32
33            if let Some(mut incomplete) = self.incomplete.take() {
34                if let Some((result, rest)) = incomplete.try_complete(input) {
35                    input = rest;
36                    match result {
37                        Ok(text) => self.data.push_str(text),
38                        Err(result_bytes) => {
39                            return Err(Error::Utf8(String::from_utf8_lossy(result_bytes).into()))
40                        }
41                    }
42                } else {
43                    input = &[];
44                    self.incomplete = Some(incomplete);
45                }
46            }
47
48            if !input.is_empty() {
49                match utf8::decode(input) {
50                    Ok(text) => {
51                        self.data.push_str(text);
52                        Ok(())
53                    }
54                    Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
55                        self.data.push_str(valid_prefix);
56                        self.incomplete = Some(incomplete_suffix);
57                        Ok(())
58                    }
59                    Err(DecodeError::Invalid { valid_prefix, invalid_sequence, .. }) => {
60                        self.data.push_str(valid_prefix);
61                        Err(Error::Utf8(String::from_utf8_lossy(invalid_sequence).into()))
62                    }
63                }
64            } else {
65                Ok(())
66            }
67        }
68
69        pub fn into_string(self) -> Result<String> {
70            if let Some(incomplete) = self.incomplete {
71                Err(Error::Utf8(format!("incomplete string: {:?}", incomplete)))
72            } else {
73                Ok(self.data)
74            }
75        }
76    }
77}
78
79use self::string_collect::StringCollector;
80use bytes::Bytes;
81
82/// A struct representing the incomplete message.
83#[derive(Debug)]
84pub struct IncompleteMessage {
85    collector: IncompleteMessageCollector,
86}
87
88#[derive(Debug)]
89enum IncompleteMessageCollector {
90    Text(StringCollector),
91    Binary(Vec<u8>),
92}
93
94impl IncompleteMessage {
95    /// Create new.
96    pub fn new(message_type: IncompleteMessageType) -> Self {
97        IncompleteMessage {
98            collector: match message_type {
99                IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
100                IncompleteMessageType::Text => {
101                    IncompleteMessageCollector::Text(StringCollector::new())
102                }
103            },
104        }
105    }
106
107    /// Get the current filled size of the buffer.
108    pub fn len(&self) -> usize {
109        match self.collector {
110            IncompleteMessageCollector::Text(ref t) => t.len(),
111            IncompleteMessageCollector::Binary(ref b) => b.len(),
112        }
113    }
114
115    /// Add more data to an existing message.
116    pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()> {
117        // Always have a max size. This ensures an error in case of concatenating two buffers
118        // of more than `usize::max_value()` bytes in total.
119        let max_size = size_limit.unwrap_or_else(usize::max_value);
120        let my_size = self.len();
121        let portion_size = tail.as_ref().len();
122        // Be careful about integer overflows here.
123        if my_size > max_size || portion_size > max_size - my_size {
124            return Err(Error::Capacity(CapacityError::MessageTooLong {
125                size: my_size + portion_size,
126                max_size,
127            }));
128        }
129
130        match self.collector {
131            IncompleteMessageCollector::Binary(ref mut v) => {
132                v.extend(tail.as_ref());
133                Ok(())
134            }
135            IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
136        }
137    }
138
139    /// Convert an incomplete message into a complete one.
140    pub fn complete(self) -> Result<Message> {
141        match self.collector {
142            IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())),
143            IncompleteMessageCollector::Text(t) => {
144                let text = t.into_string()?;
145                Ok(Message::text(text))
146            }
147        }
148    }
149}
150
/// The type of incomplete message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncompleteMessageType {
    /// A text message; its payload is collected and validated as UTF-8.
    Text,
    /// A binary message; its payload is collected as raw bytes.
    Binary,
}
156
157/// An enum representing the various forms of a WebSocket message.
158#[derive(Debug, Eq, PartialEq, Clone)]
159pub enum Message {
160    /// A text WebSocket message
161    Text(Utf8Bytes),
162    /// A binary WebSocket message
163    Binary(Bytes),
164    /// A ping message with the specified payload
165    ///
166    /// The payload here must have a length less than 125 bytes
167    Ping(Bytes),
168    /// A pong message with the specified payload
169    ///
170    /// The payload here must have a length less than 125 bytes
171    Pong(Bytes),
172    /// A close message with the optional close frame.
173    Close(Option<CloseFrame>),
174    /// Raw frame. Note, that you're not going to get this value while reading the message.
175    Frame(Frame),
176}
177
178impl Message {
179    /// Create a new text WebSocket message from a stringable.
180    pub fn text<S>(string: S) -> Message
181    where
182        S: Into<Utf8Bytes>,
183    {
184        Message::Text(string.into())
185    }
186
187    /// Create a new binary WebSocket message by converting to `Bytes`.
188    pub fn binary<B>(bin: B) -> Message
189    where
190        B: Into<Bytes>,
191    {
192        Message::Binary(bin.into())
193    }
194
195    /// Indicates whether a message is a text message.
196    pub fn is_text(&self) -> bool {
197        matches!(*self, Message::Text(_))
198    }
199
200    /// Indicates whether a message is a binary message.
201    pub fn is_binary(&self) -> bool {
202        matches!(*self, Message::Binary(_))
203    }
204
205    /// Indicates whether a message is a ping message.
206    pub fn is_ping(&self) -> bool {
207        matches!(*self, Message::Ping(_))
208    }
209
210    /// Indicates whether a message is a pong message.
211    pub fn is_pong(&self) -> bool {
212        matches!(*self, Message::Pong(_))
213    }
214
215    /// Indicates whether a message is a close message.
216    pub fn is_close(&self) -> bool {
217        matches!(*self, Message::Close(_))
218    }
219
220    /// Get the length of the WebSocket message.
221    pub fn len(&self) -> usize {
222        match *self {
223            Message::Text(ref string) => string.len(),
224            Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
225                data.len()
226            }
227            Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
228            Message::Frame(ref frame) => frame.len(),
229        }
230    }
231
232    /// Returns true if the WebSocket message has no content.
233    /// For example, if the other side of the connection sent an empty string.
234    pub fn is_empty(&self) -> bool {
235        self.len() == 0
236    }
237
238    /// Consume the WebSocket and return it as binary data.
239    pub fn into_data(self) -> Bytes {
240        match self {
241            Message::Text(utf8) => utf8.into(),
242            Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
243            Message::Close(None) => <_>::default(),
244            Message::Close(Some(frame)) => frame.reason.into(),
245            Message::Frame(frame) => frame.into_payload(),
246        }
247    }
248
249    /// Attempt to consume the WebSocket message and convert it to a String.
250    pub fn into_text(self) -> Result<Utf8Bytes> {
251        match self {
252            Message::Text(txt) => Ok(txt),
253            Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
254                Ok(data.try_into()?)
255            }
256            Message::Close(None) => Ok(<_>::default()),
257            Message::Close(Some(frame)) => Ok(frame.reason),
258            Message::Frame(frame) => Ok(frame.into_text()?),
259        }
260    }
261
262    /// Attempt to get a &str from the WebSocket message,
263    /// this will try to convert binary data to utf8.
264    pub fn to_text(&self) -> Result<&str> {
265        match *self {
266            Message::Text(ref string) => Ok(string.as_str()),
267            Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
268                Ok(str::from_utf8(data)?)
269            }
270            Message::Close(None) => Ok(""),
271            Message::Close(Some(ref frame)) => Ok(&frame.reason),
272            Message::Frame(ref frame) => Ok(frame.to_text()?),
273        }
274    }
275}
276
277impl From<String> for Message {
278    #[inline]
279    fn from(string: String) -> Self {
280        Message::text(string)
281    }
282}
283
284impl<'s> From<&'s str> for Message {
285    #[inline]
286    fn from(string: &'s str) -> Self {
287        Message::text(string)
288    }
289}
290
291impl<'b> From<&'b [u8]> for Message {
292    #[inline]
293    fn from(data: &'b [u8]) -> Self {
294        Message::binary(Bytes::copy_from_slice(data))
295    }
296}
297
298impl From<Bytes> for Message {
299    fn from(data: Bytes) -> Self {
300        Message::binary(data)
301    }
302}
303
304impl From<Vec<u8>> for Message {
305    #[inline]
306    fn from(data: Vec<u8>) -> Self {
307        Message::binary(data)
308    }
309}
310
311impl From<Message> for Bytes {
312    #[inline]
313    fn from(message: Message) -> Self {
314        message.into_data()
315    }
316}
317
318impl fmt::Display for Message {
319    fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
320        if let Ok(string) = self.to_text() {
321            write!(f, "{string}")
322        } else {
323            write!(f, "Binary Data<length={}>", self.len())
324        }
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload that is not valid UTF-8: it ends with 0xF1, the lead byte of a
    /// four-byte sequence, with no continuation bytes after it.
    const INVALID_UTF8: [u8; 6] = [6, 7, 8, 9, 10, 241];

    /// Checks that `msg` is binary and cannot be reinterpreted as text.
    fn assert_binary_not_text(msg: Message) {
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn display() {
        let text = Message::text("test".to_owned());
        assert_eq!(text.to_string(), "test");

        let binary = Message::binary(vec![0, 1, 3, 4, 241]);
        assert_eq!(binary.to_string(), "Binary Data<length=5>");
    }

    #[test]
    fn binary_convert() {
        assert_binary_not_text(Message::from(&INVALID_UTF8[..]));
    }

    #[test]
    fn binary_convert_bytes() {
        assert_binary_not_text(Message::from(Bytes::from_iter(INVALID_UTF8)));
    }

    #[test]
    fn binary_convert_vec() {
        assert_binary_not_text(Message::from(INVALID_UTF8.to_vec()));
    }

    #[test]
    fn binary_convert_into_bytes() {
        let original = INVALID_UTF8.to_vec();
        let serialized: Bytes = Message::from(original.clone()).into();
        assert_eq!(original, serialized);
    }

    #[test]
    fn text_convert() {
        let msg = Message::from("kiwotsukete");
        assert!(msg.is_text());
    }
}