// quinn_proto/packet.rs

1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7    ConnectionId,
8    coding::{self, BufExt, BufMutExt},
9    crypto,
10};
11
12/// Decodes a QUIC packet's invariant header
13///
14/// Due to packet number encryption, it is impossible to fully decode a header
15/// (which includes a variable-length packet number) without crypto context.
16/// The crypto context (represented by the `Crypto` type in Quinn) is usually
17/// part of the `Connection`, or can be derived from the destination CID for
18/// Initial packets.
19///
20/// To cope with this, we decode the invariant header (which should be stable
21/// across QUIC versions), which gives us the destination CID and allows us
22/// to inspect the version and packet type (which depends on the version).
23/// This information allows us to fully decode and decrypt the packet.
24#[cfg_attr(test, derive(Clone))]
25#[derive(Debug)]
26pub struct PartialDecode {
27    plain_header: ProtectedHeader,
28    buf: io::Cursor<BytesMut>,
29}
30
31#[allow(clippy::len_without_is_empty)]
32impl PartialDecode {
33    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
34    pub fn new(
35        bytes: BytesMut,
36        cid_parser: &(impl ConnectionIdParser + ?Sized),
37        supported_versions: &[u32],
38        grease_quic_bit: bool,
39    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
40        let mut buf = io::Cursor::new(bytes);
41        let plain_header =
42            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
43        let dgram_len = buf.get_ref().len();
44        let packet_len = plain_header
45            .payload_len()
46            .map(|len| (buf.position() + len) as usize)
47            .unwrap_or(dgram_len);
48        match dgram_len.cmp(&packet_len) {
49            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
50            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
51                "packet too short to contain payload length",
52            )),
53            Ordering::Greater => {
54                let rest = Some(buf.get_mut().split_off(packet_len));
55                Ok((Self { plain_header, buf }, rest))
56            }
57        }
58    }
59
60    /// The underlying partially-decoded packet data
61    pub(crate) fn data(&self) -> &[u8] {
62        self.buf.get_ref()
63    }
64
65    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
66        self.plain_header.as_initial()
67    }
68
69    pub(crate) fn has_long_header(&self) -> bool {
70        !matches!(self.plain_header, ProtectedHeader::Short { .. })
71    }
72
73    pub(crate) fn is_initial(&self) -> bool {
74        self.space() == Some(SpaceId::Initial)
75    }
76
77    pub(crate) fn space(&self) -> Option<SpaceId> {
78        use ProtectedHeader::*;
79        match self.plain_header {
80            Initial { .. } => Some(SpaceId::Initial),
81            Long {
82                ty: LongType::Handshake,
83                ..
84            } => Some(SpaceId::Handshake),
85            Long {
86                ty: LongType::ZeroRtt,
87                ..
88            } => Some(SpaceId::Data),
89            Short { .. } => Some(SpaceId::Data),
90            _ => None,
91        }
92    }
93
94    pub(crate) fn is_0rtt(&self) -> bool {
95        match self.plain_header {
96            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
97            _ => false,
98        }
99    }
100
101    /// The destination connection ID of the packet
102    pub fn dst_cid(&self) -> &ConnectionId {
103        self.plain_header.dst_cid()
104    }
105
106    /// Length of QUIC packet being decoded
107    #[allow(unreachable_pub)] // fuzzing only
108    pub fn len(&self) -> usize {
109        self.buf.get_ref().len()
110    }
111
112    pub(crate) fn finish(
113        self,
114        header_crypto: Option<&dyn crypto::HeaderKey>,
115    ) -> Result<Packet, PacketDecodeError> {
116        use ProtectedHeader::*;
117        let Self {
118            plain_header,
119            mut buf,
120        } = self;
121
122        if let Initial(ProtectedInitialHeader {
123            dst_cid,
124            src_cid,
125            token_pos,
126            version,
127            ..
128        }) = plain_header
129        {
130            let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
131            let header_len = buf.position() as usize;
132            let mut bytes = buf.into_inner();
133
134            let header_data = bytes.split_to(header_len).freeze();
135            let token = header_data.slice(token_pos.start..token_pos.end);
136            return Ok(Packet {
137                header: Header::Initial(InitialHeader {
138                    dst_cid,
139                    src_cid,
140                    token,
141                    number,
142                    version,
143                }),
144                header_data,
145                payload: bytes,
146            });
147        }
148
149        let header = match plain_header {
150            Long {
151                ty,
152                dst_cid,
153                src_cid,
154                version,
155                ..
156            } => Header::Long {
157                ty,
158                dst_cid,
159                src_cid,
160                number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
161                version,
162            },
163            Retry {
164                dst_cid,
165                src_cid,
166                version,
167            } => Header::Retry {
168                dst_cid,
169                src_cid,
170                version,
171            },
172            Short { spin, dst_cid, .. } => {
173                let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
174                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
175                Header::Short {
176                    spin,
177                    key_phase,
178                    dst_cid,
179                    number,
180                }
181            }
182            VersionNegotiate {
183                random,
184                dst_cid,
185                src_cid,
186            } => Header::VersionNegotiate {
187                random,
188                dst_cid,
189                src_cid,
190            },
191            Initial { .. } => unreachable!(),
192        };
193
194        let header_len = buf.position() as usize;
195        let mut bytes = buf.into_inner();
196        Ok(Packet {
197            header,
198            header_data: bytes.split_to(header_len).freeze(),
199            payload: bytes,
200        })
201    }
202
203    fn decrypt_header(
204        buf: &mut io::Cursor<BytesMut>,
205        header_crypto: &dyn crypto::HeaderKey,
206    ) -> Result<PacketNumber, PacketDecodeError> {
207        let packet_length = buf.get_ref().len();
208        let pn_offset = buf.position() as usize;
209        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
210            return Err(PacketDecodeError::InvalidHeader(
211                "packet too short to extract header protection sample",
212            ));
213        }
214
215        header_crypto.decrypt(pn_offset, buf.get_mut());
216
217        let len = PacketNumber::decode_len(buf.get_ref()[0]);
218        PacketNumber::decode(len, buf)
219    }
220}
221
222pub(crate) struct Packet {
223    pub(crate) header: Header,
224    pub(crate) header_data: Bytes,
225    pub(crate) payload: BytesMut,
226}
227
228impl Packet {
229    pub(crate) fn reserved_bits_valid(&self) -> bool {
230        let mask = match self.header {
231            Header::Short { .. } => SHORT_RESERVED_BITS,
232            _ => LONG_RESERVED_BITS,
233        };
234        self.header_data[0] & mask == 0
235    }
236}
237
238pub(crate) struct InitialPacket {
239    pub(crate) header: InitialHeader,
240    pub(crate) header_data: Bytes,
241    pub(crate) payload: BytesMut,
242}
243
244impl From<InitialPacket> for Packet {
245    fn from(x: InitialPacket) -> Self {
246        Self {
247            header: Header::Initial(x.header),
248            header_data: x.header_data,
249            payload: x.payload,
250        }
251    }
252}
253
254#[cfg_attr(test, derive(Clone))]
255#[derive(Debug)]
256pub(crate) enum Header {
257    Initial(InitialHeader),
258    Long {
259        ty: LongType,
260        dst_cid: ConnectionId,
261        src_cid: ConnectionId,
262        number: PacketNumber,
263        version: u32,
264    },
265    Retry {
266        dst_cid: ConnectionId,
267        src_cid: ConnectionId,
268        version: u32,
269    },
270    Short {
271        spin: bool,
272        key_phase: bool,
273        dst_cid: ConnectionId,
274        number: PacketNumber,
275    },
276    VersionNegotiate {
277        random: u8,
278        src_cid: ConnectionId,
279        dst_cid: ConnectionId,
280    },
281}
282
impl Header {
    /// Appends the plaintext encoding of this header to `w`
    ///
    /// Initial and Handshake/0-RTT headers get a two-byte zero placeholder where the payload
    /// length belongs; `PartialEncode::finish` overwrites it once the payload size is known.
    /// The returned [`PartialEncode`] records where the header starts in `w`, how long it is,
    /// and (for packets with a packet number) the packet number length and whether the length
    /// field still needs writing.
    pub(crate) fn encode(&self, w: &mut Vec<u8>) -> PartialEncode {
        use Header::*;
        let start = w.len();
        match *self {
            Initial(InitialHeader {
                ref dst_cid,
                ref src_cid,
                ref token,
                number,
                version,
            }) => {
                // First byte: long form + fixed bit + Initial type + packet number length tag
                w.write(u8::from(LongHeaderType::Initial) | number.tag());
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                // Token is length-prefixed with a varint
                w.write_var(token.len() as u64);
                w.put_slice(token);
                w.write::<u16>(0); // Placeholder for payload length; filled in by `PartialEncode::finish`
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: Some((number.len(), true)),
                }
            }
            Long {
                ty,
                ref dst_cid,
                ref src_cid,
                number,
                version,
            } => {
                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                w.write::<u16>(0); // Placeholder for payload length; filled in by `PartialEncode::finish`
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: Some((number.len(), true)),
                }
            }
            Retry {
                ref dst_cid,
                ref src_cid,
                version,
            } => {
                // No length field and no packet number
                w.write(u8::from(LongHeaderType::Retry));
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: None,
                }
            }
            Short {
                spin,
                key_phase,
                ref dst_cid,
                number,
            } => {
                w.write(
                    FIXED_BIT
                        | if key_phase { KEY_PHASE_BIT } else { 0 }
                        | if spin { SPIN_BIT } else { 0 }
                        | number.tag(),
                );
                // Short headers carry the destination CID without a length prefix
                w.put_slice(dst_cid);
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    // Short headers have no payload length field
                    pn: Some((number.len(), false)),
                }
            }
            VersionNegotiate {
                ref random,
                ref dst_cid,
                ref src_cid,
            } => {
                // Long form bit plus caller-chosen bits, followed by a zero version
                w.write(0x80u8 | random);
                w.write::<u32>(0);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: None,
                }
            }
        }
    }

    /// Whether the packet is encrypted on the wire
    pub(crate) fn is_protected(&self) -> bool {
        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
    }

    /// The truncated packet number, or `None` for Retry and Version Negotiation headers
    pub(crate) fn number(&self) -> Option<PacketNumber> {
        use Header::*;
        Some(match *self {
            Initial(InitialHeader { number, .. }) => number,
            Long { number, .. } => number,
            Short { number, .. } => number,
            _ => {
                return None;
            }
        })
    }

    /// The packet number space this header belongs to
    ///
    /// Retry and Version Negotiation headers fall through to [`SpaceId::Initial`].
    pub(crate) fn space(&self) -> SpaceId {
        use Header::*;
        match *self {
            Short { .. } => SpaceId::Data,
            Long {
                ty: LongType::ZeroRtt,
                ..
            } => SpaceId::Data,
            Long {
                ty: LongType::Handshake,
                ..
            } => SpaceId::Handshake,
            _ => SpaceId::Initial,
        }
    }

    /// The key phase bit; always `false` for non-short headers
    pub(crate) fn key_phase(&self) -> bool {
        match *self {
            Self::Short { key_phase, .. } => key_phase,
            _ => false,
        }
    }

    /// Whether this is a short header
    pub(crate) fn is_short(&self) -> bool {
        matches!(*self, Self::Short { .. })
    }

    /// Whether this is a 1-RTT packet (i.e. a short header)
    pub(crate) fn is_1rtt(&self) -> bool {
        self.is_short()
    }

    /// Whether this is a 0-RTT packet
    pub(crate) fn is_0rtt(&self) -> bool {
        matches!(
            *self,
            Self::Long {
                ty: LongType::ZeroRtt,
                ..
            }
        )
    }

    /// The destination connection ID, present in every header type
    pub(crate) fn dst_cid(&self) -> ConnectionId {
        use Header::*;
        match *self {
            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
            Long { dst_cid, .. } => dst_cid,
            Retry { dst_cid, .. } => dst_cid,
            Short { dst_cid, .. } => dst_cid,
            VersionNegotiate { dst_cid, .. } => dst_cid,
        }
    }

    /// Whether the payload of this packet contains QUIC frames
    pub(crate) fn has_frames(&self) -> bool {
        use Header::*;
        match *self {
            Initial(_) => true,
            Long { .. } => true,
            Retry { .. } => false,
            Short { .. } => true,
            VersionNegotiate { .. } => false,
        }
    }
}
462
/// A header that has been written to the output buffer but not yet finalized
///
/// Produced by `Header::encode`; `finish` completes it once the payload has been appended.
pub(crate) struct PartialEncode {
    /// Offset of the first header byte in the output buffer
    pub(crate) start: usize,
    /// Length of the encoded header, packet number included
    pub(crate) header_len: usize,
    /// Packet number length, and whether the payload length field must be written;
    /// `None` for headers without a packet number
    pn: Option<(usize, bool)>,
}
469
470impl PartialEncode {
471    pub(crate) fn finish(
472        self,
473        buf: &mut [u8],
474        header_crypto: &dyn crypto::HeaderKey,
475        crypto: Option<(u64, &dyn crypto::PacketKey)>,
476    ) {
477        let Self { header_len, pn, .. } = self;
478        let (pn_len, write_len) = match pn {
479            Some((pn_len, write_len)) => (pn_len, write_len),
480            None => return,
481        };
482
483        let pn_pos = header_len - pn_len;
484        if write_len {
485            let len = buf.len() - header_len + pn_len;
486            assert!(len < 2usize.pow(14)); // Fits in reserved space
487            let mut slice = &mut buf[pn_pos - 2..pn_pos];
488            slice.put_u16(len as u16 | (0b01 << 14));
489        }
490
491        if let Some((number, crypto)) = crypto {
492            crypto.encrypt(number, buf, header_len);
493        }
494
495        debug_assert!(
496            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
497            "packet must be padded to at least {} bytes for header protection sampling",
498            pn_pos + 4 + header_crypto.sample_size()
499        );
500        header_crypto.encrypt(pn_pos, buf);
501    }
502}
503
504/// Plain packet header
505#[derive(Clone, Debug)]
506pub enum ProtectedHeader {
507    /// An Initial packet header
508    Initial(ProtectedInitialHeader),
509    /// A Long packet header, as used during the handshake
510    Long {
511        /// Type of the Long header packet
512        ty: LongType,
513        /// Destination Connection ID
514        dst_cid: ConnectionId,
515        /// Source Connection ID
516        src_cid: ConnectionId,
517        /// Length of the packet payload
518        len: u64,
519        /// QUIC version
520        version: u32,
521    },
522    /// A Retry packet header
523    Retry {
524        /// Destination Connection ID
525        dst_cid: ConnectionId,
526        /// Source Connection ID
527        src_cid: ConnectionId,
528        /// QUIC version
529        version: u32,
530    },
531    /// A short packet header, as used during the data phase
532    Short {
533        /// Spin bit
534        spin: bool,
535        /// Destination Connection ID
536        dst_cid: ConnectionId,
537    },
538    /// A Version Negotiation packet header
539    VersionNegotiate {
540        /// Random value
541        random: u8,
542        /// Destination Connection ID
543        dst_cid: ConnectionId,
544        /// Source Connection ID
545        src_cid: ConnectionId,
546    },
547}
548
549impl ProtectedHeader {
550    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
551        match self {
552            Self::Initial(x) => Some(x),
553            _ => None,
554        }
555    }
556
557    /// The destination Connection ID of the packet
558    pub fn dst_cid(&self) -> &ConnectionId {
559        use ProtectedHeader::*;
560        match self {
561            Initial(header) => &header.dst_cid,
562            Long { dst_cid, .. } => dst_cid,
563            Retry { dst_cid, .. } => dst_cid,
564            Short { dst_cid, .. } => dst_cid,
565            VersionNegotiate { dst_cid, .. } => dst_cid,
566        }
567    }
568
569    fn payload_len(&self) -> Option<u64> {
570        use ProtectedHeader::*;
571        match self {
572            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
573            _ => None,
574        }
575    }
576
577    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
578    pub fn decode(
579        buf: &mut io::Cursor<BytesMut>,
580        cid_parser: &(impl ConnectionIdParser + ?Sized),
581        supported_versions: &[u32],
582        grease_quic_bit: bool,
583    ) -> Result<Self, PacketDecodeError> {
584        let first = buf.get::<u8>()?;
585        if !grease_quic_bit && first & FIXED_BIT == 0 {
586            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
587        }
588        if first & LONG_HEADER_FORM == 0 {
589            let spin = first & SPIN_BIT != 0;
590
591            Ok(Self::Short {
592                spin,
593                dst_cid: cid_parser.parse(buf)?,
594            })
595        } else {
596            let version = buf.get::<u32>()?;
597
598            let dst_cid = ConnectionId::decode_long(buf)
599                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
600            let src_cid = ConnectionId::decode_long(buf)
601                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
602
603            // TODO: Support long CIDs for compatibility with future QUIC versions
604            if version == 0 {
605                let random = first & !LONG_HEADER_FORM;
606                return Ok(Self::VersionNegotiate {
607                    random,
608                    dst_cid,
609                    src_cid,
610                });
611            }
612
613            if !supported_versions.contains(&version) {
614                return Err(PacketDecodeError::UnsupportedVersion {
615                    src_cid,
616                    dst_cid,
617                    version,
618                });
619            }
620
621            match LongHeaderType::from_byte(first)? {
622                LongHeaderType::Initial => {
623                    let token_len = buf.get_var()? as usize;
624                    let token_start = buf.position() as usize;
625                    if token_len > buf.remaining() {
626                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
627                    }
628                    buf.advance(token_len);
629
630                    let len = buf.get_var()?;
631                    Ok(Self::Initial(ProtectedInitialHeader {
632                        dst_cid,
633                        src_cid,
634                        token_pos: token_start..token_start + token_len,
635                        len,
636                        version,
637                    }))
638                }
639                LongHeaderType::Retry => Ok(Self::Retry {
640                    dst_cid,
641                    src_cid,
642                    version,
643                }),
644                LongHeaderType::Standard(ty) => Ok(Self::Long {
645                    ty,
646                    dst_cid,
647                    src_cid,
648                    len: buf.get_var()?,
649                    version,
650                }),
651            }
652        }
653    }
654}
655
656/// Header of an Initial packet, before decryption
657#[derive(Clone, Debug)]
658pub struct ProtectedInitialHeader {
659    /// Destination Connection ID
660    pub dst_cid: ConnectionId,
661    /// Source Connection ID
662    pub src_cid: ConnectionId,
663    /// The position of a token in the packet buffer
664    pub token_pos: Range<usize>,
665    /// Length of the packet payload
666    pub len: u64,
667    /// QUIC version
668    pub version: u32,
669}
670
671#[derive(Clone, Debug)]
672pub(crate) struct InitialHeader {
673    pub(crate) dst_cid: ConnectionId,
674    pub(crate) src_cid: ConnectionId,
675    pub(crate) token: Bytes,
676    pub(crate) number: PacketNumber,
677    pub(crate) version: u32,
678}
679
/// A packet number truncated to the 1 to 4 bytes sent on the wire
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding
    U8(u8),
    /// Two-byte encoding
    U16(u16),
    /// Three-byte encoding, held in the low 24 bits
    U24(u32),
    /// Four-byte encoding
    U32(u32),
}
688
689impl PacketNumber {
690    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
691        let range = (n - largest_acked) * 2;
692        if range < 1 << 8 {
693            Self::U8(n as u8)
694        } else if range < 1 << 16 {
695            Self::U16(n as u16)
696        } else if range < 1 << 24 {
697            Self::U24(n as u32)
698        } else if range < 1 << 32 {
699            Self::U32(n as u32)
700        } else {
701            panic!("packet number too large to encode")
702        }
703    }
704
705    pub(crate) fn len(self) -> usize {
706        use PacketNumber::*;
707        match self {
708            U8(_) => 1,
709            U16(_) => 2,
710            U24(_) => 3,
711            U32(_) => 4,
712        }
713    }
714
715    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
716        use PacketNumber::*;
717        match self {
718            U8(x) => w.write(x),
719            U16(x) => w.write(x),
720            U24(x) => w.put_uint(u64::from(x), 3),
721            U32(x) => w.write(x),
722        }
723    }
724
725    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
726        use PacketNumber::*;
727        let pn = match len {
728            1 => U8(r.get()?),
729            2 => U16(r.get()?),
730            3 => U24(r.get_uint(3) as u32),
731            4 => U32(r.get()?),
732            _ => unreachable!(),
733        };
734        Ok(pn)
735    }
736
737    pub(crate) fn decode_len(tag: u8) -> usize {
738        1 + (tag & 0x03) as usize
739    }
740
741    fn tag(self) -> u8 {
742        use PacketNumber::*;
743        match self {
744            U8(_) => 0b00,
745            U16(_) => 0b01,
746            U24(_) => 0b10,
747            U32(_) => 0b11,
748        }
749    }
750
751    pub(crate) fn expand(self, expected: u64) -> u64 {
752        // From Appendix A
753        use PacketNumber::*;
754        let truncated = match self {
755            U8(x) => u64::from(x),
756            U16(x) => u64::from(x),
757            U24(x) => u64::from(x),
758            U32(x) => u64::from(x),
759        };
760        let nbits = self.len() * 8;
761        let win = 1 << nbits;
762        let hwin = win / 2;
763        let mask = win - 1;
764        // The incoming packet number should be greater than expected - hwin and less than or equal
765        // to expected + hwin
766        //
767        // This means we can't just strip the trailing bits from expected and add the truncated
768        // because that might yield a value outside the window.
769        //
770        // The following code calculates a candidate value and makes sure it's within the packet
771        // number window.
772        let candidate = (expected & !mask) | truncated;
773        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
774            candidate + win
775        } else if candidate > expected + hwin && candidate > win {
776            candidate - win
777        } else {
778            candidate
779        }
780    }
781}
782
/// A [`ConnectionIdParser`] that treats every connection ID as having the same, fixed length
pub struct FixedLengthConnectionIdParser {
    /// Number of bytes each connection ID is assumed to occupy
    expected_len: usize,
}

impl FixedLengthConnectionIdParser {
    /// Create a parser that reads exactly `expected_len` bytes as the connection ID
    pub fn new(expected_len: usize) -> Self {
        FixedLengthConnectionIdParser { expected_len }
    }
}
794
795impl ConnectionIdParser for FixedLengthConnectionIdParser {
796    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
797        (buffer.remaining() >= self.expected_len)
798            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
799            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
800    }
801}
802
803/// Parse connection id in short header packet
804pub trait ConnectionIdParser {
805    /// Parse a connection id from given buffer
806    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
807}
808
809/// Long packet type including non-uniform cases
810#[derive(Clone, Copy, Debug, Eq, PartialEq)]
811pub(crate) enum LongHeaderType {
812    Initial,
813    Retry,
814    Standard(LongType),
815}
816
817impl LongHeaderType {
818    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
819        use {LongHeaderType::*, LongType::*};
820        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
821        Ok(match (b & 0x30) >> 4 {
822            0x0 => Initial,
823            0x1 => Standard(ZeroRtt),
824            0x2 => Standard(Handshake),
825            0x3 => Retry,
826            _ => unreachable!(),
827        })
828    }
829}
830
831impl From<LongHeaderType> for u8 {
832    fn from(ty: LongHeaderType) -> Self {
833        use {LongHeaderType::*, LongType::*};
834        match ty {
835            Initial => LONG_HEADER_FORM | FIXED_BIT,
836            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
837            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
838            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
839        }
840    }
841}
842
/// Long-header packet types that share one header layout
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// A Handshake packet
    Handshake,
    /// A 0-RTT packet
    ZeroRtt,
}
851
852/// Packet decode error
853#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
854pub enum PacketDecodeError {
855    /// Packet uses a QUIC version that is not supported
856    #[error("unsupported version {version:x}")]
857    UnsupportedVersion {
858        /// Source Connection ID
859        src_cid: ConnectionId,
860        /// Destination Connection ID
861        dst_cid: ConnectionId,
862        /// The version that was unsupported
863        version: u32,
864    },
865    /// The packet header is invalid
866    #[error("invalid header: {0}")]
867    InvalidHeader(&'static str),
868}
869
870impl From<coding::UnexpectedEnd> for PacketDecodeError {
871    fn from(_: coding::UnexpectedEnd) -> Self {
872        Self::InvalidHeader("unexpected end of packet")
873    }
874}
875
/// First-byte bit selecting the long header form
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// First-byte fixed ("QUIC") bit, required unless greasing is permitted
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Latency spin bit of short headers
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Short-header reserved bits, which must be zero once header protection is removed
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Long-header reserved bits, which must be zero once header protection is removed
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Key phase bit of short headers
const KEY_PHASE_BIT: u8 = 0b0000_0100;
882
/// Packet number space identifiers
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Initial packets, used to bootstrap the handshake
    ///
    /// Their keys are derived from the destination connection ID rather than from the
    /// handshake, so they are protected but not confidential against observers.
    Initial = 0,
    /// Handshake packets, used to complete the cryptographic handshake
    Handshake = 1,
    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
    Data = 2,
}

impl SpaceId {
    /// Iterate over every packet number space in handshake order: Initial, Handshake, Data
    pub fn iter() -> impl Iterator<Item = Self> {
        [Self::Initial, Self::Handshake, Self::Data].iter().copied()
    }
}
898
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encodes `typed`, checks the wire bytes equal `encoded`, then decodes them back and checks
    /// the round trip
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    /// Each packet number width encodes big-endian and decodes back unchanged
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` picks a wider encoding as the distance from the largest ack grows
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Truncating with `new` and expanding against the same reference recovers the original
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Encodes and protects an Initial header with client Initial keys, checks the exact wire
    /// bytes, then decodes, unprotects and decrypts it with the matching server keys
    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // 16 zero bytes of payload plus room for the AEAD tag
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, &*client.packet.local)),
        );

        // Dump the protected packet to ease debugging if the comparison below fails
        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // Header bytes after header protection removal
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}