1use std::{
2 fmt,
3 mem::size_of,
4 net::{IpAddr, SocketAddr},
5};
6
7use bytes::{Buf, BufMut, Bytes};
8use rand::Rng;
9
10use crate::{
11 Duration, RESET_TOKEN_SIZE, ServerConfig, SystemTime, UNIX_EPOCH,
12 coding::{BufExt, BufMutExt},
13 crypto::{HandshakeTokenKey, HmacKey},
14 packet::InitialHeader,
15 shared::ConnectionId,
16};
17
/// Log of token nonces, consulted while an incoming validation token is being checked.
///
/// In this file the only caller is `IncomingToken::from_header`, which invokes
/// [`check_and_insert`](Self::check_and_insert) for validation tokens after their IP address
/// and lifetime checks have already passed.
pub trait TokenLog: Send + Sync {
    /// Record the use of the token identified by `nonce`.
    ///
    /// `issued` is the issue time carried inside the token and `lifetime` is the configured
    /// validation-token lifetime. Returning `Err(TokenReuseError)` makes the caller ignore the
    /// token and treat the client's address as unvalidated.
    ///
    /// NOTE(review): presumably implementations should fail when `nonce` has been inserted
    /// before — confirm against the concrete implementations, which are not visible here.
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError>;
}
65
/// Error returned by [`TokenLog::check_and_insert`] to reject a token.
pub struct TokenReuseError;
68
/// [`TokenLog`] implementation that stores nothing and rejects every token.
pub struct NoneTokenLog;
71
72impl TokenLog for NoneTokenLog {
73 fn check_and_insert(&self, _: u128, _: SystemTime, _: Duration) -> Result<(), TokenReuseError> {
74 Err(TokenReuseError)
75 }
76}
77
/// Storage for opaque tokens, keyed by server name.
///
/// NOTE(review): presumably used on the client side to keep tokens received from servers for
/// later connection attempts — the callers are not visible in this file; confirm.
pub trait TokenStore: Send + Sync {
    /// Store `token` under `server_name`.
    fn insert(&self, server_name: &str, token: Bytes);

    /// Return a token stored for `server_name`, or `None` if there is none to give out.
    ///
    /// NOTE(review): the name suggests the returned token is removed from the store — confirm
    /// against the implementations.
    fn take(&self, server_name: &str) -> Option<Bytes>;
}
94
/// [`TokenStore`] implementation that discards every inserted token and never returns one.
pub struct NoneTokenStore;
97
98impl TokenStore for NoneTokenStore {
99 fn insert(&self, _: &str, _: Bytes) {}
100 fn take(&self, _: &str) -> Option<Bytes> {
101 None
102 }
103}
104
/// Outcome of inspecting the token field of an incoming `InitialHeader`.
///
/// Built by `IncomingToken::from_header`.
#[derive(Debug)]
pub(crate) struct IncomingToken {
    /// Set to the header's destination connection ID when a retry token was accepted;
    /// `None` otherwise.
    pub(crate) retry_src_cid: Option<ConnectionId>,
    /// The original destination connection ID: taken from an accepted retry token, or else
    /// the header's own destination connection ID.
    pub(crate) orig_dst_cid: ConnectionId,
    /// Whether a token was accepted (either a retry token or a validation token).
    pub(crate) validated: bool,
}
112
113impl IncomingToken {
114 pub(crate) fn from_header(
117 header: &InitialHeader,
118 server_config: &ServerConfig,
119 remote_address: SocketAddr,
120 ) -> Result<Self, InvalidRetryTokenError> {
121 let unvalidated = Self {
122 retry_src_cid: None,
123 orig_dst_cid: header.dst_cid,
124 validated: false,
125 };
126
127 if header.token.is_empty() {
129 return Ok(unvalidated);
130 }
131
132 let Some(retry) = Token::decode(&*server_config.token_key, &header.token) else {
142 return Ok(unvalidated);
143 };
144
145 match retry.payload {
147 TokenPayload::Retry {
148 address,
149 orig_dst_cid,
150 issued,
151 } => {
152 if address != remote_address {
153 return Err(InvalidRetryTokenError);
154 }
155 if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
156 return Err(InvalidRetryTokenError);
157 }
158
159 Ok(Self {
160 retry_src_cid: Some(header.dst_cid),
161 orig_dst_cid,
162 validated: true,
163 })
164 }
165 TokenPayload::Validation { ip, issued } => {
166 if ip != remote_address.ip() {
167 return Ok(unvalidated);
168 }
169 if issued + server_config.validation_token.lifetime
170 < server_config.time_source.now()
171 {
172 return Ok(unvalidated);
173 }
174 if server_config
175 .validation_token
176 .log
177 .check_and_insert(retry.nonce, issued, server_config.validation_token.lifetime)
178 .is_err()
179 {
180 return Ok(unvalidated);
181 }
182
183 Ok(Self {
184 retry_src_cid: None,
185 orig_dst_cid: header.dst_cid,
186 validated: true,
187 })
188 }
189 }
190 }
191}
192
/// Error returned by `IncomingToken::from_header` when a retry token decodes successfully but
/// was issued for a different address or has expired.
pub(crate) struct InvalidRetryTokenError;
197
/// A token in decoded form: its payload plus the nonce it is sealed under.
pub(crate) struct Token {
    /// The data the token carries.
    pub(crate) payload: TokenPayload,
    /// Value drawn from the RNG in `Token::new`. Its little-endian bytes are passed to
    /// `HandshakeTokenKey::aead_from_hkdf` to obtain the sealing key and are appended to the
    /// encoded token; it is also the identifier handed to the `TokenLog`.
    nonce: u128,
}
205
206impl Token {
207 pub(crate) fn new(payload: TokenPayload, rng: &mut impl Rng) -> Self {
209 Self {
210 nonce: rng.random(),
211 payload,
212 }
213 }
214
215 pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey) -> Vec<u8> {
217 let mut buf = Vec::new();
218
219 match self.payload {
221 TokenPayload::Retry {
222 address,
223 orig_dst_cid,
224 issued,
225 } => {
226 buf.put_u8(TokenType::Retry as u8);
227 encode_addr(&mut buf, address);
228 orig_dst_cid.encode_long(&mut buf);
229 encode_unix_secs(&mut buf, issued);
230 }
231 TokenPayload::Validation { ip, issued } => {
232 buf.put_u8(TokenType::Validation as u8);
233 encode_ip(&mut buf, ip);
234 encode_unix_secs(&mut buf, issued);
235 }
236 }
237
238 let aead_key = key.aead_from_hkdf(&self.nonce.to_le_bytes());
240 aead_key.seal(&mut buf, &[]).unwrap();
241 buf.extend(&self.nonce.to_le_bytes());
242
243 buf
244 }
245
246 fn decode(key: &dyn HandshakeTokenKey, raw_token_bytes: &[u8]) -> Option<Self> {
248 let nonce_slice_start = raw_token_bytes.len().checked_sub(size_of::<u128>())?;
252 let (sealed_token, nonce_bytes) = raw_token_bytes.split_at(nonce_slice_start);
253
254 let nonce = u128::from_le_bytes(nonce_bytes.try_into().unwrap());
255
256 let aead_key = key.aead_from_hkdf(nonce_bytes);
257 let mut sealed_token = sealed_token.to_vec();
258 let data = aead_key.open(&mut sealed_token, &[]).ok()?;
259
260 let mut reader = &data[..];
262 let payload = match TokenType::from_byte((&mut reader).get::<u8>().ok()?)? {
263 TokenType::Retry => TokenPayload::Retry {
264 address: decode_addr(&mut reader)?,
265 orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
266 issued: decode_unix_secs(&mut reader)?,
267 },
268 TokenType::Validation => TokenPayload::Validation {
269 ip: decode_ip(&mut reader)?,
270 issued: decode_unix_secs(&mut reader)?,
271 },
272 };
273
274 if !reader.is_empty() {
275 return None;
277 }
278
279 Some(Self { nonce, payload })
280 }
281}
282
/// The content of a [`Token`]; one variant per [`TokenType`].
pub(crate) enum TokenPayload {
    /// Payload of a retry token.
    Retry {
        /// Socket address the token is bound to; must equal the sender's address on
        /// validation.
        address: SocketAddr,
        /// Destination connection ID reported back as `IncomingToken::orig_dst_cid` when the
        /// token is accepted.
        orig_dst_cid: ConnectionId,
        /// Time of issue; encoded with whole-second resolution.
        issued: SystemTime,
    },
    /// Payload of a validation token.
    Validation {
        /// IP address the token is bound to; must equal the sender's IP on validation.
        ip: IpAddr,
        /// Time of issue; encoded with whole-second resolution.
        issued: SystemTime,
    },
}
302
/// Tag identifying the kind of payload inside an encoded token; stored as a single byte.
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
    Retry = 0,
    Validation = 1,
}

impl TokenType {
    /// Map a tag byte back to its variant; `None` for any byte that is not a known
    /// discriminant.
    fn from_byte(n: u8) -> Option<Self> {
        // Derived from the enum itself so the mapping cannot drift from the discriminants.
        const RETRY: u8 = TokenType::Retry as u8;
        const VALIDATION: u8 = TokenType::Validation as u8;

        match n {
            RETRY => Some(Self::Retry),
            VALIDATION => Some(Self::Validation),
            _ => None,
        }
    }
}
317
318fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
319 encode_ip(buf, address.ip());
320 buf.put_u16(address.port());
321}
322
323fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
324 let ip = decode_ip(buf)?;
325 let port = buf.get().ok()?;
326 Some(SocketAddr::new(ip, port))
327}
328
329fn encode_ip(buf: &mut Vec<u8>, ip: IpAddr) {
330 match ip {
331 IpAddr::V4(x) => {
332 buf.put_u8(0);
333 buf.put_slice(&x.octets());
334 }
335 IpAddr::V6(x) => {
336 buf.put_u8(1);
337 buf.put_slice(&x.octets());
338 }
339 }
340}
341
342fn decode_ip<B: Buf>(buf: &mut B) -> Option<IpAddr> {
343 match buf.get::<u8>().ok()? {
344 0 => buf.get().ok().map(IpAddr::V4),
345 1 => buf.get().ok().map(IpAddr::V6),
346 _ => None,
347 }
348}
349
350fn encode_unix_secs(buf: &mut Vec<u8>, time: SystemTime) {
351 buf.write::<u64>(
352 time.duration_since(UNIX_EPOCH)
353 .unwrap_or_default()
354 .as_secs(),
355 );
356}
357
358fn decode_unix_secs<B: Buf>(buf: &mut B) -> Option<SystemTime> {
359 Some(UNIX_EPOCH + Duration::from_secs(buf.get::<u64>().ok()?))
360}
361
/// Fixed-size token of `RESET_TOKEN_SIZE` bytes.
///
/// `ResetToken::new` builds one from the leading bytes of an `HmacKey` signature over a
/// connection ID.
// `Hash` is derived while `PartialEq` is hand-written further down in this file (it goes
// through `crate::constant_time::eq` on the same bytes), hence the lint exemption.
// NOTE(review): this presumes `constant_time::eq` is plain byte equality, which keeps
// `Hash` and `Eq` consistent — confirm.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Copy, Clone, Hash)]
pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);
368
369impl ResetToken {
370 pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
371 let mut signature = vec![0; key.signature_len()];
372 key.sign(&id, &mut signature);
373 let mut result = [0; RESET_TOKEN_SIZE];
375 result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]);
376 result.into()
377 }
378}
379
380impl PartialEq for ResetToken {
381 fn eq(&self, other: &Self) -> bool {
382 crate::constant_time::eq(&self.0, &other.0)
383 }
384}
385
386impl Eq for ResetToken {}
387
388impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken {
389 fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self {
390 Self(x)
391 }
392}
393
394impl std::ops::Deref for ResetToken {
395 type Target = [u8];
396 fn deref(&self) -> &[u8] {
397 &self.0
398 }
399}
400
401impl fmt::Display for ResetToken {
402 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403 for byte in self.iter() {
404 write!(f, "{byte:02x}")?;
405 }
406 Ok(())
407 }
408}
409
#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
    use super::*;
    #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
    use aws_lc_rs::hkdf;
    use rand::prelude::*;
    #[cfg(feature = "ring")]
    use ring::hkdf;

    /// Extract an HKDF-SHA256 pseudorandom key from 64 random bytes and an empty salt.
    fn random_prk(rng: &mut impl RngCore) -> hkdf::Prk {
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key)
    }

    /// Encode `payload` under a random key, decode it again with the same key, check that the
    /// nonce survived, and hand back the decoded payload.
    fn token_round_trip(payload: TokenPayload) -> TokenPayload {
        let mut rng = rand::rng();
        let original = Token::new(payload, &mut rng);
        let prk = random_prk(&mut rng);

        let wire = original.encode(&prk);
        let parsed = Token::decode(&prk, &wire).expect("token didn't decrypt / decode");

        assert_eq!(original.nonce, parsed.nonce);
        parsed.payload
    }

    #[test]
    fn retry_token_sanity() {
        use crate::MAX_CID_SIZE;
        use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};
        use crate::{Duration, UNIX_EPOCH};

        use std::net::Ipv6Addr;

        let sent_address = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
        let sent_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
        // Whole seconds only: the encoding has one-second resolution.
        let sent_issued = UNIX_EPOCH + Duration::from_secs(42);

        let received = token_round_trip(TokenPayload::Retry {
            address: sent_address,
            orig_dst_cid: sent_cid,
            issued: sent_issued,
        });
        let TokenPayload::Retry {
            address: got_address,
            orig_dst_cid: got_cid,
            issued: got_issued,
        } = received
        else {
            panic!("token decoded as wrong variant");
        };

        assert_eq!(sent_address, got_address);
        assert_eq!(sent_cid, got_cid);
        assert_eq!(sent_issued, got_issued);
    }

    #[test]
    fn validation_token_sanity() {
        use crate::{Duration, UNIX_EPOCH};

        use std::net::Ipv6Addr;

        let sent_ip = Ipv6Addr::LOCALHOST.into();
        // Whole seconds only: the encoding has one-second resolution.
        let sent_issued = UNIX_EPOCH + Duration::from_secs(42);

        let received = token_round_trip(TokenPayload::Validation {
            ip: sent_ip,
            issued: sent_issued,
        });
        let TokenPayload::Validation {
            ip: got_ip,
            issued: got_issued,
        } = received
        else {
            panic!("token decoded as wrong variant");
        };

        assert_eq!(sent_ip, got_ip);
        assert_eq!(sent_issued, got_issued);
    }

    /// Random bytes that were never sealed must fail to decode (`Token::decode` returns
    /// `None`).
    #[test]
    fn invalid_token_returns_err() {
        use super::*;
        use rand::RngCore;

        let mut rng = rand::rng();
        let prk = random_prk(&mut rng);

        let mut garbage = [0; 32];
        rand::rng().fill_bytes(&mut garbage);

        assert!(Token::decode(&prk, &garbage[..]).is_none());
    }
}