rustls/client/handy.rs

1use pki_types::ServerName;
2
3use crate::enums::SignatureScheme;
4use crate::msgs::persist;
5use crate::sync::Arc;
6use crate::{NamedGroup, client, sign};
7
8/// An implementer of `ClientSessionStore` which does nothing.
9#[derive(Debug)]
10pub(super) struct NoClientSessionStorage;
11
12impl client::ClientSessionStore for NoClientSessionStorage {
13    fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
14
15    fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
16        None
17    }
18
19    fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
20
21    fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
22        None
23    }
24
25    fn remove_tls12_session(&self, _: &ServerName<'_>) {}
26
27    fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
28
29    fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
30        None
31    }
32}
33
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;

    use pki_types::ServerName;

    use crate::lock::Mutex;
    use crate::msgs::persist;
    use crate::{NamedGroup, limited_cache};

    /// Maximum number of TLS1.3 tickets retained for any single server.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Everything the cache remembers about one server.
    struct ServerData {
        /// Key-exchange group recorded by `set_kx_hint`, if any.
        kx_hint: Option<NamedGroup>,

        /// Zero or one TLS1.2 sessions.
        #[cfg(feature = "tls12")]
        tls12: Option<persist::Tls12ClientSessionValue>,

        /// Up to `MAX_TLS13_TICKETS_PER_SERVER` TLS1.3 tickets, oldest first.
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        fn default() -> Self {
            Self {
                kx_hint: None,
                #[cfg(feature = "tls12")]
                tls12: None,
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// Number of server entries needed to hold `size` TLS1.3 tickets when each
    /// server keeps up to `MAX_TLS13_TICKETS_PER_SERVER` of them.
    ///
    /// This is `size / MAX_TLS13_TICKETS_PER_SERVER` rounded up. It uses a
    /// saturating add so that huge `size` values cannot overflow.
    fn max_servers_for(size: usize) -> usize {
        size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER
    }

    /// An implementer of `ClientSessionStore` that stores everything
    /// in memory.
    ///
    /// It enforces a limit on the number of entries to bound memory usage.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    impl ClientSessionMemoryCache {
        /// Make a new ClientSessionMemoryCache.  `size` is the
        /// maximum number of stored sessions.
        ///
        /// Internally `size` is turned into room for `size / 8` servers
        /// (rounded up), each holding up to 8 TLS1.3 tickets.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Self {
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers_for(size))),
            }
        }

        /// Make a new ClientSessionMemoryCache.  `size` is the
        /// maximum number of stored sessions.
        ///
        /// Internally `size` is turned into room for `size / 8` servers
        /// (rounded up), each holding up to 8 TLS1.3 tickets. `M` supplies
        /// the mutex implementation used to guard the cache.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Self {
            Self {
                servers: Mutex::new::<M>(limited_cache::LimitedCache::new(max_servers_for(
                    size,
                ))),
            }
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        fn set_tls12_session(
            &self,
            _server_name: ServerName<'static>,
            _value: persist::Tls12ClientSessionValue,
        ) {
            // The name is owned and not used again, so it is moved in rather than cloned.
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(_server_name, |data| {
                    data.tls12 = Some(_value)
                });
        }

        fn tls12_session(
            &self,
            _server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            #[cfg(not(feature = "tls12"))]
            return None;

            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get(_server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
            // Discards the stored TLS1.2 session (if any); the server entry itself stays.
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_mut(_server_name)
                .and_then(|data| data.tls12.take());
        }

        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Bound by the constant, not by `VecDeque::capacity()`. `with_capacity`
                    // only promises *at least* the requested room, so the allocated
                    // capacity may exceed the intended per-server limit.
                    if data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        // Evict the oldest ticket (front of the oldest-first deque).
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            // Removes and returns the most recently inserted ticket (back of the deque).
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Note: we omit self.servers as it may contain sensitive data.
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}
191
192#[cfg(any(feature = "std", feature = "hashbrown"))]
193pub use cache::ClientSessionMemoryCache;
194
195#[derive(Debug)]
196pub(super) struct FailResolveClientCert {}
197
198impl client::ResolvesClientCert for FailResolveClientCert {
199    fn resolve(
200        &self,
201        _root_hint_subjects: &[&[u8]],
202        _sigschemes: &[SignatureScheme],
203    ) -> Option<Arc<sign::CertifiedKey>> {
204        None
205    }
206
207    fn has_certs(&self) -> bool {
208        false
209    }
210}
211
212/// An exemplar `ResolvesClientCert` implementation that always resolves to a single
213/// [RFC 7250] raw public key.
214///
215/// [RFC 7250]: https://tools.ietf.org/html/rfc7250
216#[derive(Clone, Debug)]
217pub struct AlwaysResolvesClientRawPublicKeys(Arc<sign::CertifiedKey>);
218impl AlwaysResolvesClientRawPublicKeys {
219    /// Create a new `AlwaysResolvesClientRawPublicKeys` instance.
220    pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
221        Self(certified_key)
222    }
223}
224
225impl client::ResolvesClientCert for AlwaysResolvesClientRawPublicKeys {
226    fn resolve(
227        &self,
228        _root_hint_subjects: &[&[u8]],
229        _sigschemes: &[SignatureScheme],
230    ) -> Option<Arc<sign::CertifiedKey>> {
231        Some(Arc::clone(&self.0))
232    }
233
234    fn only_raw_public_keys(&self) -> bool {
235        true
236    }
237
238    /// Returns true if the resolver is ready to present an identity.
239    ///
240    /// Even though the function is called `has_certs`, it returns true
241    /// although only an RPK (Raw Public Key) is available, not an actual certificate.
242    fn has_certs(&self) -> bool {
243        true
244    }
245}
246
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use std::prelude::v1::*;

    use pki_types::{ServerName, UnixTime};

    use super::NoClientSessionStorage;
    use super::provider::cipher_suite;
    use crate::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use crate::client::{ClientSessionStore, ResolvesClientCert};
    use crate::msgs::base::PayloadU16;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::pki_types::CertificateDer;
    use crate::suites::SupportedCipherSuite;
    use crate::sync::Arc;
    use crate::{DigitallySignedStruct, Error, SignatureScheme, sign};

    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let c = NoClientSessionStorage {};
        let name = ServerName::try_from("example.com").unwrap();
        let now = UnixTime::now();
        let server_cert_verifier: Arc<dyn ServerCertVerifier> = Arc::new(DummyServerCertVerifier);
        let resolves_client_cert: Arc<dyn ResolvesClientCert> = Arc::new(DummyResolvesClientCert);

        c.set_kx_hint(name.clone(), NamedGroup::X25519);
        assert_eq!(None, c.kx_hint(&name));

        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            c.set_tls12_session(
                name.clone(),
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Arc::new(PayloadU16::empty()),
                    &[],
                    CertificateChain::default(),
                    &server_cert_verifier,
                    &resolves_client_cert,
                    now,
                    0,
                    true,
                ),
            );
            assert!(c.tls12_session(&name).is_none());
            c.remove_tls12_session(&name);
        }

        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
        else {
            unreachable!();
        };
        c.insert_tls13_ticket(
            name.clone(),
            Tls13ClientSessionValue::new(
                tls13_suite,
                Arc::new(PayloadU16::empty()),
                &[],
                CertificateChain::default(),
                &server_cert_verifier,
                &resolves_client_cert,
                now,
                0,
                0,
                0,
            ),
        );
        assert!(c.take_tls13_ticket(&name).is_none());
    }

    /// `ClientSessionMemoryCache` keeps at most eight TLS1.3 tickets per server
    /// (its private `MAX_TLS13_TICKETS_PER_SERVER`), no matter how many are
    /// inserted. It also never hands out tickets for a server it has not seen.
    #[cfg(feature = "std")]
    #[test]
    fn test_clientsessionmemorycache_limits_tls13_tickets_per_server() {
        // Mirrors the private `MAX_TLS13_TICKETS_PER_SERVER` in the cache module.
        const EXPECTED_LIMIT: usize = 8;

        let c = super::ClientSessionMemoryCache::new(32);
        let name = ServerName::try_from("example.com").unwrap();
        let unknown = ServerName::try_from("example.org").unwrap();

        for _ in 0..EXPECTED_LIMIT + 4 {
            c.insert_tls13_ticket(name.clone(), tls13_session_value());
        }

        assert!(c.take_tls13_ticket(&unknown).is_none());

        let mut taken = 0;
        while c.take_tls13_ticket(&name).is_some() {
            taken += 1;
        }
        assert_eq!(taken, EXPECTED_LIMIT);
    }

    /// Builds a throwaway TLS1.3 session value, using the dummy verifier and
    /// resolver below, for exercising session stores.
    #[cfg(feature = "std")]
    fn tls13_session_value() -> Tls13ClientSessionValue {
        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
        else {
            unreachable!();
        };
        let server_cert_verifier: Arc<dyn ServerCertVerifier> = Arc::new(DummyServerCertVerifier);
        let resolves_client_cert: Arc<dyn ResolvesClientCert> = Arc::new(DummyResolvesClientCert);
        Tls13ClientSessionValue::new(
            tls13_suite,
            Arc::new(PayloadU16::empty()),
            &[],
            CertificateChain::default(),
            &server_cert_verifier,
            &resolves_client_cert,
            UnixTime::now(),
            0,
            0,
            0,
        )
    }

    #[derive(Debug)]
    struct DummyServerCertVerifier;

    impl ServerCertVerifier for DummyServerCertVerifier {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            unreachable!()
        }
    }

    #[derive(Debug)]
    struct DummyResolvesClientCert;

    impl ResolvesClientCert for DummyResolvesClientCert {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn resolve(
            &self,
            _root_hint_subjects: &[&[u8]],
            _sigschemes: &[SignatureScheme],
        ) -> Option<Arc<sign::CertifiedKey>> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn has_certs(&self) -> bool {
            unreachable!()
        }
    }
}