quinn_proto/
bloom_token_log.rs1use std::{
2 collections::HashSet,
3 f64::consts::LN_2,
4 hash::{BuildHasher, Hasher},
5 mem::{size_of, take},
6 sync::Mutex,
7};
8
9use fastbloom::BloomFilter;
10use rustc_hash::FxBuildHasher;
11use tracing::{trace, warn};
12
13use crate::{Duration, SystemTime, TokenLog, TokenReuseError, UNIX_EPOCH};
14
15pub struct BloomTokenLog(Mutex<State>);
26
27impl BloomTokenLog {
28 pub fn new_expected_items(max_bytes: usize, expected_hits: u64) -> Self {
33 Self::new(max_bytes, optimal_k_num(max_bytes, expected_hits))
34 }
35
36 pub fn new(max_bytes: usize, k_num: u32) -> Self {
44 Self(Mutex::new(State {
45 config: FilterConfig {
46 filter_max_bytes: max_bytes / 2,
47 k_num,
48 },
49 period_1_start: UNIX_EPOCH,
50 filter_1: Filter::default(),
51 filter_2: Filter::default(),
52 }))
53 }
54}
55
56impl TokenLog for BloomTokenLog {
57 fn check_and_insert(
58 &self,
59 nonce: u128,
60 issued: SystemTime,
61 lifetime: Duration,
62 ) -> Result<(), TokenReuseError> {
63 trace!(%nonce, "check_and_insert");
64
65 if lifetime.is_zero() {
66 return Err(TokenReuseError);
68 }
69
70 let mut guard = self.0.lock().unwrap();
71 let state = &mut *guard;
72
73 let expires_at = issued + lifetime;
75 let Ok(periods_forward) = expires_at
76 .duration_since(state.period_1_start)
77 .map(|duration| duration.as_nanos() / lifetime.as_nanos())
78 else {
79 warn!("BloomTokenLog presented with token too far in past");
82 return Err(TokenReuseError);
83 };
84
85 let filter = match periods_forward {
87 0 => &mut state.filter_1,
88 1 => &mut state.filter_2,
89 2 => {
90 state.filter_1 = take(&mut state.filter_2);
92 state.period_1_start += lifetime;
93 &mut state.filter_2
94 }
95 _ => {
96 state.filter_1 = Filter::default();
98 state.filter_2 = Filter::default();
99 state.period_1_start = expires_at;
100 &mut state.filter_1
101 }
102 };
103
104 filter.check_and_insert(nonce as u64, &state.config)
119 }
120}
121
122impl Default for BloomTokenLog {
127 fn default() -> Self {
128 Self::new_expected_items(DEFAULT_MAX_BYTES, DEFAULT_EXPECTED_HITS)
129 }
130}
131
132struct State {
134 config: FilterConfig,
135 period_1_start: SystemTime,
138 filter_1: Filter,
139 filter_2: Filter,
140}
141
/// Sizing parameters shared by both filters.
struct FilterConfig {
    /// Byte budget of a single filter: the threshold at which a hash set is converted into a
    /// bloom filter, and (times eight) the number of bits of that bloom filter.
    filter_max_bytes: usize,
    /// Number of hashes configured on a bloom filter when one is created.
    k_num: u32,
}
147
148enum Filter {
150 Set(HashSet<u64, IdentityBuildHasher>),
151 Bloom(BloomFilter<512, FxBuildHasher>),
152}
153
154impl Filter {
155 fn check_and_insert(
156 &mut self,
157 fingerprint: u64,
158 config: &FilterConfig,
159 ) -> Result<(), TokenReuseError> {
160 match self {
161 Self::Set(hset) => {
162 if !hset.insert(fingerprint) {
163 return Err(TokenReuseError);
164 }
165
166 if hset.capacity() * size_of::<u64>() <= config.filter_max_bytes {
167 return Ok(());
168 }
169
170 let mut bloom = BloomFilter::with_num_bits((config.filter_max_bytes * 8).max(1))
174 .hasher(FxBuildHasher)
175 .hashes(config.k_num);
176 for item in &*hset {
177 bloom.insert(item);
178 }
179 *self = Self::Bloom(bloom);
180 }
181 Self::Bloom(bloom) => {
182 if bloom.insert(&fingerprint) {
183 return Err(TokenReuseError);
184 }
185 }
186 }
187 Ok(())
188 }
189}
190
191impl Default for Filter {
192 fn default() -> Self {
193 Self::Set(HashSet::default())
194 }
195}
196
197#[derive(Default)]
199struct IdentityBuildHasher;
200
201impl BuildHasher for IdentityBuildHasher {
202 type Hasher = IdentityHasher;
203
204 fn build_hasher(&self) -> Self::Hasher {
205 IdentityHasher::default()
206 }
207}
208
/// Hasher that does no mixing: it expects a single 8-byte write and returns those bytes,
/// read as a native-endian `u64`, as the hash.
#[derive(Default)]
struct IdentityHasher {
    /// Bytes received by `write`; all zero until then.
    data: [u8; 8],
    /// Debug-only record that the single permitted write has happened.
    #[cfg(debug_assertions)]
    wrote_8_byte_slice: bool,
}

impl Hasher for IdentityHasher {
    /// Stores `bytes` as the hash value.
    ///
    /// `bytes` must be exactly 8 bytes long, otherwise `copy_from_slice` panics. Debug builds
    /// additionally assert that this is the first write to this hasher.
    fn write(&mut self, bytes: &[u8]) {
        #[cfg(debug_assertions)]
        {
            assert!(!self.wrote_8_byte_slice);
            assert_eq!(bytes.len(), 8);
            self.wrote_8_byte_slice = true;
        }

        let slot: &mut [u8] = &mut self.data;
        slot.copy_from_slice(bytes);
    }

    /// Returns the stored bytes as a native-endian `u64`.
    ///
    /// Debug builds assert that a write has happened first.
    fn finish(&self) -> u64 {
        #[cfg(debug_assertions)]
        assert!(self.wrote_8_byte_slice);

        let raw: [u8; 8] = self.data;
        u64::from_ne_bytes(raw)
    }
}
235
/// Computes a bloom filter hash count as `round((bits / items) * ln 2)`, where `bits` is
/// `num_bytes * 8` and `items` is `expected_hits`. The result is never less than 1.
fn optimal_k_num(num_bytes: usize, expected_hits: u64) -> u32 {
    // Saturate instead of overflowing for enormous budgets, and treat zero expected hits as one
    // so that the division is well defined.
    let total_bits = (num_bytes as u64).saturating_mul(8) as f64;
    let items = expected_hits.max(1) as f64;
    let ideal = (total_bits / items * LN_2).round();
    // Float-to-integer `as` casts saturate, so a huge ratio clamps to `u32::MAX`.
    u32::max(ideal as u32, 1)
}
250
/// Memory budget used by `BloomTokenLog::default()`: 10 MiB.
const DEFAULT_MAX_BYTES: usize = 10 * 1024 * 1024;
/// Expected number of hits used by `BloomTokenLog::default()`.
const DEFAULT_EXPECTED_HITS: u64 = 1_000_000;
254
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;

    /// Fixed seed so that every test run draws the same random sequence.
    const SEED: u128 = 0xdeadbeefdeadbeefdeadbeefdeadbeef;

    /// Returns a deterministically seeded random number generator.
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(SEED.to_le_bytes())
    }

    /// Hashing a `u64` with the identity hasher yields the value itself.
    #[test]
    fn identity_hash_test() {
        let mut rng = new_rng();
        for _ in 0..100 {
            let value: u64 = rng.random();
            assert_eq!(IdentityBuildHasher.hash_one(value), value);
        }
    }

    /// Checks known values and that extreme inputs do not panic.
    #[test]
    fn optimal_k_num_test() {
        assert_eq!(optimal_k_num(10 << 20, 1_000_000), 58);
        assert_eq!(optimal_k_num(10 << 20, 1_000_000_000_000_000), 1);
        // Only checked for not panicking.
        let _ = optimal_k_num(10 << 20, 0);
        let _ = optimal_k_num(usize::MAX, 1_000_000);
    }

    /// The first filter starts as a hash set and turns into a bloom filter once it outgrows its
    /// byte budget; immediate reuse of a token is rejected throughout.
    #[test]
    fn bloom_token_log_conversion() {
        let mut rng = new_rng();
        let mut log = BloomTokenLog::new_expected_items(800, 200);

        let issued = SystemTime::now();
        let lifetime = Duration::from_secs(1_000_000);

        for round in 0..200 {
            let token: u128 = rng.random();
            let first_try = log.check_and_insert(token, issued, lifetime);
            {
                // Scoped so the guard is released before the log is used again below.
                let state = log.0.lock().unwrap();
                match &state.filter_1 {
                    Filter::Set(exact) => {
                        assert!(exact.capacity() * size_of::<u64>() <= 800);
                        assert_eq!(exact.len(), round + 1);
                        assert!(first_try.is_ok());
                    }
                    Filter::Bloom(_) => {
                        assert!(round > 10, "definitely bloomed too early");
                    }
                }
            }
            assert!(log.check_and_insert(token, issued, lifetime).is_err());
        }

        let final_state = log.0.get_mut().unwrap();
        assert!(
            matches!(final_state.filter_1, Filter::Bloom { .. }),
            "didn't bloom"
        );
    }

    /// While time advances and the filters rotate, re-presenting any previously presented token
    /// is always rejected, and at least some new tokens are accepted.
    #[test]
    fn turn_over() {
        let mut rng = new_rng();
        let log = BloomTokenLog::new_expected_items(800, 200);
        let lifetime = Duration::from_secs(1_000);
        let mut presented = Vec::default();
        let mut accepted = 0usize;

        for step in 0..200 {
            let token: u128 = rng.random();
            let now = UNIX_EPOCH + lifetime * 10 + lifetime * step / 10;
            let issued = now - lifetime.mul_f32(rng.random_range(0.0..3.0));
            if log.check_and_insert(token, issued, lifetime).is_ok() {
                accepted += 1;
            }
            presented.push((token, issued));

            let pick = rng.random_range(0..presented.len());
            let (old_token, old_issued) = presented[pick];
            let replay = log.check_and_insert(old_token, old_issued, lifetime);
            assert!(replay.is_err());
        }
        assert!(accepted > 0);
    }

    /// Feeds 200 random tokens into `log`; the results are ignored, only panics matter.
    fn test_doesnt_panic(log: BloomTokenLog) {
        let mut rng = new_rng();

        let issued = SystemTime::now();
        let lifetime = Duration::from_secs(1_000_000);

        for _ in 0..200 {
            let token: u128 = rng.random();
            let _ = log.check_and_insert(token, issued, lifetime);
        }
    }

    #[test]
    fn max_bytes_zero() {
        let log = BloomTokenLog::new_expected_items(0, 200);
        test_doesnt_panic(log);
    }

    #[test]
    fn expected_hits_zero() {
        let log = BloomTokenLog::new_expected_items(100, 0);
        test_doesnt_panic(log);
    }

    #[test]
    fn k_num_zero() {
        let log = BloomTokenLog::new(100, 0);
        test_doesnt_panic(log);
    }
}