rustls/msgs/
persist.rs

1use alloc::vec::Vec;
2use core::cmp;
3
4use pki_types::{DnsName, UnixTime};
5use zeroize::Zeroizing;
6
7use crate::client::ResolvesClientCert;
8use crate::enums::{CipherSuite, ProtocolVersion};
9use crate::error::InvalidMessage;
10use crate::msgs::base::{PayloadU8, PayloadU16};
11use crate::msgs::codec::{Codec, Reader};
12use crate::msgs::handshake::CertificateChain;
13#[cfg(feature = "tls12")]
14use crate::msgs::handshake::SessionId;
15use crate::sync::{Arc, Weak};
16#[cfg(feature = "tls12")]
17use crate::tls12::Tls12CipherSuite;
18use crate::tls13::Tls13CipherSuite;
19use crate::verify::ServerCertVerifier;
20
21pub(crate) struct Retrieved<T> {
22    pub(crate) value: T,
23    retrieved_at: UnixTime,
24}
25
26impl<T> Retrieved<T> {
27    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
28        Self {
29            value,
30            retrieved_at,
31        }
32    }
33
34    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
35        Some(Retrieved {
36            value: f(&self.value)?,
37            retrieved_at: self.retrieved_at,
38        })
39    }
40}
41
42impl Retrieved<&Tls13ClientSessionValue> {
43    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
44        let age_secs = self
45            .retrieved_at
46            .as_secs()
47            .saturating_sub(self.value.common.epoch);
48        let age_millis = age_secs as u32 * 1000;
49        age_millis.wrapping_add(self.value.age_add)
50    }
51}
52
53impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
54    pub(crate) fn has_expired(&self) -> bool {
55        let common = &*self.value;
56        common.lifetime_secs != 0
57            && common
58                .epoch
59                .saturating_add(u64::from(common.lifetime_secs))
60                < self.retrieved_at.as_secs()
61    }
62}
63
64impl<T> core::ops::Deref for Retrieved<T> {
65    type Target = T;
66
67    fn deref(&self) -> &Self::Target {
68        &self.value
69    }
70}
71
72#[derive(Debug)]
73pub struct Tls13ClientSessionValue {
74    suite: &'static Tls13CipherSuite,
75    age_add: u32,
76    max_early_data_size: u32,
77    pub(crate) common: ClientSessionCommon,
78    quic_params: PayloadU16,
79}
80
81impl Tls13ClientSessionValue {
82    pub(crate) fn new(
83        suite: &'static Tls13CipherSuite,
84        ticket: Arc<PayloadU16>,
85        secret: &[u8],
86        server_cert_chain: CertificateChain<'static>,
87        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
88        client_creds: &Arc<dyn ResolvesClientCert>,
89        time_now: UnixTime,
90        lifetime_secs: u32,
91        age_add: u32,
92        max_early_data_size: u32,
93    ) -> Self {
94        Self {
95            suite,
96            age_add,
97            max_early_data_size,
98            common: ClientSessionCommon::new(
99                ticket,
100                secret,
101                time_now,
102                lifetime_secs,
103                server_cert_chain,
104                server_cert_verifier,
105                client_creds,
106            ),
107            quic_params: PayloadU16(Vec::new()),
108        }
109    }
110
111    pub fn max_early_data_size(&self) -> u32 {
112        self.max_early_data_size
113    }
114
115    pub fn suite(&self) -> &'static Tls13CipherSuite {
116        self.suite
117    }
118
119    #[doc(hidden)]
120    /// Test only: rewind epoch by `delta` seconds.
121    pub fn rewind_epoch(&mut self, delta: u32) {
122        self.common.epoch -= delta as u64;
123    }
124
125    #[doc(hidden)]
126    /// Test only: replace `max_early_data_size` with `new`
127    pub fn _private_set_max_early_data_size(&mut self, new: u32) {
128        self.max_early_data_size = new;
129    }
130
131    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
132        self.quic_params = PayloadU16(quic_params.to_vec());
133    }
134
135    pub fn quic_params(&self) -> Vec<u8> {
136        self.quic_params.0.clone()
137    }
138}
139
140impl core::ops::Deref for Tls13ClientSessionValue {
141    type Target = ClientSessionCommon;
142
143    fn deref(&self) -> &Self::Target {
144        &self.common
145    }
146}
147
/// Client-side state kept for resuming a TLS 1.2 session (by session ID or
/// ticket).
///
/// Without the `tls12` feature this type has no fields.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    /// Cipher suite this session is associated with.
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    /// Session ID recorded for this session.
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    /// Extended-master-secret flag, as passed to `new`.
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}
160
#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Builds the stored state for a TLS 1.2 session established at `time_now`.
    ///
    /// `lifetime_secs` is clamped to `MAX_TICKET_LIFETIME` by
    /// `ClientSessionCommon::new`.
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Arc<PayloadU16>,
        master_secret: &[u8],
        server_cert_chain: CertificateChain<'static>,
        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
        client_creds: &Arc<dyn ResolvesClientCert>,
        time_now: UnixTime,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        Self {
            suite,
            session_id,
            extended_ms,
            common: ClientSessionCommon::new(
                ticket,
                master_secret,
                time_now,
                lifetime_secs,
                server_cert_chain,
                server_cert_verifier,
                client_creds,
            ),
        }
    }

    /// Returns another shared handle to the session ticket.
    ///
    /// Only the `Arc` is cloned; the ticket bytes are not copied. This needs
    /// only shared access, so it takes `&self`. Existing `&mut` callers still
    /// work unchanged.
    pub(crate) fn ticket(&self) -> Arc<PayloadU16> {
        Arc::clone(&self.common.ticket)
    }

    /// Returns the extended-master-secret flag recorded for this session.
    pub(crate) fn extended_ms(&self) -> bool {
        self.extended_ms
    }

    /// Returns the cipher suite this session is associated with.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        self.suite
    }

    #[doc(hidden)]
    /// Test only: rewind epoch by `delta` seconds.
    ///
    /// Saturates at zero instead of underflowing. Underflow would panic in
    /// debug builds and wrap to a far-future epoch in release builds.
    pub fn rewind_epoch(&mut self, delta: u32) {
        self.common.epoch = self.common.epoch.saturating_sub(u64::from(delta));
    }
}
209
#[cfg(feature = "tls12")]
impl core::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    /// Gives read access to the state shared with TLS 1.3 sessions.
    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
218
219#[derive(Debug, Clone)]
220pub struct ClientSessionCommon {
221    ticket: Arc<PayloadU16>,
222    secret: Zeroizing<PayloadU8>,
223    epoch: u64,
224    lifetime_secs: u32,
225    server_cert_chain: Arc<CertificateChain<'static>>,
226    server_cert_verifier: Weak<dyn ServerCertVerifier>,
227    client_creds: Weak<dyn ResolvesClientCert>,
228}
229
230impl ClientSessionCommon {
231    fn new(
232        ticket: Arc<PayloadU16>,
233        secret: &[u8],
234        time_now: UnixTime,
235        lifetime_secs: u32,
236        server_cert_chain: CertificateChain<'static>,
237        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
238        client_creds: &Arc<dyn ResolvesClientCert>,
239    ) -> Self {
240        Self {
241            ticket,
242            secret: Zeroizing::new(PayloadU8(secret.to_vec())),
243            epoch: time_now.as_secs(),
244            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
245            server_cert_chain: Arc::new(server_cert_chain),
246            server_cert_verifier: Arc::downgrade(server_cert_verifier),
247            client_creds: Arc::downgrade(client_creds),
248        }
249    }
250
251    pub(crate) fn compatible_config(
252        &self,
253        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
254        client_creds: &Arc<dyn ResolvesClientCert>,
255    ) -> bool {
256        let same_verifier = Weak::ptr_eq(
257            &Arc::downgrade(server_cert_verifier),
258            &self.server_cert_verifier,
259        );
260        let same_creds = Weak::ptr_eq(&Arc::downgrade(client_creds), &self.client_creds);
261
262        match (same_verifier, same_creds) {
263            (true, true) => true,
264            (false, _) => {
265                crate::log::trace!("resumption not allowed between different ServerCertVerifiers");
266                false
267            }
268            (_, _) => {
269                crate::log::trace!(
270                    "resumption not allowed between different ResolvesClientCert values"
271                );
272                false
273            }
274        }
275    }
276
277    pub(crate) fn server_cert_chain(&self) -> &CertificateChain<'static> {
278        &self.server_cert_chain
279    }
280
281    pub(crate) fn secret(&self) -> &[u8] {
282        self.secret.0.as_ref()
283    }
284
285    pub(crate) fn ticket(&self) -> &[u8] {
286        self.ticket.0.as_ref()
287    }
288}
289
/// Upper bound, in seconds, on the lifetime recorded for a client session
/// (7 days).
///
/// RFC 8446 §4.6.1 forbids ticket lifetimes above 604800 seconds, so longer
/// values are clamped to this.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;
291
/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period, in milliseconds. It covers TCP
/// retransmission delays if packets are lost while the client sends the
/// ClientHello or receives the NewSessionTicket, as well as actual clock
/// skew over this period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
297
298// --- Server types ---
299#[derive(Debug)]
300pub struct ServerSessionValue {
301    pub(crate) sni: Option<DnsName<'static>>,
302    pub(crate) version: ProtocolVersion,
303    pub(crate) cipher_suite: CipherSuite,
304    pub(crate) master_secret: Zeroizing<PayloadU8>,
305    pub(crate) extended_ms: bool,
306    pub(crate) client_cert_chain: Option<CertificateChain<'static>>,
307    pub(crate) alpn: Option<PayloadU8>,
308    pub(crate) application_data: PayloadU16,
309    pub creation_time_sec: u64,
310    pub(crate) age_obfuscation_offset: u32,
311    freshness: Option<bool>,
312}
313
314impl Codec<'_> for ServerSessionValue {
315    fn encode(&self, bytes: &mut Vec<u8>) {
316        if let Some(sni) = &self.sni {
317            1u8.encode(bytes);
318            let sni_bytes: &str = sni.as_ref();
319            PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
320        } else {
321            0u8.encode(bytes);
322        }
323        self.version.encode(bytes);
324        self.cipher_suite.encode(bytes);
325        self.master_secret.encode(bytes);
326        (u8::from(self.extended_ms)).encode(bytes);
327        if let Some(chain) = &self.client_cert_chain {
328            1u8.encode(bytes);
329            chain.encode(bytes);
330        } else {
331            0u8.encode(bytes);
332        }
333        if let Some(alpn) = &self.alpn {
334            1u8.encode(bytes);
335            alpn.encode(bytes);
336        } else {
337            0u8.encode(bytes);
338        }
339        self.application_data.encode(bytes);
340        self.creation_time_sec.encode(bytes);
341        self.age_obfuscation_offset
342            .encode(bytes);
343    }
344
345    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
346        let has_sni = u8::read(r)?;
347        let sni = if has_sni == 1 {
348            let dns_name = PayloadU8::read(r)?;
349            let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
350                Ok(dns_name) => dns_name.to_owned(),
351                Err(_) => return Err(InvalidMessage::InvalidServerName),
352            };
353
354            Some(dns_name)
355        } else {
356            None
357        };
358
359        let v = ProtocolVersion::read(r)?;
360        let cs = CipherSuite::read(r)?;
361        let ms = Zeroizing::new(PayloadU8::read(r)?);
362        let ems = u8::read(r)?;
363        let has_ccert = u8::read(r)? == 1;
364        let ccert = if has_ccert {
365            Some(CertificateChain::read(r)?.into_owned())
366        } else {
367            None
368        };
369        let has_alpn = u8::read(r)? == 1;
370        let alpn = if has_alpn {
371            Some(PayloadU8::read(r)?)
372        } else {
373            None
374        };
375        let application_data = PayloadU16::read(r)?;
376        let creation_time_sec = u64::read(r)?;
377        let age_obfuscation_offset = u32::read(r)?;
378
379        Ok(Self {
380            sni,
381            version: v,
382            cipher_suite: cs,
383            master_secret: ms,
384            extended_ms: ems == 1u8,
385            client_cert_chain: ccert,
386            alpn,
387            application_data,
388            creation_time_sec,
389            age_obfuscation_offset,
390            freshness: None,
391        })
392    }
393}
394
395impl ServerSessionValue {
396    pub(crate) fn new(
397        sni: Option<&DnsName<'_>>,
398        v: ProtocolVersion,
399        cs: CipherSuite,
400        ms: &[u8],
401        client_cert_chain: Option<CertificateChain<'static>>,
402        alpn: Option<Vec<u8>>,
403        application_data: Vec<u8>,
404        creation_time: UnixTime,
405        age_obfuscation_offset: u32,
406    ) -> Self {
407        Self {
408            sni: sni.map(|dns| dns.to_owned()),
409            version: v,
410            cipher_suite: cs,
411            master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
412            extended_ms: false,
413            client_cert_chain,
414            alpn: alpn.map(PayloadU8::new),
415            application_data: PayloadU16::new(application_data),
416            creation_time_sec: creation_time.as_secs(),
417            age_obfuscation_offset,
418            freshness: None,
419        }
420    }
421
422    #[cfg(feature = "tls12")]
423    pub(crate) fn set_extended_ms_used(&mut self) {
424        self.extended_ms = true;
425    }
426
427    pub(crate) fn set_freshness(
428        mut self,
429        obfuscated_client_age_ms: u32,
430        time_now: UnixTime,
431    ) -> Self {
432        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
433        let server_age_ms = (time_now
434            .as_secs()
435            .saturating_sub(self.creation_time_sec) as u32)
436            .saturating_mul(1000);
437
438        let age_difference = if client_age_ms < server_age_ms {
439            server_age_ms - client_age_ms
440        } else {
441            client_age_ms - server_age_ms
442        };
443
444        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
445        self
446    }
447
448    pub(crate) fn is_fresh(&self) -> bool {
449        self.freshness.unwrap_or_default()
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `bytes`, checks it re-encodes to exactly the same bytes
    /// (which also proves nothing was left unread), and returns the value.
    fn round_trip(bytes: &[u8]) -> ServerSessionValue {
        let mut rd = Reader::init(bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(ssv.get_encoding(), bytes);
        ssv
    }

    #[cfg(feature = "std")] // for UnixTime::now
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            &[1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            UnixTime::now(),
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x00, // no SNI
            0x03, 0x03, // protocol version
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret (u8 length prefix)
            0x00, // extended master secret not used
            0x00, // no client certificate chain
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation_time_sec
            0xfe, 0xed, 0xf0, 0x0d, // age_obfuscation_offset
        ];
        let ssv = round_trip(&bytes);
        assert!(ssv.sni.is_none());
        assert_eq!(ssv.version, ProtocolVersion::TLSv1_2);
        assert_eq!(ssv.master_secret.0, [1u8, 2, 3]);
        assert!(!ssv.extended_ms);
        assert!(ssv.client_cert_chain.is_none());
        assert!(ssv.alpn.is_none());
        assert!(ssv.application_data.0.is_empty());
        assert_eq!(ssv.creation_time_sec, 0x1223_3445_5667_7889);
        assert_eq!(ssv.age_obfuscation_offset, 0xfeed_f00d);
        assert!(!ssv.is_fresh());
    }

    #[test]
    fn serversessionvalue_with_sni() {
        let bytes = [
            0x01, // SNI present
            0x0b, // name length (11)
            b'e', b'x', b'a', b'm', b'p', b'l', b'e', b'.', b'c', b'o', b'm',
            0x03, 0x03, // protocol version
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret
            0x00, // extended master secret not used
            0x00, // no client certificate chain
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation_time_sec
            0xfe, 0xed, 0xf0, 0x0d, // age_obfuscation_offset
        ];
        let ssv = round_trip(&bytes);
        let sni: &str = ssv.sni.as_ref().unwrap().as_ref();
        assert_eq!(sni, "example.com");
    }

    #[test]
    fn serversessionvalue_with_cert() {
        let bytes = [
            0x00, // no SNI
            0x03, 0x03, // protocol version
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret
            0x00, // extended master secret not used
            0x01, // client certificate chain present
            0x00, 0x00, 0x04, // chain: 4 bytes of certificate entries (u24)
            0x00, 0x00, 0x01, 0x99, // one 1-byte certificate (u24 length + body)
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation_time_sec
            0xfe, 0xed, 0xf0, 0x0d, // age_obfuscation_offset
        ];
        let ssv = round_trip(&bytes);
        assert!(ssv.client_cert_chain.is_some());
        assert!(ssv.sni.is_none());
        assert!(ssv.alpn.is_none());
    }
}