rustls/msgs/
handshake.rs

1use alloc::boxed::Box;
2use alloc::collections::BTreeSet;
3#[cfg(feature = "logging")]
4use alloc::string::String;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::ops::{Deref, DerefMut};
8use core::{fmt, iter};
9
10use pki_types::{CertificateDer, DnsName};
11
12#[cfg(feature = "tls12")]
13use crate::crypto::ActiveKeyExchange;
14use crate::crypto::SecureRandom;
15use crate::enums::{
16    CertificateCompressionAlgorithm, CertificateType, CipherSuite, EchClientHelloType,
17    HandshakeType, ProtocolVersion, SignatureScheme,
18};
19use crate::error::InvalidMessage;
20#[cfg(feature = "tls12")]
21use crate::ffdhe_groups::FfdheGroup;
22use crate::log::warn;
23use crate::msgs::base::{MaybeEmpty, NonEmpty, Payload, PayloadU8, PayloadU16, PayloadU24};
24use crate::msgs::codec::{
25    self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement, TlsListIter,
26};
27use crate::msgs::enums::{
28    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
29    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
30    PskKeyExchangeMode, ServerNameType,
31};
32use crate::rand;
33use crate::sync::Arc;
34use crate::verify::DigitallySignedStruct;
35use crate::x509::wrap_in_sequence;
36
/// Create a newtype wrapper around a given type.
///
/// This is used to create newtypes for the various TLS message types that wrap
/// the `PayloadU8` or `PayloadU16` types. This is typically used for types where we don't need
/// anything other than access to the underlying bytes.
///
/// The generated type derives `Clone` and `Debug`, and additionally implements:
///
/// - `From<Vec<u8>>`, building the inner payload with `$inner::new`,
/// - `AsRef<[u8]>`, exposing the wrapped bytes as a slice,
/// - `Codec`, delegating both `encode` and `read` to the inner payload type.
macro_rules! wrapped_payload(
  ($(#[$comment:meta])* $vis:vis struct $name:ident, $inner:ident$(<$inner_ty:ty>)?,) => {
    // Attributes (typically doc comments) given at the invocation site are forwarded as-is.
    $(#[$comment])*
    #[derive(Clone, Debug)]
    $vis struct $name($inner$(<$inner_ty>)?);

    impl From<Vec<u8>> for $name {
        fn from(v: Vec<u8>) -> Self {
            Self($inner::new(v))
        }
    }

    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            // `self.0` is the inner payload; its `.0` field holds the bytes.
            self.0.0.as_slice()
        }
    }

    impl Codec<'_> for $name {
        fn encode(&self, bytes: &mut Vec<u8>) {
            self.0.encode(bytes);
        }

        fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
            Ok(Self($inner::read(r)?))
        }
    }
  }
);
71
/// A fixed-size, 32-byte opaque value.
///
/// Its `Codec` implementation encodes it as exactly those 32 bytes, with no length prefix.
#[derive(Clone, Copy, Eq, PartialEq)]
pub(crate) struct Random(pub(crate) [u8; 32]);
74
75impl fmt::Debug for Random {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        super::base::hex(f, &self.0)
78    }
79}
80
/// Fixed `Random` value associated with HelloRetryRequest messages.
///
/// NOTE(review): presumably the special `random` value that RFC 8446 section 4.1.3 assigns to
/// HelloRetryRequest -- confirm the bytes against the RFC when touching this.
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);

/// A `Random` consisting of 32 zero bytes.
static ZERO_RANDOM: Random = Random([0u8; 32]);
87
88impl Codec<'_> for Random {
89    fn encode(&self, bytes: &mut Vec<u8>) {
90        bytes.extend_from_slice(&self.0);
91    }
92
93    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
94        let Some(bytes) = r.take(32) else {
95            return Err(InvalidMessage::MissingData("Random"));
96        };
97
98        let mut opaque = [0; 32];
99        opaque.clone_from_slice(bytes);
100        Ok(Self(opaque))
101    }
102}
103
104impl Random {
105    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
106        let mut data = [0u8; 32];
107        secure_random.fill(&mut data)?;
108        Ok(Self(data))
109    }
110}
111
/// Wraps an existing 32-byte array without modifying it.
impl From<[u8; 32]> for Random {
    #[inline]
    fn from(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }
}
118
/// A session identifier of at most 32 bytes, stored inline.
///
/// Only the first `len` bytes of `data` are meaningful; the remainder is padding.
#[derive(Copy, Clone)]
pub(crate) struct SessionId {
    /// Number of valid bytes at the start of `data`.
    len: usize,
    /// Fixed-size backing storage for the identifier.
    data: [u8; 32],
}
124
125impl fmt::Debug for SessionId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        super::base::hex(f, &self.data[..self.len])
128    }
129}
130
131impl PartialEq for SessionId {
132    fn eq(&self, other: &Self) -> bool {
133        if self.len != other.len {
134            return false;
135        }
136
137        let mut diff = 0u8;
138        for i in 0..self.len {
139            diff |= self.data[i] ^ other.data[i];
140        }
141
142        diff == 0u8
143    }
144}
145
146impl Codec<'_> for SessionId {
147    fn encode(&self, bytes: &mut Vec<u8>) {
148        debug_assert!(self.len <= 32);
149        bytes.push(self.len as u8);
150        bytes.extend_from_slice(self.as_ref());
151    }
152
153    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
154        let len = u8::read(r)? as usize;
155        if len > 32 {
156            return Err(InvalidMessage::TrailingData("SessionID"));
157        }
158
159        let Some(bytes) = r.take(len) else {
160            return Err(InvalidMessage::MissingData("SessionID"));
161        };
162
163        let mut out = [0u8; 32];
164        out[..len].clone_from_slice(&bytes[..len]);
165        Ok(Self { data: out, len })
166    }
167}
168
169impl SessionId {
170    pub(crate) fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
171        let mut data = [0u8; 32];
172        secure_random.fill(&mut data)?;
173        Ok(Self { data, len: 32 })
174    }
175
176    pub(crate) fn empty() -> Self {
177        Self {
178            data: [0u8; 32],
179            len: 0,
180        }
181    }
182
183    #[cfg(feature = "tls12")]
184    pub(crate) fn is_empty(&self) -> bool {
185        self.len == 0
186    }
187}
188
189impl AsRef<[u8]> for SessionId {
190    fn as_ref(&self) -> &[u8] {
191        &self.data[..self.len]
192    }
193}
194
/// An extension kept as an uninterpreted payload alongside its extension type.
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownExtension {
    /// The extension type this payload was read for.
    pub(crate) typ: ExtensionType,
    /// The raw, owned extension body.
    pub(crate) payload: Payload<'static>,
}
200
201impl UnknownExtension {
202    fn encode(&self, bytes: &mut Vec<u8>) {
203        self.payload.encode(bytes);
204    }
205
206    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
207        let payload = Payload::read(r).into_owned();
208        Self { typ, payload }
209    }
210}
211
/// Condensed form of an EC point formats list: it only records whether
/// `ECPointFormat::Uncompressed` is present.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SupportedEcPointFormats {
    /// True when the uncompressed point format is (to be) listed.
    pub(crate) uncompressed: bool,
}
216
217impl Codec<'_> for SupportedEcPointFormats {
218    fn encode(&self, bytes: &mut Vec<u8>) {
219        let inner = LengthPrefixedBuffer::new(ECPointFormat::SIZE_LEN, bytes);
220
221        if self.uncompressed {
222            ECPointFormat::Uncompressed.encode(inner.buf);
223        }
224    }
225
226    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
227        let mut uncompressed = false;
228
229        for pf in TlsListIter::<ECPointFormat>::new(r)? {
230            if let ECPointFormat::Uncompressed = pf? {
231                uncompressed = true;
232            }
233        }
234
235        Ok(Self { uncompressed })
236    }
237}
238
/// The default lists the uncompressed point format.
impl Default for SupportedEcPointFormats {
    fn default() -> Self {
        Self { uncompressed: true }
    }
}
244
/// RFC8422: `ECPointFormat ec_point_format_list<1..2^8-1>`
///
/// Uses `ListLength::NonZeroU8`, supplying `IllegalEmptyList("ECPointFormats")` as its
/// `empty_error`.
impl TlsListElement for ECPointFormat {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ECPointFormats"),
    };
}
251
/// RFC8422: `NamedCurve named_curve_list<2..2^16-1>`
///
/// Uses `ListLength::NonZeroU16`, supplying `IllegalEmptyList("NamedGroups")` as its
/// `empty_error`.
impl TlsListElement for NamedGroup {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("NamedGroups"),
    };
}
258
/// RFC8446: `SignatureScheme supported_signature_algorithms<2..2^16-2>;`
///
/// Uses `ListLength::NonZeroU16`, supplying the dedicated
/// `InvalidMessage::NoSignatureSchemes` as its `empty_error`.
impl TlsListElement for SignatureScheme {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::NoSignatureSchemes,
    };
}
265
/// Outcome of decoding (or the value to encode for) a server name extension payload.
#[derive(Clone, Debug)]
pub(crate) enum ServerNamePayload<'a> {
    /// A successfully decoded value:
    SingleDnsName(DnsName<'a>),

    /// A DNS name which was actually an IP address
    IpAddress,

    /// A successfully decoded, but syntactically-invalid value.
    Invalid,
}
277
278impl ServerNamePayload<'_> {
279    fn into_owned(self) -> ServerNamePayload<'static> {
280        match self {
281            Self::SingleDnsName(d) => ServerNamePayload::SingleDnsName(d.to_owned()),
282            Self::IpAddress => ServerNamePayload::IpAddress,
283            Self::Invalid => ServerNamePayload::Invalid,
284        }
285    }
286
287    /// RFC6066: `ServerName server_name_list<1..2^16-1>`
288    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
289        empty_error: InvalidMessage::IllegalEmptyList("ServerNames"),
290    };
291}
292
293/// Simplified encoding/decoding for a `ServerName` extension payload to/from `DnsName`
294///
295/// This is possible because:
296///
297/// - the spec (RFC6066) disallows multiple names for a given name type
298/// - name types other than ServerNameType::HostName are not defined, and they and
299///   any data that follows them cannot be skipped over.
300impl<'a> Codec<'a> for ServerNamePayload<'a> {
301    fn encode(&self, bytes: &mut Vec<u8>) {
302        let server_name_list = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
303
304        let ServerNamePayload::SingleDnsName(dns_name) = self else {
305            return;
306        };
307
308        ServerNameType::HostName.encode(server_name_list.buf);
309        let name_slice = dns_name.as_ref().as_bytes();
310        (name_slice.len() as u16).encode(server_name_list.buf);
311        server_name_list
312            .buf
313            .extend_from_slice(name_slice);
314    }
315
316    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
317        let mut found = None;
318
319        let len = Self::SIZE_LEN.read(r)?;
320        let mut sub = r.sub(len)?;
321
322        while sub.any_left() {
323            let typ = ServerNameType::read(&mut sub)?;
324
325            let payload = match typ {
326                ServerNameType::HostName => HostNamePayload::read(&mut sub)?,
327                _ => {
328                    // Consume remainder of extension bytes.  Since the length of the item
329                    // is an unknown encoding, we cannot continue.
330                    sub.rest();
331                    break;
332                }
333            };
334
335            // "The ServerNameList MUST NOT contain more than one name of
336            // the same name_type." - RFC6066
337            if found.is_some() {
338                warn!("Illegal SNI extension: duplicate host_name received");
339                return Err(InvalidMessage::InvalidServerName);
340            }
341
342            found = match payload {
343                HostNamePayload::HostName(dns_name) => {
344                    Some(Self::SingleDnsName(dns_name.to_owned()))
345                }
346
347                HostNamePayload::IpAddress(_invalid) => {
348                    warn!(
349                        "Illegal SNI extension: ignoring IP address presented as hostname ({_invalid:?})"
350                    );
351                    Some(Self::IpAddress)
352                }
353
354                HostNamePayload::Invalid(_invalid) => {
355                    warn!(
356                        "Illegal SNI hostname received {:?}",
357                        String::from_utf8_lossy(&_invalid.0)
358                    );
359                    Some(Self::Invalid)
360                }
361            };
362        }
363
364        Ok(found.unwrap_or(Self::Invalid))
365    }
366}
367
368impl<'a> From<&DnsName<'a>> for ServerNamePayload<'static> {
369    fn from(value: &DnsName<'a>) -> Self {
370        Self::SingleDnsName(trim_hostname_trailing_dot_for_sni(value))
371    }
372}
373
/// Classification of a single `host_name` entry after parsing its bytes as a server name.
#[derive(Clone, Debug)]
pub(crate) enum HostNamePayload {
    /// The bytes parsed as a DNS name; an owned copy is stored.
    HostName(DnsName<'static>),
    /// The bytes parsed as an IP address; the raw bytes are kept.
    IpAddress(PayloadU16<NonEmpty>),
    /// The bytes did not parse as a DNS name or IP address; the raw bytes are kept.
    Invalid(PayloadU16<NonEmpty>),
}
380
381impl HostNamePayload {
382    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
383        use pki_types::ServerName;
384        let raw = PayloadU16::<NonEmpty>::read(r)?;
385
386        match ServerName::try_from(raw.0.as_slice()) {
387            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
388            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
389            Ok(_) | Err(_) => Ok(Self::Invalid(raw)),
390        }
391    }
392}
393
wrapped_payload!(
    /// RFC7301: `opaque ProtocolName<1..2^8-1>;`
    ///
    /// A single ALPN protocol name, held as a `PayloadU8<NonEmpty>`.
    pub(crate) struct ProtocolName, PayloadU8<NonEmpty>,
);
398
399impl PartialEq for ProtocolName {
400    fn eq(&self, other: &Self) -> bool {
401        self.0 == other.0
402    }
403}
404
405impl Deref for ProtocolName {
406    type Target = [u8];
407
408    fn deref(&self) -> &Self::Target {
409        self.as_ref()
410    }
411}
412
/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
///
/// Uses `ListLength::NonZeroU16`, supplying `IllegalEmptyList("ProtocolNames")` as its
/// `empty_error`.
impl TlsListElement for ProtocolName {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
    };
}
419
/// RFC7301 encodes a single protocol name as `Vec<ProtocolName>`
///
/// This wrapper holds exactly one `ProtocolName` and encodes it as a one-element list.
#[derive(Clone, Debug)]
pub(crate) struct SingleProtocolName(ProtocolName);
423
impl SingleProtocolName {
    /// Wraps the given protocol name.
    pub(crate) fn new(single: ProtocolName) -> Self {
        Self(single)
    }

    /// Length prefix of the enclosing list: the same `NonZeroU16` kind and error
    /// as `ProtocolName::SIZE_LEN`.
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
    };
}
433
434impl Codec<'_> for SingleProtocolName {
435    fn encode(&self, bytes: &mut Vec<u8>) {
436        let body = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
437        self.0.encode(body.buf);
438    }
439
440    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
441        let len = Self::SIZE_LEN.read(reader)?;
442        let mut sub = reader.sub(len)?;
443
444        let item = ProtocolName::read(&mut sub)?;
445
446        if sub.any_left() {
447            Err(InvalidMessage::TrailingData("SingleProtocolName"))
448        } else {
449            Ok(Self(item))
450        }
451    }
452}
453
454impl AsRef<ProtocolName> for SingleProtocolName {
455    fn as_ref(&self) -> &ProtocolName {
456        &self.0
457    }
458}
459
// --- TLS 1.3 Key shares ---
/// One key share: a named group together with its key exchange bytes.
#[derive(Clone, Debug)]
pub(crate) struct KeyShareEntry {
    /// The group the key exchange data belongs to.
    pub(crate) group: NamedGroup,
    /// RFC8446: `opaque key_exchange<1..2^16-1>;`
    pub(crate) payload: PayloadU16<NonEmpty>,
}
467
468impl KeyShareEntry {
469    pub(crate) fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
470        Self {
471            group,
472            payload: PayloadU16::new(payload.into()),
473        }
474    }
475}
476
477impl Codec<'_> for KeyShareEntry {
478    fn encode(&self, bytes: &mut Vec<u8>) {
479        self.group.encode(bytes);
480        self.payload.encode(bytes);
481    }
482
483    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
484        let group = NamedGroup::read(r)?;
485        let payload = PayloadU16::read(r)?;
486
487        Ok(Self { group, payload })
488    }
489}
490
// --- TLS 1.3 PresharedKey offers ---
/// One offered PSK identity together with its obfuscated ticket age.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyIdentity {
    /// RFC8446: `opaque identity<1..2^16-1>;`
    pub(crate) identity: PayloadU16<NonEmpty>,
    /// Encoded as a `u32` directly after the identity.
    pub(crate) obfuscated_ticket_age: u32,
}
498
499impl PresharedKeyIdentity {
500    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
501        Self {
502            identity: PayloadU16::new(id),
503            obfuscated_ticket_age: age,
504        }
505    }
506}
507
508impl Codec<'_> for PresharedKeyIdentity {
509    fn encode(&self, bytes: &mut Vec<u8>) {
510        self.identity.encode(bytes);
511        self.obfuscated_ticket_age.encode(bytes);
512    }
513
514    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
515        Ok(Self {
516            identity: PayloadU16::read(r)?,
517            obfuscated_ticket_age: u32::read(r)?,
518        })
519    }
520}
521
/// RFC8446: `PskIdentity identities<7..2^16-1>;`
///
/// Uses `ListLength::NonZeroU16`, supplying `IllegalEmptyList("PskIdentities")` as its
/// `empty_error`.
impl TlsListElement for PresharedKeyIdentity {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("PskIdentities"),
    };
}
528
wrapped_payload!(
    /// RFC8446: `opaque PskBinderEntry<32..255>;`
    ///
    /// NOTE(review): the wrapped type is `PayloadU8<NonEmpty>`; whether the 32-byte
    /// minimum is enforced anywhere is not visible here -- confirm.
    pub(crate) struct PresharedKeyBinder, PayloadU8<NonEmpty>,
);
533
/// RFC8446: `PskBinderEntry binders<33..2^16-1>;`
///
/// Uses `ListLength::NonZeroU16`, supplying `IllegalEmptyList("PskBinders")` as its
/// `empty_error`.
impl TlsListElement for PresharedKeyBinder {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("PskBinders"),
    };
}
540
/// A pre-shared key offer: a list of identities followed by a list of binders.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyOffer {
    /// Offered identities, encoded first.
    pub(crate) identities: Vec<PresharedKeyIdentity>,
    /// Binders, encoded after the identities.
    pub(crate) binders: Vec<PresharedKeyBinder>,
}
546
547impl PresharedKeyOffer {
548    /// Make a new one with one entry.
549    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
550        Self {
551            identities: vec![id],
552            binders: vec![PresharedKeyBinder::from(binder)],
553        }
554    }
555}
556
557impl Codec<'_> for PresharedKeyOffer {
558    fn encode(&self, bytes: &mut Vec<u8>) {
559        self.identities.encode(bytes);
560        self.binders.encode(bytes);
561    }
562
563    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
564        Ok(Self {
565            identities: Vec::read(r)?,
566            binders: Vec::read(r)?,
567        })
568    }
569}
570
// --- RFC6066 certificate status request ---
// `ResponderId` is an opaque byte string wrapped around a `PayloadU16`.
wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
573
/// RFC6066: `ResponderID responder_id_list<0..2^16-1>;`
///
/// Uses plain `ListLength::U16`, which carries no empty-list error.
impl TlsListElement for ResponderId {
    const SIZE_LEN: ListLength = ListLength::U16;
}
578
/// Body of an OCSP certificate status request: responder IDs followed by extensions.
#[derive(Clone, Debug)]
pub(crate) struct OcspCertificateStatusRequest {
    /// List of responder IDs; may be empty.
    pub(crate) responder_ids: Vec<ResponderId>,
    /// Opaque request extensions, kept as raw bytes.
    pub(crate) extensions: PayloadU16,
}
584
585impl Codec<'_> for OcspCertificateStatusRequest {
586    fn encode(&self, bytes: &mut Vec<u8>) {
587        CertificateStatusType::OCSP.encode(bytes);
588        self.responder_ids.encode(bytes);
589        self.extensions.encode(bytes);
590    }
591
592    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
593        Ok(Self {
594            responder_ids: Vec::read(r)?,
595            extensions: PayloadU16::read(r)?,
596        })
597    }
598}
599
/// A certificate status request, either the understood OCSP form or an opaque one.
#[derive(Clone, Debug)]
pub(crate) enum CertificateStatusRequest {
    /// A request whose status type was `CertificateStatusType::OCSP`.
    Ocsp(OcspCertificateStatusRequest),
    /// Any other status type, paired with its uninterpreted body.
    Unknown((CertificateStatusType, Payload<'static>)),
}
605
606impl Codec<'_> for CertificateStatusRequest {
607    fn encode(&self, bytes: &mut Vec<u8>) {
608        match self {
609            Self::Ocsp(r) => r.encode(bytes),
610            Self::Unknown((typ, payload)) => {
611                typ.encode(bytes);
612                payload.encode(bytes);
613            }
614        }
615    }
616
617    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
618        let typ = CertificateStatusType::read(r)?;
619
620        match typ {
621            CertificateStatusType::OCSP => {
622                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
623                Ok(Self::Ocsp(ocsp_req))
624            }
625            _ => {
626                let data = Payload::read(r).into_owned();
627                Ok(Self::Unknown((typ, data)))
628            }
629        }
630    }
631}
632
633impl CertificateStatusRequest {
634    pub(crate) fn build_ocsp() -> Self {
635        let ocsp = OcspCertificateStatusRequest {
636            responder_ids: Vec::new(),
637            extensions: PayloadU16::empty(),
638        };
639        Self::Ocsp(ocsp)
640    }
641}
642
643// ---
644
/// RFC8446: `PskKeyExchangeMode ke_modes<1..255>;`
///
/// The list is condensed to one flag per mode this code distinguishes.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct PskKeyExchangeModes {
    /// `PskKeyExchangeMode::PSK_DHE_KE` is listed.
    pub(crate) psk_dhe: bool,
    /// `PskKeyExchangeMode::PSK_KE` is listed.
    pub(crate) psk: bool,
}
651
652impl Codec<'_> for PskKeyExchangeModes {
653    fn encode(&self, bytes: &mut Vec<u8>) {
654        let inner = LengthPrefixedBuffer::new(PskKeyExchangeMode::SIZE_LEN, bytes);
655        if self.psk_dhe {
656            PskKeyExchangeMode::PSK_DHE_KE.encode(inner.buf);
657        }
658        if self.psk {
659            PskKeyExchangeMode::PSK_KE.encode(inner.buf);
660        }
661    }
662
663    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
664        let mut psk_dhe = false;
665        let mut psk = false;
666
667        for ke in TlsListIter::<PskKeyExchangeMode>::new(reader)? {
668            match ke? {
669                PskKeyExchangeMode::PSK_DHE_KE => psk_dhe = true,
670                PskKeyExchangeMode::PSK_KE => psk = true,
671                _ => continue,
672            };
673        }
674
675        Ok(Self { psk_dhe, psk })
676    }
677}
678
/// Uses `ListLength::NonZeroU8`, supplying `IllegalEmptyList("PskKeyExchangeModes")`
/// as its `empty_error`.
impl TlsListElement for PskKeyExchangeMode {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("PskKeyExchangeModes"),
    };
}
684
/// RFC8446: `KeyShareEntry client_shares<0..2^16-1>;`
///
/// Uses plain `ListLength::U16`, which carries no empty-list error.
impl TlsListElement for KeyShareEntry {
    const SIZE_LEN: ListLength = ListLength::U16;
}
689
/// The body of the `SupportedVersions` extension when it appears in a
/// `ClientHello`
///
/// This is documented as a preference-order vector, but we (as a server)
/// ignore the preference of the client.
///
/// RFC8446: `ProtocolVersion versions<2..254>;`
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct SupportedProtocolVersions {
    /// `ProtocolVersion::TLSv1_3` is listed.
    pub(crate) tls13: bool,
    /// `ProtocolVersion::TLSv1_2` is listed.
    pub(crate) tls12: bool,
}
702
703impl SupportedProtocolVersions {
704    /// Return true if `filter` returns true for any enabled version.
705    pub(crate) fn any(&self, filter: impl Fn(ProtocolVersion) -> bool) -> bool {
706        if self.tls13 && filter(ProtocolVersion::TLSv1_3) {
707            return true;
708        }
709        if self.tls12 && filter(ProtocolVersion::TLSv1_2) {
710            return true;
711        }
712        false
713    }
714
715    const LIST_LENGTH: ListLength = ListLength::NonZeroU8 {
716        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
717    };
718}
719
720impl Codec<'_> for SupportedProtocolVersions {
721    fn encode(&self, bytes: &mut Vec<u8>) {
722        let inner = LengthPrefixedBuffer::new(Self::LIST_LENGTH, bytes);
723        if self.tls13 {
724            ProtocolVersion::TLSv1_3.encode(inner.buf);
725        }
726        if self.tls12 {
727            ProtocolVersion::TLSv1_2.encode(inner.buf);
728        }
729    }
730
731    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
732        let mut tls12 = false;
733        let mut tls13 = false;
734
735        for pv in TlsListIter::<ProtocolVersion>::new(reader)? {
736            match pv? {
737                ProtocolVersion::TLSv1_3 => tls13 = true,
738                ProtocolVersion::TLSv1_2 => tls12 = true,
739                _ => continue,
740            };
741        }
742
743        Ok(Self { tls13, tls12 })
744    }
745}
746
/// Uses `ListLength::NonZeroU8`, supplying `IllegalEmptyList("ProtocolVersions")` as its
/// `empty_error`.
impl TlsListElement for ProtocolVersion {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
    };
}
752
/// RFC7250: `CertificateType client_certificate_types<1..2^8-1>;`
///
/// Ditto `CertificateType server_certificate_types<1..2^8-1>;`
///
/// Uses `ListLength::NonZeroU8`, supplying `IllegalEmptyList("CertificateTypes")` as its
/// `empty_error`.
impl TlsListElement for CertificateType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("CertificateTypes"),
    };
}
761
/// RFC8879: `CertificateCompressionAlgorithm algorithms<2..2^8-2>;`
///
/// Uses `ListLength::NonZeroU8`, supplying
/// `IllegalEmptyList("CertificateCompressionAlgorithms")` as its `empty_error`.
impl TlsListElement for CertificateCompressionAlgorithm {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("CertificateCompressionAlgorithms"),
    };
}
768
/// A precursor to `ClientExtensions`, allowing customisation.
///
/// This is smaller than `ClientExtensions`, as it only contains the extensions
/// we need to vary between different protocols (eg, TCP-TLS versus QUIC).
///
/// The derived `Default` leaves both fields as `None`.
#[derive(Clone, Default)]
pub(crate) struct ClientExtensionsInput<'a> {
    /// QUIC transport parameters
    pub(crate) transport_parameters: Option<TransportParameters<'a>>,

    /// ALPN protocols
    pub(crate) protocols: Option<Vec<ProtocolName>>,
}
781
782impl ClientExtensionsInput<'_> {
783    pub(crate) fn from_alpn(alpn_protocols: Vec<Vec<u8>>) -> ClientExtensionsInput<'static> {
784        let protocols = match alpn_protocols.is_empty() {
785            true => None,
786            false => Some(
787                alpn_protocols
788                    .into_iter()
789                    .map(ProtocolName::from)
790                    .collect::<Vec<_>>(),
791            ),
792        };
793
794        ClientExtensionsInput {
795            transport_parameters: None,
796            protocols,
797        }
798    }
799
800    pub(crate) fn into_owned(self) -> ClientExtensionsInput<'static> {
801        let Self {
802            transport_parameters,
803            protocols,
804        } = self;
805        ClientExtensionsInput {
806            transport_parameters: transport_parameters.map(|x| x.into_owned()),
807            protocols,
808        }
809    }
810}
811
/// QUIC transport parameters, tagged by which of the two extension flavours carries them.
#[derive(Clone)]
pub(crate) enum TransportParameters<'a> {
    /// QUIC transport parameters (RFC9001 prior to draft 33)
    QuicDraft(Payload<'a>),

    /// QUIC transport parameters (RFC9001)
    Quic(Payload<'a>),
}
820
821impl TransportParameters<'_> {
822    pub(crate) fn into_owned(self) -> TransportParameters<'static> {
823        match self {
824            Self::QuicDraft(v) => TransportParameters::QuicDraft(v.into_owned()),
825            Self::Quic(v) => TransportParameters::Quic(v.into_owned()),
826        }
827    }
828}
829
extension_struct! {
    /// A representation of extensions present in a `ClientHello` message
    ///
    /// All extensions are optional (by definition) so are represented with `Option<T>`.
    ///
    /// Some extensions have an empty value and are represented with `Option<()>`.
    ///
    /// Unknown extensions are dropped during parsing.
    pub(crate) struct ClientExtensions<'a> {
        /// Requested server name indication (RFC6066)
        ExtensionType::ServerName =>
            pub(crate) server_name: Option<ServerNamePayload<'a>>,

        /// Certificate status is requested (RFC6066)
        ExtensionType::StatusRequest =>
            pub(crate) certificate_status_request: Option<CertificateStatusRequest>,

        /// Supported groups (RFC4492/RFC8446)
        ExtensionType::EllipticCurves =>
            pub(crate) named_groups: Option<Vec<NamedGroup>>,

        /// Supported EC point formats (RFC4492)
        ExtensionType::ECPointFormats =>
            pub(crate) ec_point_formats: Option<SupportedEcPointFormats>,

        /// Supported signature schemes (RFC5246/RFC8446)
        ExtensionType::SignatureAlgorithms =>
            pub(crate) signature_schemes: Option<Vec<SignatureScheme>>,

        /// Offered ALPN protocols (RFC7301)
        ExtensionType::ALProtocolNegotiation =>
            pub(crate) protocols: Option<Vec<ProtocolName>>,

        /// Available client certificate types (RFC7250)
        ExtensionType::ClientCertificateType =>
            pub(crate) client_certificate_types: Option<Vec<CertificateType>>,

        /// Acceptable server certificate types (RFC7250)
        ExtensionType::ServerCertificateType =>
            pub(crate) server_certificate_types: Option<Vec<CertificateType>>,

        /// Extended master secret is requested (RFC7627)
        ExtensionType::ExtendedMasterSecret =>
            pub(crate) extended_master_secret_request: Option<()>,

        /// Offered certificate compression methods (RFC8879)
        ExtensionType::CompressCertificate =>
            pub(crate) certificate_compression_algorithms: Option<Vec<CertificateCompressionAlgorithm>>,

        /// Session ticket offer or request (RFC5077/RFC8446)
        ExtensionType::SessionTicket =>
            pub(crate) session_ticket: Option<ClientSessionTicket>,

        /// Offered preshared keys (RFC8446)
        ExtensionType::PreSharedKey =>
            pub(crate) preshared_key_offer: Option<PresharedKeyOffer>,

        /// Early data is requested (RFC8446)
        ExtensionType::EarlyData =>
            pub(crate) early_data_request: Option<()>,

        /// Supported TLS versions (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) supported_versions: Option<SupportedProtocolVersions>,

        /// Stateless HelloRetryRequest cookie (RFC8446)
        ExtensionType::Cookie =>
            pub(crate) cookie: Option<PayloadU16<NonEmpty>>,

        /// Offered preshared key modes (RFC8446)
        ExtensionType::PSKKeyExchangeModes =>
            pub(crate) preshared_key_modes: Option<PskKeyExchangeModes>,

        /// Certificate authority names (RFC8446)
        ExtensionType::CertificateAuthorities =>
            pub(crate) certificate_authority_names: Option<Vec<DistinguishedName>>,

        /// Offered key exchange shares (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_shares: Option<Vec<KeyShareEntry>>,

        /// QUIC transport parameters (RFC9001)
        ExtensionType::TransportParameters =>
            pub(crate) transport_parameters: Option<Payload<'a>>,

        /// Secure renegotiation (RFC5746)
        ExtensionType::RenegotiationInfo =>
            pub(crate) renegotiation_info: Option<PayloadU8>,

        /// QUIC transport parameters (RFC9001 prior to draft 33)
        ExtensionType::TransportParametersDraft =>
            pub(crate) transport_parameters_draft: Option<Payload<'a>>,

        /// Encrypted inner client hello (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello: Option<EncryptedClientHello>,

        /// Encrypted client hello outer extensions (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHelloOuterExtensions =>
            pub(crate) encrypted_client_hello_outer: Option<Vec<ExtensionType>>,
    } + {
        // The fields below are not tied to an `ExtensionType`.

        /// Order randomization seed.
        pub(crate) order_seed: u16,

        /// Extensions that must appear contiguously.
        pub(crate) contiguous_extensions: Vec<ExtensionType>,
    }
}
938
939impl ClientExtensions<'_> {
    /// Converts into a value that no longer borrows from the input.
    ///
    /// Only `server_name`, `transport_parameters` and `transport_parameters_draft`
    /// are converted with their own `into_owned`; every other field is moved across
    /// unchanged.
    pub(crate) fn into_owned(self) -> ClientExtensions<'static> {
        // Destructure without `..` so that every field must be named here.
        let Self {
            server_name,
            certificate_status_request,
            named_groups,
            ec_point_formats,
            signature_schemes,
            protocols,
            client_certificate_types,
            server_certificate_types,
            extended_master_secret_request,
            certificate_compression_algorithms,
            session_ticket,
            preshared_key_offer,
            early_data_request,
            supported_versions,
            cookie,
            preshared_key_modes,
            certificate_authority_names,
            key_shares,
            transport_parameters,
            renegotiation_info,
            transport_parameters_draft,
            encrypted_client_hello,
            encrypted_client_hello_outer,
            order_seed,
            contiguous_extensions,
        } = self;
        ClientExtensions {
            // May borrow a `DnsName<'a>`.
            server_name: server_name.map(|x| x.into_owned()),
            certificate_status_request,
            named_groups,
            ec_point_formats,
            signature_schemes,
            protocols,
            client_certificate_types,
            server_certificate_types,
            extended_master_secret_request,
            certificate_compression_algorithms,
            session_ticket,
            preshared_key_offer,
            early_data_request,
            supported_versions,
            cookie,
            preshared_key_modes,
            certificate_authority_names,
            key_shares,
            // Both transport parameter flavours hold a `Payload<'a>`.
            transport_parameters: transport_parameters.map(|x| x.into_owned()),
            renegotiation_info,
            transport_parameters_draft: transport_parameters_draft.map(|x| x.into_owned()),
            encrypted_client_hello,
            encrypted_client_hello_outer,
            order_seed,
            contiguous_extensions,
        }
    }
996
997    pub(crate) fn used_extensions_in_encoding_order(&self) -> Vec<ExtensionType> {
998        let mut exts = self.order_insensitive_extensions_in_random_order();
999        exts.extend(&self.contiguous_extensions);
1000
1001        if self
1002            .encrypted_client_hello_outer
1003            .is_some()
1004        {
1005            exts.push(ExtensionType::EncryptedClientHelloOuterExtensions);
1006        }
1007        if self.encrypted_client_hello.is_some() {
1008            exts.push(ExtensionType::EncryptedClientHello);
1009        }
1010        if self.preshared_key_offer.is_some() {
1011            exts.push(ExtensionType::PreSharedKey);
1012        }
1013        exts
1014    }
1015
1016    /// Returns extensions which don't need a specific order, in randomized order.
1017    ///
1018    /// Extensions are encoded in three portions:
1019    ///
1020    /// - First, extensions not otherwise dealt with by other cases.
1021    ///   These are encoded in random order, controlled by `self.order_seed`,
1022    ///   and this is the set of extensions returned by this function.
1023    ///
1024    /// - Second, extensions named in `self.contiguous_extensions`, in the order
1025    ///   given by that field.
1026    ///
1027    /// - Lastly, any ECH and PSK extensions (in that order).  These
1028    ///   are required to be last by the standard.
1029    fn order_insensitive_extensions_in_random_order(&self) -> Vec<ExtensionType> {
1030        let mut order = self.collect_used();
1031
1032        // Remove extensions which have specific order requirements.
1033        order.retain(|ext| {
1034            !(matches!(
1035                ext,
1036                ExtensionType::PreSharedKey
1037                    | ExtensionType::EncryptedClientHello
1038                    | ExtensionType::EncryptedClientHelloOuterExtensions
1039            ) || self.contiguous_extensions.contains(ext))
1040        });
1041
1042        order.sort_by_cached_key(|new_ext| {
1043            let seed = ((self.order_seed as u32) << 16) | (u16::from(*new_ext) as u32);
1044            low_quality_integer_hash(seed)
1045        });
1046
1047        order
1048    }
1049}
1050
1051impl<'a> Codec<'a> for ClientExtensions<'a> {
1052    fn encode(&self, bytes: &mut Vec<u8>) {
1053        let order = self.used_extensions_in_encoding_order();
1054
1055        if order.is_empty() {
1056            return;
1057        }
1058
1059        let body = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1060        for item in order {
1061            self.encode_one(item, body.buf);
1062        }
1063    }
1064
1065    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1066        let mut out = Self::default();
1067
1068        // extensions length can be absent if no extensions
1069        if !r.any_left() {
1070            return Ok(out);
1071        }
1072
1073        let mut checker = DuplicateExtensionChecker::new();
1074
1075        let len = usize::from(u16::read(r)?);
1076        let mut sub = r.sub(len)?;
1077
1078        while sub.any_left() {
1079            let typ = out.read_one(&mut sub, |unknown| checker.check(unknown))?;
1080
1081            // PreSharedKey offer must come last
1082            if typ == ExtensionType::PreSharedKey && sub.any_left() {
1083                return Err(InvalidMessage::PreSharedKeyIsNotFinalExtension);
1084            }
1085        }
1086
1087        Ok(out)
1088    }
1089}
1090
1091fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
1092    let dns_name_str = dns_name.as_ref();
1093
1094    // RFC6066: "The hostname is represented as a byte string using
1095    // ASCII encoding without a trailing dot"
1096    if dns_name_str.ends_with('.') {
1097        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
1098        DnsName::try_from(trimmed)
1099            .unwrap()
1100            .to_owned()
1101    } else {
1102        dns_name.to_owned()
1103    }
1104}
1105
1106#[derive(Clone, Debug)]
1107pub(crate) enum ClientSessionTicket {
1108    Request,
1109    Offer(Payload<'static>),
1110}
1111
1112impl<'a> Codec<'a> for ClientSessionTicket {
1113    fn encode(&self, bytes: &mut Vec<u8>) {
1114        match self {
1115            Self::Request => (),
1116            Self::Offer(p) => p.encode(bytes),
1117        }
1118    }
1119
1120    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1121        Ok(match r.left() {
1122            0 => Self::Request,
1123            _ => Self::Offer(Payload::read(r).into_owned()),
1124        })
1125    }
1126}
1127
/// Holder for optional QUIC transport parameters.
///
/// NOTE(review): judging by the name, this is caller-supplied input used when
/// building `ServerExtensions` -- confirm against the use sites.
#[derive(Default)]
pub(crate) struct ServerExtensionsInput<'a> {
    /// QUIC transport parameters
    pub(crate) transport_parameters: Option<TransportParameters<'a>>,
}
1133
extension_struct! {
    /// A representation of extensions present in a `ServerHello` message
    pub(crate) struct ServerExtensions<'a> {
        /// Supported EC point formats (RFC4492)
        ExtensionType::ECPointFormats =>
            pub(crate) ec_point_formats: Option<SupportedEcPointFormats>,

        /// Server name indication acknowledgement (RFC6066)
        ExtensionType::ServerName =>
            pub(crate) server_name_ack: Option<()>,

        /// Session ticket acknowledgement (RFC5077)
        ExtensionType::SessionTicket =>
            pub(crate) session_ticket_ack: Option<()>,

        /// Secure renegotiation (RFC5746)
        ExtensionType::RenegotiationInfo =>
            pub(crate) renegotiation_info: Option<PayloadU8>,

        /// Selected ALPN protocol (RFC7301)
        ExtensionType::ALProtocolNegotiation =>
            pub(crate) selected_protocol: Option<SingleProtocolName>,

        /// Key exchange server share (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_share: Option<KeyShareEntry>,

        /// Selected preshared key index (RFC8446)
        ExtensionType::PreSharedKey =>
            pub(crate) preshared_key: Option<u16>,

        /// Required client certificate type (RFC7250)
        ExtensionType::ClientCertificateType =>
            pub(crate) client_certificate_type: Option<CertificateType>,

        /// Selected server certificate type (RFC7250)
        ExtensionType::ServerCertificateType =>
            pub(crate) server_certificate_type: Option<CertificateType>,

        /// Extended master secret is in use (RFC7627)
        ExtensionType::ExtendedMasterSecret =>
            pub(crate) extended_master_secret_ack: Option<()>,

        /// Certificate status acknowledgement (RFC6066)
        ExtensionType::StatusRequest =>
            pub(crate) certificate_status_request_ack: Option<()>,

        /// Selected TLS version (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) selected_version: Option<ProtocolVersion>,

        /// QUIC transport parameters (RFC9001)
        ExtensionType::TransportParameters =>
            pub(crate) transport_parameters: Option<Payload<'a>>,

        /// QUIC transport parameters (RFC9001 prior to draft 33)
        ExtensionType::TransportParametersDraft =>
            pub(crate) transport_parameters_draft: Option<Payload<'a>>,

        /// Early data is accepted (RFC8446)
        ExtensionType::EarlyData =>
            pub(crate) early_data_ack: Option<()>,

        /// Encrypted inner client hello response (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello_ack: Option<ServerEncryptedClientHello>,
    } + {
        /// Extension type codes taken from the `DuplicateExtensionChecker` when decoding.
        pub(crate) unknown_extensions: BTreeSet<u16>,
    }
}
1202
1203impl ServerExtensions<'_> {
1204    fn into_owned(self) -> ServerExtensions<'static> {
1205        let Self {
1206            ec_point_formats,
1207            server_name_ack,
1208            session_ticket_ack,
1209            renegotiation_info,
1210            selected_protocol,
1211            key_share,
1212            preshared_key,
1213            client_certificate_type,
1214            server_certificate_type,
1215            extended_master_secret_ack,
1216            certificate_status_request_ack,
1217            selected_version,
1218            transport_parameters,
1219            transport_parameters_draft,
1220            early_data_ack,
1221            encrypted_client_hello_ack,
1222            unknown_extensions,
1223        } = self;
1224        ServerExtensions {
1225            ec_point_formats,
1226            server_name_ack,
1227            session_ticket_ack,
1228            renegotiation_info,
1229            selected_protocol,
1230            key_share,
1231            preshared_key,
1232            client_certificate_type,
1233            server_certificate_type,
1234            extended_master_secret_ack,
1235            certificate_status_request_ack,
1236            selected_version,
1237            transport_parameters: transport_parameters.map(|x| x.into_owned()),
1238            transport_parameters_draft: transport_parameters_draft.map(|x| x.into_owned()),
1239            early_data_ack,
1240            encrypted_client_hello_ack,
1241            unknown_extensions,
1242        }
1243    }
1244}
1245
1246impl<'a> Codec<'a> for ServerExtensions<'a> {
1247    fn encode(&self, bytes: &mut Vec<u8>) {
1248        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1249
1250        for ext in Self::ALL_EXTENSIONS {
1251            self.encode_one(*ext, extensions.buf);
1252        }
1253    }
1254
1255    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1256        let mut out = Self::default();
1257        let mut checker = DuplicateExtensionChecker::new();
1258
1259        let len = usize::from(u16::read(r)?);
1260        let mut sub = r.sub(len)?;
1261
1262        while sub.any_left() {
1263            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
1264        }
1265
1266        out.unknown_extensions = checker.0;
1267        Ok(out)
1268    }
1269}
1270
1271#[derive(Clone, Debug)]
1272pub(crate) struct ClientHelloPayload {
1273    pub(crate) client_version: ProtocolVersion,
1274    pub(crate) random: Random,
1275    pub(crate) session_id: SessionId,
1276    pub(crate) cipher_suites: Vec<CipherSuite>,
1277    pub(crate) compression_methods: Vec<Compression>,
1278    pub(crate) extensions: Box<ClientExtensions<'static>>,
1279}
1280
1281impl ClientHelloPayload {
1282    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
1283        let mut bytes = Vec::new();
1284        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
1285        bytes
1286    }
1287
1288    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1289        self.client_version.encode(bytes);
1290        self.random.encode(bytes);
1291
1292        match purpose {
1293            // SessionID is required to be empty in the encoded inner client hello.
1294            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
1295            _ => self.session_id.encode(bytes),
1296        }
1297
1298        self.cipher_suites.encode(bytes);
1299        self.compression_methods.encode(bytes);
1300
1301        let to_compress = match purpose {
1302            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
1303            _ => {
1304                self.extensions.encode(bytes);
1305                return;
1306            }
1307        };
1308
1309        let mut compressed = self.extensions.clone();
1310
1311        // First, eliminate the full-fat versions of the extensions
1312        for e in &to_compress {
1313            compressed.clear(*e);
1314        }
1315
1316        // Replace with the marker noting which extensions were elided.
1317        compressed.encrypted_client_hello_outer = Some(to_compress);
1318
1319        // And encode as normal.
1320        compressed.encode(bytes);
1321    }
1322
1323    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1324        self.key_shares
1325            .as_ref()
1326            .map(|entries| {
1327                has_duplicates::<_, _, u16>(
1328                    entries
1329                        .iter()
1330                        .map(|kse| u16::from(kse.group)),
1331                )
1332            })
1333            .unwrap_or_default()
1334    }
1335
1336    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1337        if let Some(algs) = &self.certificate_compression_algorithms {
1338            has_duplicates::<_, _, u16>(algs.iter().cloned())
1339        } else {
1340            false
1341        }
1342    }
1343}
1344
1345impl Codec<'_> for ClientHelloPayload {
1346    fn encode(&self, bytes: &mut Vec<u8>) {
1347        self.payload_encode(bytes, Encoding::Standard)
1348    }
1349
1350    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1351        let ret = Self {
1352            client_version: ProtocolVersion::read(r)?,
1353            random: Random::read(r)?,
1354            session_id: SessionId::read(r)?,
1355            cipher_suites: Vec::read(r)?,
1356            compression_methods: Vec::read(r)?,
1357            extensions: Box::new(ClientExtensions::read(r)?.into_owned()),
1358        };
1359
1360        match r.any_left() {
1361            true => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
1362            false => Ok(ret),
1363        }
1364    }
1365}
1366
1367impl Deref for ClientHelloPayload {
1368    type Target = ClientExtensions<'static>;
1369    fn deref(&self) -> &Self::Target {
1370        &self.extensions
1371    }
1372}
1373
1374impl DerefMut for ClientHelloPayload {
1375    fn deref_mut(&mut self) -> &mut Self::Target {
1376        &mut self.extensions
1377    }
1378}
1379
/// RFC8446: `CipherSuite cipher_suites<2..2^16-2>;`
///
/// Uses `ListLength::NonZeroU16`, with `IllegalEmptyList("CipherSuites")` as
/// the error associated with an empty list.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("CipherSuites"),
    };
}

/// RFC5246: `CompressionMethod compression_methods<1..2^8-1>;`
///
/// Uses `ListLength::NonZeroU8`, with `IllegalEmptyList("Compressions")` as
/// the error associated with an empty list.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("Compressions"),
    };
}

/// draft-ietf-tls-esni-17: `ExtensionType OuterExtensions<2..254>;`
///
/// Uses `ListLength::NonZeroU8`, with `IllegalEmptyList("ExtensionTypes")` as
/// the error associated with an empty list.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ExtensionTypes"),
    };
}
1400
extension_struct! {
    /// A representation of extensions present in a `HelloRetryRequest` message
    pub(crate) struct HelloRetryRequestExtensions<'a> {
        /// `key_share` extension body: a single named group
        ExtensionType::KeyShare =>
            pub(crate) key_share: Option<NamedGroup>,

        /// `cookie` extension body: a non-empty opaque value
        ExtensionType::Cookie =>
            pub(crate) cookie: Option<PayloadU16<NonEmpty>>,

        /// `supported_versions` extension body: a single protocol version
        ExtensionType::SupportedVersions =>
            pub(crate) supported_versions: Option<ProtocolVersion>,

        /// Encrypted client hello extension body, kept as an opaque payload
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello: Option<Payload<'a>>,
    } + {
        /// Records decoding order of extensions, and controls encoding order.
        pub(crate) order: Option<Vec<ExtensionType>>,
    }
}
1420
1421impl HelloRetryRequestExtensions<'_> {
1422    fn into_owned(self) -> HelloRetryRequestExtensions<'static> {
1423        let Self {
1424            key_share,
1425            cookie,
1426            supported_versions,
1427            encrypted_client_hello,
1428            order,
1429        } = self;
1430        HelloRetryRequestExtensions {
1431            key_share,
1432            cookie,
1433            supported_versions,
1434            encrypted_client_hello: encrypted_client_hello.map(|x| x.into_owned()),
1435            order,
1436        }
1437    }
1438}
1439
1440impl<'a> Codec<'a> for HelloRetryRequestExtensions<'a> {
1441    fn encode(&self, bytes: &mut Vec<u8>) {
1442        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1443
1444        for ext in self
1445            .order
1446            .as_deref()
1447            .unwrap_or(Self::ALL_EXTENSIONS)
1448        {
1449            self.encode_one(*ext, extensions.buf);
1450        }
1451    }
1452
1453    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1454        let mut out = Self::default();
1455
1456        // we must record order, so re-encoding round trips.  this is needed,
1457        // unfortunately, for ECH HRR confirmation
1458        let mut order = vec![];
1459
1460        let len = usize::from(u16::read(r)?);
1461        let mut sub = r.sub(len)?;
1462
1463        while sub.any_left() {
1464            let typ = out.read_one(&mut sub, |_unk| {
1465                Err(InvalidMessage::UnknownHelloRetryRequestExtension)
1466            })?;
1467
1468            order.push(typ);
1469        }
1470
1471        out.order = Some(order);
1472        Ok(out)
1473    }
1474}
1475
1476#[derive(Clone, Debug)]
1477pub(crate) struct HelloRetryRequest {
1478    pub(crate) legacy_version: ProtocolVersion,
1479    pub(crate) session_id: SessionId,
1480    pub(crate) cipher_suite: CipherSuite,
1481    pub(crate) extensions: HelloRetryRequestExtensions<'static>,
1482}
1483
1484impl Codec<'_> for HelloRetryRequest {
1485    fn encode(&self, bytes: &mut Vec<u8>) {
1486        self.payload_encode(bytes, Encoding::Standard)
1487    }
1488
1489    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1490        let session_id = SessionId::read(r)?;
1491        let cipher_suite = CipherSuite::read(r)?;
1492        let compression = Compression::read(r)?;
1493
1494        if compression != Compression::Null {
1495            return Err(InvalidMessage::UnsupportedCompression);
1496        }
1497
1498        Ok(Self {
1499            legacy_version: ProtocolVersion::Unknown(0),
1500            session_id,
1501            cipher_suite,
1502            extensions: HelloRetryRequestExtensions::read(r)?.into_owned(),
1503        })
1504    }
1505}
1506
1507impl HelloRetryRequest {
1508    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1509        self.legacy_version.encode(bytes);
1510        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1511        self.session_id.encode(bytes);
1512        self.cipher_suite.encode(bytes);
1513        Compression::Null.encode(bytes);
1514
1515        match purpose {
1516            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1517            // must have its payload replaced by 8 zero bytes.
1518            //
1519            // See draft-ietf-tls-esni-18 7.2.1:
1520            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1521            Encoding::EchConfirmation
1522                if self
1523                    .extensions
1524                    .encrypted_client_hello
1525                    .is_some() =>
1526            {
1527                let hrr_confirmation = [0u8; 8];
1528                HelloRetryRequestExtensions {
1529                    encrypted_client_hello: Some(Payload::Borrowed(&hrr_confirmation)),
1530                    ..self.extensions.clone()
1531                }
1532                .encode(bytes);
1533            }
1534            _ => self.extensions.encode(bytes),
1535        }
1536    }
1537}
1538
1539impl Deref for HelloRetryRequest {
1540    type Target = HelloRetryRequestExtensions<'static>;
1541    fn deref(&self) -> &Self::Target {
1542        &self.extensions
1543    }
1544}
1545
1546impl DerefMut for HelloRetryRequest {
1547    fn deref_mut(&mut self) -> &mut Self::Target {
1548        &mut self.extensions
1549    }
1550}
1551
1552#[derive(Clone, Debug)]
1553pub(crate) struct ServerHelloPayload {
1554    pub(crate) legacy_version: ProtocolVersion,
1555    pub(crate) random: Random,
1556    pub(crate) session_id: SessionId,
1557    pub(crate) cipher_suite: CipherSuite,
1558    pub(crate) compression_method: Compression,
1559    pub(crate) extensions: Box<ServerExtensions<'static>>,
1560}
1561
1562impl Codec<'_> for ServerHelloPayload {
1563    fn encode(&self, bytes: &mut Vec<u8>) {
1564        self.payload_encode(bytes, Encoding::Standard)
1565    }
1566
1567    // minus version and random, which have already been read.
1568    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1569        let session_id = SessionId::read(r)?;
1570        let suite = CipherSuite::read(r)?;
1571        let compression = Compression::read(r)?;
1572
1573        // RFC5246:
1574        // "The presence of extensions can be detected by determining whether
1575        //  there are bytes following the compression_method field at the end of
1576        //  the ServerHello."
1577        let extensions = Box::new(
1578            if r.any_left() {
1579                ServerExtensions::read(r)?
1580            } else {
1581                ServerExtensions::default()
1582            }
1583            .into_owned(),
1584        );
1585
1586        let ret = Self {
1587            legacy_version: ProtocolVersion::Unknown(0),
1588            random: ZERO_RANDOM,
1589            session_id,
1590            cipher_suite: suite,
1591            compression_method: compression,
1592            extensions,
1593        };
1594
1595        r.expect_empty("ServerHelloPayload")
1596            .map(|_| ret)
1597    }
1598}
1599
1600impl ServerHelloPayload {
1601    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1602        debug_assert!(
1603            !matches!(encoding, Encoding::EchConfirmation),
1604            "we cannot compute an ECH confirmation on a received ServerHello"
1605        );
1606
1607        self.legacy_version.encode(bytes);
1608        self.random.encode(bytes);
1609        self.session_id.encode(bytes);
1610        self.cipher_suite.encode(bytes);
1611        self.compression_method.encode(bytes);
1612        self.extensions.encode(bytes);
1613    }
1614}
1615
1616impl Deref for ServerHelloPayload {
1617    type Target = ServerExtensions<'static>;
1618    fn deref(&self) -> &Self::Target {
1619        &self.extensions
1620    }
1621}
1622
1623impl DerefMut for ServerHelloPayload {
1624    fn deref_mut(&mut self) -> &mut Self::Target {
1625        &mut self.extensions
1626    }
1627}
1628
1629#[derive(Clone, Default, Debug)]
1630pub(crate) struct CertificateChain<'a>(pub(crate) Vec<CertificateDer<'a>>);
1631
1632impl CertificateChain<'_> {
1633    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1634        CertificateChain(
1635            self.0
1636                .into_iter()
1637                .map(|c| c.into_owned())
1638                .collect(),
1639        )
1640    }
1641}
1642
1643impl<'a> Codec<'a> for CertificateChain<'a> {
1644    fn encode(&self, bytes: &mut Vec<u8>) {
1645        Vec::encode(&self.0, bytes)
1646    }
1647
1648    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1649        Vec::read(r).map(Self)
1650    }
1651}
1652
1653impl<'a> Deref for CertificateChain<'a> {
1654    type Target = [CertificateDer<'a>];
1655
1656    fn deref(&self) -> &[CertificateDer<'a>] {
1657        &self.0
1658    }
1659}
1660
/// Lists of certificates use `ListLength::U24` with `max` set to
/// `CERTIFICATE_MAX_SIZE_LIMIT`, and `CertificatePayloadTooLarge` as the
/// associated error.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}

/// TLS has a 16MB size limit on any handshake message,
/// plus a 16MB limit on any given certificate.
///
/// We contract that to 64KB to limit the amount of memory allocation
/// that is directly controllable by the peer.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1674
1675extension_struct! {
1676    pub(crate) struct CertificateExtensions<'a> {
1677        ExtensionType::StatusRequest =>
1678            pub(crate) status: Option<CertificateStatus<'a>>,
1679    }
1680}
1681
1682impl CertificateExtensions<'_> {
1683    fn into_owned(self) -> CertificateExtensions<'static> {
1684        CertificateExtensions {
1685            status: self.status.map(|s| s.into_owned()),
1686        }
1687    }
1688}
1689
1690impl<'a> Codec<'a> for CertificateExtensions<'a> {
1691    fn encode(&self, bytes: &mut Vec<u8>) {
1692        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1693
1694        for ext in Self::ALL_EXTENSIONS {
1695            self.encode_one(*ext, extensions.buf);
1696        }
1697    }
1698
1699    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1700        let mut out = Self::default();
1701
1702        let len = usize::from(u16::read(r)?);
1703        let mut sub = r.sub(len)?;
1704
1705        while sub.any_left() {
1706            out.read_one(&mut sub, |_unk| {
1707                Err(InvalidMessage::UnknownCertificateExtension)
1708            })?;
1709        }
1710
1711        Ok(out)
1712    }
1713}
1714
1715#[derive(Debug)]
1716pub(crate) struct CertificateEntry<'a> {
1717    pub(crate) cert: CertificateDer<'a>,
1718    pub(crate) extensions: CertificateExtensions<'a>,
1719}
1720
1721impl<'a> Codec<'a> for CertificateEntry<'a> {
1722    fn encode(&self, bytes: &mut Vec<u8>) {
1723        self.cert.encode(bytes);
1724        self.extensions.encode(bytes);
1725    }
1726
1727    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1728        Ok(Self {
1729            cert: CertificateDer::read(r)?,
1730            extensions: CertificateExtensions::read(r)?.into_owned(),
1731        })
1732    }
1733}
1734
1735impl<'a> CertificateEntry<'a> {
1736    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1737        Self {
1738            cert,
1739            extensions: CertificateExtensions::default(),
1740        }
1741    }
1742
1743    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1744        CertificateEntry {
1745            cert: self.cert.into_owned(),
1746            extensions: self.extensions.into_owned(),
1747        }
1748    }
1749}
1750
/// Lists of certificate entries use `ListLength::U24` with `max` set to
/// `CERTIFICATE_MAX_SIZE_LIMIT`, and `CertificatePayloadTooLarge` as the
/// associated error.
impl TlsListElement for CertificateEntry<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1757
1758#[derive(Debug)]
1759pub(crate) struct CertificatePayloadTls13<'a> {
1760    pub(crate) context: PayloadU8,
1761    pub(crate) entries: Vec<CertificateEntry<'a>>,
1762}
1763
1764impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1765    fn encode(&self, bytes: &mut Vec<u8>) {
1766        self.context.encode(bytes);
1767        self.entries.encode(bytes);
1768    }
1769
1770    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1771        Ok(Self {
1772            context: PayloadU8::read(r)?,
1773            entries: Vec::read(r)?,
1774        })
1775    }
1776}
1777
1778impl<'a> CertificatePayloadTls13<'a> {
1779    pub(crate) fn new(
1780        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1781        ocsp_response: Option<&'a [u8]>,
1782    ) -> Self {
1783        Self {
1784            context: PayloadU8::empty(),
1785            entries: certs
1786                // zip certificate iterator with `ocsp_response` followed by
1787                // an infinite-length iterator of `None`.
1788                .zip(
1789                    ocsp_response
1790                        .into_iter()
1791                        .map(Some)
1792                        .chain(iter::repeat(None)),
1793                )
1794                .map(|(cert, ocsp)| {
1795                    let mut e = CertificateEntry::new(cert.clone());
1796                    if let Some(ocsp) = ocsp {
1797                        e.extensions.status = Some(CertificateStatus::new(ocsp));
1798                    }
1799                    e
1800                })
1801                .collect(),
1802        }
1803    }
1804
1805    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1806        CertificatePayloadTls13 {
1807            context: self.context,
1808            entries: self
1809                .entries
1810                .into_iter()
1811                .map(CertificateEntry::into_owned)
1812                .collect(),
1813        }
1814    }
1815
1816    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1817        let Some(entry) = self.entries.first() else {
1818            return vec![];
1819        };
1820        entry
1821            .extensions
1822            .status
1823            .as_ref()
1824            .map(|status| {
1825                status
1826                    .ocsp_response
1827                    .0
1828                    .clone()
1829                    .into_vec()
1830            })
1831            .unwrap_or_default()
1832    }
1833
1834    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1835        CertificateChain(
1836            self.entries
1837                .into_iter()
1838                .map(|e| e.cert)
1839                .collect(),
1840        )
1841    }
1842}
1843
/// Describes supported key exchange mechanisms.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman Key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}

/// Every `KeyExchangeAlgorithm` variant, with `ECDHE` listed before `DHE`.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1858
1859// We don't support arbitrary curves.  It's a terrible
1860// idea and unnecessary attack surface.  Please,
1861// get a grip.
1862#[derive(Debug)]
1863pub(crate) struct EcParameters {
1864    pub(crate) curve_type: ECCurveType,
1865    pub(crate) named_group: NamedGroup,
1866}
1867
1868impl Codec<'_> for EcParameters {
1869    fn encode(&self, bytes: &mut Vec<u8>) {
1870        self.curve_type.encode(bytes);
1871        self.named_group.encode(bytes);
1872    }
1873
1874    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1875        let ct = ECCurveType::read(r)?;
1876        if ct != ECCurveType::NamedCurve {
1877            return Err(InvalidMessage::UnsupportedCurveType);
1878        }
1879
1880        let grp = NamedGroup::read(r)?;
1881
1882        Ok(Self {
1883            curve_type: ct,
1884            named_group: grp,
1885        })
1886    }
1887}
1888
/// Decoding for key exchange messages, whose layout depends on the key
/// exchange algorithm in use.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decodes a key exchange message from `r`, interpreting it according to
    /// the key exchange `algo`.
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1894
/// Client key exchange parameters, one variant per [`KeyExchangeAlgorithm`].
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// Parameters for an ECDHE key exchange.
    Ecdh(ClientEcdhParams),
    /// Parameters for a DHE key exchange.
    Dh(ClientDhParams),
}
1901
#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the bytes of the client's public value, whichever variant this is.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Dh(params) => params.public.0.as_slice(),
            Self::Ecdh(params) => params.public.0.as_slice(),
        }
    }

    /// Appends the encoding of the contained parameters to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(params) => params.encode(buf),
            Self::Ecdh(params) => params.encode(buf),
        }
    }
}
1918
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Reads the parameter structure matching `algo` from `r`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1929
/// The client's ECDH public value.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// The public point. RFC4492: `opaque point <1..2^8-1>;`
    pub(crate) public: PayloadU8<NonEmpty>,
}
1936
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    /// Appends the encoded public point to `bytes`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.public.encode(bytes);
    }

    /// Reads the public point from `r`.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        Ok(Self {
            public: PayloadU8::read(r)?,
        })
    }
}
1948
/// The client's Diffie-Hellman public value.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// The public value. RFC5246: `opaque dh_Yc<1..2^16-1>;`
    pub(crate) public: PayloadU16<NonEmpty>,
}
1955
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    /// Appends the encoded public value to `bytes`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.public.encode(bytes);
    }

    /// Reads the public value from `r`.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
1968
1969#[derive(Debug)]
1970pub(crate) struct ServerEcdhParams {
1971    pub(crate) curve_params: EcParameters,
1972    /// RFC4492: `opaque point <1..2^8-1>;`
1973    pub(crate) public: PayloadU8<NonEmpty>,
1974}
1975
1976impl ServerEcdhParams {
1977    #[cfg(feature = "tls12")]
1978    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1979        Self {
1980            curve_params: EcParameters {
1981                curve_type: ECCurveType::NamedCurve,
1982                named_group: kx.group(),
1983            },
1984            public: PayloadU8::new(kx.pub_key().to_vec()),
1985        }
1986    }
1987}
1988
1989impl Codec<'_> for ServerEcdhParams {
1990    fn encode(&self, bytes: &mut Vec<u8>) {
1991        self.curve_params.encode(bytes);
1992        self.public.encode(bytes);
1993    }
1994
1995    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1996        let cp = EcParameters::read(r)?;
1997        let pb = PayloadU8::read(r)?;
1998
1999        Ok(Self {
2000            curve_params: cp,
2001            public: pb,
2002        })
2003    }
2004}
2005
2006#[derive(Debug)]
2007#[allow(non_snake_case)]
2008pub(crate) struct ServerDhParams {
2009    /// RFC5246: `opaque dh_p<1..2^16-1>;`
2010    pub(crate) dh_p: PayloadU16<NonEmpty>,
2011    /// RFC5246: `opaque dh_g<1..2^16-1>;`
2012    pub(crate) dh_g: PayloadU16<NonEmpty>,
2013    /// RFC5246: `opaque dh_Ys<1..2^16-1>;`
2014    pub(crate) dh_Ys: PayloadU16<NonEmpty>,
2015}
2016
2017impl ServerDhParams {
2018    #[cfg(feature = "tls12")]
2019    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2020        let Some(params) = kx.ffdhe_group() else {
2021            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
2022        };
2023
2024        Self {
2025            dh_p: PayloadU16::new(params.p.to_vec()),
2026            dh_g: PayloadU16::new(params.g.to_vec()),
2027            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
2028        }
2029    }
2030
2031    #[cfg(feature = "tls12")]
2032    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
2033        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
2034    }
2035}
2036
2037impl Codec<'_> for ServerDhParams {
2038    fn encode(&self, bytes: &mut Vec<u8>) {
2039        self.dh_p.encode(bytes);
2040        self.dh_g.encode(bytes);
2041        self.dh_Ys.encode(bytes);
2042    }
2043
2044    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2045        Ok(Self {
2046            dh_p: PayloadU16::read(r)?,
2047            dh_g: PayloadU16::read(r)?,
2048            dh_Ys: PayloadU16::read(r)?,
2049        })
2050    }
2051}
2052
2053#[allow(dead_code)]
2054#[derive(Debug)]
2055pub(crate) enum ServerKeyExchangeParams {
2056    Ecdh(ServerEcdhParams),
2057    Dh(ServerDhParams),
2058}
2059
2060impl ServerKeyExchangeParams {
2061    #[cfg(feature = "tls12")]
2062    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2063        match kx.group().key_exchange_algorithm() {
2064            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
2065            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
2066        }
2067    }
2068
2069    #[cfg(feature = "tls12")]
2070    pub(crate) fn pub_key(&self) -> &[u8] {
2071        match self {
2072            Self::Ecdh(ecdh) => &ecdh.public.0,
2073            Self::Dh(dh) => &dh.dh_Ys.0,
2074        }
2075    }
2076
2077    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2078        match self {
2079            Self::Ecdh(ecdh) => ecdh.encode(buf),
2080            Self::Dh(dh) => dh.encode(buf),
2081        }
2082    }
2083}
2084
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ServerKeyExchangeParams {
    /// Reads the parameter structure matching `algo` from `r`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ServerEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ServerDhParams::read(r).map(Self::Dh),
        }
    }
}
2095
2096#[derive(Debug)]
2097pub(crate) struct ServerKeyExchange {
2098    pub(crate) params: ServerKeyExchangeParams,
2099    pub(crate) dss: DigitallySignedStruct,
2100}
2101
2102impl ServerKeyExchange {
2103    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2104        self.params.encode(buf);
2105        self.dss.encode(buf);
2106    }
2107}
2108
2109#[derive(Debug)]
2110pub(crate) enum ServerKeyExchangePayload {
2111    Known(ServerKeyExchange),
2112    Unknown(Payload<'static>),
2113}
2114
2115impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2116    fn from(value: ServerKeyExchange) -> Self {
2117        Self::Known(value)
2118    }
2119}
2120
2121impl Codec<'_> for ServerKeyExchangePayload {
2122    fn encode(&self, bytes: &mut Vec<u8>) {
2123        match self {
2124            Self::Known(x) => x.encode(bytes),
2125            Self::Unknown(x) => x.encode(bytes),
2126        }
2127    }
2128
2129    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2130        // read as Unknown, fully parse when we know the
2131        // KeyExchangeAlgorithm
2132        Ok(Self::Unknown(Payload::read(r).into_owned()))
2133    }
2134}
2135
2136impl ServerKeyExchangePayload {
2137    #[cfg(feature = "tls12")]
2138    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2139        if let Self::Unknown(unk) = self {
2140            let mut rd = Reader::init(unk.bytes());
2141
2142            let result = ServerKeyExchange {
2143                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2144                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2145            };
2146
2147            if !rd.any_left() {
2148                return Some(result);
2149            };
2150        }
2151
2152        None
2153    }
2154}
2155
2156/// RFC5246: `ClientCertificateType certificate_types<1..2^8-1>;`
2157impl TlsListElement for ClientCertificateType {
2158    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
2159        empty_error: InvalidMessage::IllegalEmptyList("ClientCertificateTypes"),
2160    };
2161}
2162
2163wrapped_payload!(
2164    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
2165    ///
2166    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
2167    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
2168    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2169    ///
2170    /// ```ignore
2171    /// for name in distinguished_names {
2172    ///     use x509_parser::prelude::FromDer;
2173    ///     println!("{}", x509_parser::x509::X509Name::from_der(&name.0)?.1);
2174    /// }
2175    /// ```
2176    ///
2177    /// The TLS encoding is defined in RFC5246: `opaque DistinguishedName<1..2^16-1>;`
2178    pub struct DistinguishedName,
2179    PayloadU16<NonEmpty>,
2180);
2181
2182impl DistinguishedName {
2183    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2184    ///
2185    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2186    ///
2187    /// ```ignore
2188    /// use x509_parser::prelude::FromDer;
2189    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2190    /// ```
2191    pub fn in_sequence(bytes: &[u8]) -> Self {
2192        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2193    }
2194}
2195
2196/// RFC8446: `DistinguishedName authorities<3..2^16-1>;` however,
2197/// RFC5246: `DistinguishedName certificate_authorities<0..2^16-1>;`
2198impl TlsListElement for DistinguishedName {
2199    const SIZE_LEN: ListLength = ListLength::U16;
2200}
2201
2202#[derive(Debug)]
2203pub(crate) struct CertificateRequestPayload {
2204    pub(crate) certtypes: Vec<ClientCertificateType>,
2205    pub(crate) sigschemes: Vec<SignatureScheme>,
2206    pub(crate) canames: Vec<DistinguishedName>,
2207}
2208
2209impl Codec<'_> for CertificateRequestPayload {
2210    fn encode(&self, bytes: &mut Vec<u8>) {
2211        self.certtypes.encode(bytes);
2212        self.sigschemes.encode(bytes);
2213        self.canames.encode(bytes);
2214    }
2215
2216    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2217        let certtypes = Vec::read(r)?;
2218        let sigschemes = Vec::read(r)?;
2219        let canames = Vec::read(r)?;
2220
2221        if sigschemes.is_empty() {
2222            warn!("meaningless CertificateRequest message");
2223            Err(InvalidMessage::NoSignatureSchemes)
2224        } else {
2225            Ok(Self {
2226                certtypes,
2227                sigschemes,
2228                canames,
2229            })
2230        }
2231    }
2232}
2233
2234extension_struct! {
2235    pub(crate) struct CertificateRequestExtensions {
2236        ExtensionType::SignatureAlgorithms =>
2237            pub(crate) signature_algorithms: Option<Vec<SignatureScheme>>,
2238
2239        ExtensionType::CertificateAuthorities =>
2240            pub(crate) authority_names: Option<Vec<DistinguishedName>>,
2241
2242        ExtensionType::CompressCertificate =>
2243            pub(crate) certificate_compression_algorithms: Option<Vec<CertificateCompressionAlgorithm>>,
2244    }
2245}
2246
2247impl Codec<'_> for CertificateRequestExtensions {
2248    fn encode(&self, bytes: &mut Vec<u8>) {
2249        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2250
2251        for ext in Self::ALL_EXTENSIONS {
2252            self.encode_one(*ext, extensions.buf);
2253        }
2254    }
2255
2256    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2257        let mut out = Self::default();
2258
2259        let mut checker = DuplicateExtensionChecker::new();
2260
2261        let len = usize::from(u16::read(r)?);
2262        let mut sub = r.sub(len)?;
2263
2264        while sub.any_left() {
2265            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
2266        }
2267
2268        if out
2269            .signature_algorithms
2270            .as_ref()
2271            .map(|algs| algs.is_empty())
2272            .unwrap_or_default()
2273        {
2274            return Err(InvalidMessage::NoSignatureSchemes);
2275        }
2276
2277        Ok(out)
2278    }
2279}
2280
2281#[derive(Debug)]
2282pub(crate) struct CertificateRequestPayloadTls13 {
2283    pub(crate) context: PayloadU8,
2284    pub(crate) extensions: CertificateRequestExtensions,
2285}
2286
2287impl Codec<'_> for CertificateRequestPayloadTls13 {
2288    fn encode(&self, bytes: &mut Vec<u8>) {
2289        self.context.encode(bytes);
2290        self.extensions.encode(bytes);
2291    }
2292
2293    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2294        let context = PayloadU8::read(r)?;
2295        let extensions = CertificateRequestExtensions::read(r)?;
2296
2297        Ok(Self {
2298            context,
2299            extensions,
2300        })
2301    }
2302}
2303
2304// -- NewSessionTicket --
2305#[derive(Debug)]
2306pub(crate) struct NewSessionTicketPayload {
2307    pub(crate) lifetime_hint: u32,
2308    // Tickets can be large (KB), so we deserialise this straight
2309    // into an Arc, so it can be passed directly into the client's
2310    // session object without copying.
2311    pub(crate) ticket: Arc<PayloadU16>,
2312}
2313
2314impl NewSessionTicketPayload {
2315    #[cfg(feature = "tls12")]
2316    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2317        Self {
2318            lifetime_hint,
2319            ticket: Arc::new(PayloadU16::new(ticket)),
2320        }
2321    }
2322}
2323
2324impl Codec<'_> for NewSessionTicketPayload {
2325    fn encode(&self, bytes: &mut Vec<u8>) {
2326        self.lifetime_hint.encode(bytes);
2327        self.ticket.encode(bytes);
2328    }
2329
2330    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2331        let lifetime = u32::read(r)?;
2332        let ticket = Arc::new(PayloadU16::read(r)?);
2333
2334        Ok(Self {
2335            lifetime_hint: lifetime,
2336            ticket,
2337        })
2338    }
2339}
2340
2341// -- NewSessionTicket electric boogaloo --
2342extension_struct! {
2343    pub(crate) struct NewSessionTicketExtensions {
2344        ExtensionType::EarlyData =>
2345            pub(crate) max_early_data_size: Option<u32>,
2346    }
2347}
2348
2349impl Codec<'_> for NewSessionTicketExtensions {
2350    fn encode(&self, bytes: &mut Vec<u8>) {
2351        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2352
2353        for ext in Self::ALL_EXTENSIONS {
2354            self.encode_one(*ext, extensions.buf);
2355        }
2356    }
2357
2358    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2359        let mut out = Self::default();
2360
2361        let mut checker = DuplicateExtensionChecker::new();
2362
2363        let len = usize::from(u16::read(r)?);
2364        let mut sub = r.sub(len)?;
2365
2366        while sub.any_left() {
2367            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
2368        }
2369
2370        Ok(out)
2371    }
2372}
2373
2374#[derive(Debug)]
2375pub(crate) struct NewSessionTicketPayloadTls13 {
2376    pub(crate) lifetime: u32,
2377    pub(crate) age_add: u32,
2378    pub(crate) nonce: PayloadU8,
2379    pub(crate) ticket: Arc<PayloadU16>,
2380    pub(crate) extensions: NewSessionTicketExtensions,
2381}
2382
2383impl NewSessionTicketPayloadTls13 {
2384    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2385        Self {
2386            lifetime,
2387            age_add,
2388            nonce: PayloadU8::new(nonce),
2389            ticket: Arc::new(PayloadU16::new(ticket)),
2390            extensions: NewSessionTicketExtensions::default(),
2391        }
2392    }
2393}
2394
2395impl Codec<'_> for NewSessionTicketPayloadTls13 {
2396    fn encode(&self, bytes: &mut Vec<u8>) {
2397        self.lifetime.encode(bytes);
2398        self.age_add.encode(bytes);
2399        self.nonce.encode(bytes);
2400        self.ticket.encode(bytes);
2401        self.extensions.encode(bytes);
2402    }
2403
2404    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2405        let lifetime = u32::read(r)?;
2406        let age_add = u32::read(r)?;
2407        let nonce = PayloadU8::read(r)?;
2408        // nb. RFC8446: `opaque ticket<1..2^16-1>;`
2409        let ticket = Arc::new(match PayloadU16::<NonEmpty>::read(r) {
2410            Err(InvalidMessage::IllegalEmptyValue) => Err(InvalidMessage::EmptyTicketValue),
2411            Err(err) => Err(err),
2412            Ok(pl) => Ok(PayloadU16::new(pl.0)),
2413        }?);
2414        let extensions = NewSessionTicketExtensions::read(r)?;
2415
2416        Ok(Self {
2417            lifetime,
2418            age_add,
2419            nonce,
2420            ticket,
2421            extensions,
2422        })
2423    }
2424}
2425
2426// -- RFC6066 certificate status types
2427
2428/// Only supports OCSP
2429#[derive(Clone, Debug)]
2430pub(crate) struct CertificateStatus<'a> {
2431    pub(crate) ocsp_response: PayloadU24<'a>,
2432}
2433
2434impl<'a> Codec<'a> for CertificateStatus<'a> {
2435    fn encode(&self, bytes: &mut Vec<u8>) {
2436        CertificateStatusType::OCSP.encode(bytes);
2437        self.ocsp_response.encode(bytes);
2438    }
2439
2440    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2441        let typ = CertificateStatusType::read(r)?;
2442
2443        match typ {
2444            CertificateStatusType::OCSP => Ok(Self {
2445                ocsp_response: PayloadU24::read(r)?,
2446            }),
2447            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2448        }
2449    }
2450}
2451
2452impl<'a> CertificateStatus<'a> {
2453    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2454        CertificateStatus {
2455            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2456        }
2457    }
2458
2459    #[cfg(feature = "tls12")]
2460    pub(crate) fn into_inner(self) -> Vec<u8> {
2461        self.ocsp_response.0.into_vec()
2462    }
2463
2464    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2465        CertificateStatus {
2466            ocsp_response: self.ocsp_response.into_owned(),
2467        }
2468    }
2469}
2470
2471// -- RFC8879 compressed certificates
2472
2473#[derive(Debug)]
2474pub(crate) struct CompressedCertificatePayload<'a> {
2475    pub(crate) alg: CertificateCompressionAlgorithm,
2476    pub(crate) uncompressed_len: u32,
2477    pub(crate) compressed: PayloadU24<'a>,
2478}
2479
2480impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2481    fn encode(&self, bytes: &mut Vec<u8>) {
2482        self.alg.encode(bytes);
2483        codec::u24(self.uncompressed_len).encode(bytes);
2484        self.compressed.encode(bytes);
2485    }
2486
2487    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2488        Ok(Self {
2489            alg: CertificateCompressionAlgorithm::read(r)?,
2490            uncompressed_len: codec::u24::read(r)?.0,
2491            compressed: PayloadU24::read(r)?,
2492        })
2493    }
2494}
2495
2496impl CompressedCertificatePayload<'_> {
2497    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2498        CompressedCertificatePayload {
2499            compressed: self.compressed.into_owned(),
2500            ..self
2501        }
2502    }
2503
2504    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2505        CompressedCertificatePayload {
2506            alg: self.alg,
2507            uncompressed_len: self.uncompressed_len,
2508            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2509        }
2510    }
2511}
2512
2513#[derive(Debug)]
2514pub(crate) enum HandshakePayload<'a> {
2515    HelloRequest,
2516    ClientHello(ClientHelloPayload),
2517    ServerHello(ServerHelloPayload),
2518    HelloRetryRequest(HelloRetryRequest),
2519    Certificate(CertificateChain<'a>),
2520    CertificateTls13(CertificatePayloadTls13<'a>),
2521    CompressedCertificate(CompressedCertificatePayload<'a>),
2522    ServerKeyExchange(ServerKeyExchangePayload),
2523    CertificateRequest(CertificateRequestPayload),
2524    CertificateRequestTls13(CertificateRequestPayloadTls13),
2525    CertificateVerify(DigitallySignedStruct),
2526    ServerHelloDone,
2527    EndOfEarlyData,
2528    ClientKeyExchange(Payload<'a>),
2529    NewSessionTicket(NewSessionTicketPayload),
2530    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
2531    EncryptedExtensions(Box<ServerExtensions<'a>>),
2532    KeyUpdate(KeyUpdateRequest),
2533    Finished(Payload<'a>),
2534    CertificateStatus(CertificateStatus<'a>),
2535    MessageHash(Payload<'a>),
2536    Unknown((HandshakeType, Payload<'a>)),
2537}
2538
2539impl HandshakePayload<'_> {
2540    fn encode(&self, bytes: &mut Vec<u8>) {
2541        use self::HandshakePayload::*;
2542        match self {
2543            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2544            ClientHello(x) => x.encode(bytes),
2545            ServerHello(x) => x.encode(bytes),
2546            HelloRetryRequest(x) => x.encode(bytes),
2547            Certificate(x) => x.encode(bytes),
2548            CertificateTls13(x) => x.encode(bytes),
2549            CompressedCertificate(x) => x.encode(bytes),
2550            ServerKeyExchange(x) => x.encode(bytes),
2551            ClientKeyExchange(x) => x.encode(bytes),
2552            CertificateRequest(x) => x.encode(bytes),
2553            CertificateRequestTls13(x) => x.encode(bytes),
2554            CertificateVerify(x) => x.encode(bytes),
2555            NewSessionTicket(x) => x.encode(bytes),
2556            NewSessionTicketTls13(x) => x.encode(bytes),
2557            EncryptedExtensions(x) => x.encode(bytes),
2558            KeyUpdate(x) => x.encode(bytes),
2559            Finished(x) => x.encode(bytes),
2560            CertificateStatus(x) => x.encode(bytes),
2561            MessageHash(x) => x.encode(bytes),
2562            Unknown((_, x)) => x.encode(bytes),
2563        }
2564    }
2565
2566    pub(crate) fn handshake_type(&self) -> HandshakeType {
2567        use self::HandshakePayload::*;
2568        match self {
2569            HelloRequest => HandshakeType::HelloRequest,
2570            ClientHello(_) => HandshakeType::ClientHello,
2571            ServerHello(_) => HandshakeType::ServerHello,
2572            HelloRetryRequest(_) => HandshakeType::HelloRetryRequest,
2573            Certificate(_) | CertificateTls13(_) => HandshakeType::Certificate,
2574            CompressedCertificate(_) => HandshakeType::CompressedCertificate,
2575            ServerKeyExchange(_) => HandshakeType::ServerKeyExchange,
2576            CertificateRequest(_) | CertificateRequestTls13(_) => HandshakeType::CertificateRequest,
2577            CertificateVerify(_) => HandshakeType::CertificateVerify,
2578            ServerHelloDone => HandshakeType::ServerHelloDone,
2579            EndOfEarlyData => HandshakeType::EndOfEarlyData,
2580            ClientKeyExchange(_) => HandshakeType::ClientKeyExchange,
2581            NewSessionTicket(_) | NewSessionTicketTls13(_) => HandshakeType::NewSessionTicket,
2582            EncryptedExtensions(_) => HandshakeType::EncryptedExtensions,
2583            KeyUpdate(_) => HandshakeType::KeyUpdate,
2584            Finished(_) => HandshakeType::Finished,
2585            CertificateStatus(_) => HandshakeType::CertificateStatus,
2586            MessageHash(_) => HandshakeType::MessageHash,
2587            Unknown((t, _)) => *t,
2588        }
2589    }
2590
2591    fn wire_handshake_type(&self) -> HandshakeType {
2592        match self.handshake_type() {
2593            // A `HelloRetryRequest` appears on the wire as a `ServerHello` with a magic `random` value.
2594            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2595            other => other,
2596        }
2597    }
2598
2599    fn into_owned(self) -> HandshakePayload<'static> {
2600        use HandshakePayload::*;
2601
2602        match self {
2603            HelloRequest => HelloRequest,
2604            ClientHello(x) => ClientHello(x),
2605            ServerHello(x) => ServerHello(x),
2606            HelloRetryRequest(x) => HelloRetryRequest(x),
2607            Certificate(x) => Certificate(x.into_owned()),
2608            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2609            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2610            ServerKeyExchange(x) => ServerKeyExchange(x),
2611            CertificateRequest(x) => CertificateRequest(x),
2612            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2613            CertificateVerify(x) => CertificateVerify(x),
2614            ServerHelloDone => ServerHelloDone,
2615            EndOfEarlyData => EndOfEarlyData,
2616            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2617            NewSessionTicket(x) => NewSessionTicket(x),
2618            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2619            EncryptedExtensions(x) => EncryptedExtensions(Box::new(x.into_owned())),
2620            KeyUpdate(x) => KeyUpdate(x),
2621            Finished(x) => Finished(x.into_owned()),
2622            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2623            MessageHash(x) => MessageHash(x.into_owned()),
2624            Unknown((t, x)) => Unknown((t, x.into_owned())),
2625        }
2626    }
2627}
2628
2629#[derive(Debug)]
2630pub struct HandshakeMessagePayload<'a>(pub(crate) HandshakePayload<'a>);
2631
2632impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2633    fn encode(&self, bytes: &mut Vec<u8>) {
2634        self.payload_encode(bytes, Encoding::Standard);
2635    }
2636
2637    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2638        Self::read_version(r, ProtocolVersion::TLSv1_2)
2639    }
2640}
2641
2642impl<'a> HandshakeMessagePayload<'a> {
2643    pub(crate) fn read_version(
2644        r: &mut Reader<'a>,
2645        vers: ProtocolVersion,
2646    ) -> Result<Self, InvalidMessage> {
2647        let typ = HandshakeType::read(r)?;
2648        let len = codec::u24::read(r)?.0 as usize;
2649        let mut sub = r.sub(len)?;
2650
2651        let payload = match typ {
2652            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
2653            HandshakeType::ClientHello => {
2654                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
2655            }
2656            HandshakeType::ServerHello => {
2657                let version = ProtocolVersion::read(&mut sub)?;
2658                let random = Random::read(&mut sub)?;
2659
2660                if random == HELLO_RETRY_REQUEST_RANDOM {
2661                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
2662                    hrr.legacy_version = version;
2663                    HandshakePayload::HelloRetryRequest(hrr)
2664                } else {
2665                    let mut shp = ServerHelloPayload::read(&mut sub)?;
2666                    shp.legacy_version = version;
2667                    shp.random = random;
2668                    HandshakePayload::ServerHello(shp)
2669                }
2670            }
2671            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
2672                let p = CertificatePayloadTls13::read(&mut sub)?;
2673                HandshakePayload::CertificateTls13(p)
2674            }
2675            HandshakeType::Certificate => {
2676                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
2677            }
2678            HandshakeType::ServerKeyExchange => {
2679                let p = ServerKeyExchangePayload::read(&mut sub)?;
2680                HandshakePayload::ServerKeyExchange(p)
2681            }
2682            HandshakeType::ServerHelloDone => {
2683                sub.expect_empty("ServerHelloDone")?;
2684                HandshakePayload::ServerHelloDone
2685            }
2686            HandshakeType::ClientKeyExchange => {
2687                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
2688            }
2689            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
2690                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
2691                HandshakePayload::CertificateRequestTls13(p)
2692            }
2693            HandshakeType::CertificateRequest => {
2694                let p = CertificateRequestPayload::read(&mut sub)?;
2695                HandshakePayload::CertificateRequest(p)
2696            }
2697            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
2698                CompressedCertificatePayload::read(&mut sub)?,
2699            ),
2700            HandshakeType::CertificateVerify => {
2701                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
2702            }
2703            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
2704                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
2705                HandshakePayload::NewSessionTicketTls13(p)
2706            }
2707            HandshakeType::NewSessionTicket => {
2708                let p = NewSessionTicketPayload::read(&mut sub)?;
2709                HandshakePayload::NewSessionTicket(p)
2710            }
2711            HandshakeType::EncryptedExtensions => {
2712                HandshakePayload::EncryptedExtensions(Box::new(ServerExtensions::read(&mut sub)?))
2713            }
2714            HandshakeType::KeyUpdate => {
2715                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
2716            }
2717            HandshakeType::EndOfEarlyData => {
2718                sub.expect_empty("EndOfEarlyData")?;
2719                HandshakePayload::EndOfEarlyData
2720            }
2721            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
2722            HandshakeType::CertificateStatus => {
2723                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
2724            }
2725            HandshakeType::MessageHash => {
2726                // does not appear on the wire
2727                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
2728            }
2729            HandshakeType::HelloRetryRequest => {
2730                // not legal on wire
2731                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
2732            }
2733            _ => HandshakePayload::Unknown((typ, Payload::read(&mut sub))),
2734        };
2735
2736        sub.expect_empty("HandshakeMessagePayload")
2737            .map(|_| Self(payload))
2738    }
2739
2740    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
2741        let mut ret = self.get_encoding();
2742        let ret_len = ret.len() - self.total_binder_length();
2743        ret.truncate(ret_len);
2744        ret
2745    }
2746
2747    pub(crate) fn total_binder_length(&self) -> usize {
2748        match &self.0 {
2749            HandshakePayload::ClientHello(ch) => match &ch.preshared_key_offer {
2750                Some(offer) => {
2751                    let mut binders_encoding = Vec::new();
2752                    offer
2753                        .binders
2754                        .encode(&mut binders_encoding);
2755                    binders_encoding.len()
2756                }
2757                _ => 0,
2758            },
2759            _ => 0,
2760        }
2761    }
2762
2763    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
2764        // output type, length, and encoded payload
2765        self.0
2766            .wire_handshake_type()
2767            .encode(bytes);
2768
2769        let nested = LengthPrefixedBuffer::new(
2770            ListLength::U24 {
2771                max: usize::MAX,
2772                error: InvalidMessage::MessageTooLarge,
2773            },
2774            bytes,
2775        );
2776
2777        match &self.0 {
2778            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
2779            // differently based on the purpose of the encoding.
2780            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
2781            HandshakePayload::HelloRetryRequest(payload) => {
2782                payload.payload_encode(nested.buf, encoding)
2783            }
2784
2785            // All other payload types are encoded the same regardless of purpose.
2786            _ => self.0.encode(nested.buf),
2787        }
2788    }
2789
2790    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
2791        Self(HandshakePayload::MessageHash(Payload::new(hash.to_vec())))
2792    }
2793
2794    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
2795        HandshakeMessagePayload(self.0.into_owned())
2796    }
2797}
2798
2799#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2800pub struct HpkeSymmetricCipherSuite {
2801    pub kdf_id: HpkeKdf,
2802    pub aead_id: HpkeAead,
2803}
2804
2805impl Codec<'_> for HpkeSymmetricCipherSuite {
2806    fn encode(&self, bytes: &mut Vec<u8>) {
2807        self.kdf_id.encode(bytes);
2808        self.aead_id.encode(bytes);
2809    }
2810
2811    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2812        Ok(Self {
2813            kdf_id: HpkeKdf::read(r)?,
2814            aead_id: HpkeAead::read(r)?,
2815        })
2816    }
2817}
2818
2819/// draft-ietf-tls-esni-24: `HpkeSymmetricCipherSuite cipher_suites<4..2^16-4>;`
2820impl TlsListElement for HpkeSymmetricCipherSuite {
2821    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
2822        empty_error: InvalidMessage::IllegalEmptyList("HpkeSymmetricCipherSuites"),
2823    };
2824}
2825
2826#[derive(Clone, Debug, PartialEq)]
2827pub struct HpkeKeyConfig {
2828    pub config_id: u8,
2829    pub kem_id: HpkeKem,
2830    /// draft-ietf-tls-esni-24: `opaque HpkePublicKey<1..2^16-1>;`
2831    pub public_key: PayloadU16<NonEmpty>,
2832    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2833}
2834
2835impl Codec<'_> for HpkeKeyConfig {
2836    fn encode(&self, bytes: &mut Vec<u8>) {
2837        self.config_id.encode(bytes);
2838        self.kem_id.encode(bytes);
2839        self.public_key.encode(bytes);
2840        self.symmetric_cipher_suites
2841            .encode(bytes);
2842    }
2843
2844    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2845        Ok(Self {
2846            config_id: u8::read(r)?,
2847            kem_id: HpkeKem::read(r)?,
2848            public_key: PayloadU16::read(r)?,
2849            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2850        })
2851    }
2852}
2853
2854#[derive(Clone, Debug, PartialEq)]
2855pub struct EchConfigContents {
2856    pub key_config: HpkeKeyConfig,
2857    pub maximum_name_length: u8,
2858    pub public_name: DnsName<'static>,
2859    pub extensions: Vec<EchConfigExtension>,
2860}
2861
2862impl EchConfigContents {
2863    /// Returns true if there is more than one extension of a given
2864    /// type.
2865    pub(crate) fn has_duplicate_extension(&self) -> bool {
2866        has_duplicates::<_, _, u16>(
2867            self.extensions
2868                .iter()
2869                .map(|ext| ext.ext_type()),
2870        )
2871    }
2872
2873    /// Returns true if there is at least one mandatory unsupported extension.
2874    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2875        self.extensions
2876            .iter()
2877            // An extension is considered mandatory if the high bit of its type is set.
2878            .any(|ext| {
2879                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2880                    && u16::from(ext.ext_type()) & 0x8000 != 0
2881            })
2882    }
2883}
2884
2885impl Codec<'_> for EchConfigContents {
2886    fn encode(&self, bytes: &mut Vec<u8>) {
2887        self.key_config.encode(bytes);
2888        self.maximum_name_length.encode(bytes);
2889        let dns_name = &self.public_name.borrow();
2890        PayloadU8::<MaybeEmpty>::encode_slice(dns_name.as_ref().as_ref(), bytes);
2891        self.extensions.encode(bytes);
2892    }
2893
2894    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2895        Ok(Self {
2896            key_config: HpkeKeyConfig::read(r)?,
2897            maximum_name_length: u8::read(r)?,
2898            public_name: {
2899                DnsName::try_from(
2900                    PayloadU8::<MaybeEmpty>::read(r)?
2901                        .0
2902                        .as_slice(),
2903                )
2904                .map_err(|_| InvalidMessage::InvalidServerName)?
2905                .to_owned()
2906            },
2907            extensions: Vec::read(r)?,
2908        })
2909    }
2910}
2911
2912/// An encrypted client hello (ECH) config.
2913#[derive(Clone, Debug, PartialEq)]
2914pub enum EchConfigPayload {
2915    /// A recognized V18 ECH configuration.
2916    V18(EchConfigContents),
2917    /// An unknown version ECH configuration.
2918    Unknown {
2919        version: EchVersion,
2920        contents: PayloadU16,
2921    },
2922}
2923
2924impl TlsListElement for EchConfigPayload {
2925    const SIZE_LEN: ListLength = ListLength::U16;
2926}
2927
2928impl Codec<'_> for EchConfigPayload {
2929    fn encode(&self, bytes: &mut Vec<u8>) {
2930        match self {
2931            Self::V18(c) => {
2932                // Write the version, the length, and the contents.
2933                EchVersion::V18.encode(bytes);
2934                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2935                c.encode(inner.buf);
2936            }
2937            Self::Unknown { version, contents } => {
2938                // Unknown configuration versions are opaque.
2939                version.encode(bytes);
2940                contents.encode(bytes);
2941            }
2942        }
2943    }
2944
2945    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2946        let version = EchVersion::read(r)?;
2947        let length = u16::read(r)?;
2948        let mut contents = r.sub(length as usize)?;
2949
2950        Ok(match version {
2951            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2952            _ => {
2953                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2954                let data = PayloadU16::new(contents.rest().into());
2955                Self::Unknown {
2956                    version,
2957                    contents: data,
2958                }
2959            }
2960        })
2961    }
2962}
2963
2964#[derive(Clone, Debug, PartialEq)]
2965pub enum EchConfigExtension {
2966    Unknown(UnknownExtension),
2967}
2968
2969impl EchConfigExtension {
2970    pub(crate) fn ext_type(&self) -> ExtensionType {
2971        match self {
2972            Self::Unknown(r) => r.typ,
2973        }
2974    }
2975}
2976
2977impl Codec<'_> for EchConfigExtension {
2978    fn encode(&self, bytes: &mut Vec<u8>) {
2979        self.ext_type().encode(bytes);
2980
2981        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2982        match self {
2983            Self::Unknown(r) => r.encode(nested.buf),
2984        }
2985    }
2986
2987    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2988        let typ = ExtensionType::read(r)?;
2989        let len = u16::read(r)? as usize;
2990        let mut sub = r.sub(len)?;
2991
2992        #[allow(clippy::match_single_binding)] // Future-proofing.
2993        let ext = match typ {
2994            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2995        };
2996
2997        sub.expect_empty("EchConfigExtension")
2998            .map(|_| ext)
2999    }
3000}
3001
3002impl TlsListElement for EchConfigExtension {
3003    const SIZE_LEN: ListLength = ListLength::U16;
3004}
3005
3006/// Representation of the `ECHClientHello` client extension specified in
3007/// [draft-ietf-tls-esni Section 5].
3008///
3009/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3010#[derive(Clone, Debug)]
3011pub(crate) enum EncryptedClientHello {
3012    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
3013    Outer(EncryptedClientHelloOuter),
3014    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
3015    ///
3016    /// This variant has no payload.
3017    Inner,
3018}
3019
3020impl Codec<'_> for EncryptedClientHello {
3021    fn encode(&self, bytes: &mut Vec<u8>) {
3022        match self {
3023            Self::Outer(payload) => {
3024                EchClientHelloType::ClientHelloOuter.encode(bytes);
3025                payload.encode(bytes);
3026            }
3027            Self::Inner => {
3028                EchClientHelloType::ClientHelloInner.encode(bytes);
3029                // Empty payload.
3030            }
3031        }
3032    }
3033
3034    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3035        match EchClientHelloType::read(r)? {
3036            EchClientHelloType::ClientHelloOuter => {
3037                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3038            }
3039            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3040            _ => Err(InvalidMessage::InvalidContentType),
3041        }
3042    }
3043}
3044
3045/// Representation of the ECHClientHello extension with type outer specified in
3046/// [draft-ietf-tls-esni Section 5].
3047///
3048/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3049#[derive(Clone, Debug)]
3050pub(crate) struct EncryptedClientHelloOuter {
3051    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3052    /// ECHConfigContents.cipher_suites list.
3053    pub cipher_suite: HpkeSymmetricCipherSuite,
3054    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3055    pub config_id: u8,
3056    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3057    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3058    pub enc: PayloadU16,
3059    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3060    pub payload: PayloadU16<NonEmpty>,
3061}
3062
3063impl Codec<'_> for EncryptedClientHelloOuter {
3064    fn encode(&self, bytes: &mut Vec<u8>) {
3065        self.cipher_suite.encode(bytes);
3066        self.config_id.encode(bytes);
3067        self.enc.encode(bytes);
3068        self.payload.encode(bytes);
3069    }
3070
3071    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3072        Ok(Self {
3073            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3074            config_id: u8::read(r)?,
3075            enc: PayloadU16::read(r)?,
3076            payload: PayloadU16::read(r)?,
3077        })
3078    }
3079}
3080
3081/// Representation of the ECHEncryptedExtensions extension specified in
3082/// [draft-ietf-tls-esni Section 5].
3083///
3084/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3085#[derive(Clone, Debug)]
3086pub(crate) struct ServerEncryptedClientHello {
3087    pub(crate) retry_configs: Vec<EchConfigPayload>,
3088}
3089
3090impl Codec<'_> for ServerEncryptedClientHello {
3091    fn encode(&self, bytes: &mut Vec<u8>) {
3092        self.retry_configs.encode(bytes);
3093    }
3094
3095    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3096        Ok(Self {
3097            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3098        })
3099    }
3100}
3101
3102/// The method of encoding to use for a handshake message.
3103///
3104/// In some cases a handshake message may be encoded differently depending on the purpose
3105/// the encoded message is being used for.
3106pub(crate) enum Encoding {
3107    /// Standard RFC 8446 encoding.
3108    Standard,
3109    /// Encoding for ECH confirmation for HRR.
3110    EchConfirmation,
3111    /// Encoding for ECH inner client hello.
3112    EchInnerHello { to_compress: Vec<ExtensionType> },
3113}
3114
/// Returns true if two items of `iter` convert (via `Into<T>`) to the same value.
///
/// Iteration stops at the first repeated value.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::<T>::new();

    // `insert` returns false when the value was already present.
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3126
3127struct DuplicateExtensionChecker(BTreeSet<u16>);
3128
3129impl DuplicateExtensionChecker {
3130    fn new() -> Self {
3131        Self(BTreeSet::new())
3132    }
3133
3134    fn check(&mut self, typ: ExtensionType) -> Result<(), InvalidMessage> {
3135        let u = u16::from(typ);
3136        match self.0.insert(u) {
3137            true => Ok(()),
3138            false => Err(InvalidMessage::DuplicateExtension(u)),
3139        }
3140    }
3141}
3142
/// Mixes the bits of `x` through six rounds of wrapping additions, shifts and
/// XORs with fixed constants, returning the mixed value.
///
/// All arithmetic wraps, so the function never panics on overflow.
fn low_quality_integer_hash(x: u32) -> u32 {
    // Each round rebinds `x`; the right-hand side always reads the previous round's value.
    let x = x
        .wrapping_add(0x7ed55d16)
        .wrapping_add(x << 12);
    let x = x ^ 0xc761c23c ^ (x >> 19);
    let x = x
        .wrapping_add(0x165667b1)
        .wrapping_add(x << 5);
    let x = x.wrapping_add(0xd3a2646c) ^ (x << 9);
    let x = x
        .wrapping_add(0xfd7046c5)
        .wrapping_add(x << 3);
    x ^ 0xb55a4f09 ^ (x >> 16)
}
3158
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical non-mandatory extensions are duplicates, but not
    /// unknown-mandatory.
    #[test]
    fn test_ech_config_dupe_exts() {
        let ext = unknown_ext(0x42);
        let config = EchConfigContents {
            extensions: vec![ext.clone(), ext],
            ..config_template()
        };

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    /// A single unknown extension with the high bit set is unknown-mandatory,
    /// and not a duplicate.
    #[test]
    fn test_ech_config_mandatory_exts() {
        let config = EchConfigContents {
            // Note: high bit set.
            extensions: vec![unknown_ext(0x42 | 0x8000)],
            ..config_template()
        };

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// Builds an unknown extension of the given numeric type with a one-byte payload.
    fn unknown_ext(typ: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(typ),
            payload: Payload::new(vec![0x42]),
        })
    }

    /// Builds a minimal configuration with no extensions.
    fn config_template() -> EchConfigContents {
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16::new(b"xxx".into()),
            symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
                kdf_id: HpkeKdf::HKDF_SHA256,
                aead_id: HpkeAead::AES_128_GCM,
            }],
        };

        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}