1use crate::ferron_util::match_hostname::match_hostname;
2use rustls::{server::ResolvesServerCert, sign::CertifiedKey};
3use std::{collections::HashMap, sync::Arc};
4
5#[derive(Debug)]
6pub struct CustomSniResolver {
7 fallback_cert_key: Option<Arc<CertifiedKey>>,
8 cert_keys: HashMap<String, Arc<CertifiedKey>>,
9}
10
11impl CustomSniResolver {
12 pub fn new() -> Self {
13 Self {
14 fallback_cert_key: None,
15 cert_keys: HashMap::new(),
16 }
17 }
18
19 pub fn load_fallback_cert_key(&mut self, fallback_cert_key: Arc<CertifiedKey>) {
20 self.fallback_cert_key = Some(fallback_cert_key);
21 }
22
23 pub fn load_host_cert_key(&mut self, host: &str, cert_key: Arc<CertifiedKey>) {
24 self.cert_keys.insert(String::from(host), cert_key);
25 }
26}
27
28impl ResolvesServerCert for CustomSniResolver {
29 fn resolve(
30 &self,
31 client_hello: rustls::server::ClientHello<'_>,
32 ) -> Option<Arc<rustls::sign::CertifiedKey>> {
33 let hostname = client_hello.server_name();
34 if let Some(hostname) = hostname {
35 let keys_iterator = self.cert_keys.keys();
36 for configured_hostname in keys_iterator {
37 if match_hostname(Some(configured_hostname), Some(hostname)) {
38 return self.cert_keys.get(configured_hostname).cloned();
39 }
40 }
41 self.fallback_cert_key.clone()
42 } else {
43 self.fallback_cert_key.clone()
44 }
45 }
46}