quinn_proto/
cid_generator.rs1use std::hash::Hasher;
2
3use rand::{Rng, RngCore};
4
5use crate::Duration;
6use crate::MAX_CID_SIZE;
7use crate::shared::ConnectionId;
8
9pub trait ConnectionIdGenerator: Send + Sync {
11 fn generate_cid(&mut self) -> ConnectionId;
19
20 fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
24 Ok(())
25 }
26
27 fn cid_len(&self) -> usize;
29 fn cid_lifetime(&self) -> Option<Duration>;
33}
34
/// Error returned by [`ConnectionIdGenerator::validate`] when a CID is rejected.
///
/// Implements [`std::error::Error`] so it composes with `?`, `Box<dyn Error>` and error
/// reporting crates; `PartialEq`/`Eq`/`Hash` allow comparing validation results directly.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct InvalidCid;

impl std::fmt::Display for InvalidCid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid connection ID")
    }
}

impl std::error::Error for InvalidCid {}
38
39#[derive(Debug, Clone, Copy)]
44pub struct RandomConnectionIdGenerator {
45 cid_len: usize,
46 lifetime: Option<Duration>,
47}
48
49impl Default for RandomConnectionIdGenerator {
50 fn default() -> Self {
51 Self {
52 cid_len: 8,
53 lifetime: None,
54 }
55 }
56}
57
58impl RandomConnectionIdGenerator {
59 pub fn new(cid_len: usize) -> Self {
63 debug_assert!(cid_len <= MAX_CID_SIZE);
64 Self {
65 cid_len,
66 ..Self::default()
67 }
68 }
69
70 pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
72 self.lifetime = Some(d);
73 self
74 }
75}
76
77impl ConnectionIdGenerator for RandomConnectionIdGenerator {
78 fn generate_cid(&mut self) -> ConnectionId {
79 let mut bytes_arr = [0; MAX_CID_SIZE];
80 rand::rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
81
82 ConnectionId::new(&bytes_arr[..self.cid_len])
83 }
84
85 fn cid_len(&self) -> usize {
87 self.cid_len
88 }
89
90 fn cid_lifetime(&self) -> Option<Duration> {
91 self.lifetime
92 }
93}
94
95pub struct HashedConnectionIdGenerator {
101 key: u64,
102 lifetime: Option<Duration>,
103}
104
105impl HashedConnectionIdGenerator {
106 pub fn new() -> Self {
108 Self::from_key(rand::rng().random())
109 }
110
111 pub fn from_key(key: u64) -> Self {
116 Self {
117 key,
118 lifetime: None,
119 }
120 }
121
122 pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
124 self.lifetime = Some(d);
125 self
126 }
127}
128
129impl Default for HashedConnectionIdGenerator {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl ConnectionIdGenerator for HashedConnectionIdGenerator {
136 fn generate_cid(&mut self) -> ConnectionId {
137 let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
138 rand::rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
139 let mut hasher = rustc_hash::FxHasher::default();
140 hasher.write_u64(self.key);
141 hasher.write(&bytes_arr[..NONCE_LEN]);
142 bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
143 ConnectionId::new(&bytes_arr)
144 }
145
146 fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
147 let (nonce, signature) = cid.split_at(NONCE_LEN);
148 let mut hasher = rustc_hash::FxHasher::default();
149 hasher.write_u64(self.key);
150 hasher.write(nonce);
151 let expected = hasher.finish().to_le_bytes();
152 match expected[..SIGNATURE_LEN] == signature[..] {
153 true => Ok(()),
154 false => Err(InvalidCid),
155 }
156 }
157
158 fn cid_len(&self) -> usize {
159 NONCE_LEN + SIGNATURE_LEN
160 }
161
162 fn cid_lifetime(&self) -> Option<Duration> {
163 self.lifetime
164 }
165}
166
167const NONCE_LEN: usize = 3; const SIGNATURE_LEN: usize = 8 - NONCE_LEN; #[cfg(test)]
mod tests {
    use super::*;

    /// A CID produced by the hashed generator validates under the same generator.
    #[test]
    fn validate_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        generator.validate(&cid).unwrap();
    }

    /// Hashed CIDs have exactly the length the generator advertises.
    #[test]
    fn keyed_cid_has_advertised_length() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        assert_eq!(cid.len(), generator.cid_len());
    }

    /// Flipping a signature bit must make validation fail; the signature is recomputed from
    /// the unchanged nonce, so this rejection is deterministic.
    #[test]
    fn reject_tampered_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::from_key(0x5eed);
        let cid = generator.generate_cid();
        let mut forged = cid.to_vec();
        forged[NONCE_LEN] ^= 0x01;
        assert!(generator.validate(&ConnectionId::new(&forged)).is_err());
    }

    /// The random generator honours the requested length across the allowed range.
    #[test]
    fn random_cid_has_requested_length() {
        for &len in &[0, 1, 8, MAX_CID_SIZE] {
            let mut generator = RandomConnectionIdGenerator::new(len);
            assert_eq!(generator.generate_cid().len(), len);
            assert_eq!(generator.cid_len(), len);
        }
    }
}