1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 ConnectionId,
8 coding::{self, BufExt, BufMutExt},
9 crypto,
10};
11
12#[cfg_attr(test, derive(Clone))]
25#[derive(Debug)]
26pub struct PartialDecode {
27 plain_header: ProtectedHeader,
28 buf: io::Cursor<BytesMut>,
29}
30
31#[allow(clippy::len_without_is_empty)]
32impl PartialDecode {
33 pub fn new(
35 bytes: BytesMut,
36 cid_parser: &(impl ConnectionIdParser + ?Sized),
37 supported_versions: &[u32],
38 grease_quic_bit: bool,
39 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
40 let mut buf = io::Cursor::new(bytes);
41 let plain_header =
42 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
43 let dgram_len = buf.get_ref().len();
44 let packet_len = plain_header
45 .payload_len()
46 .map(|len| (buf.position() + len) as usize)
47 .unwrap_or(dgram_len);
48 match dgram_len.cmp(&packet_len) {
49 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
50 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
51 "packet too short to contain payload length",
52 )),
53 Ordering::Greater => {
54 let rest = Some(buf.get_mut().split_off(packet_len));
55 Ok((Self { plain_header, buf }, rest))
56 }
57 }
58 }
59
60 pub(crate) fn data(&self) -> &[u8] {
62 self.buf.get_ref()
63 }
64
65 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
66 self.plain_header.as_initial()
67 }
68
69 pub(crate) fn has_long_header(&self) -> bool {
70 !matches!(self.plain_header, ProtectedHeader::Short { .. })
71 }
72
73 pub(crate) fn is_initial(&self) -> bool {
74 self.space() == Some(SpaceId::Initial)
75 }
76
77 pub(crate) fn space(&self) -> Option<SpaceId> {
78 use ProtectedHeader::*;
79 match self.plain_header {
80 Initial { .. } => Some(SpaceId::Initial),
81 Long {
82 ty: LongType::Handshake,
83 ..
84 } => Some(SpaceId::Handshake),
85 Long {
86 ty: LongType::ZeroRtt,
87 ..
88 } => Some(SpaceId::Data),
89 Short { .. } => Some(SpaceId::Data),
90 _ => None,
91 }
92 }
93
94 pub(crate) fn is_0rtt(&self) -> bool {
95 match self.plain_header {
96 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
97 _ => false,
98 }
99 }
100
101 pub fn dst_cid(&self) -> &ConnectionId {
103 self.plain_header.dst_cid()
104 }
105
106 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
109 self.buf.get_ref().len()
110 }
111
112 pub(crate) fn finish(
113 self,
114 header_crypto: Option<&dyn crypto::HeaderKey>,
115 ) -> Result<Packet, PacketDecodeError> {
116 use ProtectedHeader::*;
117 let Self {
118 plain_header,
119 mut buf,
120 } = self;
121
122 if let Initial(ProtectedInitialHeader {
123 dst_cid,
124 src_cid,
125 token_pos,
126 version,
127 ..
128 }) = plain_header
129 {
130 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
131 let header_len = buf.position() as usize;
132 let mut bytes = buf.into_inner();
133
134 let header_data = bytes.split_to(header_len).freeze();
135 let token = header_data.slice(token_pos.start..token_pos.end);
136 return Ok(Packet {
137 header: Header::Initial(InitialHeader {
138 dst_cid,
139 src_cid,
140 token,
141 number,
142 version,
143 }),
144 header_data,
145 payload: bytes,
146 });
147 }
148
149 let header = match plain_header {
150 Long {
151 ty,
152 dst_cid,
153 src_cid,
154 version,
155 ..
156 } => Header::Long {
157 ty,
158 dst_cid,
159 src_cid,
160 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
161 version,
162 },
163 Retry {
164 dst_cid,
165 src_cid,
166 version,
167 } => Header::Retry {
168 dst_cid,
169 src_cid,
170 version,
171 },
172 Short { spin, dst_cid, .. } => {
173 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
174 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
175 Header::Short {
176 spin,
177 key_phase,
178 dst_cid,
179 number,
180 }
181 }
182 VersionNegotiate {
183 random,
184 dst_cid,
185 src_cid,
186 } => Header::VersionNegotiate {
187 random,
188 dst_cid,
189 src_cid,
190 },
191 Initial { .. } => unreachable!(),
192 };
193
194 let header_len = buf.position() as usize;
195 let mut bytes = buf.into_inner();
196 Ok(Packet {
197 header,
198 header_data: bytes.split_to(header_len).freeze(),
199 payload: bytes,
200 })
201 }
202
203 fn decrypt_header(
204 buf: &mut io::Cursor<BytesMut>,
205 header_crypto: &dyn crypto::HeaderKey,
206 ) -> Result<PacketNumber, PacketDecodeError> {
207 let packet_length = buf.get_ref().len();
208 let pn_offset = buf.position() as usize;
209 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
210 return Err(PacketDecodeError::InvalidHeader(
211 "packet too short to extract header protection sample",
212 ));
213 }
214
215 header_crypto.decrypt(pn_offset, buf.get_mut());
216
217 let len = PacketNumber::decode_len(buf.get_ref()[0]);
218 PacketNumber::decode(len, buf)
219 }
220}
221
222pub(crate) struct Packet {
223 pub(crate) header: Header,
224 pub(crate) header_data: Bytes,
225 pub(crate) payload: BytesMut,
226}
227
228impl Packet {
229 pub(crate) fn reserved_bits_valid(&self) -> bool {
230 let mask = match self.header {
231 Header::Short { .. } => SHORT_RESERVED_BITS,
232 _ => LONG_RESERVED_BITS,
233 };
234 self.header_data[0] & mask == 0
235 }
236}
237
238pub(crate) struct InitialPacket {
239 pub(crate) header: InitialHeader,
240 pub(crate) header_data: Bytes,
241 pub(crate) payload: BytesMut,
242}
243
244impl From<InitialPacket> for Packet {
245 fn from(x: InitialPacket) -> Self {
246 Self {
247 header: Header::Initial(x.header),
248 header_data: x.header_data,
249 payload: x.payload,
250 }
251 }
252}
253
254#[cfg_attr(test, derive(Clone))]
255#[derive(Debug)]
256pub(crate) enum Header {
257 Initial(InitialHeader),
258 Long {
259 ty: LongType,
260 dst_cid: ConnectionId,
261 src_cid: ConnectionId,
262 number: PacketNumber,
263 version: u32,
264 },
265 Retry {
266 dst_cid: ConnectionId,
267 src_cid: ConnectionId,
268 version: u32,
269 },
270 Short {
271 spin: bool,
272 key_phase: bool,
273 dst_cid: ConnectionId,
274 number: PacketNumber,
275 },
276 VersionNegotiate {
277 random: u8,
278 src_cid: ConnectionId,
279 dst_cid: ConnectionId,
280 },
281}
282
283impl Header {
284 pub(crate) fn encode(&self, w: &mut Vec<u8>) -> PartialEncode {
285 use Header::*;
286 let start = w.len();
287 match *self {
288 Initial(InitialHeader {
289 ref dst_cid,
290 ref src_cid,
291 ref token,
292 number,
293 version,
294 }) => {
295 w.write(u8::from(LongHeaderType::Initial) | number.tag());
296 w.write(version);
297 dst_cid.encode_long(w);
298 src_cid.encode_long(w);
299 w.write_var(token.len() as u64);
300 w.put_slice(token);
301 w.write::<u16>(0); number.encode(w);
303 PartialEncode {
304 start,
305 header_len: w.len() - start,
306 pn: Some((number.len(), true)),
307 }
308 }
309 Long {
310 ty,
311 ref dst_cid,
312 ref src_cid,
313 number,
314 version,
315 } => {
316 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
317 w.write(version);
318 dst_cid.encode_long(w);
319 src_cid.encode_long(w);
320 w.write::<u16>(0); number.encode(w);
322 PartialEncode {
323 start,
324 header_len: w.len() - start,
325 pn: Some((number.len(), true)),
326 }
327 }
328 Retry {
329 ref dst_cid,
330 ref src_cid,
331 version,
332 } => {
333 w.write(u8::from(LongHeaderType::Retry));
334 w.write(version);
335 dst_cid.encode_long(w);
336 src_cid.encode_long(w);
337 PartialEncode {
338 start,
339 header_len: w.len() - start,
340 pn: None,
341 }
342 }
343 Short {
344 spin,
345 key_phase,
346 ref dst_cid,
347 number,
348 } => {
349 w.write(
350 FIXED_BIT
351 | if key_phase { KEY_PHASE_BIT } else { 0 }
352 | if spin { SPIN_BIT } else { 0 }
353 | number.tag(),
354 );
355 w.put_slice(dst_cid);
356 number.encode(w);
357 PartialEncode {
358 start,
359 header_len: w.len() - start,
360 pn: Some((number.len(), false)),
361 }
362 }
363 VersionNegotiate {
364 ref random,
365 ref dst_cid,
366 ref src_cid,
367 } => {
368 w.write(0x80u8 | random);
369 w.write::<u32>(0);
370 dst_cid.encode_long(w);
371 src_cid.encode_long(w);
372 PartialEncode {
373 start,
374 header_len: w.len() - start,
375 pn: None,
376 }
377 }
378 }
379 }
380
381 pub(crate) fn is_protected(&self) -> bool {
383 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
384 }
385
386 pub(crate) fn number(&self) -> Option<PacketNumber> {
387 use Header::*;
388 Some(match *self {
389 Initial(InitialHeader { number, .. }) => number,
390 Long { number, .. } => number,
391 Short { number, .. } => number,
392 _ => {
393 return None;
394 }
395 })
396 }
397
398 pub(crate) fn space(&self) -> SpaceId {
399 use Header::*;
400 match *self {
401 Short { .. } => SpaceId::Data,
402 Long {
403 ty: LongType::ZeroRtt,
404 ..
405 } => SpaceId::Data,
406 Long {
407 ty: LongType::Handshake,
408 ..
409 } => SpaceId::Handshake,
410 _ => SpaceId::Initial,
411 }
412 }
413
414 pub(crate) fn key_phase(&self) -> bool {
415 match *self {
416 Self::Short { key_phase, .. } => key_phase,
417 _ => false,
418 }
419 }
420
421 pub(crate) fn is_short(&self) -> bool {
422 matches!(*self, Self::Short { .. })
423 }
424
425 pub(crate) fn is_1rtt(&self) -> bool {
426 self.is_short()
427 }
428
429 pub(crate) fn is_0rtt(&self) -> bool {
430 matches!(
431 *self,
432 Self::Long {
433 ty: LongType::ZeroRtt,
434 ..
435 }
436 )
437 }
438
439 pub(crate) fn dst_cid(&self) -> ConnectionId {
440 use Header::*;
441 match *self {
442 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
443 Long { dst_cid, .. } => dst_cid,
444 Retry { dst_cid, .. } => dst_cid,
445 Short { dst_cid, .. } => dst_cid,
446 VersionNegotiate { dst_cid, .. } => dst_cid,
447 }
448 }
449
450 pub(crate) fn has_frames(&self) -> bool {
452 use Header::*;
453 match *self {
454 Initial(_) => true,
455 Long { .. } => true,
456 Retry { .. } => false,
457 Short { .. } => true,
458 VersionNegotiate { .. } => false,
459 }
460 }
461}
462
463pub(crate) struct PartialEncode {
464 pub(crate) start: usize,
465 pub(crate) header_len: usize,
466 pn: Option<(usize, bool)>,
468}
469
470impl PartialEncode {
471 pub(crate) fn finish(
472 self,
473 buf: &mut [u8],
474 header_crypto: &dyn crypto::HeaderKey,
475 crypto: Option<(u64, &dyn crypto::PacketKey)>,
476 ) {
477 let Self { header_len, pn, .. } = self;
478 let (pn_len, write_len) = match pn {
479 Some((pn_len, write_len)) => (pn_len, write_len),
480 None => return,
481 };
482
483 let pn_pos = header_len - pn_len;
484 if write_len {
485 let len = buf.len() - header_len + pn_len;
486 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
488 slice.put_u16(len as u16 | (0b01 << 14));
489 }
490
491 if let Some((number, crypto)) = crypto {
492 crypto.encrypt(number, buf, header_len);
493 }
494
495 debug_assert!(
496 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
497 "packet must be padded to at least {} bytes for header protection sampling",
498 pn_pos + 4 + header_crypto.sample_size()
499 );
500 header_crypto.encrypt(pn_pos, buf);
501 }
502}
503
504#[derive(Clone, Debug)]
506pub enum ProtectedHeader {
507 Initial(ProtectedInitialHeader),
509 Long {
511 ty: LongType,
513 dst_cid: ConnectionId,
515 src_cid: ConnectionId,
517 len: u64,
519 version: u32,
521 },
522 Retry {
524 dst_cid: ConnectionId,
526 src_cid: ConnectionId,
528 version: u32,
530 },
531 Short {
533 spin: bool,
535 dst_cid: ConnectionId,
537 },
538 VersionNegotiate {
540 random: u8,
542 dst_cid: ConnectionId,
544 src_cid: ConnectionId,
546 },
547}
548
549impl ProtectedHeader {
550 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
551 match self {
552 Self::Initial(x) => Some(x),
553 _ => None,
554 }
555 }
556
557 pub fn dst_cid(&self) -> &ConnectionId {
559 use ProtectedHeader::*;
560 match self {
561 Initial(header) => &header.dst_cid,
562 Long { dst_cid, .. } => dst_cid,
563 Retry { dst_cid, .. } => dst_cid,
564 Short { dst_cid, .. } => dst_cid,
565 VersionNegotiate { dst_cid, .. } => dst_cid,
566 }
567 }
568
569 fn payload_len(&self) -> Option<u64> {
570 use ProtectedHeader::*;
571 match self {
572 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
573 _ => None,
574 }
575 }
576
577 pub fn decode(
579 buf: &mut io::Cursor<BytesMut>,
580 cid_parser: &(impl ConnectionIdParser + ?Sized),
581 supported_versions: &[u32],
582 grease_quic_bit: bool,
583 ) -> Result<Self, PacketDecodeError> {
584 let first = buf.get::<u8>()?;
585 if !grease_quic_bit && first & FIXED_BIT == 0 {
586 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
587 }
588 if first & LONG_HEADER_FORM == 0 {
589 let spin = first & SPIN_BIT != 0;
590
591 Ok(Self::Short {
592 spin,
593 dst_cid: cid_parser.parse(buf)?,
594 })
595 } else {
596 let version = buf.get::<u32>()?;
597
598 let dst_cid = ConnectionId::decode_long(buf)
599 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
600 let src_cid = ConnectionId::decode_long(buf)
601 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
602
603 if version == 0 {
605 let random = first & !LONG_HEADER_FORM;
606 return Ok(Self::VersionNegotiate {
607 random,
608 dst_cid,
609 src_cid,
610 });
611 }
612
613 if !supported_versions.contains(&version) {
614 return Err(PacketDecodeError::UnsupportedVersion {
615 src_cid,
616 dst_cid,
617 version,
618 });
619 }
620
621 match LongHeaderType::from_byte(first)? {
622 LongHeaderType::Initial => {
623 let token_len = buf.get_var()? as usize;
624 let token_start = buf.position() as usize;
625 if token_len > buf.remaining() {
626 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
627 }
628 buf.advance(token_len);
629
630 let len = buf.get_var()?;
631 Ok(Self::Initial(ProtectedInitialHeader {
632 dst_cid,
633 src_cid,
634 token_pos: token_start..token_start + token_len,
635 len,
636 version,
637 }))
638 }
639 LongHeaderType::Retry => Ok(Self::Retry {
640 dst_cid,
641 src_cid,
642 version,
643 }),
644 LongHeaderType::Standard(ty) => Ok(Self::Long {
645 ty,
646 dst_cid,
647 src_cid,
648 len: buf.get_var()?,
649 version,
650 }),
651 }
652 }
653 }
654}
655
656#[derive(Clone, Debug)]
658pub struct ProtectedInitialHeader {
659 pub dst_cid: ConnectionId,
661 pub src_cid: ConnectionId,
663 pub token_pos: Range<usize>,
665 pub len: u64,
667 pub version: u32,
669}
670
671#[derive(Clone, Debug)]
672pub(crate) struct InitialHeader {
673 pub(crate) dst_cid: ConnectionId,
674 pub(crate) src_cid: ConnectionId,
675 pub(crate) token: Bytes,
676 pub(crate) number: PacketNumber,
677 pub(crate) version: u32,
678}
679
680#[derive(Debug, Copy, Clone, Eq, PartialEq)]
682pub(crate) enum PacketNumber {
683 U8(u8),
684 U16(u16),
685 U24(u32),
686 U32(u32),
687}
688
689impl PacketNumber {
690 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
691 let range = (n - largest_acked) * 2;
692 if range < 1 << 8 {
693 Self::U8(n as u8)
694 } else if range < 1 << 16 {
695 Self::U16(n as u16)
696 } else if range < 1 << 24 {
697 Self::U24(n as u32)
698 } else if range < 1 << 32 {
699 Self::U32(n as u32)
700 } else {
701 panic!("packet number too large to encode")
702 }
703 }
704
705 pub(crate) fn len(self) -> usize {
706 use PacketNumber::*;
707 match self {
708 U8(_) => 1,
709 U16(_) => 2,
710 U24(_) => 3,
711 U32(_) => 4,
712 }
713 }
714
715 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
716 use PacketNumber::*;
717 match self {
718 U8(x) => w.write(x),
719 U16(x) => w.write(x),
720 U24(x) => w.put_uint(u64::from(x), 3),
721 U32(x) => w.write(x),
722 }
723 }
724
725 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
726 use PacketNumber::*;
727 let pn = match len {
728 1 => U8(r.get()?),
729 2 => U16(r.get()?),
730 3 => U24(r.get_uint(3) as u32),
731 4 => U32(r.get()?),
732 _ => unreachable!(),
733 };
734 Ok(pn)
735 }
736
737 pub(crate) fn decode_len(tag: u8) -> usize {
738 1 + (tag & 0x03) as usize
739 }
740
741 fn tag(self) -> u8 {
742 use PacketNumber::*;
743 match self {
744 U8(_) => 0b00,
745 U16(_) => 0b01,
746 U24(_) => 0b10,
747 U32(_) => 0b11,
748 }
749 }
750
751 pub(crate) fn expand(self, expected: u64) -> u64 {
752 use PacketNumber::*;
754 let truncated = match self {
755 U8(x) => u64::from(x),
756 U16(x) => u64::from(x),
757 U24(x) => u64::from(x),
758 U32(x) => u64::from(x),
759 };
760 let nbits = self.len() * 8;
761 let win = 1 << nbits;
762 let hwin = win / 2;
763 let mask = win - 1;
764 let candidate = (expected & !mask) | truncated;
773 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
774 candidate + win
775 } else if candidate > expected + hwin && candidate > win {
776 candidate - win
777 } else {
778 candidate
779 }
780 }
781}
782
783pub struct FixedLengthConnectionIdParser {
785 expected_len: usize,
786}
787
788impl FixedLengthConnectionIdParser {
789 pub fn new(expected_len: usize) -> Self {
791 Self { expected_len }
792 }
793}
794
795impl ConnectionIdParser for FixedLengthConnectionIdParser {
796 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
797 (buffer.remaining() >= self.expected_len)
798 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
799 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
800 }
801}
802
803pub trait ConnectionIdParser {
805 fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
807}
808
809#[derive(Clone, Copy, Debug, Eq, PartialEq)]
811pub(crate) enum LongHeaderType {
812 Initial,
813 Retry,
814 Standard(LongType),
815}
816
817impl LongHeaderType {
818 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
819 use {LongHeaderType::*, LongType::*};
820 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
821 Ok(match (b & 0x30) >> 4 {
822 0x0 => Initial,
823 0x1 => Standard(ZeroRtt),
824 0x2 => Standard(Handshake),
825 0x3 => Retry,
826 _ => unreachable!(),
827 })
828 }
829}
830
831impl From<LongHeaderType> for u8 {
832 fn from(ty: LongHeaderType) -> Self {
833 use {LongHeaderType::*, LongType::*};
834 match ty {
835 Initial => LONG_HEADER_FORM | FIXED_BIT,
836 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
837 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
838 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
839 }
840 }
841}
842
/// The long-header packet types that share the common layout with a length
/// field and a packet number.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet.
    Handshake,
    /// 0-RTT packet.
    ZeroRtt,
}
851
852#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
854pub enum PacketDecodeError {
855 #[error("unsupported version {version:x}")]
857 UnsupportedVersion {
858 src_cid: ConnectionId,
860 dst_cid: ConnectionId,
862 version: u32,
864 },
865 #[error("invalid header: {0}")]
867 InvalidHeader(&'static str),
868}
869
870impl From<coding::UnexpectedEnd> for PacketDecodeError {
871 fn from(_: coding::UnexpectedEnd) -> Self {
872 Self::InvalidHeader("unexpected end of packet")
873 }
874}
875
/// First-byte bit that is set for long headers and clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// First-byte bit that must be set unless the QUIC bit is being greased.
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// First-byte spin bit of a short header.
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// First-byte bits that must be zero in a short header.
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// First-byte bits that must be zero in every other header form.
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// First-byte key phase bit of a short header.
const KEY_PHASE_BIT: u8 = 0b0000_0100;
882
/// The packet number spaces, ordered from first to last.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Space used by Initial packets.
    Initial = 0,
    /// Space used by Handshake packets.
    Handshake = 1,
    /// Space used by 0-RTT and short-header packets.
    Data = 2,
}

impl SpaceId {
    /// Iterate over every space in ascending order.
    pub fn iter() -> impl Iterator<Item = Self> {
        IntoIterator::into_iter([Self::Initial, Self::Handshake, Self::Data])
    }
}
898
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Asserts that `typed` encodes to exactly `encoded` and decodes back to
    /// itself.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut wire = Vec::new();
        typed.encode(&mut wire);
        assert_eq!(wire.as_slice(), encoded);

        let mut reader = io::Cursor::new(&wire);
        let decoded = PacketNumber::decode(typed.len(), &mut reader).unwrap();
        assert_eq!(typed, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    #[test]
    fn pn_encode() {
        // With nothing acknowledged, the width grows with the number itself.
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                let truncated = PacketNumber::new(actual, expected);
                assert_eq!(actual, truncated.expand(expected));
            }
        }
    }

    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = std::sync::Arc::new(default_provider());
        let suite = initial_suite_from_provider(&provider).unwrap();

        // Encode and protect an Initial packet with the client's keys.
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let mut buf = Vec::new();
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, &*client.packet.local)),
        );

        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        // Decode the same bytes with the server's keys.
        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let (decode, _rest) = PartialDecode::new(
            BytesMut::from(buf.as_slice()),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap();
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        assert!(
            matches!(
                packet.header,
                Header::Initial(InitialHeader {
                    number: PacketNumber::U8(0),
                    ..
                })
            ),
            "unexpected header {:?}",
            packet.header
        );
    }
}