quinn_proto/
bloom_token_log.rs

1use std::{
2    collections::HashSet,
3    f64::consts::LN_2,
4    hash::{BuildHasher, Hasher},
5    mem::{size_of, take},
6    sync::Mutex,
7};
8
9use fastbloom::BloomFilter;
10use rustc_hash::FxBuildHasher;
11use tracing::{trace, warn};
12
13use crate::{Duration, SystemTime, TokenLog, TokenReuseError, UNIX_EPOCH};
14
15/// Bloom filter-based [`TokenLog`]
16///
17/// Parameterizable over an approximate maximum number of bytes to allocate. Starts out by storing
18/// used tokens in a hash set. Once the hash set becomes too large, converts it to a bloom filter.
19/// This achieves a memory profile of linear growth with an upper bound.
20///
21/// Divides time into periods based on `lifetime` and stores two filters at any given moment, for
22/// each of the two periods currently non-expired tokens could expire in. As such, turns over
23/// filters as time goes on to avoid bloom filter false positive rate increasing infinitely over
24/// time.
25pub struct BloomTokenLog(Mutex<State>);
26
27impl BloomTokenLog {
28    /// Construct with an approximate maximum memory usage and expected number of validation token
29    /// usages per expiration period
30    ///
31    /// Calculates the optimal bloom filter k number automatically.
32    pub fn new_expected_items(max_bytes: usize, expected_hits: u64) -> Self {
33        Self::new(max_bytes, optimal_k_num(max_bytes, expected_hits))
34    }
35
36    /// Construct with an approximate maximum memory usage and a [bloom filter k number][bloom]
37    ///
38    /// [bloom]: https://en.wikipedia.org/wiki/Bloom_filter
39    ///
40    /// If choosing a custom k number, note that `BloomTokenLog` always maintains two filters
41    /// between them and divides the allocation budget of `max_bytes` evenly between them. As such,
42    /// each bloom filter will contain `max_bytes * 4` bits.
43    pub fn new(max_bytes: usize, k_num: u32) -> Self {
44        Self(Mutex::new(State {
45            config: FilterConfig {
46                filter_max_bytes: max_bytes / 2,
47                k_num,
48            },
49            period_1_start: UNIX_EPOCH,
50            filter_1: Filter::default(),
51            filter_2: Filter::default(),
52        }))
53    }
54}
55
56impl TokenLog for BloomTokenLog {
57    fn check_and_insert(
58        &self,
59        nonce: u128,
60        issued: SystemTime,
61        lifetime: Duration,
62    ) -> Result<(), TokenReuseError> {
63        trace!(%nonce, "check_and_insert");
64
65        if lifetime.is_zero() {
66            // avoid divide-by-zero if lifetime is zero
67            return Err(TokenReuseError);
68        }
69
70        let mut guard = self.0.lock().unwrap();
71        let state = &mut *guard;
72
73        // calculate how many periods past period 1 the token expires
74        let expires_at = issued + lifetime;
75        let Ok(periods_forward) = expires_at
76            .duration_since(state.period_1_start)
77            .map(|duration| duration.as_nanos() / lifetime.as_nanos())
78        else {
79            // shouldn't happen unless time travels backwards or lifetime changes or the current
80            // system time is before the Unix epoch
81            warn!("BloomTokenLog presented with token too far in past");
82            return Err(TokenReuseError);
83        };
84
85        // get relevant filter
86        let filter = match periods_forward {
87            0 => &mut state.filter_1,
88            1 => &mut state.filter_2,
89            2 => {
90                // turn over filter 1
91                state.filter_1 = take(&mut state.filter_2);
92                state.period_1_start += lifetime;
93                &mut state.filter_2
94            }
95            _ => {
96                // turn over both filters
97                state.filter_1 = Filter::default();
98                state.filter_2 = Filter::default();
99                state.period_1_start = expires_at;
100                &mut state.filter_1
101            }
102        };
103
104        // insert into the filter
105        //
106        // the token's nonce needs to guarantee uniqueness because of the role it plays in the
107        // encryption of the tokens, so it is 128 bits. but since the token log can tolerate false
108        // positives, we trim it down to 64 bits, which would still only have a small collision
109        // rate even at significant amounts of usage, while allowing us to store twice as many in
110        // the hash set variant.
111        //
112        // token nonce values are uniformly randomly generated server-side and cryptographically
113        // integrity-checked, so we don't need to employ secure hashing to trim it down to 64 bits,
114        // we can simply truncate.
115        //
116        // per the Rust reference, we can truncate by simply casting:
117        // https://doc.rust-lang.org/stable/reference/expressions/operator-expr.html#numeric-cast
118        filter.check_and_insert(nonce as u64, &state.config)
119    }
120}
121
122/// Default to 20 MiB max memory consumption and expected one million hits
123///
124/// With the default validation token lifetime of 2 weeks, this corresponds to one token usage per
125/// 1.21 seconds.
126impl Default for BloomTokenLog {
127    fn default() -> Self {
128        Self::new_expected_items(DEFAULT_MAX_BYTES, DEFAULT_EXPECTED_HITS)
129    }
130}
131
132/// Lockable state of [`BloomTokenLog`]
133struct State {
134    config: FilterConfig,
135    // filter_1 covers tokens that expire in the period starting at period_1_start and extending
136    // lifetime after. filter_2 covers tokens for the next lifetime after that.
137    period_1_start: SystemTime,
138    filter_1: Filter,
139    filter_2: Filter,
140}
141
/// Fixed parameters that control how a [`Filter`] behaves
struct FilterConfig {
    /// Approximate byte budget for a single filter; a hash set whose reserved capacity exceeds
    /// it is converted into a bloom filter of this size
    filter_max_bytes: usize,
    /// Number of hash functions ("k") used once the filter is a bloom filter
    k_num: u32,
}
147
148/// Period filter within [`State`]
149enum Filter {
150    Set(HashSet<u64, IdentityBuildHasher>),
151    Bloom(BloomFilter<512, FxBuildHasher>),
152}
153
154impl Filter {
155    fn check_and_insert(
156        &mut self,
157        fingerprint: u64,
158        config: &FilterConfig,
159    ) -> Result<(), TokenReuseError> {
160        match self {
161            Self::Set(hset) => {
162                if !hset.insert(fingerprint) {
163                    return Err(TokenReuseError);
164                }
165
166                if hset.capacity() * size_of::<u64>() <= config.filter_max_bytes {
167                    return Ok(());
168                }
169
170                // convert to bloom
171                // avoid panicking if user passed in filter_max_bytes of 0. we document that this
172                // limit is approximate, so just fudge it up to 1.
173                let mut bloom = BloomFilter::with_num_bits((config.filter_max_bytes * 8).max(1))
174                    .hasher(FxBuildHasher)
175                    .hashes(config.k_num);
176                for item in &*hset {
177                    bloom.insert(item);
178                }
179                *self = Self::Bloom(bloom);
180            }
181            Self::Bloom(bloom) => {
182                if bloom.insert(&fingerprint) {
183                    return Err(TokenReuseError);
184                }
185            }
186        }
187        Ok(())
188    }
189}
190
191impl Default for Filter {
192    fn default() -> Self {
193        Self::Set(HashSet::default())
194    }
195}
196
/// [`BuildHasher`] that hands out [`IdentityHasher`]s
#[derive(Default)]
struct IdentityBuildHasher;

impl BuildHasher for IdentityBuildHasher {
    type Hasher = IdentityHasher;

    fn build_hasher(&self) -> IdentityHasher {
        Default::default()
    }
}

/// [`Hasher`] whose hash is its own input
///
/// It expects to be fed exactly one 8-byte slice and returns those bytes reinterpreted as a
/// native-endian `u64`, so hashing a `u64` yields the `u64` itself. Debug builds check that
/// contract with assertions; release builds skip the bookkeeping entirely.
#[derive(Default)]
struct IdentityHasher {
    buf: [u8; 8],
    #[cfg(debug_assertions)]
    wrote_8_byte_slice: bool,
}

impl Hasher for IdentityHasher {
    fn write(&mut self, bytes: &[u8]) {
        // debug-only contract check: exactly one write, of exactly 8 bytes
        #[cfg(debug_assertions)]
        {
            assert!(!self.wrote_8_byte_slice);
            assert_eq!(bytes.len(), 8);
            self.wrote_8_byte_slice = true;
        }
        self.buf.copy_from_slice(bytes);
    }

    fn finish(&self) -> u64 {
        #[cfg(debug_assertions)]
        assert!(self.wrote_8_byte_slice);
        u64::from_ne_bytes(self.buf)
    }
}
235
/// Computes the bloom filter k number (hash function count) for a filter of `num_bytes` bytes
/// expected to hold `expected_hits` items
///
/// Uses the standard formula `k = (m / n) * ln 2`, where `m` is the number of bits and `n` the
/// number of items (see https://programming.guide/bloom-filter-calculator.html). Degenerate
/// inputs are tolerated instead of panicking, and the result is never below 1.
fn optimal_k_num(num_bytes: usize, expected_hits: u64) -> u32 {
    // saturate rather than overflow: a huge byte count may mean "effectively unbounded"
    let bits = (num_bytes as u64).saturating_mul(8) as f64;
    // zero expected items would divide by zero, so treat it as one
    let items = expected_hits.max(1) as f64;
    let k = (bits / items * LN_2).round() as u32;
    // keep at least one hash so the filter stays useful even for an absurd hits-to-bytes ratio
    k.max(1)
}
250
// remember to change the doc comment for `impl Default for BloomTokenLog` if these ever change
/// Default approximate memory budget (10 MiB), split evenly between the two filters
const DEFAULT_MAX_BYTES: usize = 10 * 1024 * 1024;
/// Default expected number of token usages per expiration period
const DEFAULT_EXPECTED_HITS: u64 = 1_000_000;
254
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;

    /// Deterministically seeded RNG so test runs are reproducible
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeef_u128.to_le_bytes())
    }

    /// `IdentityBuildHasher` must hash a `u64` to itself
    #[test]
    fn identity_hash_test() {
        let mut rng = new_rng();
        let builder = IdentityBuildHasher;
        for _ in 0..100 {
            let n = rng.random::<u64>();
            let hash = builder.hash_one(n);
            assert_eq!(hash, n);
        }
    }

    #[test]
    fn optimal_k_num_test() {
        assert_eq!(optimal_k_num(10 << 20, 1_000_000), 58);
        assert_eq!(optimal_k_num(10 << 20, 1_000_000_000_000_000), 1);
        // assert that these don't panic:
        optimal_k_num(10 << 20, 0);
        optimal_k_num(usize::MAX, 1_000_000);
    }

    /// The period-1 filter must stay a within-budget hash set until it outgrows its budget, then
    /// become a bloom filter, rejecting reuse throughout
    #[test]
    fn bloom_token_log_conversion() {
        let mut rng = new_rng();
        let mut log = BloomTokenLog::new_expected_items(800, 200);

        let issued = SystemTime::now();
        let lifetime = Duration::from_secs(1_000_000);

        for i in 0..200 {
            let token = rng.random::<u128>();
            let result = log.check_and_insert(token, issued, lifetime);
            {
                let filter = &log.0.lock().unwrap().filter_1;
                if let Filter::Set(ref hset) = *filter {
                    // the 800-byte budget is split between two filters, so a filter that is
                    // still a hash set must fit within its 400-byte half
                    assert!(hset.capacity() * size_of::<u64>() <= 800 / 2);
                    assert_eq!(hset.len(), i + 1);
                    assert!(result.is_ok());
                } else {
                    assert!(i > 10, "definitely bloomed too early");
                }
            }
            assert!(log.check_and_insert(token, issued, lifetime).is_err());
        }

        assert!(
            matches!(log.0.get_mut().unwrap().filter_1, Filter::Bloom(_)),
            "didn't bloom"
        );
    }

    /// Reusing any previously presented token must fail, even as filters turn over
    #[test]
    fn turn_over() {
        let mut rng = new_rng();
        let log = BloomTokenLog::new_expected_items(800, 200);
        let lifetime = Duration::from_secs(1_000);
        let mut old = Vec::default();
        let mut accepted = 0;

        for i in 0..200 {
            let token = rng.random::<u128>();
            let now = UNIX_EPOCH + lifetime * 10 + lifetime * i / 10;
            let issued = now - lifetime.mul_f32(rng.random_range(0.0..3.0));
            let result = log.check_and_insert(token, issued, lifetime);
            if result.is_ok() {
                accepted += 1;
            }
            old.push((token, issued));
            let old_idx = rng.random_range(0..old.len());
            let (old_token, old_issued) = old[old_idx];
            assert!(
                log.check_and_insert(old_token, old_issued, lifetime)
                    .is_err()
            );
        }
        assert!(accepted > 0);
    }

    /// Feeds 200 random tokens into `log`, ignoring the results
    fn test_doesnt_panic(log: BloomTokenLog) {
        let mut rng = new_rng();

        let issued = SystemTime::now();
        let lifetime = Duration::from_secs(1_000_000);

        for _ in 0..200 {
            let _ = log.check_and_insert(rng.random::<u128>(), issued, lifetime);
        }
    }

    #[test]
    fn max_bytes_zero() {
        // "max bytes" is documented to be approximate. but make sure it doesn't panic.
        test_doesnt_panic(BloomTokenLog::new_expected_items(0, 200));
    }

    #[test]
    fn expected_hits_zero() {
        test_doesnt_panic(BloomTokenLog::new_expected_items(100, 0));
    }

    #[test]
    fn k_num_zero() {
        test_doesnt_panic(BloomTokenLog::new(100, 0));
    }
}