1use alloc::vec::Vec;
2use core::cmp;
3
4use pki_types::{DnsName, UnixTime};
5use zeroize::Zeroizing;
6
7use crate::client::ResolvesClientCert;
8use crate::enums::{CipherSuite, ProtocolVersion};
9use crate::error::InvalidMessage;
10use crate::msgs::base::{PayloadU8, PayloadU16};
11use crate::msgs::codec::{Codec, Reader};
12use crate::msgs::handshake::CertificateChain;
13#[cfg(feature = "tls12")]
14use crate::msgs::handshake::SessionId;
15use crate::sync::{Arc, Weak};
16#[cfg(feature = "tls12")]
17use crate::tls12::Tls12CipherSuite;
18use crate::tls13::Tls13CipherSuite;
19use crate::verify::ServerCertVerifier;
20
/// A value paired with the time at which it was retrieved.
///
/// The timestamp is supplied by the caller of [`Retrieved::new`] and is used
/// by the age and expiry helpers implemented on this type.
pub(crate) struct Retrieved<T> {
    /// The wrapped value.
    pub(crate) value: T,
    /// The time recorded when `value` was retrieved.
    retrieved_at: UnixTime,
}
25
26impl<T> Retrieved<T> {
27 pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
28 Self {
29 value,
30 retrieved_at,
31 }
32 }
33
34 pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
35 Some(Retrieved {
36 value: f(&self.value)?,
37 retrieved_at: self.retrieved_at,
38 })
39 }
40}
41
42impl Retrieved<&Tls13ClientSessionValue> {
43 pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
44 let age_secs = self
45 .retrieved_at
46 .as_secs()
47 .saturating_sub(self.value.common.epoch);
48 let age_millis = age_secs as u32 * 1000;
49 age_millis.wrapping_add(self.value.age_add)
50 }
51}
52
53impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
54 pub(crate) fn has_expired(&self) -> bool {
55 let common = &*self.value;
56 common.lifetime_secs != 0
57 && common
58 .epoch
59 .saturating_add(u64::from(common.lifetime_secs))
60 < self.retrieved_at.as_secs()
61 }
62}
63
64impl<T> core::ops::Deref for Retrieved<T> {
65 type Target = T;
66
67 fn deref(&self) -> &Self::Target {
68 &self.value
69 }
70}
71
/// Client-side session state for a TLS 1.3 cipher suite.
///
/// Dereferences to [`ClientSessionCommon`] for the fields shared with
/// [`Tls12ClientSessionValue`].
#[derive(Debug)]
pub struct Tls13ClientSessionValue {
    /// The TLS 1.3 cipher suite this session was created with.
    suite: &'static Tls13CipherSuite,
    /// Offset added (with wrapping) to the ticket age when it is obfuscated.
    age_add: u32,
    /// Value returned by [`Tls13ClientSessionValue::max_early_data_size`].
    max_early_data_size: u32,
    /// Ticket, secret, timing and configuration-identity data.
    pub(crate) common: ClientSessionCommon,
    /// Opaque QUIC parameters; empty until `set_quic_params` is called.
    quic_params: PayloadU16,
}
80
81impl Tls13ClientSessionValue {
82 pub(crate) fn new(
83 suite: &'static Tls13CipherSuite,
84 ticket: Arc<PayloadU16>,
85 secret: &[u8],
86 server_cert_chain: CertificateChain<'static>,
87 server_cert_verifier: &Arc<dyn ServerCertVerifier>,
88 client_creds: &Arc<dyn ResolvesClientCert>,
89 time_now: UnixTime,
90 lifetime_secs: u32,
91 age_add: u32,
92 max_early_data_size: u32,
93 ) -> Self {
94 Self {
95 suite,
96 age_add,
97 max_early_data_size,
98 common: ClientSessionCommon::new(
99 ticket,
100 secret,
101 time_now,
102 lifetime_secs,
103 server_cert_chain,
104 server_cert_verifier,
105 client_creds,
106 ),
107 quic_params: PayloadU16(Vec::new()),
108 }
109 }
110
111 pub fn max_early_data_size(&self) -> u32 {
112 self.max_early_data_size
113 }
114
115 pub fn suite(&self) -> &'static Tls13CipherSuite {
116 self.suite
117 }
118
119 #[doc(hidden)]
120 pub fn rewind_epoch(&mut self, delta: u32) {
122 self.common.epoch -= delta as u64;
123 }
124
125 #[doc(hidden)]
126 pub fn _private_set_max_early_data_size(&mut self, new: u32) {
128 self.max_early_data_size = new;
129 }
130
131 pub fn set_quic_params(&mut self, quic_params: &[u8]) {
132 self.quic_params = PayloadU16(quic_params.to_vec());
133 }
134
135 pub fn quic_params(&self) -> Vec<u8> {
136 self.quic_params.0.clone()
137 }
138}
139
140impl core::ops::Deref for Tls13ClientSessionValue {
141 type Target = ClientSessionCommon;
142
143 fn deref(&self) -> &Self::Target {
144 &self.common
145 }
146}
147
/// Client-side session state for a TLS 1.2 cipher suite.
///
/// Every field is gated on the `tls12` feature, so the struct is empty when
/// that feature is disabled.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    /// The TLS 1.2 cipher suite this session was created with.
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    /// The session ID supplied when the value was created.
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    /// Flag returned by `extended_ms()`.
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    /// Ticket, secret, timing and configuration-identity data.
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}
160
#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Builds a new session value.
    ///
    /// `ticket`, `master_secret`, `time_now`, `lifetime_secs`, the certificate
    /// chain, the verifier and the client credentials are forwarded to
    /// [`ClientSessionCommon`].
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Arc<PayloadU16>,
        master_secret: &[u8],
        server_cert_chain: CertificateChain<'static>,
        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
        client_creds: &Arc<dyn ResolvesClientCert>,
        time_now: UnixTime,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        Self {
            suite,
            session_id,
            extended_ms,
            common: ClientSessionCommon::new(
                ticket,
                master_secret,
                time_now,
                lifetime_secs,
                server_cert_chain,
                server_cert_verifier,
                client_creds,
            ),
        }
    }

    /// Returns another handle to the shared ticket (a reference-count bump,
    /// not a copy of the ticket bytes).
    pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
        Arc::clone(&self.common.ticket)
    }

    /// Returns the stored `extended_ms` flag.
    pub(crate) fn extended_ms(&self) -> bool {
        self.extended_ms
    }

    /// Returns the cipher suite this session was created with.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        self.suite
    }

    /// Moves the session's epoch `delta` seconds into the past.
    ///
    /// The epoch is clamped at zero rather than underflowing, so an
    /// over-large `delta` cannot panic (debug builds) or wrap around to a
    /// far-future epoch (release builds).
    #[doc(hidden)]
    pub fn rewind_epoch(&mut self, delta: u32) {
        self.common.epoch = self
            .common
            .epoch
            .saturating_sub(u64::from(delta));
    }
}
209
/// Exposes the shared [`ClientSessionCommon`] state of a TLS 1.2 session.
#[cfg(feature = "tls12")]
impl core::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
218
/// Session state shared by the TLS 1.2 and TLS 1.3 client session values.
#[derive(Debug, Clone)]
pub struct ClientSessionCommon {
    /// The session ticket, shared by reference count.
    ticket: Arc<PayloadU16>,
    /// The session secret, wrapped in `Zeroizing`.
    secret: Zeroizing<PayloadU8>,
    /// Creation time, in seconds, taken from the `time_now` passed to `new`.
    epoch: u64,
    /// Ticket lifetime in seconds, capped at `MAX_TICKET_LIFETIME` by `new`.
    lifetime_secs: u32,
    /// The server certificate chain supplied when the session was created.
    server_cert_chain: Arc<CertificateChain<'static>>,
    /// Weak handle to the verifier supplied at creation; compared by pointer
    /// in `compatible_config`.
    server_cert_verifier: Weak<dyn ServerCertVerifier>,
    /// Weak handle to the client credential resolver supplied at creation;
    /// compared by pointer in `compatible_config`.
    client_creds: Weak<dyn ResolvesClientCert>,
}
229
230impl ClientSessionCommon {
231 fn new(
232 ticket: Arc<PayloadU16>,
233 secret: &[u8],
234 time_now: UnixTime,
235 lifetime_secs: u32,
236 server_cert_chain: CertificateChain<'static>,
237 server_cert_verifier: &Arc<dyn ServerCertVerifier>,
238 client_creds: &Arc<dyn ResolvesClientCert>,
239 ) -> Self {
240 Self {
241 ticket,
242 secret: Zeroizing::new(PayloadU8(secret.to_vec())),
243 epoch: time_now.as_secs(),
244 lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
245 server_cert_chain: Arc::new(server_cert_chain),
246 server_cert_verifier: Arc::downgrade(server_cert_verifier),
247 client_creds: Arc::downgrade(client_creds),
248 }
249 }
250
251 pub(crate) fn compatible_config(
252 &self,
253 server_cert_verifier: &Arc<dyn ServerCertVerifier>,
254 client_creds: &Arc<dyn ResolvesClientCert>,
255 ) -> bool {
256 let same_verifier = Weak::ptr_eq(
257 &Arc::downgrade(server_cert_verifier),
258 &self.server_cert_verifier,
259 );
260 let same_creds = Weak::ptr_eq(&Arc::downgrade(client_creds), &self.client_creds);
261
262 match (same_verifier, same_creds) {
263 (true, true) => true,
264 (false, _) => {
265 crate::log::trace!("resumption not allowed between different ServerCertVerifiers");
266 false
267 }
268 (_, _) => {
269 crate::log::trace!(
270 "resumption not allowed between different ResolvesClientCert values"
271 );
272 false
273 }
274 }
275 }
276
277 pub(crate) fn server_cert_chain(&self) -> &CertificateChain<'static> {
278 &self.server_cert_chain
279 }
280
281 pub(crate) fn secret(&self) -> &[u8] {
282 self.secret.0.as_ref()
283 }
284
285 pub(crate) fn ticket(&self) -> &[u8] {
286 self.ticket.0.as_ref()
287 }
288}
289
/// Upper bound, in seconds, applied to a client session's `lifetime_secs`
/// when it is created (seven days).
static MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;

/// Largest difference, in milliseconds, between the client-reported and
/// server-computed ticket ages for which `ServerSessionValue::set_freshness`
/// still marks a session as fresh (sixty seconds).
static MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
297
/// Server-side session state, serializable through its `Codec` implementation.
///
/// Every field except `freshness` is part of the encoded form; `freshness` is
/// reset to `None` whenever a value is constructed or decoded.
#[derive(Debug)]
pub struct ServerSessionValue {
    /// Optional server name associated with the session.
    pub(crate) sni: Option<DnsName<'static>>,
    /// Protocol version recorded for the session.
    pub(crate) version: ProtocolVersion,
    /// Cipher suite recorded for the session.
    pub(crate) cipher_suite: CipherSuite,
    /// Session secret, wrapped in `Zeroizing`.
    pub(crate) master_secret: Zeroizing<PayloadU8>,
    /// `false` on construction; set by `set_extended_ms_used`.
    pub(crate) extended_ms: bool,
    /// Optional client certificate chain.
    pub(crate) client_cert_chain: Option<CertificateChain<'static>>,
    /// Optional ALPN protocol bytes.
    pub(crate) alpn: Option<PayloadU8>,
    /// Opaque application data stored alongside the session.
    pub(crate) application_data: PayloadU16,
    /// Creation time in seconds, taken from the `UnixTime` passed to `new`.
    pub creation_time_sec: u64,
    /// Offset subtracted (with wrapping) from the client-reported ticket age
    /// in `set_freshness`.
    pub(crate) age_obfuscation_offset: u32,
    /// Result of the last `set_freshness` call, if any; never serialized.
    freshness: Option<bool>,
}
313
314impl Codec<'_> for ServerSessionValue {
315 fn encode(&self, bytes: &mut Vec<u8>) {
316 if let Some(sni) = &self.sni {
317 1u8.encode(bytes);
318 let sni_bytes: &str = sni.as_ref();
319 PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
320 } else {
321 0u8.encode(bytes);
322 }
323 self.version.encode(bytes);
324 self.cipher_suite.encode(bytes);
325 self.master_secret.encode(bytes);
326 (u8::from(self.extended_ms)).encode(bytes);
327 if let Some(chain) = &self.client_cert_chain {
328 1u8.encode(bytes);
329 chain.encode(bytes);
330 } else {
331 0u8.encode(bytes);
332 }
333 if let Some(alpn) = &self.alpn {
334 1u8.encode(bytes);
335 alpn.encode(bytes);
336 } else {
337 0u8.encode(bytes);
338 }
339 self.application_data.encode(bytes);
340 self.creation_time_sec.encode(bytes);
341 self.age_obfuscation_offset
342 .encode(bytes);
343 }
344
345 fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
346 let has_sni = u8::read(r)?;
347 let sni = if has_sni == 1 {
348 let dns_name = PayloadU8::read(r)?;
349 let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
350 Ok(dns_name) => dns_name.to_owned(),
351 Err(_) => return Err(InvalidMessage::InvalidServerName),
352 };
353
354 Some(dns_name)
355 } else {
356 None
357 };
358
359 let v = ProtocolVersion::read(r)?;
360 let cs = CipherSuite::read(r)?;
361 let ms = Zeroizing::new(PayloadU8::read(r)?);
362 let ems = u8::read(r)?;
363 let has_ccert = u8::read(r)? == 1;
364 let ccert = if has_ccert {
365 Some(CertificateChain::read(r)?.into_owned())
366 } else {
367 None
368 };
369 let has_alpn = u8::read(r)? == 1;
370 let alpn = if has_alpn {
371 Some(PayloadU8::read(r)?)
372 } else {
373 None
374 };
375 let application_data = PayloadU16::read(r)?;
376 let creation_time_sec = u64::read(r)?;
377 let age_obfuscation_offset = u32::read(r)?;
378
379 Ok(Self {
380 sni,
381 version: v,
382 cipher_suite: cs,
383 master_secret: ms,
384 extended_ms: ems == 1u8,
385 client_cert_chain: ccert,
386 alpn,
387 application_data,
388 creation_time_sec,
389 age_obfuscation_offset,
390 freshness: None,
391 })
392 }
393}
394
395impl ServerSessionValue {
396 pub(crate) fn new(
397 sni: Option<&DnsName<'_>>,
398 v: ProtocolVersion,
399 cs: CipherSuite,
400 ms: &[u8],
401 client_cert_chain: Option<CertificateChain<'static>>,
402 alpn: Option<Vec<u8>>,
403 application_data: Vec<u8>,
404 creation_time: UnixTime,
405 age_obfuscation_offset: u32,
406 ) -> Self {
407 Self {
408 sni: sni.map(|dns| dns.to_owned()),
409 version: v,
410 cipher_suite: cs,
411 master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
412 extended_ms: false,
413 client_cert_chain,
414 alpn: alpn.map(PayloadU8::new),
415 application_data: PayloadU16::new(application_data),
416 creation_time_sec: creation_time.as_secs(),
417 age_obfuscation_offset,
418 freshness: None,
419 }
420 }
421
422 #[cfg(feature = "tls12")]
423 pub(crate) fn set_extended_ms_used(&mut self) {
424 self.extended_ms = true;
425 }
426
427 pub(crate) fn set_freshness(
428 mut self,
429 obfuscated_client_age_ms: u32,
430 time_now: UnixTime,
431 ) -> Self {
432 let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
433 let server_age_ms = (time_now
434 .as_secs()
435 .saturating_sub(self.creation_time_sec) as u32)
436 .saturating_mul(1000);
437
438 let age_difference = if client_age_ms < server_age_ms {
439 server_age_ms - client_age_ms
440 } else {
441 client_age_ms - server_age_ms
442 };
443
444 self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
445 self
446 }
447
448 pub(crate) fn is_fresh(&self) -> bool {
449 self.freshness.unwrap_or_default()
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoded session value with every presence flag (SNI, client
    /// certificate chain, ALPN) set to zero and empty application data.
    const ENCODED_WITHOUT_OPTIONALS: [u8; 26] = [
        0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12,
        0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
    ];

    /// Decodes `encoded` and checks that re-encoding reproduces it exactly.
    fn assert_round_trips(encoded: &[u8]) {
        let mut reader = Reader::init(encoded);
        let decoded = ServerSessionValue::read(&mut reader).unwrap();
        assert_eq!(decoded.get_encoding(), encoded);
    }

    /// Smoke test: the `Debug` implementation can format a value.
    #[cfg(feature = "std")]
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            &[1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            UnixTime::now(),
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    #[test]
    fn serversessionvalue_no_sni() {
        assert_round_trips(&ENCODED_WITHOUT_OPTIONALS);
    }

    // NOTE(review): despite its name, this test feeds the same bytes as
    // `serversessionvalue_no_sni`; the client-certificate presence flag is 0,
    // so the certificate chain path is not exercised here.
    #[test]
    fn serversessionvalue_with_cert() {
        assert_round_trips(&ENCODED_WITHOUT_OPTIONALS);
    }
}