// quinn-proto/token_memory_cache.rs

1//! Storing tokens sent from servers in NEW_TOKEN frames and using them in subsequent connections
2
3use std::{
4    collections::{HashMap, VecDeque, hash_map},
5    sync::{Arc, Mutex},
6};
7
8use bytes::Bytes;
9use lru_slab::LruSlab;
10use tracing::trace;
11
12use crate::token::TokenStore;
13
/// `TokenStore` implementation that stores up to `N` tokens per server name for up to a
/// limited number of server names, in-memory
///
/// All state lives behind a `Mutex`, so `insert`/`take` may be called concurrently
/// through a shared reference.
#[derive(Debug)]
pub struct TokenMemoryCache(Mutex<State>);
18
19impl TokenMemoryCache {
20    /// Construct empty
21    pub fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
22        Self(Mutex::new(State::new(
23            max_server_names,
24            max_tokens_per_server,
25        )))
26    }
27}
28
29impl TokenStore for TokenMemoryCache {
30    fn insert(&self, server_name: &str, token: Bytes) {
31        trace!(%server_name, "storing token");
32        self.0.lock().unwrap().store(server_name, token)
33    }
34
35    fn take(&self, server_name: &str) -> Option<Bytes> {
36        let token = self.0.lock().unwrap().take(server_name);
37        trace!(%server_name, found=%token.is_some(), "taking token");
38        token
39    }
40}
41
42/// Defaults to a maximum of 256 servers and 2 tokens per server
43impl Default for TokenMemoryCache {
44    fn default() -> Self {
45        Self::new(256, 2)
46    }
47}
48
/// Lockable inner state of `TokenMemoryCache`
#[derive(Debug)]
struct State {
    /// Maximum number of distinct server names tracked at once; the least-recently-used
    /// entry is evicted beyond this (see `store`)
    max_server_names: u32,
    /// Maximum queued tokens per server name; the oldest token is dropped beyond this
    max_tokens_per_server: usize,
    // map from server name to index in lru
    //
    // the `Arc<str>` key is shared with `CacheEntry::server_name`, so each name is
    // allocated only once
    lookup: HashMap<Arc<str>, u32>,
    lru: LruSlab<CacheEntry>,
}
58
impl State {
    /// Construct empty state with the given capacity limits
    fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
        Self {
            max_server_names,
            max_tokens_per_server,
            lookup: HashMap::new(),
            lru: LruSlab::default(),
        }
    }

    /// Store `token` under `server_name`, evicting the least-recently-used server entry
    /// and/or the oldest queued token as needed to stay within the configured limits
    fn store(&mut self, server_name: &str, token: Bytes) {
        if self.max_server_names == 0 {
            // the rest of this method assumes that we can always insert a new entry so long as
            // we're willing to evict a pre-existing entry. thus, an entry limit of 0 is an edge
            // case we must short-circuit on now.
            return;
        }
        if self.max_tokens_per_server == 0 {
            // similarly to above, the rest of this method assumes that we can always push a new
            // token to a queue so long as we're willing to evict a pre-existing token, so we
            // short-circuit on the edge case of a token limit of 0.
            return;
        }

        // one shared allocation of the name: the clone below (a refcount bump) keys
        // `lookup`, and the original moves into the `CacheEntry`
        let server_name = Arc::<str>::from(server_name);
        match self.lookup.entry(server_name.clone()) {
            hash_map::Entry::Occupied(hmap_entry) => {
                // key already exists, push the new token to its token queue
                // NOTE(review): this relies on `LruSlab::get_mut` also refreshing the
                // entry's recency — confirm against lru_slab's documentation
                let tokens = &mut self.lru.get_mut(*hmap_entry.get()).tokens;
                if tokens.len() >= self.max_tokens_per_server {
                    // the queue can be at the limit but never above it
                    debug_assert!(tokens.len() == self.max_tokens_per_server);
                    tokens.pop_front().unwrap();
                }
                tokens.push_back(token);
            }
            hash_map::Entry::Vacant(hmap_entry) => {
                // key does not yet exist, create a new one, evicting the oldest if necessary
                let removed_key = if self.lru.len() >= self.max_server_names {
                    // unwrap safety: max_server_names is > 0, so there's at least one entry, so
                    //                lru() is some
                    Some(self.lru.remove(self.lru.lru().unwrap()).server_name)
                } else {
                    None
                };

                hmap_entry.insert(self.lru.insert(CacheEntry::new(server_name, token)));

                // for borrowing reasons, we must defer removing the evicted hmap entry to here
                // (the `VacantEntry` above keeps `self.lookup` mutably borrowed until the insert)
                if let Some(removed_slot) = removed_key {
                    let removed = self.lookup.remove(&removed_slot);
                    // every slab entry must have had a corresponding lookup entry
                    debug_assert!(removed.is_some());
                }
            }
        };
    }

    /// Remove and return the oldest token stored for `server_name`, if any, dropping the
    /// whole entry once its queue is emptied
    fn take(&mut self, server_name: &str) -> Option<Bytes> {
        let slab_key = *self.lookup.get(server_name)?;

        // pop from entry's token queue
        let entry = self.lru.get_mut(slab_key);
        // unwrap safety: we never leave tokens empty
        let token = entry.tokens.pop_front().unwrap();

        if entry.tokens.is_empty() {
            // token stack emptied, remove entry — this maintains the invariant that every
            // stored `CacheEntry` has a non-empty token queue
            self.lru.remove(slab_key);
            self.lookup.remove(server_name);
        }

        Some(token)
    }
}
132
/// Cache entry within `TokenMemoryCache`'s LRU slab
#[derive(Debug)]
struct CacheEntry {
    // same shared allocation as the corresponding `State::lookup` key
    server_name: Arc<str>,
    // invariant: tokens is never empty
    tokens: VecDeque<Bytes>,
}
140
141impl CacheEntry {
142    /// Construct with a single token
143    fn new(server_name: Arc<str>, token: Bytes) -> Self {
144        let mut tokens = VecDeque::new();
145        tokens.push_back(token);
146        Self {
147            server_name,
148            tokens,
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;

    /// Fixed-seed RNG so test failures are reproducible
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes())
    }

    /// Randomized model-based test: drive the real cache and a naive model (a vec kept
    /// sorted oldest-to-newest) with the same operations and assert they agree
    #[test]
    fn cache_test() {
        let mut rng = new_rng();
        // limits under test; declared once so the model and the real cache can never
        // disagree about them (previously `new(20, 2)` duplicated both as magic numbers)
        const MAX_SERVER_NAMES: usize = 20;
        const N: usize = 2; // max tokens per server

        for _ in 0..10 {
            let mut cache_1: Vec<(u32, VecDeque<Bytes>)> = Vec::new(); // keep it sorted oldest to newest
            let cache_2 = TokenMemoryCache::new(MAX_SERVER_NAMES as u32, N);

            for i in 0..200 {
                let server_name = rng.random::<u32>() % 10;
                if rng.random_bool(0.666) {
                    // store
                    let token = Bytes::from(vec![i]);
                    println!("STORE {} {:?}", server_name, token);
                    if let Some((j, _)) = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                    {
                        // existing name: re-push to the back (a store refreshes recency)
                        let (_, mut queue) = cache_1.remove(j);
                        queue.push_back(token.clone());
                        if queue.len() > N {
                            queue.pop_front();
                        }
                        cache_1.push((server_name, queue));
                    } else {
                        // new name: append, evicting the oldest entry beyond the limit
                        let mut queue = VecDeque::new();
                        queue.push_back(token.clone());
                        cache_1.push((server_name, queue));
                        if cache_1.len() > MAX_SERVER_NAMES {
                            cache_1.remove(0);
                        }
                    }
                    cache_2.insert(&server_name.to_string(), token);
                } else {
                    // take
                    println!("TAKE {}", server_name);
                    let expecting = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                        .map(|(j, _)| j)
                        .map(|j| {
                            // pop the oldest token; a take also refreshes recency, and an
                            // emptied queue drops the entry entirely
                            let (_, mut queue) = cache_1.remove(j);
                            let token = queue.pop_front().unwrap();
                            if !queue.is_empty() {
                                cache_1.push((server_name, queue));
                            }
                            token
                        });
                    println!("EXPECTING {:?}", expecting);
                    assert_eq!(cache_2.take(&server_name.to_string()), expecting);
                }
            }
        }
    }

    /// A server-name limit of 0 must store nothing and must not panic
    #[test]
    fn zero_max_server_names() {
        // test that this edge case doesn't panic
        let cache = TokenMemoryCache::new(0, 2);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }

    /// A per-server token limit of 0 must store nothing and must not panic
    #[test]
    fn zero_queue_length() {
        // test that this edge case doesn't panic
        let cache = TokenMemoryCache::new(256, 0);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }
}