use std::{
    collections::{HashMap, VecDeque, hash_map},
    sync::{Arc, Mutex},
};

use bytes::Bytes;
use lru_slab::LruSlab;
use tracing::trace;

use crate::token::TokenStore;
/// In-memory [`TokenStore`] implementation: a bounded, mutex-guarded LRU
/// cache of tokens keyed by server name.
#[derive(Debug)]
pub struct TokenMemoryCache(Mutex<State>);
18
19impl TokenMemoryCache {
20 pub fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
22 Self(Mutex::new(State::new(
23 max_server_names,
24 max_tokens_per_server,
25 )))
26 }
27}
28
29impl TokenStore for TokenMemoryCache {
30 fn insert(&self, server_name: &str, token: Bytes) {
31 trace!(%server_name, "storing token");
32 self.0.lock().unwrap().store(server_name, token)
33 }
34
35 fn take(&self, server_name: &str) -> Option<Bytes> {
36 let token = self.0.lock().unwrap().take(server_name);
37 trace!(%server_name, found=%token.is_some(), "taking token");
38 token
39 }
40}
41
42impl Default for TokenMemoryCache {
44 fn default() -> Self {
45 Self::new(256, 2)
46 }
47}
48
/// Interior state of a [`TokenMemoryCache`], guarded by its mutex.
#[derive(Debug)]
struct State {
    // Maximum number of distinct server names tracked; 0 disables storage.
    max_server_names: u32,
    // Maximum tokens queued per server name; 0 disables storage.
    max_tokens_per_server: usize,
    // Maps a server name to its entry's slot key in `lru`. Kept in sync
    // with `lru`: every slot has exactly one lookup mapping and vice versa.
    lookup: HashMap<Arc<str>, u32>,
    // Per-server-name entries tracked by recency; the least-recently-used
    // entry is evicted first when at capacity.
    lru: LruSlab<CacheEntry>,
}
58
impl State {
    /// Create an empty state honoring the given capacity limits.
    fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
        Self {
            max_server_names,
            max_tokens_per_server,
            lookup: HashMap::new(),
            lru: LruSlab::default(),
        }
    }

    /// Store `token` for `server_name`, evicting the oldest token for that
    /// name and/or the least-recently-used server name as needed to stay
    /// within the configured limits.
    fn store(&mut self, server_name: &str, token: Bytes) {
        // With either limit at zero nothing can ever be stored, so bail out
        // before touching any of the collections.
        if self.max_server_names == 0 {
            return;
        }
        if self.max_tokens_per_server == 0 {
            return;
        }

        let server_name = Arc::<str>::from(server_name);
        match self.lookup.entry(server_name.clone()) {
            hash_map::Entry::Occupied(hmap_entry) => {
                // Known server name: access its entry (NOTE(review):
                // `LruSlab::get_mut` appears to also freshen the entry's
                // recency — confirm against lru_slab docs) and append.
                let tokens = &mut self.lru.get_mut(*hmap_entry.get()).tokens;
                // Drop the oldest token first if the per-server queue is
                // already full; the queue never exceeds the limit, so it
                // can be at most exactly full here.
                if tokens.len() >= self.max_tokens_per_server {
                    debug_assert!(tokens.len() == self.max_tokens_per_server);
                    tokens.pop_front().unwrap();
                }
                tokens.push_back(token);
            }
            hash_map::Entry::Vacant(hmap_entry) => {
                // New server name: if at capacity, evict the least-recently
                // used entry from the slab first, remembering its name so
                // its `lookup` mapping can be removed afterwards.
                let removed_key = if self.lru.len() >= self.max_server_names {
                    Some(self.lru.remove(self.lru.lru().unwrap()).server_name)
                } else {
                    None
                };

                hmap_entry.insert(self.lru.insert(CacheEntry::new(server_name, token)));

                // Remove the evicted entry's lookup mapping only after the
                // insert above, since the vacant entry borrows `lookup`
                // until then. The evicted name is necessarily different
                // from the one just inserted (that entry was vacant).
                if let Some(removed_slot) = removed_key {
                    let removed = self.lookup.remove(&removed_slot);
                    debug_assert!(removed.is_some());
                }
            }
        };
    }

    /// Remove and return the oldest token stored for `server_name`, if any.
    fn take(&mut self, server_name: &str) -> Option<Bytes> {
        let slab_key = *self.lookup.get(server_name)?;

        let entry = self.lru.get_mut(slab_key);
        // Invariant: an entry present in `lookup` always has a non-empty
        // token queue (empty entries are removed below and `store` always
        // pushes at least one token), so this unwrap cannot fail.
        let token = entry.tokens.pop_front().unwrap();

        // Drop the entry entirely once its last token is taken, keeping
        // `lookup` and `lru` in sync.
        if entry.tokens.is_empty() {
            self.lru.remove(slab_key);
            self.lookup.remove(server_name);
        }

        Some(token)
    }
}
132
/// A single server name's slot in the LRU slab.
#[derive(Debug)]
struct CacheEntry {
    // The server name this entry belongs to; also used as the `lookup` key.
    server_name: Arc<str>,
    // Tokens stored for this server name, oldest at the front.
    tokens: VecDeque<Bytes>,
}
140
141impl CacheEntry {
142 fn new(server_name: Arc<str>, token: Bytes) -> Self {
144 let mut tokens = VecDeque::new();
145 tokens.push_back(token);
146 Self {
147 server_name,
148 tokens,
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;

    // Deterministically seeded RNG so test runs are reproducible.
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes())
    }

    // Randomized differential test: drive the real cache and a naive
    // reference model with the same operation sequence and assert that
    // `take` always agrees.
    #[test]
    fn cache_test() {
        let mut rng = new_rng();
        // Per-server token queue limit, mirroring the cache's limit below.
        const N: usize = 2;

        for _ in 0..10 {
            // `cache_1` is the reference model: a Vec ordered from least-
            // to most-recently used server name, each with its token queue.
            let mut cache_1: Vec<(u32, VecDeque<Bytes>)> = Vec::new();
            let cache_2 = TokenMemoryCache::new(20, 2);

            for i in 0..200 {
                let server_name = rng.random::<u32>() % 10;
                if rng.random_bool(0.666) {
                    // Store operation.
                    let token = Bytes::from(vec![i]);
                    println!("STORE {} {:?}", server_name, token);
                    if let Some((j, _)) = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                    {
                        // Known name: move it to the back (most recent) and
                        // append, capping the queue at N tokens.
                        let (_, mut queue) = cache_1.remove(j);
                        queue.push_back(token.clone());
                        if queue.len() > N {
                            queue.pop_front();
                        }
                        cache_1.push((server_name, queue));
                    } else {
                        // New name: append, evicting the front (least
                        // recent) entry if over the 20-name capacity.
                        let mut queue = VecDeque::new();
                        queue.push_back(token.clone());
                        cache_1.push((server_name, queue));
                        if cache_1.len() > 20 {
                            cache_1.remove(0);
                        }
                    }
                    cache_2.insert(&server_name.to_string(), token);
                } else {
                    // Take operation: the model pops the oldest token and,
                    // like the real cache, freshens the entry (moves it to
                    // the back) unless its queue is now empty.
                    println!("TAKE {}", server_name);
                    let expecting = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                        .map(|(j, _)| j)
                        .map(|j| {
                            let (_, mut queue) = cache_1.remove(j);
                            let token = queue.pop_front().unwrap();
                            if !queue.is_empty() {
                                cache_1.push((server_name, queue));
                            }
                            token
                        });
                    println!("EXPECTING {:?}", expecting);
                    assert_eq!(cache_2.take(&server_name.to_string()), expecting);
                }
            }
        }
    }

    // A zero server-name limit must make the cache a no-op.
    #[test]
    fn zero_max_server_names() {
        let cache = TokenMemoryCache::new(0, 2);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }

    // A zero per-server queue limit must make the cache a no-op.
    #[test]
    fn zero_queue_length() {
        let cache = TokenMemoryCache::new(256, 0);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }
}