1use pki_types::ServerName;
2
3use crate::enums::SignatureScheme;
4use crate::msgs::persist;
5use crate::sync::Arc;
6use crate::{NamedGroup, client, sign};
7
8#[derive(Debug)]
10pub(super) struct NoClientSessionStorage;
11
12impl client::ClientSessionStore for NoClientSessionStorage {
13 fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
14
15 fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
16 None
17 }
18
19 fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
20
21 fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
22 None
23 }
24
25 fn remove_tls12_session(&self, _: &ServerName<'_>) {}
26
27 fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
28
29 fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
30 None
31 }
32}
33
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;

    use pki_types::ServerName;

    use crate::lock::Mutex;
    use crate::msgs::persist;
    use crate::{NamedGroup, limited_cache};

    /// Maximum number of TLS1.3 tickets kept for any one server.
    ///
    /// Also used to turn the caller's `size` into a number of servers.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Per-server state held by [`ClientSessionMemoryCache`].
    struct ServerData {
        /// Most recent key-exchange group recorded via `set_kx_hint`.
        kx_hint: Option<NamedGroup>,

        /// The single stored TLS1.2 session, if any.
        #[cfg(feature = "tls12")]
        tls12: Option<persist::Tls12ClientSessionValue>,

        /// Stored TLS1.3 tickets: pushed at the back, evicted from the front
        /// once `MAX_TLS13_TICKETS_PER_SERVER` is reached, and taken from the
        /// back (newest first).
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        fn default() -> Self {
            Self {
                kx_hint: None,
                #[cfg(feature = "tls12")]
                tls12: None,
                // Pre-size for the per-server limit. The limit itself is
                // enforced explicitly in `insert_tls13_ticket`, not via capacity.
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// An in-memory implementation of `ClientSessionStore`.
    ///
    /// Entries are keyed by server name and held in a `LimitedCache`, whose
    /// size is derived from the `size` passed to `new`. Each server keeps at
    /// most `MAX_TLS13_TICKETS_PER_SERVER` TLS1.3 tickets.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    impl ClientSessionMemoryCache {
        /// Make a new `ClientSessionMemoryCache`.
        ///
        /// `size` is divided by `MAX_TLS13_TICKETS_PER_SERVER`, rounding up,
        /// to obtain the number of servers the underlying cache is created with.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Self {
            let max_servers = Self::max_servers_for(size);
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
            }
        }

        /// Make a new `ClientSessionMemoryCache`, using the mutex
        /// implementation provided by `M`.
        ///
        /// `size` is divided by `MAX_TLS13_TICKETS_PER_SERVER`, rounding up,
        /// to obtain the number of servers the underlying cache is created with.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Self {
            let max_servers = Self::max_servers_for(size);
            Self {
                servers: Mutex::new::<M>(limited_cache::LimitedCache::new(max_servers)),
            }
        }

        /// Ceiling division of `size` by `MAX_TLS13_TICKETS_PER_SERVER`.
        ///
        /// The rounding addend is added with saturation so a huge `size`
        /// cannot overflow.
        fn max_servers_for(size: usize) -> usize {
            size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        fn set_tls12_session(
            &self,
            _server_name: ServerName<'static>,
            _value: persist::Tls12ClientSessionValue,
        ) {
            // Without the `tls12` feature there is nowhere to store the value,
            // so the call is a no-op (hence the `_`-prefixed parameters).
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(_server_name, |data| {
                    data.tls12 = Some(_value)
                });
        }

        fn tls12_session(
            &self,
            _server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            #[cfg(not(feature = "tls12"))]
            return None;

            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get(_server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
            // The lock guard lives until the end of the `if let`, which covers
            // the mutable borrow of `data`.
            #[cfg(feature = "tls12")]
            if let Some(data) = self.servers.lock().unwrap().get_mut(_server_name) {
                data.tls12 = None;
            }
        }

        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Enforce the per-server limit explicitly. `VecDeque::capacity()`
                    // is only guaranteed to be *at least* the requested capacity,
                    // so comparing against it could retain more tickets than intended.
                    while data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            // Newest ticket first: tickets are pushed at the back.
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // `servers` is deliberately omitted: it holds session values and
            // tickets, which are presumably sensitive.
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}
191
192#[cfg(any(feature = "std", feature = "hashbrown"))]
193pub use cache::ClientSessionMemoryCache;
194
195#[derive(Debug)]
196pub(super) struct FailResolveClientCert {}
197
198impl client::ResolvesClientCert for FailResolveClientCert {
199 fn resolve(
200 &self,
201 _root_hint_subjects: &[&[u8]],
202 _sigschemes: &[SignatureScheme],
203 ) -> Option<Arc<sign::CertifiedKey>> {
204 None
205 }
206
207 fn has_certs(&self) -> bool {
208 false
209 }
210}
211
212#[derive(Clone, Debug)]
217pub struct AlwaysResolvesClientRawPublicKeys(Arc<sign::CertifiedKey>);
218impl AlwaysResolvesClientRawPublicKeys {
219 pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
221 Self(certified_key)
222 }
223}
224
225impl client::ResolvesClientCert for AlwaysResolvesClientRawPublicKeys {
226 fn resolve(
227 &self,
228 _root_hint_subjects: &[&[u8]],
229 _sigschemes: &[SignatureScheme],
230 ) -> Option<Arc<sign::CertifiedKey>> {
231 Some(Arc::clone(&self.0))
232 }
233
234 fn only_raw_public_keys(&self) -> bool {
235 true
236 }
237
238 fn has_certs(&self) -> bool {
243 true
244 }
245}
246
// Unit tests for the helpers in this file. The module is expanded by the
// `test_for_each_provider` macro, which supplies `super::provider`.
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use std::prelude::v1::*;

    use pki_types::{ServerName, UnixTime};

    use super::NoClientSessionStorage;
    use super::provider::cipher_suite;
    use crate::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use crate::client::{ClientSessionStore, ResolvesClientCert};
    use crate::msgs::base::PayloadU16;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::pki_types::CertificateDer;
    use crate::suites::SupportedCipherSuite;
    use crate::sync::Arc;
    use crate::{DigitallySignedStruct, Error, SignatureScheme, sign};

    /// Checks that `NoClientSessionStorage` discards every write: after each
    /// `set_*`/`insert_*` call, the matching getter still returns `None`.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let c = NoClientSessionStorage {};
        let name = ServerName::try_from("example.com").unwrap();
        let now = UnixTime::now();
        // Placeholders needed only to build session values. Every one of their
        // methods is `unreachable!()`, so reaching one fails the test.
        let server_cert_verifier: Arc<dyn ServerCertVerifier> = Arc::new(DummyServerCertVerifier);
        let resolves_client_cert: Arc<dyn ResolvesClientCert> = Arc::new(DummyResolvesClientCert);

        // Key-exchange hint is not remembered.
        c.set_kx_hint(name.clone(), NamedGroup::X25519);
        assert_eq!(None, c.kx_hint(&name));

        // TLS1.2 session is not remembered; removal is also a harmless no-op.
        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            c.set_tls12_session(
                name.clone(),
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Arc::new(PayloadU16::empty()),
                    &[],
                    CertificateChain::default(),
                    &server_cert_verifier,
                    &resolves_client_cert,
                    now,
                    0,
                    true,
                ),
            );
            assert!(c.tls12_session(&name).is_none());
            c.remove_tls12_session(&name);
        }

        // TLS1.3 ticket is not remembered.
        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
        else {
            unreachable!();
        };
        c.insert_tls13_ticket(
            name.clone(),
            Tls13ClientSessionValue::new(
                tls13_suite,
                Arc::new(PayloadU16::empty()),
                &[],
                CertificateChain::default(),
                &server_cert_verifier,
                &resolves_client_cert,
                now,
                0,
                0,
                0,
            ),
        );
        assert!(c.take_tls13_ticket(&name).is_none());
    }

    /// Placeholder verifier whose methods all panic via `unreachable!()`.
    /// It exists only to satisfy session-value constructors.
    #[derive(Debug)]
    struct DummyServerCertVerifier;

    impl ServerCertVerifier for DummyServerCertVerifier {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            unreachable!()
        }
    }

    /// Placeholder client-cert resolver whose methods all panic via
    /// `unreachable!()`. It exists only to satisfy session-value constructors.
    #[derive(Debug)]
    struct DummyResolvesClientCert;

    impl ResolvesClientCert for DummyResolvesClientCert {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn resolve(
            &self,
            _root_hint_subjects: &[&[u8]],
            _sigschemes: &[SignatureScheme],
        ) -> Option<Arc<sign::CertifiedKey>> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn has_certs(&self) -> bool {
            unreachable!()
        }
    }
}