rustls/msgs/
handshake.rs

1use alloc::collections::BTreeSet;
2#[cfg(feature = "logging")]
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::ops::Deref;
7use core::{fmt, iter};
8
9use pki_types::{CertificateDer, DnsName};
10
11#[cfg(feature = "tls12")]
12use crate::crypto::ActiveKeyExchange;
13use crate::crypto::SecureRandom;
14use crate::enums::{
15    CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType,
16    ProtocolVersion, SignatureScheme,
17};
18use crate::error::InvalidMessage;
19#[cfg(feature = "tls12")]
20use crate::ffdhe_groups::FfdheGroup;
21use crate::log::warn;
22use crate::msgs::base::{Payload, PayloadU8, PayloadU16, PayloadU24};
23use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24use crate::msgs::enums::{
25    CertificateStatusType, CertificateType, ClientCertificateType, Compression, ECCurveType,
26    ECPointFormat, EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest,
27    NamedGroup, PSKKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::sync::Arc;
31use crate::verify::DigitallySignedStruct;
32use crate::x509::wrap_in_sequence;
33
/// Define a newtype around one of the length-prefixed payload types.
///
/// The generated type wraps `$inner` (e.g. `PayloadU8` or `PayloadU16`) and exposes only
/// its raw bytes: it gets `From<Vec<u8>>`, `AsRef<[u8]>`, and a `Codec` impl that defers
/// entirely to `$inner`. Any attributes (including doc comments) written before the
/// `struct` keyword are forwarded to the generated type.
macro_rules! wrapped_payload(
  ($(#[$comment:meta])* $vis:vis struct $name:ident, $inner:ident,) => {
    $(#[$comment])*
    #[derive(Clone, Debug)]
    $vis struct $name($inner);

    impl From<Vec<u8>> for $name {
        fn from(bytes: Vec<u8>) -> Self {
            let inner = $inner::new(bytes);
            Self(inner)
        }
    }

    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            let Self(inner) = self;
            inner.0.as_slice()
        }
    }

    impl Codec<'_> for $name {
        fn encode(&self, bytes: &mut Vec<u8>) {
            let Self(inner) = self;
            inner.encode(bytes);
        }

        fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
            $inner::read(r).map(Self)
        }
    }
  }
);
68
69#[derive(Clone, Copy, Eq, PartialEq)]
70pub struct Random(pub(crate) [u8; 32]);
71
72impl fmt::Debug for Random {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        super::base::hex(f, &self.0)
75    }
76}
77
/// The fixed `Random` value that marks a ServerHello as a HelloRetryRequest
/// (RFC 8446, section 4.1.3).
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);

/// An all-zero `Random`.
static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85impl Codec<'_> for Random {
86    fn encode(&self, bytes: &mut Vec<u8>) {
87        bytes.extend_from_slice(&self.0);
88    }
89
90    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91        let Some(bytes) = r.take(32) else {
92            return Err(InvalidMessage::MissingData("Random"));
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
109impl From<[u8; 32]> for Random {
110    #[inline]
111    fn from(bytes: [u8; 32]) -> Self {
112        Self(bytes)
113    }
114}
115
116#[derive(Copy, Clone)]
117pub struct SessionId {
118    len: usize,
119    data: [u8; 32],
120}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(self.as_ref());
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let Some(bytes) = r.take(len) else {
157            return Err(InvalidMessage::MissingData("SessionID"));
158        };
159
160        let mut out = [0u8; 32];
161        out[..len].clone_from_slice(&bytes[..len]);
162        Ok(Self { data: out, len })
163    }
164}
165
166impl SessionId {
167    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168        let mut data = [0u8; 32];
169        secure_random.fill(&mut data)?;
170        Ok(Self { data, len: 32 })
171    }
172
173    pub(crate) fn empty() -> Self {
174        Self {
175            data: [0u8; 32],
176            len: 0,
177        }
178    }
179
180    #[cfg(feature = "tls12")]
181    pub(crate) fn is_empty(&self) -> bool {
182        self.len == 0
183    }
184}
185
186impl AsRef<[u8]> for SessionId {
187    fn as_ref(&self) -> &[u8] {
188        &self.data[..self.len]
189    }
190}
191
/// An extension this module does not otherwise parse, kept as its raw body bytes.
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownExtension {
    /// The extension type as read from the wire.
    pub(crate) typ: ExtensionType,
    /// The extension body, without the type or length prefix.
    pub(crate) payload: Payload<'static>,
}
197
198impl UnknownExtension {
199    fn encode(&self, bytes: &mut Vec<u8>) {
200        self.payload.encode(bytes);
201    }
202
203    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
204        let payload = Payload::read(r).into_owned();
205        Self { typ, payload }
206    }
207}
208
/// `ECPointFormat` lists use a one-byte length prefix.
impl TlsListElement for ECPointFormat {
    const SIZE_LEN: ListLength = ListLength::U8;
}

/// `NamedGroup` lists use a two-byte length prefix.
impl TlsListElement for NamedGroup {
    const SIZE_LEN: ListLength = ListLength::U16;
}

/// `SignatureScheme` lists use a two-byte length prefix.
impl TlsListElement for SignatureScheme {
    const SIZE_LEN: ListLength = ListLength::U16;
}
220
/// The parsed contents of a single SNI entry.
#[derive(Clone, Debug)]
pub(crate) enum ServerNamePayload {
    /// A `HostName` entry that parsed as a DNS name.
    HostName(DnsName<'static>),
    /// A `HostName` entry that parsed as an IP address; the raw bytes are kept.
    IpAddress(PayloadU16),
    /// An entry of any other name type, kept as raw bytes.
    Unknown(Payload<'static>),
}
227
228impl ServerNamePayload {
229    pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self {
230        Self::HostName(hostname)
231    }
232
233    fn read_hostname(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
234        use pki_types::ServerName;
235        let raw = PayloadU16::read(r)?;
236
237        match ServerName::try_from(raw.0.as_slice()) {
238            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
239            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
240            Ok(_) | Err(_) => {
241                warn!(
242                    "Illegal SNI hostname received {:?}",
243                    String::from_utf8_lossy(&raw.0)
244                );
245                Err(InvalidMessage::InvalidServerName)
246            }
247        }
248    }
249
250    fn encode(&self, bytes: &mut Vec<u8>) {
251        match self {
252            Self::HostName(name) => {
253                (name.as_ref().len() as u16).encode(bytes);
254                bytes.extend_from_slice(name.as_ref().as_bytes());
255            }
256            Self::IpAddress(r) => r.encode(bytes),
257            Self::Unknown(r) => r.encode(bytes),
258        }
259    }
260}
261
/// One entry of an SNI server name list.
#[derive(Clone, Debug)]
pub struct ServerName {
    /// The entry's name type.
    pub(crate) typ: ServerNameType,
    /// The entry's contents, interpreted according to `typ`.
    pub(crate) payload: ServerNamePayload,
}
267
268impl Codec<'_> for ServerName {
269    fn encode(&self, bytes: &mut Vec<u8>) {
270        self.typ.encode(bytes);
271        self.payload.encode(bytes);
272    }
273
274    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
275        let typ = ServerNameType::read(r)?;
276
277        let payload = match typ {
278            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
279            _ => ServerNamePayload::Unknown(Payload::read(r).into_owned()),
280        };
281
282        Ok(Self { typ, payload })
283    }
284}
285
286impl TlsListElement for ServerName {
287    const SIZE_LEN: ListLength = ListLength::U16;
288}
289
/// Queries over a received SNI server name list.
pub(crate) trait ConvertServerNameList {
    /// Returns `true` if any name type appears more than once.
    fn has_duplicate_names_for_type(&self) -> bool;
    /// Returns the first entry holding a parsed DNS host name, if any.
    fn single_hostname(&self) -> Option<DnsName<'_>>;
}
294
295impl ConvertServerNameList for [ServerName] {
296    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
297    fn has_duplicate_names_for_type(&self) -> bool {
298        has_duplicates::<_, _, u8>(self.iter().map(|name| name.typ))
299    }
300
301    fn single_hostname(&self) -> Option<DnsName<'_>> {
302        fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> {
303            if let ServerNamePayload::HostName(dns) = &name.payload {
304                Some(dns.borrow())
305            } else {
306                None
307            }
308        }
309
310        self.iter()
311            .filter_map(only_dns_hostnames)
312            .next()
313    }
314}
315
wrapped_payload!(
    /// A single ALPN protocol name, stored as a `PayloadU8`.
    pub struct ProtocolName,
    PayloadU8,
);

/// Protocol name lists use a two-byte length prefix.
impl TlsListElement for ProtocolName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
321
/// Conversions between a list of protocol names and raw byte slices.
pub(crate) trait ConvertProtocolNameList {
    /// Builds a list from raw protocol-name byte strings, copying each one.
    fn from_slices(names: &[&[u8]]) -> Self;
    /// Borrows every name as a byte slice, in list order.
    fn to_slices(&self) -> Vec<&[u8]>;
    /// Returns the only name if the list has exactly one entry, otherwise `None`.
    fn as_single_slice(&self) -> Option<&[u8]>;
}
327
328impl ConvertProtocolNameList for Vec<ProtocolName> {
329    fn from_slices(names: &[&[u8]]) -> Self {
330        let mut ret = Self::new();
331
332        for name in names {
333            ret.push(ProtocolName::from(name.to_vec()));
334        }
335
336        ret
337    }
338
339    fn to_slices(&self) -> Vec<&[u8]> {
340        self.iter()
341            .map(|proto| proto.as_ref())
342            .collect::<Vec<&[u8]>>()
343    }
344
345    fn as_single_slice(&self) -> Option<&[u8]> {
346        if self.len() == 1 {
347            Some(self[0].as_ref())
348        } else {
349            None
350        }
351    }
352}
353
// --- TLS 1.3 Key shares ---

/// One key share: a named group and the key-exchange bytes for it.
#[derive(Clone, Debug)]
pub struct KeyShareEntry {
    /// The group this share is for.
    pub(crate) group: NamedGroup,
    /// The key-exchange bytes, stored as a `PayloadU16`.
    pub(crate) payload: PayloadU16,
}
360
361impl KeyShareEntry {
362    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
363        Self {
364            group,
365            payload: PayloadU16::new(payload.into()),
366        }
367    }
368
369    pub fn group(&self) -> NamedGroup {
370        self.group
371    }
372}
373
374impl Codec<'_> for KeyShareEntry {
375    fn encode(&self, bytes: &mut Vec<u8>) {
376        self.group.encode(bytes);
377        self.payload.encode(bytes);
378    }
379
380    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
381        let group = NamedGroup::read(r)?;
382        let payload = PayloadU16::read(r)?;
383
384        Ok(Self { group, payload })
385    }
386}
387
// --- TLS 1.3 PresharedKey offers ---

/// One identity within a pre-shared-key offer.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyIdentity {
    /// The opaque identity bytes, stored as a `PayloadU16`.
    pub(crate) identity: PayloadU16,
    /// The obfuscated ticket age, encoded as a `u32`.
    pub(crate) obfuscated_ticket_age: u32,
}
394
395impl PresharedKeyIdentity {
396    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
397        Self {
398            identity: PayloadU16::new(id),
399            obfuscated_ticket_age: age,
400        }
401    }
402}
403
404impl Codec<'_> for PresharedKeyIdentity {
405    fn encode(&self, bytes: &mut Vec<u8>) {
406        self.identity.encode(bytes);
407        self.obfuscated_ticket_age.encode(bytes);
408    }
409
410    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
411        Ok(Self {
412            identity: PayloadU16::read(r)?,
413            obfuscated_ticket_age: u32::read(r)?,
414        })
415    }
416}
417
418impl TlsListElement for PresharedKeyIdentity {
419    const SIZE_LEN: ListLength = ListLength::U16;
420}
421
wrapped_payload!(
    /// One PSK binder value, stored as a `PayloadU8`.
    pub(crate) struct PresharedKeyBinder,
    PayloadU8,
);

/// Binder lists use a two-byte length prefix.
impl TlsListElement for PresharedKeyBinder {
    const SIZE_LEN: ListLength = ListLength::U16;
}

/// The body of a ClientHello pre-shared-key extension.
#[derive(Clone, Debug)]
pub struct PresharedKeyOffer {
    /// The offered identities.
    pub(crate) identities: Vec<PresharedKeyIdentity>,
    /// The binders. NOTE(review): presumably one per identity, in the same order; this type
    /// does not enforce that.
    pub(crate) binders: Vec<PresharedKeyBinder>,
}
433
434impl PresharedKeyOffer {
435    /// Make a new one with one entry.
436    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
437        Self {
438            identities: vec![id],
439            binders: vec![PresharedKeyBinder::from(binder)],
440        }
441    }
442}
443
444impl Codec<'_> for PresharedKeyOffer {
445    fn encode(&self, bytes: &mut Vec<u8>) {
446        self.identities.encode(bytes);
447        self.binders.encode(bytes);
448    }
449
450    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
451        Ok(Self {
452            identities: Vec::read(r)?,
453            binders: Vec::read(r)?,
454        })
455    }
456}
457
// --- RFC6066 certificate status request ---
wrapped_payload!(
    /// An opaque OCSP responder ID, stored as a `PayloadU16`.
    pub(crate) struct ResponderId,
    PayloadU16,
);

/// Responder ID lists use a two-byte length prefix.
impl TlsListElement for ResponderId {
    const SIZE_LEN: ListLength = ListLength::U16;
}

/// The OCSP form of a certificate status request.
#[derive(Clone, Debug)]
pub struct OcspCertificateStatusRequest {
    /// Responder IDs; `build_ocsp` leaves this empty.
    pub(crate) responder_ids: Vec<ResponderId>,
    /// Raw request extensions, stored as a `PayloadU16`.
    pub(crate) extensions: PayloadU16,
}
470
471impl Codec<'_> for OcspCertificateStatusRequest {
472    fn encode(&self, bytes: &mut Vec<u8>) {
473        CertificateStatusType::OCSP.encode(bytes);
474        self.responder_ids.encode(bytes);
475        self.extensions.encode(bytes);
476    }
477
478    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
479        Ok(Self {
480            responder_ids: Vec::read(r)?,
481            extensions: PayloadU16::read(r)?,
482        })
483    }
484}
485
/// The body of a certificate status request extension.
#[derive(Clone, Debug)]
pub enum CertificateStatusRequest {
    /// An OCSP status request.
    Ocsp(OcspCertificateStatusRequest),
    /// A request of any other status type, kept as raw bytes.
    Unknown((CertificateStatusType, Payload<'static>)),
}
491
492impl Codec<'_> for CertificateStatusRequest {
493    fn encode(&self, bytes: &mut Vec<u8>) {
494        match self {
495            Self::Ocsp(r) => r.encode(bytes),
496            Self::Unknown((typ, payload)) => {
497                typ.encode(bytes);
498                payload.encode(bytes);
499            }
500        }
501    }
502
503    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
504        let typ = CertificateStatusType::read(r)?;
505
506        match typ {
507            CertificateStatusType::OCSP => {
508                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
509                Ok(Self::Ocsp(ocsp_req))
510            }
511            _ => {
512                let data = Payload::read(r).into_owned();
513                Ok(Self::Unknown((typ, data)))
514            }
515        }
516    }
517}
518
519impl CertificateStatusRequest {
520    pub(crate) fn build_ocsp() -> Self {
521        let ocsp = OcspCertificateStatusRequest {
522            responder_ids: Vec::new(),
523            extensions: PayloadU16::empty(),
524        };
525        Self::Ocsp(ocsp)
526    }
527}
528
// ---

/// `PSKKeyExchangeMode` lists use a one-byte length prefix.
impl TlsListElement for PSKKeyExchangeMode {
    const SIZE_LEN: ListLength = ListLength::U8;
}

/// `KeyShareEntry` lists use a two-byte length prefix.
impl TlsListElement for KeyShareEntry {
    const SIZE_LEN: ListLength = ListLength::U16;
}

/// `ProtocolVersion` lists use a one-byte length prefix.
impl TlsListElement for ProtocolVersion {
    const SIZE_LEN: ListLength = ListLength::U8;
}

/// `CertificateType` lists use a one-byte length prefix.
impl TlsListElement for CertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}

/// `CertificateCompressionAlgorithm` lists use a one-byte length prefix.
impl TlsListElement for CertificateCompressionAlgorithm {
    const SIZE_LEN: ListLength = ListLength::U8;
}
550
/// One extension as sent in, or parsed from, a ClientHello.
///
/// The wire extension type of each variant is given by `ClientExtension::ext_type`.
#[derive(Clone, Debug)]
pub enum ClientExtension {
    /// Supported EC point formats.
    EcPointFormats(Vec<ECPointFormat>),
    /// Supported groups (wire type `EllipticCurves`).
    NamedGroups(Vec<NamedGroup>),
    /// Supported signature schemes.
    SignatureAlgorithms(Vec<SignatureScheme>),
    /// SNI server name list.
    ServerName(Vec<ServerName>),
    /// Session ticket request or offer.
    SessionTicket(ClientSessionTicket),
    /// ALPN protocol names offered.
    Protocols(Vec<ProtocolName>),
    /// Supported protocol versions.
    SupportedVersions(Vec<ProtocolVersion>),
    /// Key shares offered.
    KeyShare(Vec<KeyShareEntry>),
    /// Supported PSK key-exchange modes.
    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
    /// Pre-shared-key offer.
    PresharedKey(PresharedKeyOffer),
    /// Cookie echoed back to the server.
    Cookie(PayloadU16),
    /// Extended master secret request (empty body).
    ExtendedMasterSecretRequest,
    /// Certificate status request.
    CertificateStatusRequest(CertificateStatusRequest),
    /// Server certificate types the client accepts.
    ServerCertTypes(Vec<CertificateType>),
    /// Client certificate types the client can send.
    ClientCertTypes(Vec<CertificateType>),
    /// Transport parameters, kept as raw bytes.
    TransportParameters(Vec<u8>),
    /// Transport parameters under the `TransportParametersDraft` type, kept as raw bytes.
    TransportParametersDraft(Vec<u8>),
    /// Early data indication (empty body).
    EarlyData,
    /// Supported certificate compression algorithms.
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    /// Encrypted ClientHello payload.
    EncryptedClientHello(EncryptedClientHello),
    /// ECH outer-extensions list of extension types.
    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
    /// Acceptable certificate authority names (wire type `CertificateAuthorities`).
    AuthorityNames(Vec<DistinguishedName>),
    /// Any extension not otherwise parsed.
    Unknown(UnknownExtension),
}
577
578impl ClientExtension {
579    pub(crate) fn ext_type(&self) -> ExtensionType {
580        match self {
581            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
582            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
583            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
584            Self::ServerName(_) => ExtensionType::ServerName,
585            Self::SessionTicket(_) => ExtensionType::SessionTicket,
586            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
587            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
588            Self::KeyShare(_) => ExtensionType::KeyShare,
589            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
590            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
591            Self::Cookie(_) => ExtensionType::Cookie,
592            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
593            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
594            Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType,
595            Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType,
596            Self::TransportParameters(_) => ExtensionType::TransportParameters,
597            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
598            Self::EarlyData => ExtensionType::EarlyData,
599            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
600            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
601            Self::EncryptedClientHelloOuterExtensions(_) => {
602                ExtensionType::EncryptedClientHelloOuterExtensions
603            }
604            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
605            Self::Unknown(r) => r.typ,
606        }
607    }
608}
609
impl Codec<'_> for ClientExtension {
    /// Writes the extension type, then the body inside a `u16` length-prefixed region.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        // Body bytes go into `nested.buf`. NOTE(review): the length prefix is presumably
        // filled in when `nested` is dropped at the end of this function; confirm in
        // `LengthPrefixedBuffer`.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match self {
            Self::EcPointFormats(r) => r.encode(nested.buf),
            Self::NamedGroups(r) => r.encode(nested.buf),
            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
            Self::ServerName(r) => r.encode(nested.buf),
            // These variants have an empty body.
            Self::SessionTicket(ClientSessionTicket::Request)
            | Self::ExtendedMasterSecretRequest
            | Self::EarlyData => {}
            Self::SessionTicket(ClientSessionTicket::Offer(r)) => r.encode(nested.buf),
            Self::Protocols(r) => r.encode(nested.buf),
            Self::SupportedVersions(r) => r.encode(nested.buf),
            Self::KeyShare(r) => r.encode(nested.buf),
            Self::PresharedKeyModes(r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::Cookie(r) => r.encode(nested.buf),
            Self::CertificateStatusRequest(r) => r.encode(nested.buf),
            Self::ClientCertTypes(r) => r.encode(nested.buf),
            Self::ServerCertTypes(r) => r.encode(nested.buf),
            // Transport parameters are opaque here: copied verbatim with no inner framing.
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
            Self::EncryptedClientHelloOuterExtensions(r) => r.encode(nested.buf),
            Self::AuthorityNames(r) => r.encode(nested.buf),
            Self::Unknown(r) => r.encode(nested.buf),
        }
    }

    /// Reads one extension: a type, a `u16` body length, then the body.
    ///
    /// The body is parsed from a sub-reader limited to that length. Any bytes the chosen
    /// parser leaves unread cause `expect_empty("ClientExtension")` to fail.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        // Arm order matters for the guarded arms below. An `ExtendedMasterSecret` or
        // `EarlyData` extension with a non-empty body falls through to `Unknown`. There is
        // no `EncryptedClientHello` arm, so that type is also read as `Unknown`.
        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
            ExtensionType::SessionTicket => {
                // Empty body = request; non-empty body = ticket being offered.
                if sub.any_left() {
                    let contents = Payload::read(&mut sub).into_owned();
                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
                } else {
                    Self::SessionTicket(ClientSessionTicket::Request)
                }
            }
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
                Self::ExtendedMasterSecretRequest
            }
            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
            ExtensionType::StatusRequest => {
                let csr = CertificateStatusRequest::read(&mut sub)?;
                Self::CertificateStatusRequest(csr)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
            ExtensionType::CompressCertificate => {
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
            }
            ExtensionType::EncryptedClientHelloOuterExtensions => {
                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
            }
            ExtensionType::CertificateAuthorities => Self::AuthorityNames(Vec::read(&mut sub)?),
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ClientExtension")
            .map(|_| ext)
    }
}
696
697fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
698    let dns_name_str = dns_name.as_ref();
699
700    // RFC6066: "The hostname is represented as a byte string using
701    // ASCII encoding without a trailing dot"
702    if dns_name_str.ends_with('.') {
703        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
704        DnsName::try_from(trimmed)
705            .unwrap()
706            .to_owned()
707    } else {
708        dns_name.to_owned()
709    }
710}
711
712impl ClientExtension {
713    /// Make a basic SNI ServerNameRequest quoting `hostname`.
714    pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self {
715        let name = ServerName {
716            typ: ServerNameType::HostName,
717            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
718        };
719
720        Self::ServerName(vec![name])
721    }
722}
723
/// The body of a ClientHello session ticket extension.
#[derive(Clone, Debug)]
pub enum ClientSessionTicket {
    /// An empty body.
    Request,
    /// A non-empty body: the ticket being offered, kept as raw bytes.
    Offer(Payload<'static>),
}

/// One extension sent by the server.
///
/// The wire extension type of each variant is given by `ServerExtension::ext_type`.
#[derive(Clone, Debug)]
pub enum ServerExtension {
    /// EC point formats.
    EcPointFormats(Vec<ECPointFormat>),
    /// SNI acknowledgement (empty body).
    ServerNameAck,
    /// Session ticket acknowledgement (empty body).
    SessionTicketAck,
    /// Renegotiation info.
    RenegotiationInfo(PayloadU8),
    /// ALPN protocol names.
    Protocols(Vec<ProtocolName>),
    /// The server's key share.
    KeyShare(KeyShareEntry),
    /// Selected pre-shared-key identity, encoded as a `u16`.
    PresharedKey(u16),
    /// Extended master secret acknowledgement (empty body).
    ExtendedMasterSecretAck,
    /// Certificate status acknowledgement (empty body).
    CertificateStatusAck,
    /// Server certificate type.
    ServerCertType(CertificateType),
    /// Client certificate type.
    ClientCertType(CertificateType),
    /// Selected protocol version.
    SupportedVersions(ProtocolVersion),
    /// Transport parameters, kept as raw bytes.
    TransportParameters(Vec<u8>),
    /// Transport parameters under the `TransportParametersDraft` type, kept as raw bytes.
    TransportParametersDraft(Vec<u8>),
    /// Early data indication (empty body).
    EarlyData,
    /// Server ECH payload.
    EncryptedClientHello(ServerEncryptedClientHello),
    /// Any extension not otherwise parsed.
    Unknown(UnknownExtension),
}
750
751impl ServerExtension {
752    pub(crate) fn ext_type(&self) -> ExtensionType {
753        match self {
754            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
755            Self::ServerNameAck => ExtensionType::ServerName,
756            Self::SessionTicketAck => ExtensionType::SessionTicket,
757            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
758            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
759            Self::KeyShare(_) => ExtensionType::KeyShare,
760            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
761            Self::ClientCertType(_) => ExtensionType::ClientCertificateType,
762            Self::ServerCertType(_) => ExtensionType::ServerCertificateType,
763            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
764            Self::CertificateStatusAck => ExtensionType::StatusRequest,
765            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
766            Self::TransportParameters(_) => ExtensionType::TransportParameters,
767            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
768            Self::EarlyData => ExtensionType::EarlyData,
769            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
770            Self::Unknown(r) => r.typ,
771        }
772    }
773}
774
impl Codec<'_> for ServerExtension {
    /// Writes the extension type, then the body inside a `u16` length-prefixed region.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        // Body bytes go into `nested.buf`. NOTE(review): the length prefix is presumably
        // filled in when `nested` is dropped at the end of this function; confirm in
        // `LengthPrefixedBuffer`.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match self {
            Self::EcPointFormats(r) => r.encode(nested.buf),
            // Acknowledgement-style variants have an empty body.
            Self::ServerNameAck
            | Self::SessionTicketAck
            | Self::ExtendedMasterSecretAck
            | Self::CertificateStatusAck
            | Self::EarlyData => {}
            Self::RenegotiationInfo(r) => r.encode(nested.buf),
            Self::Protocols(r) => r.encode(nested.buf),
            Self::KeyShare(r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::ClientCertType(r) => r.encode(nested.buf),
            Self::ServerCertType(r) => r.encode(nested.buf),
            Self::SupportedVersions(r) => r.encode(nested.buf),
            // Transport parameters are opaque here: copied verbatim with no inner framing.
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
            Self::Unknown(r) => r.encode(nested.buf),
        }
    }

    /// Reads one extension: a type, a `u16` body length, then the body.
    ///
    /// The acknowledgement variants consume nothing, so a non-empty body for them makes
    /// `expect_empty("ServerExtension")` fail, as it does for any other unread bytes.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerNameAck,
            ExtensionType::SessionTicket => Self::SessionTicketAck,
            ExtensionType::StatusRequest => Self::CertificateStatusAck,
            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::ClientCertificateType => {
                Self::ClientCertType(CertificateType::read(&mut sub)?)
            }
            ExtensionType::ServerCertificateType => {
                Self::ServerCertType(CertificateType::read(&mut sub)?)
            }
            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
            ExtensionType::SupportedVersions => {
                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            ExtensionType::EarlyData => Self::EarlyData,
            ExtensionType::EncryptedClientHello => {
                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
            }
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ServerExtension")
            .map(|_| ext)
    }
}
841
842impl ServerExtension {
843    pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self {
844        Self::Protocols(Vec::from_slices(proto))
845    }
846
847    #[cfg(feature = "tls12")]
848    pub(crate) fn make_empty_renegotiation_info() -> Self {
849        let empty = Vec::new();
850        Self::RenegotiationInfo(PayloadU8::new(empty))
851    }
852}
853
/// The body of a ClientHello message.
#[derive(Clone, Debug)]
pub struct ClientHelloPayload {
    /// The protocol version field of the hello.
    pub client_version: ProtocolVersion,
    /// The client's 32-byte random value.
    pub random: Random,
    /// The session ID (0 to 32 bytes).
    pub session_id: SessionId,
    /// Offered cipher suites, in the order sent.
    pub cipher_suites: Vec<CipherSuite>,
    /// Offered compression methods.
    pub compression_methods: Vec<Compression>,
    /// Extensions; `read` rejects a hello that has none.
    pub extensions: Vec<ClientExtension>,
}
863
864impl Codec<'_> for ClientHelloPayload {
865    fn encode(&self, bytes: &mut Vec<u8>) {
866        self.payload_encode(bytes, Encoding::Standard)
867    }
868
869    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
870        let mut ret = Self {
871            client_version: ProtocolVersion::read(r)?,
872            random: Random::read(r)?,
873            session_id: SessionId::read(r)?,
874            cipher_suites: Vec::read(r)?,
875            compression_methods: Vec::read(r)?,
876            extensions: Vec::new(),
877        };
878
879        if r.any_left() {
880            ret.extensions = Vec::read(r)?;
881        }
882
883        match (r.any_left(), ret.extensions.is_empty()) {
884            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
885            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
886            _ => Ok(ret),
887        }
888    }
889}
890
// Lists of cipher suites use a 16-bit length prefix (`ListLength::U16`).
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of compression methods use an 8-bit length prefix (`ListLength::U8`).
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// The `ClientHello` extensions block uses a 16-bit length prefix.
impl TlsListElement for ClientExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of bare extension types (e.g. the `to_compress` list carried by the
// ECH outer-extensions marker) use an 8-bit length prefix.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
906
907impl ClientHelloPayload {
908    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
909        let mut bytes = Vec::new();
910        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
911        bytes
912    }
913
914    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
915        self.client_version.encode(bytes);
916        self.random.encode(bytes);
917
918        match purpose {
919            // SessionID is required to be empty in the encoded inner client hello.
920            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
921            _ => self.session_id.encode(bytes),
922        }
923
924        self.cipher_suites.encode(bytes);
925        self.compression_methods.encode(bytes);
926
927        let to_compress = match purpose {
928            // Compressed extensions must be replaced in the encoded inner client hello.
929            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
930            _ => {
931                if !self.extensions.is_empty() {
932                    self.extensions.encode(bytes);
933                }
934                return;
935            }
936        };
937
938        // Safety: not empty check in match guard.
939        let first_compressed_type = *to_compress.first().unwrap();
940
941        // Compressed extensions are in a contiguous range and must be replaced
942        // with a marker extension.
943        let compressed_start_idx = self
944            .extensions
945            .iter()
946            .position(|ext| ext.ext_type() == first_compressed_type);
947        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
948        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);
949
950        let exts = self
951            .extensions
952            .iter()
953            .enumerate()
954            .filter_map(|(i, ext)| {
955                if Some(i) == compressed_start_idx {
956                    Some(&marker_ext)
957                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
958                    None
959                } else {
960                    Some(ext)
961                }
962            });
963
964        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
965        for ext in exts {
966            ext.encode(nested.buf);
967        }
968    }
969
970    /// Returns true if there is more than one extension of a given
971    /// type.
972    pub(crate) fn has_duplicate_extension(&self) -> bool {
973        has_duplicates::<_, _, u16>(
974            self.extensions
975                .iter()
976                .map(|ext| ext.ext_type()),
977        )
978    }
979
980    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
981        self.extensions
982            .iter()
983            .find(|x| x.ext_type() == ext)
984    }
985
986    pub(crate) fn sni_extension(&self) -> Option<&[ServerName]> {
987        let ext = self.find_extension(ExtensionType::ServerName)?;
988        match ext {
989            // Does this comply with RFC6066?
990            //
991            // [RFC6066][] specifies that literal IP addresses are illegal in
992            // `ServerName`s with a `name_type` of `host_name`.
993            //
994            // Some clients incorrectly send such extensions: we choose to
995            // successfully parse these (into `ServerNamePayload::IpAddress`)
996            // but then act like the client sent no `server_name` extension.
997            //
998            // [RFC6066]: https://datatracker.ietf.org/doc/html/rfc6066#section-3
999            ClientExtension::ServerName(req)
1000                if !req
1001                    .iter()
1002                    .any(|name| matches!(name.payload, ServerNamePayload::IpAddress(_))) =>
1003            {
1004                Some(req)
1005            }
1006            _ => None,
1007        }
1008    }
1009
1010    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1011        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1012        match ext {
1013            ClientExtension::SignatureAlgorithms(req) => Some(req),
1014            _ => None,
1015        }
1016    }
1017
1018    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
1019        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
1020        match ext {
1021            ClientExtension::NamedGroups(req) => Some(req),
1022            _ => None,
1023        }
1024    }
1025
1026    #[cfg(feature = "tls12")]
1027    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1028        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1029        match ext {
1030            ClientExtension::EcPointFormats(req) => Some(req),
1031            _ => None,
1032        }
1033    }
1034
1035    pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> {
1036        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
1037        match ext {
1038            ClientExtension::ServerCertTypes(req) => Some(req),
1039            _ => None,
1040        }
1041    }
1042
1043    pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> {
1044        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
1045        match ext {
1046            ClientExtension::ClientCertTypes(req) => Some(req),
1047            _ => None,
1048        }
1049    }
1050
1051    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1052        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1053        match ext {
1054            ClientExtension::Protocols(req) => Some(req),
1055            _ => None,
1056        }
1057    }
1058
1059    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1060        let ext = self
1061            .find_extension(ExtensionType::TransportParameters)
1062            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1063        match ext {
1064            ClientExtension::TransportParameters(bytes)
1065            | ClientExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
1066            _ => None,
1067        }
1068    }
1069
1070    #[cfg(feature = "tls12")]
1071    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1072        self.find_extension(ExtensionType::SessionTicket)
1073    }
1074
1075    pub(crate) fn versions_extension(&self) -> Option<&[ProtocolVersion]> {
1076        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1077        match ext {
1078            ClientExtension::SupportedVersions(vers) => Some(vers),
1079            _ => None,
1080        }
1081    }
1082
1083    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1084        let ext = self.find_extension(ExtensionType::KeyShare)?;
1085        match ext {
1086            ClientExtension::KeyShare(shares) => Some(shares),
1087            _ => None,
1088        }
1089    }
1090
1091    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1092        self.keyshare_extension()
1093            .map(|entries| {
1094                has_duplicates::<_, _, u16>(
1095                    entries
1096                        .iter()
1097                        .map(|kse| u16::from(kse.group)),
1098                )
1099            })
1100            .unwrap_or_default()
1101    }
1102
1103    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1104        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1105        match ext {
1106            ClientExtension::PresharedKey(psk) => Some(psk),
1107            _ => None,
1108        }
1109    }
1110
1111    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1112        self.extensions
1113            .last()
1114            .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey)
1115    }
1116
1117    pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
1118        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1119        match ext {
1120            ClientExtension::PresharedKeyModes(psk_modes) => Some(psk_modes),
1121            _ => None,
1122        }
1123    }
1124
1125    pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
1126        self.psk_modes()
1127            .map(|modes| modes.contains(&mode))
1128            .unwrap_or(false)
1129    }
1130
1131    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1132        let last_extension = self.extensions.last_mut();
1133        if let Some(ClientExtension::PresharedKey(offer)) = last_extension {
1134            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1135        }
1136    }
1137
1138    #[cfg(feature = "tls12")]
1139    pub(crate) fn ems_support_offered(&self) -> bool {
1140        self.find_extension(ExtensionType::ExtendedMasterSecret)
1141            .is_some()
1142    }
1143
1144    pub(crate) fn early_data_extension_offered(&self) -> bool {
1145        self.find_extension(ExtensionType::EarlyData)
1146            .is_some()
1147    }
1148
1149    pub(crate) fn certificate_compression_extension(
1150        &self,
1151    ) -> Option<&[CertificateCompressionAlgorithm]> {
1152        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1153        match ext {
1154            ClientExtension::CertificateCompressionAlgorithms(algs) => Some(algs),
1155            _ => None,
1156        }
1157    }
1158
1159    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1160        if let Some(algs) = self.certificate_compression_extension() {
1161            has_duplicates::<_, _, u16>(algs.iter().cloned())
1162        } else {
1163            false
1164        }
1165    }
1166
1167    pub(crate) fn certificate_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1168        match self.find_extension(ExtensionType::CertificateAuthorities)? {
1169            ClientExtension::AuthorityNames(ext) => Some(ext),
1170            _ => unreachable!("extension type checked"),
1171        }
1172    }
1173}
1174
/// An extension carried in a `HelloRetryRequest`.
#[derive(Clone, Debug)]
pub(crate) enum HelloRetryExtension {
    /// `key_share`: a single `NamedGroup` (exposed via `requested_key_share_group`).
    KeyShare(NamedGroup),
    /// `cookie`: opaque `u16`-length-prefixed data.
    Cookie(PayloadU16),
    /// `supported_versions`: a single protocol version.
    SupportedVersions(ProtocolVersion),
    /// `encrypted_client_hello`: the raw, unparsed extension body.
    EchHelloRetryRequest(Vec<u8>),
    /// Any other extension type (see `UnknownExtension`).
    Unknown(UnknownExtension),
}
1183
1184impl HelloRetryExtension {
1185    pub(crate) fn ext_type(&self) -> ExtensionType {
1186        match self {
1187            Self::KeyShare(_) => ExtensionType::KeyShare,
1188            Self::Cookie(_) => ExtensionType::Cookie,
1189            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1190            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1191            Self::Unknown(r) => r.typ,
1192        }
1193    }
1194}
1195
1196impl Codec<'_> for HelloRetryExtension {
1197    fn encode(&self, bytes: &mut Vec<u8>) {
1198        self.ext_type().encode(bytes);
1199
1200        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1201        match self {
1202            Self::KeyShare(r) => r.encode(nested.buf),
1203            Self::Cookie(r) => r.encode(nested.buf),
1204            Self::SupportedVersions(r) => r.encode(nested.buf),
1205            Self::EchHelloRetryRequest(r) => {
1206                nested.buf.extend_from_slice(r);
1207            }
1208            Self::Unknown(r) => r.encode(nested.buf),
1209        }
1210    }
1211
1212    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1213        let typ = ExtensionType::read(r)?;
1214        let len = u16::read(r)? as usize;
1215        let mut sub = r.sub(len)?;
1216
1217        let ext = match typ {
1218            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1219            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1220            ExtensionType::SupportedVersions => {
1221                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1222            }
1223            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1224            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1225        };
1226
1227        sub.expect_empty("HelloRetryExtension")
1228            .map(|_| ext)
1229    }
1230}
1231
// The `HelloRetryRequest` extensions block uses a 16-bit length prefix.
impl TlsListElement for HelloRetryExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1235
/// A `HelloRetryRequest` message body.
///
/// It is encoded with the same field order as `ServerHelloPayload`, with the
/// fixed `HELLO_RETRY_REQUEST_RANDOM` written where the random would go and a
/// `Null` compression method.
#[derive(Clone, Debug)]
pub struct HelloRetryRequest {
    /// Version written first when encoding. `read` does not parse it and sets
    /// it to `ProtocolVersion::Unknown(0)`.
    pub(crate) legacy_version: ProtocolVersion,
    pub session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) extensions: Vec<HelloRetryExtension>,
}
1243
1244impl Codec<'_> for HelloRetryRequest {
1245    fn encode(&self, bytes: &mut Vec<u8>) {
1246        self.payload_encode(bytes, Encoding::Standard)
1247    }
1248
1249    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1250        let session_id = SessionId::read(r)?;
1251        let cipher_suite = CipherSuite::read(r)?;
1252        let compression = Compression::read(r)?;
1253
1254        if compression != Compression::Null {
1255            return Err(InvalidMessage::UnsupportedCompression);
1256        }
1257
1258        Ok(Self {
1259            legacy_version: ProtocolVersion::Unknown(0),
1260            session_id,
1261            cipher_suite,
1262            extensions: Vec::read(r)?,
1263        })
1264    }
1265}
1266
1267impl HelloRetryRequest {
1268    /// Returns true if there is more than one extension of a given
1269    /// type.
1270    pub(crate) fn has_duplicate_extension(&self) -> bool {
1271        has_duplicates::<_, _, u16>(
1272            self.extensions
1273                .iter()
1274                .map(|ext| ext.ext_type()),
1275        )
1276    }
1277
1278    pub(crate) fn has_unknown_extension(&self) -> bool {
1279        self.extensions.iter().any(|ext| {
1280            ext.ext_type() != ExtensionType::KeyShare
1281                && ext.ext_type() != ExtensionType::SupportedVersions
1282                && ext.ext_type() != ExtensionType::Cookie
1283                && ext.ext_type() != ExtensionType::EncryptedClientHello
1284        })
1285    }
1286
1287    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1288        self.extensions
1289            .iter()
1290            .find(|x| x.ext_type() == ext)
1291    }
1292
1293    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1294        let ext = self.find_extension(ExtensionType::KeyShare)?;
1295        match ext {
1296            HelloRetryExtension::KeyShare(grp) => Some(*grp),
1297            _ => None,
1298        }
1299    }
1300
1301    pub(crate) fn cookie(&self) -> Option<&PayloadU16> {
1302        let ext = self.find_extension(ExtensionType::Cookie)?;
1303        match ext {
1304            HelloRetryExtension::Cookie(ck) => Some(ck),
1305            _ => None,
1306        }
1307    }
1308
1309    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1310        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1311        match ext {
1312            HelloRetryExtension::SupportedVersions(ver) => Some(*ver),
1313            _ => None,
1314        }
1315    }
1316
1317    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1318        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1319        match ext {
1320            HelloRetryExtension::EchHelloRetryRequest(ech) => Some(ech),
1321            _ => None,
1322        }
1323    }
1324
1325    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1326        self.legacy_version.encode(bytes);
1327        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1328        self.session_id.encode(bytes);
1329        self.cipher_suite.encode(bytes);
1330        Compression::Null.encode(bytes);
1331
1332        match purpose {
1333            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1334            // must have its payload replaced by 8 zero bytes.
1335            //
1336            // See draft-ietf-tls-esni-18 7.2.1:
1337            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1338            Encoding::EchConfirmation => {
1339                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1340                for ext in &self.extensions {
1341                    match ext.ext_type() {
1342                        ExtensionType::EncryptedClientHello => {
1343                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1344                                .encode(extensions.buf);
1345                        }
1346                        _ => {
1347                            ext.encode(extensions.buf);
1348                        }
1349                    }
1350                }
1351            }
1352            _ => {
1353                self.extensions.encode(bytes);
1354            }
1355        }
1356    }
1357}
1358
/// A `ServerHello` message body.
#[derive(Clone, Debug)]
pub struct ServerHelloPayload {
    /// Extensions, in the order they are encoded. The extensions block is
    /// omitted from the encoding when this is empty.
    pub extensions: Vec<ServerExtension>,
    /// Version written first when encoding. `read` does not parse it and sets
    /// it to `ProtocolVersion::Unknown(0)`; presumably the caller fills it in.
    pub(crate) legacy_version: ProtocolVersion,
    /// Server random. `read` sets it to `ZERO_RANDOM`; presumably the caller
    /// fills it in.
    pub(crate) random: Random,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) compression_method: Compression,
}
1368
1369impl Codec<'_> for ServerHelloPayload {
1370    fn encode(&self, bytes: &mut Vec<u8>) {
1371        self.payload_encode(bytes, Encoding::Standard)
1372    }
1373
1374    // minus version and random, which have already been read.
1375    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1376        let session_id = SessionId::read(r)?;
1377        let suite = CipherSuite::read(r)?;
1378        let compression = Compression::read(r)?;
1379
1380        // RFC5246:
1381        // "The presence of extensions can be detected by determining whether
1382        //  there are bytes following the compression_method field at the end of
1383        //  the ServerHello."
1384        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1385
1386        let ret = Self {
1387            legacy_version: ProtocolVersion::Unknown(0),
1388            random: ZERO_RANDOM,
1389            session_id,
1390            cipher_suite: suite,
1391            compression_method: compression,
1392            extensions,
1393        };
1394
1395        r.expect_empty("ServerHelloPayload")
1396            .map(|_| ret)
1397    }
1398}
1399
1400impl HasServerExtensions for ServerHelloPayload {
1401    fn extensions(&self) -> &[ServerExtension] {
1402        &self.extensions
1403    }
1404}
1405
1406impl ServerHelloPayload {
1407    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1408        let ext = self.find_extension(ExtensionType::KeyShare)?;
1409        match ext {
1410            ServerExtension::KeyShare(share) => Some(share),
1411            _ => None,
1412        }
1413    }
1414
1415    pub(crate) fn psk_index(&self) -> Option<u16> {
1416        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1417        match ext {
1418            ServerExtension::PresharedKey(index) => Some(*index),
1419            _ => None,
1420        }
1421    }
1422
1423    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1424        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1425        match ext {
1426            ServerExtension::EcPointFormats(fmts) => Some(fmts),
1427            _ => None,
1428        }
1429    }
1430
1431    #[cfg(feature = "tls12")]
1432    pub(crate) fn ems_support_acked(&self) -> bool {
1433        self.find_extension(ExtensionType::ExtendedMasterSecret)
1434            .is_some()
1435    }
1436
1437    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1438        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1439        match ext {
1440            ServerExtension::SupportedVersions(vers) => Some(*vers),
1441            _ => None,
1442        }
1443    }
1444
1445    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1446        self.legacy_version.encode(bytes);
1447
1448        match encoding {
1449            // When encoding a ServerHello for ECH confirmation, the random value
1450            // has the last 8 bytes zeroed out.
1451            Encoding::EchConfirmation => {
1452                // Indexing safety: self.random is 32 bytes long by definition.
1453                let rand_vec = self.random.get_encoding();
1454                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1455                bytes.extend_from_slice(&[0u8; 8]);
1456            }
1457            _ => self.random.encode(bytes),
1458        }
1459
1460        self.session_id.encode(bytes);
1461        self.cipher_suite.encode(bytes);
1462        self.compression_method.encode(bytes);
1463
1464        if !self.extensions.is_empty() {
1465            self.extensions.encode(bytes);
1466        }
1467    }
1468}
1469
/// An ordered list of DER-encoded certificates, as carried in a TLS 1.2
/// `Certificate` message (a `u24`-length-prefixed list of `CertificateDer`).
#[derive(Clone, Default, Debug)]
pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1472
1473impl CertificateChain<'_> {
1474    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1475        CertificateChain(
1476            self.0
1477                .into_iter()
1478                .map(|c| c.into_owned())
1479                .collect(),
1480        )
1481    }
1482}
1483
1484impl<'a> Codec<'a> for CertificateChain<'a> {
1485    fn encode(&self, bytes: &mut Vec<u8>) {
1486        Vec::encode(&self.0, bytes)
1487    }
1488
1489    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1490        Vec::read(r).map(Self)
1491    }
1492}
1493
1494impl<'a> Deref for CertificateChain<'a> {
1495    type Target = [CertificateDer<'a>];
1496
1497    fn deref(&self) -> &[CertificateDer<'a>] {
1498        &self.0
1499    }
1500}
1501
// Lists of certificates use a 24-bit length prefix. Decoding fails with
// `CertificatePayloadTooLarge` when that length exceeds
// `CERTIFICATE_MAX_SIZE_LIMIT`.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1508
/// Upper bound on the decoded length of certificate lists.
///
/// TLS has a 16MB size limit on any handshake message, plus a 16MB limit on
/// any given certificate.
///
/// We contract that to 64KB (`0x1_0000` bytes) to limit the amount of memory
/// allocation that is directly controllable by the peer. In this file it is
/// the `max` of the `ListLength::U24` used for lists of `CertificateDer` and
/// lists of `CertificateEntry`.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1515
/// An extension attached to one entry of a TLS 1.3 `Certificate` message.
#[derive(Debug)]
pub(crate) enum CertificateExtension<'a> {
    /// `status_request`: a certificate status (OCSP response) for the entry.
    CertificateStatus(CertificateStatus<'a>),
    /// Any other extension type (see `UnknownExtension`).
    Unknown(UnknownExtension),
}
1521
1522impl CertificateExtension<'_> {
1523    pub(crate) fn ext_type(&self) -> ExtensionType {
1524        match self {
1525            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1526            Self::Unknown(r) => r.typ,
1527        }
1528    }
1529
1530    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1531        match self {
1532            Self::CertificateStatus(cs) => Some(cs.ocsp_response.0.bytes()),
1533            _ => None,
1534        }
1535    }
1536
1537    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1538        match self {
1539            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1540            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1541        }
1542    }
1543}
1544
1545impl<'a> Codec<'a> for CertificateExtension<'a> {
1546    fn encode(&self, bytes: &mut Vec<u8>) {
1547        self.ext_type().encode(bytes);
1548
1549        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1550        match self {
1551            Self::CertificateStatus(r) => r.encode(nested.buf),
1552            Self::Unknown(r) => r.encode(nested.buf),
1553        }
1554    }
1555
1556    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1557        let typ = ExtensionType::read(r)?;
1558        let len = u16::read(r)? as usize;
1559        let mut sub = r.sub(len)?;
1560
1561        let ext = match typ {
1562            ExtensionType::StatusRequest => {
1563                let st = CertificateStatus::read(&mut sub)?;
1564                Self::CertificateStatus(st)
1565            }
1566            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1567        };
1568
1569        sub.expect_empty("CertificateExtension")
1570            .map(|_| ext)
1571    }
1572}
1573
// Each certificate entry's extension list uses a 16-bit length prefix.
impl TlsListElement for CertificateExtension<'_> {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1577
/// One certificate in a TLS 1.3 `Certificate` message, with its extensions.
#[derive(Debug)]
pub(crate) struct CertificateEntry<'a> {
    /// The DER-encoded certificate.
    pub(crate) cert: CertificateDer<'a>,
    /// Per-certificate extensions, in encoded order.
    pub(crate) exts: Vec<CertificateExtension<'a>>,
}
1583
1584impl<'a> Codec<'a> for CertificateEntry<'a> {
1585    fn encode(&self, bytes: &mut Vec<u8>) {
1586        self.cert.encode(bytes);
1587        self.exts.encode(bytes);
1588    }
1589
1590    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1591        Ok(Self {
1592            cert: CertificateDer::read(r)?,
1593            exts: Vec::read(r)?,
1594        })
1595    }
1596}
1597
1598impl<'a> CertificateEntry<'a> {
1599    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1600        Self {
1601            cert,
1602            exts: Vec::new(),
1603        }
1604    }
1605
1606    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1607        CertificateEntry {
1608            cert: self.cert.into_owned(),
1609            exts: self
1610                .exts
1611                .into_iter()
1612                .map(CertificateExtension::into_owned)
1613                .collect(),
1614        }
1615    }
1616
1617    pub(crate) fn has_duplicate_extension(&self) -> bool {
1618        has_duplicates::<_, _, u16>(
1619            self.exts
1620                .iter()
1621                .map(|ext| ext.ext_type()),
1622        )
1623    }
1624
1625    pub(crate) fn has_unknown_extension(&self) -> bool {
1626        self.exts
1627            .iter()
1628            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1629    }
1630
1631    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1632        self.exts
1633            .iter()
1634            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1635            .and_then(CertificateExtension::cert_status)
1636    }
1637}
1638
// The TLS 1.3 certificate list uses a 24-bit length prefix. Decoding fails
// with `CertificatePayloadTooLarge` when the whole list's length exceeds
// `CERTIFICATE_MAX_SIZE_LIMIT`.
impl TlsListElement for CertificateEntry<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1645
/// A TLS 1.3 `Certificate` message body.
#[derive(Debug)]
pub struct CertificatePayloadTls13<'a> {
    /// Certificate request context (`u8`-length-prefixed). `new` leaves it empty.
    pub(crate) context: PayloadU8,
    /// Certificate entries, in encoded order; the first is the end entity.
    pub(crate) entries: Vec<CertificateEntry<'a>>,
}
1651
1652impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1653    fn encode(&self, bytes: &mut Vec<u8>) {
1654        self.context.encode(bytes);
1655        self.entries.encode(bytes);
1656    }
1657
1658    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1659        Ok(Self {
1660            context: PayloadU8::read(r)?,
1661            entries: Vec::read(r)?,
1662        })
1663    }
1664}
1665
1666impl<'a> CertificatePayloadTls13<'a> {
1667    pub(crate) fn new(
1668        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1669        ocsp_response: Option<&'a [u8]>,
1670    ) -> Self {
1671        Self {
1672            context: PayloadU8::empty(),
1673            entries: certs
1674                // zip certificate iterator with `ocsp_response` followed by
1675                // an infinite-length iterator of `None`.
1676                .zip(
1677                    ocsp_response
1678                        .into_iter()
1679                        .map(Some)
1680                        .chain(iter::repeat(None)),
1681                )
1682                .map(|(cert, ocsp)| {
1683                    let mut e = CertificateEntry::new(cert.clone());
1684                    if let Some(ocsp) = ocsp {
1685                        e.exts
1686                            .push(CertificateExtension::CertificateStatus(
1687                                CertificateStatus::new(ocsp),
1688                            ));
1689                    }
1690                    e
1691                })
1692                .collect(),
1693        }
1694    }
1695
1696    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1697        CertificatePayloadTls13 {
1698            context: self.context,
1699            entries: self
1700                .entries
1701                .into_iter()
1702                .map(CertificateEntry::into_owned)
1703                .collect(),
1704        }
1705    }
1706
1707    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1708        for entry in &self.entries {
1709            if entry.has_duplicate_extension() {
1710                return true;
1711            }
1712        }
1713
1714        false
1715    }
1716
1717    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1718        for entry in &self.entries {
1719            if entry.has_unknown_extension() {
1720                return true;
1721            }
1722        }
1723
1724        false
1725    }
1726
1727    pub(crate) fn any_entry_has_extension(&self) -> bool {
1728        for entry in &self.entries {
1729            if !entry.exts.is_empty() {
1730                return true;
1731            }
1732        }
1733
1734        false
1735    }
1736
1737    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1738        self.entries
1739            .first()
1740            .and_then(CertificateEntry::ocsp_response)
1741            .map(|resp| resp.to_vec())
1742            .unwrap_or_default()
1743    }
1744
1745    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1746        CertificateChain(
1747            self.entries
1748                .into_iter()
1749                .map(|e| e.cert)
1750                .collect(),
1751        )
1752    }
1753}
1754
/// Describes supported key exchange mechanisms.
///
/// In this file, it selects how a `ClientKeyExchangeParams` is decoded
/// (see `KxDecode`).
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman Key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}
1766
/// Every `KeyExchangeAlgorithm` variant, listed with `ECDHE` before `DHE`.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1769
// Only named curves are supported: decoding rejects any other curve type.
// Accepting arbitrary (explicit) curve parameters is a bad idea and adds
// unnecessary attack surface.
/// Elliptic-curve parameters: the curve type plus the named group.
#[derive(Debug)]
pub(crate) struct EcParameters {
    /// Always `ECCurveType::NamedCurve` after a successful `read`.
    pub(crate) curve_type: ECCurveType,
    pub(crate) named_group: NamedGroup,
}
1778
1779impl Codec<'_> for EcParameters {
1780    fn encode(&self, bytes: &mut Vec<u8>) {
1781        self.curve_type.encode(bytes);
1782        self.named_group.encode(bytes);
1783    }
1784
1785    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1786        let ct = ECCurveType::read(r)?;
1787        if ct != ECCurveType::NamedCurve {
1788            return Err(InvalidMessage::UnsupportedCurveType);
1789        }
1790
1791        let grp = NamedGroup::read(r)?;
1792
1793        Ok(Self {
1794            curve_type: ct,
1795            named_group: grp,
1796        })
1797    }
1798}
1799
/// Decoding for TLS 1.2 key exchange messages whose wire layout depends on the
/// negotiated [`KeyExchangeAlgorithm`].
///
/// Unlike `Codec::read`, decoding here needs the algorithm as extra input.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key_exchange `algo`
    ///
    /// Implementations return an `InvalidMessage` when the bytes in `r` cannot
    /// be decoded as a message for `algo`.
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1805
/// The client's key exchange parameters, interpreted according to the
/// negotiated [`KeyExchangeAlgorithm`].
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    Ecdh(ClientEcdhParams),
    Dh(ClientDhParams),
}

#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// The raw bytes of the client's public key share.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Ecdh(ClientEcdhParams { public }) => &public.0,
            Self::Dh(ClientDhParams { public }) => &public.0,
        }
    }

    /// Appends the wire encoding of whichever parameter set this holds.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Ecdh(params) => params.encode(buf),
            Self::Dh(params) => params.encode(buf),
        }
    }
}

#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1840
/// The client's ECDHE public key share, held as a `PayloadU8`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    pub(crate) public: PayloadU8,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.public.encode(bytes);
    }

    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::read(r).map(|public| Self { public })
    }
}
1858
/// The client's finite-field DHE public value, held as a `PayloadU16`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    pub(crate) public: PayloadU16,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.public.encode(bytes);
    }

    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
1877
1878#[derive(Debug)]
1879pub(crate) struct ServerEcdhParams {
1880    pub(crate) curve_params: EcParameters,
1881    pub(crate) public: PayloadU8,
1882}
1883
1884impl ServerEcdhParams {
1885    #[cfg(feature = "tls12")]
1886    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1887        Self {
1888            curve_params: EcParameters {
1889                curve_type: ECCurveType::NamedCurve,
1890                named_group: kx.group(),
1891            },
1892            public: PayloadU8::new(kx.pub_key().to_vec()),
1893        }
1894    }
1895}
1896
1897impl Codec<'_> for ServerEcdhParams {
1898    fn encode(&self, bytes: &mut Vec<u8>) {
1899        self.curve_params.encode(bytes);
1900        self.public.encode(bytes);
1901    }
1902
1903    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1904        let cp = EcParameters::read(r)?;
1905        let pb = PayloadU8::read(r)?;
1906
1907        Ok(Self {
1908            curve_params: cp,
1909            public: pb,
1910        })
1911    }
1912}
1913
1914#[derive(Debug)]
1915#[allow(non_snake_case)]
1916pub(crate) struct ServerDhParams {
1917    pub(crate) dh_p: PayloadU16,
1918    pub(crate) dh_g: PayloadU16,
1919    pub(crate) dh_Ys: PayloadU16,
1920}
1921
1922impl ServerDhParams {
1923    #[cfg(feature = "tls12")]
1924    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1925        let Some(params) = kx.ffdhe_group() else {
1926            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
1927        };
1928
1929        Self {
1930            dh_p: PayloadU16::new(params.p.to_vec()),
1931            dh_g: PayloadU16::new(params.g.to_vec()),
1932            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
1933        }
1934    }
1935
1936    #[cfg(feature = "tls12")]
1937    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
1938        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
1939    }
1940}
1941
1942impl Codec<'_> for ServerDhParams {
1943    fn encode(&self, bytes: &mut Vec<u8>) {
1944        self.dh_p.encode(bytes);
1945        self.dh_g.encode(bytes);
1946        self.dh_Ys.encode(bytes);
1947    }
1948
1949    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1950        Ok(Self {
1951            dh_p: PayloadU16::read(r)?,
1952            dh_g: PayloadU16::read(r)?,
1953            dh_Ys: PayloadU16::read(r)?,
1954        })
1955    }
1956}
1957
1958#[allow(dead_code)]
1959#[derive(Debug)]
1960pub(crate) enum ServerKeyExchangeParams {
1961    Ecdh(ServerEcdhParams),
1962    Dh(ServerDhParams),
1963}
1964
1965impl ServerKeyExchangeParams {
1966    #[cfg(feature = "tls12")]
1967    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1968        match kx.group().key_exchange_algorithm() {
1969            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
1970            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
1971        }
1972    }
1973
1974    #[cfg(feature = "tls12")]
1975    pub(crate) fn pub_key(&self) -> &[u8] {
1976        match self {
1977            Self::Ecdh(ecdh) => &ecdh.public.0,
1978            Self::Dh(dh) => &dh.dh_Ys.0,
1979        }
1980    }
1981
1982    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
1983        match self {
1984            Self::Ecdh(ecdh) => ecdh.encode(buf),
1985            Self::Dh(dh) => dh.encode(buf),
1986        }
1987    }
1988}
1989
1990#[cfg(feature = "tls12")]
1991impl KxDecode<'_> for ServerKeyExchangeParams {
1992    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
1993        use KeyExchangeAlgorithm::*;
1994        Ok(match algo {
1995            ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
1996            DHE => Self::Dh(ServerDhParams::read(r)?),
1997        })
1998    }
1999}
2000
2001#[derive(Debug)]
2002pub struct ServerKeyExchange {
2003    pub(crate) params: ServerKeyExchangeParams,
2004    pub(crate) dss: DigitallySignedStruct,
2005}
2006
2007impl ServerKeyExchange {
2008    pub fn encode(&self, buf: &mut Vec<u8>) {
2009        self.params.encode(buf);
2010        self.dss.encode(buf);
2011    }
2012}
2013
2014#[derive(Debug)]
2015pub enum ServerKeyExchangePayload {
2016    Known(ServerKeyExchange),
2017    Unknown(Payload<'static>),
2018}
2019
2020impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2021    fn from(value: ServerKeyExchange) -> Self {
2022        Self::Known(value)
2023    }
2024}
2025
2026impl Codec<'_> for ServerKeyExchangePayload {
2027    fn encode(&self, bytes: &mut Vec<u8>) {
2028        match self {
2029            Self::Known(x) => x.encode(bytes),
2030            Self::Unknown(x) => x.encode(bytes),
2031        }
2032    }
2033
2034    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2035        // read as Unknown, fully parse when we know the
2036        // KeyExchangeAlgorithm
2037        Ok(Self::Unknown(Payload::read(r).into_owned()))
2038    }
2039}
2040
2041impl ServerKeyExchangePayload {
2042    #[cfg(feature = "tls12")]
2043    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2044        if let Self::Unknown(unk) = self {
2045            let mut rd = Reader::init(unk.bytes());
2046
2047            let result = ServerKeyExchange {
2048                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2049                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2050            };
2051
2052            if !rd.any_left() {
2053                return Some(result);
2054            };
2055        }
2056
2057        None
2058    }
2059}
2060
2061// -- EncryptedExtensions (TLS1.3 only) --
2062
2063impl TlsListElement for ServerExtension {
2064    const SIZE_LEN: ListLength = ListLength::U16;
2065}
2066
2067pub(crate) trait HasServerExtensions {
2068    fn extensions(&self) -> &[ServerExtension];
2069
2070    /// Returns true if there is more than one extension of a given
2071    /// type.
2072    fn has_duplicate_extension(&self) -> bool {
2073        has_duplicates::<_, _, u16>(
2074            self.extensions()
2075                .iter()
2076                .map(|ext| ext.ext_type()),
2077        )
2078    }
2079
2080    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2081        self.extensions()
2082            .iter()
2083            .find(|x| x.ext_type() == ext)
2084    }
2085
2086    fn alpn_protocol(&self) -> Option<&[u8]> {
2087        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2088        match ext {
2089            ServerExtension::Protocols(protos) => protos.as_single_slice(),
2090            _ => None,
2091        }
2092    }
2093
2094    fn server_cert_type(&self) -> Option<&CertificateType> {
2095        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
2096        match ext {
2097            ServerExtension::ServerCertType(req) => Some(req),
2098            _ => None,
2099        }
2100    }
2101
2102    fn client_cert_type(&self) -> Option<&CertificateType> {
2103        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
2104        match ext {
2105            ServerExtension::ClientCertType(req) => Some(req),
2106            _ => None,
2107        }
2108    }
2109
2110    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2111        let ext = self
2112            .find_extension(ExtensionType::TransportParameters)
2113            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2114        match ext {
2115            ServerExtension::TransportParameters(bytes)
2116            | ServerExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
2117            _ => None,
2118        }
2119    }
2120
2121    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2122        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2123        match ext {
2124            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2125            _ => None,
2126        }
2127    }
2128
2129    fn early_data_extension_offered(&self) -> bool {
2130        self.find_extension(ExtensionType::EarlyData)
2131            .is_some()
2132    }
2133}
2134
2135impl HasServerExtensions for Vec<ServerExtension> {
2136    fn extensions(&self) -> &[ServerExtension] {
2137        self
2138    }
2139}
2140
// Lists of client certificate types (e.g. `CertificateRequestPayload::certtypes`)
// are encoded with a U8 length prefix.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
2144
2145wrapped_payload!(
2146    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
2147    ///
2148    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
2149    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
2150    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2151    ///
2152    /// ```ignore
2153    /// for name in distinguished_names {
2154    ///     use x509_parser::prelude::FromDer;
2155    ///     println!("{}", x509_parser::x509::X509Name::from_der(&name.0)?.1);
2156    /// }
2157    /// ```
2158    pub struct DistinguishedName,
2159    PayloadU16,
2160);
2161
2162impl DistinguishedName {
2163    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2164    ///
2165    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2166    ///
2167    /// ```ignore
2168    /// use x509_parser::prelude::FromDer;
2169    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2170    /// ```
2171    pub fn in_sequence(bytes: &[u8]) -> Self {
2172        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2173    }
2174}
2175
2176impl TlsListElement for DistinguishedName {
2177    const SIZE_LEN: ListLength = ListLength::U16;
2178}
2179
2180#[derive(Debug)]
2181pub struct CertificateRequestPayload {
2182    pub(crate) certtypes: Vec<ClientCertificateType>,
2183    pub(crate) sigschemes: Vec<SignatureScheme>,
2184    pub(crate) canames: Vec<DistinguishedName>,
2185}
2186
2187impl Codec<'_> for CertificateRequestPayload {
2188    fn encode(&self, bytes: &mut Vec<u8>) {
2189        self.certtypes.encode(bytes);
2190        self.sigschemes.encode(bytes);
2191        self.canames.encode(bytes);
2192    }
2193
2194    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2195        let certtypes = Vec::read(r)?;
2196        let sigschemes = Vec::read(r)?;
2197        let canames = Vec::read(r)?;
2198
2199        if sigschemes.is_empty() {
2200            warn!("meaningless CertificateRequest message");
2201            Err(InvalidMessage::NoSignatureSchemes)
2202        } else {
2203            Ok(Self {
2204                certtypes,
2205                sigschemes,
2206                canames,
2207            })
2208        }
2209    }
2210}
2211
2212#[derive(Debug)]
2213pub(crate) enum CertReqExtension {
2214    SignatureAlgorithms(Vec<SignatureScheme>),
2215    AuthorityNames(Vec<DistinguishedName>),
2216    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
2217    Unknown(UnknownExtension),
2218}
2219
2220impl CertReqExtension {
2221    pub(crate) fn ext_type(&self) -> ExtensionType {
2222        match self {
2223            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2224            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2225            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2226            Self::Unknown(r) => r.typ,
2227        }
2228    }
2229}
2230
2231impl Codec<'_> for CertReqExtension {
2232    fn encode(&self, bytes: &mut Vec<u8>) {
2233        self.ext_type().encode(bytes);
2234
2235        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2236        match self {
2237            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
2238            Self::AuthorityNames(r) => r.encode(nested.buf),
2239            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
2240            Self::Unknown(r) => r.encode(nested.buf),
2241        }
2242    }
2243
2244    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2245        let typ = ExtensionType::read(r)?;
2246        let len = u16::read(r)? as usize;
2247        let mut sub = r.sub(len)?;
2248
2249        let ext = match typ {
2250            ExtensionType::SignatureAlgorithms => {
2251                let schemes = Vec::read(&mut sub)?;
2252                if schemes.is_empty() {
2253                    return Err(InvalidMessage::NoSignatureSchemes);
2254                }
2255                Self::SignatureAlgorithms(schemes)
2256            }
2257            ExtensionType::CertificateAuthorities => {
2258                let cas = Vec::read(&mut sub)?;
2259                Self::AuthorityNames(cas)
2260            }
2261            ExtensionType::CompressCertificate => {
2262                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2263            }
2264            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2265        };
2266
2267        sub.expect_empty("CertReqExtension")
2268            .map(|_| ext)
2269    }
2270}
2271
2272impl TlsListElement for CertReqExtension {
2273    const SIZE_LEN: ListLength = ListLength::U16;
2274}
2275
2276#[derive(Debug)]
2277pub struct CertificateRequestPayloadTls13 {
2278    pub(crate) context: PayloadU8,
2279    pub(crate) extensions: Vec<CertReqExtension>,
2280}
2281
2282impl Codec<'_> for CertificateRequestPayloadTls13 {
2283    fn encode(&self, bytes: &mut Vec<u8>) {
2284        self.context.encode(bytes);
2285        self.extensions.encode(bytes);
2286    }
2287
2288    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2289        let context = PayloadU8::read(r)?;
2290        let extensions = Vec::read(r)?;
2291
2292        Ok(Self {
2293            context,
2294            extensions,
2295        })
2296    }
2297}
2298
2299impl CertificateRequestPayloadTls13 {
2300    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2301        self.extensions
2302            .iter()
2303            .find(|x| x.ext_type() == ext)
2304    }
2305
2306    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2307        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2308        match ext {
2309            CertReqExtension::SignatureAlgorithms(sa) => Some(sa),
2310            _ => None,
2311        }
2312    }
2313
2314    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2315        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2316        match ext {
2317            CertReqExtension::AuthorityNames(an) => Some(an),
2318            _ => None,
2319        }
2320    }
2321
2322    pub(crate) fn certificate_compression_extension(
2323        &self,
2324    ) -> Option<&[CertificateCompressionAlgorithm]> {
2325        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2326        match ext {
2327            CertReqExtension::CertificateCompressionAlgorithms(comps) => Some(comps),
2328            _ => None,
2329        }
2330    }
2331}
2332
2333// -- NewSessionTicket --
2334#[derive(Debug)]
2335pub struct NewSessionTicketPayload {
2336    pub(crate) lifetime_hint: u32,
2337    // Tickets can be large (KB), so we deserialise this straight
2338    // into an Arc, so it can be passed directly into the client's
2339    // session object without copying.
2340    pub(crate) ticket: Arc<PayloadU16>,
2341}
2342
2343impl NewSessionTicketPayload {
2344    #[cfg(feature = "tls12")]
2345    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2346        Self {
2347            lifetime_hint,
2348            ticket: Arc::new(PayloadU16::new(ticket)),
2349        }
2350    }
2351}
2352
2353impl Codec<'_> for NewSessionTicketPayload {
2354    fn encode(&self, bytes: &mut Vec<u8>) {
2355        self.lifetime_hint.encode(bytes);
2356        self.ticket.encode(bytes);
2357    }
2358
2359    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2360        let lifetime = u32::read(r)?;
2361        let ticket = Arc::new(PayloadU16::read(r)?);
2362
2363        Ok(Self {
2364            lifetime_hint: lifetime,
2365            ticket,
2366        })
2367    }
2368}
2369
// -- NewSessionTicket (TLS 1.3) --
2371#[derive(Debug)]
2372pub(crate) enum NewSessionTicketExtension {
2373    EarlyData(u32),
2374    Unknown(UnknownExtension),
2375}
2376
2377impl NewSessionTicketExtension {
2378    pub(crate) fn ext_type(&self) -> ExtensionType {
2379        match self {
2380            Self::EarlyData(_) => ExtensionType::EarlyData,
2381            Self::Unknown(r) => r.typ,
2382        }
2383    }
2384}
2385
2386impl Codec<'_> for NewSessionTicketExtension {
2387    fn encode(&self, bytes: &mut Vec<u8>) {
2388        self.ext_type().encode(bytes);
2389
2390        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2391        match self {
2392            Self::EarlyData(r) => r.encode(nested.buf),
2393            Self::Unknown(r) => r.encode(nested.buf),
2394        }
2395    }
2396
2397    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2398        let typ = ExtensionType::read(r)?;
2399        let len = u16::read(r)? as usize;
2400        let mut sub = r.sub(len)?;
2401
2402        let ext = match typ {
2403            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2404            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2405        };
2406
2407        sub.expect_empty("NewSessionTicketExtension")
2408            .map(|_| ext)
2409    }
2410}
2411
2412impl TlsListElement for NewSessionTicketExtension {
2413    const SIZE_LEN: ListLength = ListLength::U16;
2414}
2415
2416#[derive(Debug)]
2417pub struct NewSessionTicketPayloadTls13 {
2418    pub(crate) lifetime: u32,
2419    pub(crate) age_add: u32,
2420    pub(crate) nonce: PayloadU8,
2421    pub(crate) ticket: Arc<PayloadU16>,
2422    pub(crate) exts: Vec<NewSessionTicketExtension>,
2423}
2424
2425impl NewSessionTicketPayloadTls13 {
2426    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2427        Self {
2428            lifetime,
2429            age_add,
2430            nonce: PayloadU8::new(nonce),
2431            ticket: Arc::new(PayloadU16::new(ticket)),
2432            exts: vec![],
2433        }
2434    }
2435
2436    pub(crate) fn has_duplicate_extension(&self) -> bool {
2437        has_duplicates::<_, _, u16>(
2438            self.exts
2439                .iter()
2440                .map(|ext| ext.ext_type()),
2441        )
2442    }
2443
2444    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2445        self.exts
2446            .iter()
2447            .find(|x| x.ext_type() == ext)
2448    }
2449
2450    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2451        let ext = self.find_extension(ExtensionType::EarlyData)?;
2452        match ext {
2453            NewSessionTicketExtension::EarlyData(sz) => Some(*sz),
2454            _ => None,
2455        }
2456    }
2457}
2458
2459impl Codec<'_> for NewSessionTicketPayloadTls13 {
2460    fn encode(&self, bytes: &mut Vec<u8>) {
2461        self.lifetime.encode(bytes);
2462        self.age_add.encode(bytes);
2463        self.nonce.encode(bytes);
2464        self.ticket.encode(bytes);
2465        self.exts.encode(bytes);
2466    }
2467
2468    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2469        let lifetime = u32::read(r)?;
2470        let age_add = u32::read(r)?;
2471        let nonce = PayloadU8::read(r)?;
2472        let ticket = Arc::new(PayloadU16::read(r)?);
2473        let exts = Vec::read(r)?;
2474
2475        Ok(Self {
2476            lifetime,
2477            age_add,
2478            nonce,
2479            ticket,
2480            exts,
2481        })
2482    }
2483}
2484
2485// -- RFC6066 certificate status types
2486
2487/// Only supports OCSP
2488#[derive(Debug)]
2489pub struct CertificateStatus<'a> {
2490    pub(crate) ocsp_response: PayloadU24<'a>,
2491}
2492
2493impl<'a> Codec<'a> for CertificateStatus<'a> {
2494    fn encode(&self, bytes: &mut Vec<u8>) {
2495        CertificateStatusType::OCSP.encode(bytes);
2496        self.ocsp_response.encode(bytes);
2497    }
2498
2499    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2500        let typ = CertificateStatusType::read(r)?;
2501
2502        match typ {
2503            CertificateStatusType::OCSP => Ok(Self {
2504                ocsp_response: PayloadU24::read(r)?,
2505            }),
2506            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2507        }
2508    }
2509}
2510
2511impl<'a> CertificateStatus<'a> {
2512    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2513        CertificateStatus {
2514            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2515        }
2516    }
2517
2518    #[cfg(feature = "tls12")]
2519    pub(crate) fn into_inner(self) -> Vec<u8> {
2520        self.ocsp_response.0.into_vec()
2521    }
2522
2523    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2524        CertificateStatus {
2525            ocsp_response: self.ocsp_response.into_owned(),
2526        }
2527    }
2528}
2529
2530// -- RFC8879 compressed certificates
2531
2532#[derive(Debug)]
2533pub struct CompressedCertificatePayload<'a> {
2534    pub(crate) alg: CertificateCompressionAlgorithm,
2535    pub(crate) uncompressed_len: u32,
2536    pub(crate) compressed: PayloadU24<'a>,
2537}
2538
2539impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2540    fn encode(&self, bytes: &mut Vec<u8>) {
2541        self.alg.encode(bytes);
2542        codec::u24(self.uncompressed_len).encode(bytes);
2543        self.compressed.encode(bytes);
2544    }
2545
2546    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2547        Ok(Self {
2548            alg: CertificateCompressionAlgorithm::read(r)?,
2549            uncompressed_len: codec::u24::read(r)?.0,
2550            compressed: PayloadU24::read(r)?,
2551        })
2552    }
2553}
2554
2555impl CompressedCertificatePayload<'_> {
2556    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2557        CompressedCertificatePayload {
2558            compressed: self.compressed.into_owned(),
2559            ..self
2560        }
2561    }
2562
2563    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2564        CompressedCertificatePayload {
2565            alg: self.alg,
2566            uncompressed_len: self.uncompressed_len,
2567            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2568        }
2569    }
2570}
2571
2572#[derive(Debug)]
2573pub enum HandshakePayload<'a> {
2574    HelloRequest,
2575    ClientHello(ClientHelloPayload),
2576    ServerHello(ServerHelloPayload),
2577    HelloRetryRequest(HelloRetryRequest),
2578    Certificate(CertificateChain<'a>),
2579    CertificateTls13(CertificatePayloadTls13<'a>),
2580    CompressedCertificate(CompressedCertificatePayload<'a>),
2581    ServerKeyExchange(ServerKeyExchangePayload),
2582    CertificateRequest(CertificateRequestPayload),
2583    CertificateRequestTls13(CertificateRequestPayloadTls13),
2584    CertificateVerify(DigitallySignedStruct),
2585    ServerHelloDone,
2586    EndOfEarlyData,
2587    ClientKeyExchange(Payload<'a>),
2588    NewSessionTicket(NewSessionTicketPayload),
2589    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
2590    EncryptedExtensions(Vec<ServerExtension>),
2591    KeyUpdate(KeyUpdateRequest),
2592    Finished(Payload<'a>),
2593    CertificateStatus(CertificateStatus<'a>),
2594    MessageHash(Payload<'a>),
2595    Unknown(Payload<'a>),
2596}
2597
2598impl HandshakePayload<'_> {
2599    fn encode(&self, bytes: &mut Vec<u8>) {
2600        use self::HandshakePayload::*;
2601        match self {
2602            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2603            ClientHello(x) => x.encode(bytes),
2604            ServerHello(x) => x.encode(bytes),
2605            HelloRetryRequest(x) => x.encode(bytes),
2606            Certificate(x) => x.encode(bytes),
2607            CertificateTls13(x) => x.encode(bytes),
2608            CompressedCertificate(x) => x.encode(bytes),
2609            ServerKeyExchange(x) => x.encode(bytes),
2610            ClientKeyExchange(x) => x.encode(bytes),
2611            CertificateRequest(x) => x.encode(bytes),
2612            CertificateRequestTls13(x) => x.encode(bytes),
2613            CertificateVerify(x) => x.encode(bytes),
2614            NewSessionTicket(x) => x.encode(bytes),
2615            NewSessionTicketTls13(x) => x.encode(bytes),
2616            EncryptedExtensions(x) => x.encode(bytes),
2617            KeyUpdate(x) => x.encode(bytes),
2618            Finished(x) => x.encode(bytes),
2619            CertificateStatus(x) => x.encode(bytes),
2620            MessageHash(x) => x.encode(bytes),
2621            Unknown(x) => x.encode(bytes),
2622        }
2623    }
2624
2625    fn into_owned(self) -> HandshakePayload<'static> {
2626        use HandshakePayload::*;
2627
2628        match self {
2629            HelloRequest => HelloRequest,
2630            ClientHello(x) => ClientHello(x),
2631            ServerHello(x) => ServerHello(x),
2632            HelloRetryRequest(x) => HelloRetryRequest(x),
2633            Certificate(x) => Certificate(x.into_owned()),
2634            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2635            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2636            ServerKeyExchange(x) => ServerKeyExchange(x),
2637            CertificateRequest(x) => CertificateRequest(x),
2638            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2639            CertificateVerify(x) => CertificateVerify(x),
2640            ServerHelloDone => ServerHelloDone,
2641            EndOfEarlyData => EndOfEarlyData,
2642            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2643            NewSessionTicket(x) => NewSessionTicket(x),
2644            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2645            EncryptedExtensions(x) => EncryptedExtensions(x),
2646            KeyUpdate(x) => KeyUpdate(x),
2647            Finished(x) => Finished(x.into_owned()),
2648            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2649            MessageHash(x) => MessageHash(x.into_owned()),
2650            Unknown(x) => Unknown(x.into_owned()),
2651        }
2652    }
2653}
2654
/// A complete handshake message: its type together with the decoded body.
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a> {
    /// The handshake message type. When decoding, a ServerHello whose random is
    /// `HELLO_RETRY_REQUEST_RANDOM` is reported as `HelloRetryRequest`.
    pub typ: HandshakeType,
    /// The decoded message body.
    pub payload: HandshakePayload<'a>,
}

impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.payload_encode(bytes, Encoding::Standard);
    }

    // Decodes using the non-TLS 1.3 layouts; callers that have negotiated
    // TLS 1.3 should use `read_version` with `ProtocolVersion::TLSv1_3`.
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
        Self::read_version(r, ProtocolVersion::TLSv1_2)
    }
}
2670
impl<'a> HandshakeMessagePayload<'a> {
    /// Decodes one handshake message from `r`.
    ///
    /// The message is a `HandshakeType`, a u24 body length, and then that many body bytes.
    /// `vers` selects between the TLS 1.3 and the earlier payload layouts for the message
    /// types whose format depends on the protocol version (Certificate, CertificateRequest
    /// and NewSessionTicket).
    ///
    /// A ServerHello whose random equals `HELLO_RETRY_REQUEST_RANDOM` is returned as a
    /// `HelloRetryRequest` payload, and the returned `typ` is rewritten to match.
    ///
    /// # Errors
    ///
    /// Returns an error if the header or body cannot be decoded, if bytes are left over in
    /// the body after its payload has been read, or if the wire type is `MessageHash` or
    /// `HelloRetryRequest` (neither of which may appear on the wire as such).
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // All payload parsing below is confined to this length-delimited body.
        let mut sub = r.sub(len)?;

        let payload = match typ {
            // A HelloRequest with a non-empty body does not match this arm and falls
            // through to the `Unknown` arm at the end of the match.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // ServerHello and HelloRetryRequest share this wire type and are told apart
                // only by the random value, so the common leading fields are read first.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            // Each TLS 1.3-guarded arm must stay ahead of the unguarded arm for the same
            // type, which handles every other protocol version.
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            // Kept as an opaque payload here (`Payload::read` is infallible).
            HandshakeType::ClientKeyExchange => {
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            // Unrecognised types (and non-empty HelloRequests) are kept as opaque bytes.
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        // Reject any trailing bytes the payload parser did not consume.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self { typ, payload })
    }

    /// Returns this message's standard encoding with the trailing PSK binders removed.
    ///
    /// For a ClientHello whose last extension is a PreSharedKey offer, this is the truncated
    /// ClientHello over which the binders are computed. For any other message nothing is
    /// removed, because `total_binder_length` returns 0.
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();
        let ret_len = ret.len() - self.total_binder_length();
        ret.truncate(ret_len);
        ret
    }

    /// Returns the encoded length of the PSK binders, or 0 if there are none.
    ///
    /// Binders are only recognised when this is a ClientHello and the PreSharedKey offer is
    /// its final extension.
    pub(crate) fn total_binder_length(&self) -> usize {
        match &self.payload {
            HandshakePayload::ClientHello(ch) => match ch.extensions.last() {
                Some(ClientExtension::PresharedKey(offer)) => {
                    // Measure the binders by encoding that field on its own.
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        }
    }

    /// Appends this message to `bytes`: its type, a u24 length, and the payload.
    ///
    /// Only ServerHello and HelloRetryRequest payloads are affected by `encoding`. Every
    /// other payload is written with its normal `Codec` encoding.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
        // output type, length, and encoded payload
        // HelloRetryRequest has no wire type of its own; it travels as a ServerHello.
        match self.typ {
            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
            _ => self.typ,
        }
        .encode(bytes);

        // The body is written through `nested.buf` behind a u24 length prefix
        // (presumably filled in when `nested` is dropped; see `LengthPrefixedBuffer`).
        let nested = LengthPrefixedBuffer::new(
            ListLength::U24 {
                max: usize::MAX,
                error: InvalidMessage::MessageTooLarge,
            },
            bytes,
        );

        match &self.payload {
            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
            // differently based on the purpose of the encoding.
            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
            HandshakePayload::HelloRetryRequest(payload) => {
                payload.payload_encode(nested.buf, encoding)
            }

            // All other payload types are encoded the same regardless of purpose.
            _ => self.payload.encode(nested.buf),
        }
    }

    /// Builds a synthetic `MessageHash` handshake message whose payload is a copy of `hash`.
    ///
    /// This message type never appears on the wire (`read_version` rejects it).
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self {
            typ: HandshakeType::MessageHash,
            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
        }
    }

    /// Converts into a message that owns all of its data, detaching it from `'a`.
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
        let Self { typ, payload } = self;
        HandshakeMessagePayload {
            typ,
            payload: payload.into_owned(),
        }
    }
}
2837
2838#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2839pub struct HpkeSymmetricCipherSuite {
2840    pub kdf_id: HpkeKdf,
2841    pub aead_id: HpkeAead,
2842}
2843
2844impl Codec<'_> for HpkeSymmetricCipherSuite {
2845    fn encode(&self, bytes: &mut Vec<u8>) {
2846        self.kdf_id.encode(bytes);
2847        self.aead_id.encode(bytes);
2848    }
2849
2850    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2851        Ok(Self {
2852            kdf_id: HpkeKdf::read(r)?,
2853            aead_id: HpkeAead::read(r)?,
2854        })
2855    }
2856}
2857
2858impl TlsListElement for HpkeSymmetricCipherSuite {
2859    const SIZE_LEN: ListLength = ListLength::U16;
2860}
2861
2862#[derive(Clone, Debug, PartialEq)]
2863pub struct HpkeKeyConfig {
2864    pub config_id: u8,
2865    pub kem_id: HpkeKem,
2866    pub public_key: PayloadU16,
2867    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2868}
2869
2870impl Codec<'_> for HpkeKeyConfig {
2871    fn encode(&self, bytes: &mut Vec<u8>) {
2872        self.config_id.encode(bytes);
2873        self.kem_id.encode(bytes);
2874        self.public_key.encode(bytes);
2875        self.symmetric_cipher_suites
2876            .encode(bytes);
2877    }
2878
2879    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2880        Ok(Self {
2881            config_id: u8::read(r)?,
2882            kem_id: HpkeKem::read(r)?,
2883            public_key: PayloadU16::read(r)?,
2884            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2885        })
2886    }
2887}
2888
2889#[derive(Clone, Debug, PartialEq)]
2890pub struct EchConfigContents {
2891    pub key_config: HpkeKeyConfig,
2892    pub maximum_name_length: u8,
2893    pub public_name: DnsName<'static>,
2894    pub extensions: Vec<EchConfigExtension>,
2895}
2896
2897impl EchConfigContents {
2898    /// Returns true if there is more than one extension of a given
2899    /// type.
2900    pub(crate) fn has_duplicate_extension(&self) -> bool {
2901        has_duplicates::<_, _, u16>(
2902            self.extensions
2903                .iter()
2904                .map(|ext| ext.ext_type()),
2905        )
2906    }
2907
2908    /// Returns true if there is at least one mandatory unsupported extension.
2909    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2910        self.extensions
2911            .iter()
2912            // An extension is considered mandatory if the high bit of its type is set.
2913            .any(|ext| {
2914                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2915                    && u16::from(ext.ext_type()) & 0x8000 != 0
2916            })
2917    }
2918}
2919
2920impl Codec<'_> for EchConfigContents {
2921    fn encode(&self, bytes: &mut Vec<u8>) {
2922        self.key_config.encode(bytes);
2923        self.maximum_name_length.encode(bytes);
2924        let dns_name = &self.public_name.borrow();
2925        PayloadU8::encode_slice(dns_name.as_ref().as_ref(), bytes);
2926        self.extensions.encode(bytes);
2927    }
2928
2929    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2930        Ok(Self {
2931            key_config: HpkeKeyConfig::read(r)?,
2932            maximum_name_length: u8::read(r)?,
2933            public_name: {
2934                DnsName::try_from(PayloadU8::read(r)?.0.as_slice())
2935                    .map_err(|_| InvalidMessage::InvalidServerName)?
2936                    .to_owned()
2937            },
2938            extensions: Vec::read(r)?,
2939        })
2940    }
2941}
2942
2943/// An encrypted client hello (ECH) config.
2944#[derive(Clone, Debug, PartialEq)]
2945pub enum EchConfigPayload {
2946    /// A recognized V18 ECH configuration.
2947    V18(EchConfigContents),
2948    /// An unknown version ECH configuration.
2949    Unknown {
2950        version: EchVersion,
2951        contents: PayloadU16,
2952    },
2953}
2954
2955impl TlsListElement for EchConfigPayload {
2956    const SIZE_LEN: ListLength = ListLength::U16;
2957}
2958
2959impl Codec<'_> for EchConfigPayload {
2960    fn encode(&self, bytes: &mut Vec<u8>) {
2961        match self {
2962            Self::V18(c) => {
2963                // Write the version, the length, and the contents.
2964                EchVersion::V18.encode(bytes);
2965                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2966                c.encode(inner.buf);
2967            }
2968            Self::Unknown { version, contents } => {
2969                // Unknown configuration versions are opaque.
2970                version.encode(bytes);
2971                contents.encode(bytes);
2972            }
2973        }
2974    }
2975
2976    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2977        let version = EchVersion::read(r)?;
2978        let length = u16::read(r)?;
2979        let mut contents = r.sub(length as usize)?;
2980
2981        Ok(match version {
2982            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2983            _ => {
2984                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2985                let data = PayloadU16::new(contents.rest().into());
2986                Self::Unknown {
2987                    version,
2988                    contents: data,
2989                }
2990            }
2991        })
2992    }
2993}
2994
2995#[derive(Clone, Debug, PartialEq)]
2996pub enum EchConfigExtension {
2997    Unknown(UnknownExtension),
2998}
2999
3000impl EchConfigExtension {
3001    pub(crate) fn ext_type(&self) -> ExtensionType {
3002        match self {
3003            Self::Unknown(r) => r.typ,
3004        }
3005    }
3006}
3007
3008impl Codec<'_> for EchConfigExtension {
3009    fn encode(&self, bytes: &mut Vec<u8>) {
3010        self.ext_type().encode(bytes);
3011
3012        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3013        match self {
3014            Self::Unknown(r) => r.encode(nested.buf),
3015        }
3016    }
3017
3018    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3019        let typ = ExtensionType::read(r)?;
3020        let len = u16::read(r)? as usize;
3021        let mut sub = r.sub(len)?;
3022
3023        #[allow(clippy::match_single_binding)] // Future-proofing.
3024        let ext = match typ {
3025            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
3026        };
3027
3028        sub.expect_empty("EchConfigExtension")
3029            .map(|_| ext)
3030    }
3031}
3032
3033impl TlsListElement for EchConfigExtension {
3034    const SIZE_LEN: ListLength = ListLength::U16;
3035}
3036
3037/// Representation of the `ECHClientHello` client extension specified in
3038/// [draft-ietf-tls-esni Section 5].
3039///
3040/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3041#[derive(Clone, Debug)]
3042pub enum EncryptedClientHello {
3043    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
3044    Outer(EncryptedClientHelloOuter),
3045    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
3046    ///
3047    /// This variant has no payload.
3048    Inner,
3049}
3050
3051impl Codec<'_> for EncryptedClientHello {
3052    fn encode(&self, bytes: &mut Vec<u8>) {
3053        match self {
3054            Self::Outer(payload) => {
3055                EchClientHelloType::ClientHelloOuter.encode(bytes);
3056                payload.encode(bytes);
3057            }
3058            Self::Inner => {
3059                EchClientHelloType::ClientHelloInner.encode(bytes);
3060                // Empty payload.
3061            }
3062        }
3063    }
3064
3065    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3066        match EchClientHelloType::read(r)? {
3067            EchClientHelloType::ClientHelloOuter => {
3068                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3069            }
3070            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3071            _ => Err(InvalidMessage::InvalidContentType),
3072        }
3073    }
3074}
3075
3076/// Representation of the ECHClientHello extension with type outer specified in
3077/// [draft-ietf-tls-esni Section 5].
3078///
3079/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3080#[derive(Clone, Debug)]
3081pub struct EncryptedClientHelloOuter {
3082    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3083    /// ECHConfigContents.cipher_suites list.
3084    pub cipher_suite: HpkeSymmetricCipherSuite,
3085    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3086    pub config_id: u8,
3087    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3088    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3089    pub enc: PayloadU16,
3090    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3091    pub payload: PayloadU16,
3092}
3093
3094impl Codec<'_> for EncryptedClientHelloOuter {
3095    fn encode(&self, bytes: &mut Vec<u8>) {
3096        self.cipher_suite.encode(bytes);
3097        self.config_id.encode(bytes);
3098        self.enc.encode(bytes);
3099        self.payload.encode(bytes);
3100    }
3101
3102    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3103        Ok(Self {
3104            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3105            config_id: u8::read(r)?,
3106            enc: PayloadU16::read(r)?,
3107            payload: PayloadU16::read(r)?,
3108        })
3109    }
3110}
3111
3112/// Representation of the ECHEncryptedExtensions extension specified in
3113/// [draft-ietf-tls-esni Section 5].
3114///
3115/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3116#[derive(Clone, Debug)]
3117pub struct ServerEncryptedClientHello {
3118    pub(crate) retry_configs: Vec<EchConfigPayload>,
3119}
3120
3121impl Codec<'_> for ServerEncryptedClientHello {
3122    fn encode(&self, bytes: &mut Vec<u8>) {
3123        self.retry_configs.encode(bytes);
3124    }
3125
3126    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3127        Ok(Self {
3128            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3129        })
3130    }
3131}
3132
/// The method of encoding to use for a handshake message.
///
/// In some cases a handshake message may be encoded differently depending on the purpose
/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
///
/// `HandshakeMessagePayload::payload_encode` passes the encoding through only to ServerHello
/// and HelloRetryRequest payloads. All other handshake payloads encode the same way,
/// whatever value is chosen here.
pub(crate) enum Encoding {
    /// Standard RFC 8446 encoding.
    Standard,
    /// Encoding for ECH confirmation.
    EchConfirmation,
    /// Encoding for ECH inner client hello.
    ///
    /// NOTE(review): `to_compress` presumably lists the extension types that the inner
    /// hello encoding should compress (refer to from the outer hello rather than repeat).
    /// Confirm against the ECH client code that constructs this variant.
    EchInnerHello { to_compress: Vec<ExtensionType> },
}
3146
/// Returns `true` as soon as `iter` yields two items that are equal once converted into `T`.
///
/// Converted items are tracked in an ordered set, so `T` needs `Ord` but not `Hash`.
/// Iteration stops at the first repeat.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen: BTreeSet<T> = BTreeSet::new();
    // `insert` returns false when the value was already present.
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ech_config_dupe_exts() {
        let unknown_ext = EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(0x42),
            payload: Payload::new(vec![0x42]),
        });
        let mut config = config_template();
        config
            .extensions
            .push(unknown_ext.clone());
        config.extensions.push(unknown_ext);

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    #[test]
    fn test_ech_config_mandatory_exts() {
        let mandatory_unknown_ext = EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(0x42 | 0x8000), // Note: high bit set.
            payload: Payload::new(vec![0x42]),
        });
        let mut config = config_template();
        config
            .extensions
            .push(mandatory_unknown_ext);

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// Encoding a V18 config and reading it back must reproduce the same config and consume
    /// every byte that was written.
    #[test]
    fn test_ech_config_payload_v18_roundtrip() {
        let original = EchConfigPayload::V18(config_template());
        let encoded = original.get_encoding();

        let mut reader = Reader::init(&encoded);
        let decoded = EchConfigPayload::read(&mut reader).unwrap();
        assert_eq!(reader.left(), 0);
        assert_eq!(decoded, original);
    }

    /// Unknown versions are carried opaquely; they must survive an encode/read roundtrip
    /// without the length prefix being doubled or lost.
    #[test]
    fn test_ech_config_payload_unknown_version_roundtrip() {
        let original = EchConfigPayload::Unknown {
            version: EchVersion::Unknown(0x1234),
            contents: PayloadU16::new(vec![1, 2, 3]),
        };
        let encoded = original.get_encoding();

        let mut reader = Reader::init(&encoded);
        let decoded = EchConfigPayload::read(&mut reader).unwrap();
        assert_eq!(reader.left(), 0);
        assert_eq!(decoded, original);
    }

    /// A minimal V18 configuration used as the basis for the tests above.
    fn config_template() -> EchConfigContents {
        EchConfigContents {
            key_config: HpkeKeyConfig {
                config_id: 0,
                kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
                public_key: PayloadU16(b"xxx".into()),
                symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
                    kdf_id: HpkeKdf::HKDF_SHA256,
                    aead_id: HpkeAead::AES_128_GCM,
                }],
            },
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}