1use std::{any::Any, io, str, sync::Arc};
2
3#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
4use aws_lc_rs::aead;
5use bytes::BytesMut;
6#[cfg(feature = "ring")]
7use ring::aead;
8pub use rustls::Error;
9use rustls::{
10 self, CipherSuite,
11 client::danger::ServerCertVerifier,
12 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
13 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
14};
15
16use crate::{
17 ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
18 crypto::{
19 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
20 },
21 transport_parameters::TransportParameters,
22};
23
24impl From<Side> for rustls::Side {
25 fn from(s: Side) -> Self {
26 match s {
27 Side::Client => Self::Client,
28 Side::Server => Self::Server,
29 }
30 }
31}
32
33pub struct TlsSession {
35 version: Version,
36 got_handshake_data: bool,
37 next_secrets: Option<Secrets>,
38 inner: Connection,
39 suite: Suite,
40}
41
42impl TlsSession {
43 fn side(&self) -> Side {
44 match self.inner {
45 Connection::Client(_) => Side::Client,
46 Connection::Server(_) => Side::Server,
47 }
48 }
49}
50
51impl crypto::Session for TlsSession {
52 fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
53 initial_keys(self.version, *dst_cid, side, &self.suite)
54 }
55
56 fn handshake_data(&self) -> Option<Box<dyn Any>> {
57 if !self.got_handshake_data {
58 return None;
59 }
60 Some(Box::new(HandshakeData {
61 protocol: self.inner.alpn_protocol().map(|x| x.into()),
62 server_name: match self.inner {
63 Connection::Client(_) => None,
64 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
65 },
66 }))
67 }
68
69 fn peer_identity(&self) -> Option<Box<dyn Any>> {
71 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
72 Box::new(
73 v.iter()
74 .map(|v| v.clone().into_owned())
75 .collect::<Vec<CertificateDer<'static>>>(),
76 )
77 })
78 }
79
80 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
81 let keys = self.inner.zero_rtt_keys()?;
82 Some((Box::new(keys.header), Box::new(keys.packet)))
83 }
84
85 fn early_data_accepted(&self) -> Option<bool> {
86 match self.inner {
87 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
88 _ => None,
89 }
90 }
91
92 fn is_handshaking(&self) -> bool {
93 self.inner.is_handshaking()
94 }
95
96 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
97 self.inner.read_hs(buf).map_err(|e| {
98 if let Some(alert) = self.inner.alert() {
99 TransportError {
100 code: TransportErrorCode::crypto(alert.into()),
101 frame: None,
102 reason: e.to_string(),
103 }
104 } else {
105 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
106 }
107 })?;
108 if !self.got_handshake_data {
109 let have_server_name = match self.inner {
113 Connection::Client(_) => false,
114 Connection::Server(ref session) => session.server_name().is_some(),
115 };
116 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
117 self.got_handshake_data = true;
118 return Ok(true);
119 }
120 }
121 Ok(false)
122 }
123
124 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
125 match self.inner.quic_transport_parameters() {
126 None => Ok(None),
127 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
128 Ok(params) => Ok(Some(params)),
129 Err(e) => Err(e.into()),
130 },
131 }
132 }
133
134 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
135 let keys = match self.inner.write_hs(buf)? {
136 KeyChange::Handshake { keys } => keys,
137 KeyChange::OneRtt { keys, next } => {
138 self.next_secrets = Some(next);
139 keys
140 }
141 };
142
143 Some(Keys {
144 header: KeyPair {
145 local: Box::new(keys.local.header),
146 remote: Box::new(keys.remote.header),
147 },
148 packet: KeyPair {
149 local: Box::new(keys.local.packet),
150 remote: Box::new(keys.remote.packet),
151 },
152 })
153 }
154
155 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
156 let secrets = self.next_secrets.as_mut()?;
157 let keys = secrets.next_packet_keys();
158 Some(KeyPair {
159 local: Box::new(keys.local),
160 remote: Box::new(keys.remote),
161 })
162 }
163
164 fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
165 let tag_start = match payload.len().checked_sub(16) {
166 Some(x) => x,
167 None => return false,
168 };
169
170 let mut pseudo_packet =
171 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
172 pseudo_packet.push(orig_dst_cid.len() as u8);
173 pseudo_packet.extend_from_slice(orig_dst_cid);
174 pseudo_packet.extend_from_slice(header);
175 let tag_start = tag_start + pseudo_packet.len();
176 pseudo_packet.extend_from_slice(payload);
177
178 let (nonce, key) = match self.version {
179 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
180 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
181 _ => unreachable!(),
182 };
183
184 let nonce = aead::Nonce::assume_unique_for_key(nonce);
185 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
186
187 let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
188 key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
189 }
190
191 fn export_keying_material(
192 &self,
193 output: &mut [u8],
194 label: &[u8],
195 context: &[u8],
196 ) -> Result<(), ExportKeyingMaterialError> {
197 self.inner
198 .export_keying_material(output, label, Some(context))
199 .map_err(|_| ExportKeyingMaterialError)?;
200 Ok(())
201 }
202}
203
/// AES-128-GCM key for the Retry integrity tag of draft QUIC versions (`Version::V1Draft`).
/// Should match the value published in the QUIC-TLS draft (draft-29); verify before changing.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
/// AES-128-GCM nonce paired with `RETRY_INTEGRITY_KEY_DRAFT`.
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];

/// AES-128-GCM key for the Retry integrity tag of QUIC v1 (`Version::V1`), as fixed by
/// RFC 9001 §5.8.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
/// AES-128-GCM nonce paired with `RETRY_INTEGRITY_KEY_V1` (RFC 9001 §5.8).
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
216];
217
218impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
219 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
220 let (header, sample) = packet.split_at_mut(pn_offset + 4);
221 let (first, rest) = header.split_at_mut(1);
222 let pn_end = Ord::min(pn_offset + 3, rest.len());
223 self.decrypt_in_place(
224 &sample[..self.sample_size()],
225 &mut first[0],
226 &mut rest[pn_offset - 1..pn_end],
227 )
228 .unwrap();
229 }
230
231 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
232 let (header, sample) = packet.split_at_mut(pn_offset + 4);
233 let (first, rest) = header.split_at_mut(1);
234 let pn_end = Ord::min(pn_offset + 3, rest.len());
235 self.encrypt_in_place(
236 &sample[..self.sample_size()],
237 &mut first[0],
238 &mut rest[pn_offset - 1..pn_end],
239 )
240 .unwrap();
241 }
242
243 fn sample_size(&self) -> usize {
244 self.sample_len()
245 }
246}
247
/// Handshake details made available by a rustls-backed session.
///
/// Boxed as `dyn Any` by the session's `handshake_data`. Derives `Debug`/`Clone`/`Eq` so
/// callers who downcast it can log, copy and compare it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeData {
    /// The negotiated ALPN protocol, if any.
    pub protocol: Option<Vec<u8>>,
    /// The server name (SNI) sent by the client.
    ///
    /// Only populated on the server side; always `None` for client sessions.
    pub server_name: Option<String>,
}
259
260pub struct QuicClientConfig {
279 pub(crate) inner: Arc<rustls::ClientConfig>,
280 initial: Suite,
281}
282
283impl QuicClientConfig {
284 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
289 let inner = Self::inner(verifier);
290 Self {
291 initial: initial_suite_from_provider(inner.crypto_provider())
293 .expect("no initial cipher suite found"),
294 inner: Arc::new(inner),
295 }
296 }
297
298 pub fn with_initial(
302 inner: Arc<rustls::ClientConfig>,
303 initial: Suite,
304 ) -> Result<Self, NoInitialCipherSuite> {
305 match initial.suite.common.suite {
306 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
307 _ => Err(NoInitialCipherSuite { specific: true }),
308 }
309 }
310
311 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
312 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
313 .with_protocol_versions(&[&rustls::version::TLS13])
314 .unwrap() .dangerous()
316 .with_custom_certificate_verifier(verifier)
317 .with_no_client_auth();
318
319 config.enable_early_data = true;
320 config
321 }
322}
323
324impl crypto::ClientConfig for QuicClientConfig {
325 fn start_session(
326 self: Arc<Self>,
327 version: u32,
328 server_name: &str,
329 params: &TransportParameters,
330 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
331 let version = interpret_version(version)?;
332 Ok(Box::new(TlsSession {
333 version,
334 got_handshake_data: false,
335 next_secrets: None,
336 inner: rustls::quic::Connection::Client(
337 rustls::quic::ClientConnection::new(
338 self.inner.clone(),
339 version,
340 ServerName::try_from(server_name)
341 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
342 .to_owned(),
343 to_vec(params),
344 )
345 .unwrap(),
346 ),
347 suite: self.initial,
348 }))
349 }
350}
351
352impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
353 type Error = NoInitialCipherSuite;
354
355 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
356 Arc::new(inner).try_into()
357 }
358}
359
360impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
361 type Error = NoInitialCipherSuite;
362
363 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
364 Ok(Self {
365 initial: initial_suite_from_provider(inner.crypto_provider())
366 .ok_or(NoInitialCipherSuite { specific: false })?,
367 inner,
368 })
369 }
370}
371
/// Error returned when a QUIC config cannot obtain a usable Initial cipher suite.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` when an explicitly supplied suite was rejected.
    ///
    /// `false` when the crypto provider offered no suitable suite at all.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
395
396pub struct QuicServerConfig {
409 inner: Arc<rustls::ServerConfig>,
410 initial: Suite,
411}
412
413impl QuicServerConfig {
414 pub(crate) fn new(
415 cert_chain: Vec<CertificateDer<'static>>,
416 key: PrivateKeyDer<'static>,
417 ) -> Result<Self, rustls::Error> {
418 let inner = Self::inner(cert_chain, key)?;
419 Ok(Self {
420 initial: initial_suite_from_provider(inner.crypto_provider())
422 .expect("no initial cipher suite found"),
423 inner: Arc::new(inner),
424 })
425 }
426
427 pub fn with_initial(
431 inner: Arc<rustls::ServerConfig>,
432 initial: Suite,
433 ) -> Result<Self, NoInitialCipherSuite> {
434 match initial.suite.common.suite {
435 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
436 _ => Err(NoInitialCipherSuite { specific: true }),
437 }
438 }
439
440 pub(crate) fn inner(
446 cert_chain: Vec<CertificateDer<'static>>,
447 key: PrivateKeyDer<'static>,
448 ) -> Result<rustls::ServerConfig, rustls::Error> {
449 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
450 .with_protocol_versions(&[&rustls::version::TLS13])
451 .unwrap() .with_no_client_auth()
453 .with_single_cert(cert_chain, key)?;
454
455 inner.max_early_data_size = u32::MAX;
456 Ok(inner)
457 }
458}
459
460impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
461 type Error = NoInitialCipherSuite;
462
463 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
464 Arc::new(inner).try_into()
465 }
466}
467
468impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
469 type Error = NoInitialCipherSuite;
470
471 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
472 Ok(Self {
473 initial: initial_suite_from_provider(inner.crypto_provider())
474 .ok_or(NoInitialCipherSuite { specific: false })?,
475 inner,
476 })
477 }
478}
479
480impl crypto::ServerConfig for QuicServerConfig {
481 fn start_session(
482 self: Arc<Self>,
483 version: u32,
484 params: &TransportParameters,
485 ) -> Box<dyn crypto::Session> {
486 let version = interpret_version(version).unwrap();
488 Box::new(TlsSession {
489 version,
490 got_handshake_data: false,
491 next_secrets: None,
492 inner: rustls::quic::Connection::Server(
493 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
494 .unwrap(),
495 ),
496 suite: self.initial,
497 })
498 }
499
500 fn initial_keys(
501 &self,
502 version: u32,
503 dst_cid: &ConnectionId,
504 ) -> Result<Keys, UnsupportedVersion> {
505 let version = interpret_version(version)?;
506 Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
507 }
508
509 fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
510 let version = interpret_version(version).unwrap();
512 let (nonce, key) = match version {
513 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
514 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
515 _ => unreachable!(),
516 };
517
518 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
519 pseudo_packet.push(orig_dst_cid.len() as u8);
520 pseudo_packet.extend_from_slice(orig_dst_cid);
521 pseudo_packet.extend_from_slice(packet);
522
523 let nonce = aead::Nonce::assume_unique_for_key(nonce);
524 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
525
526 let tag = key
527 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
528 .unwrap();
529 let mut result = [0; 16];
530 result.copy_from_slice(tag.as_ref());
531 result
532 }
533}
534
535pub(crate) fn initial_suite_from_provider(
536 provider: &Arc<rustls::crypto::CryptoProvider>,
537) -> Option<Suite> {
538 provider
539 .cipher_suites
540 .iter()
541 .find_map(|cs| match (cs.suite(), cs.tls13()) {
542 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
543 Some(suite.quic_suite())
544 }
545 _ => None,
546 })
547 .flatten()
548}
549
550pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
551 #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
552 let provider = rustls::crypto::aws_lc_rs::default_provider();
553 #[cfg(feature = "rustls-ring")]
554 let provider = rustls::crypto::ring::default_provider();
555 Arc::new(provider)
556}
557
558fn to_vec(params: &TransportParameters) -> Vec<u8> {
559 let mut bytes = Vec::new();
560 params.write(&mut bytes);
561 bytes
562}
563
564pub(crate) fn initial_keys(
565 version: Version,
566 dst_cid: ConnectionId,
567 side: Side,
568 suite: &Suite,
569) -> Keys {
570 let keys = suite.keys(&dst_cid, side.into(), version);
571 Keys {
572 header: KeyPair {
573 local: Box::new(keys.local.header),
574 remote: Box::new(keys.remote.header),
575 },
576 packet: KeyPair {
577 local: Box::new(keys.local.packet),
578 remote: Box::new(keys.remote.packet),
579 },
580 }
581}
582
583impl crypto::PacketKey for Box<dyn PacketKey> {
584 fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
585 let (header, payload_tag) = buf.split_at_mut(header_len);
586 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
587 let tag = self.encrypt_in_place(packet, &*header, payload).unwrap();
588 tag_storage.copy_from_slice(tag.as_ref());
589 }
590
591 fn decrypt(
592 &self,
593 packet: u64,
594 header: &[u8],
595 payload: &mut BytesMut,
596 ) -> Result<(), CryptoError> {
597 let plain = self
598 .decrypt_in_place(packet, header, payload.as_mut())
599 .map_err(|_| CryptoError)?;
600 let plain_len = plain.len();
601 payload.truncate(plain_len);
602 Ok(())
603 }
604
605 fn tag_len(&self) -> usize {
606 (**self).tag_len()
607 }
608
609 fn confidentiality_limit(&self) -> u64 {
610 (**self).confidentiality_limit()
611 }
612
613 fn integrity_limit(&self) -> u64 {
614 (**self).integrity_limit()
615 }
616}
617
618fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
619 match version {
620 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
621 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
622 _ => Err(UnsupportedVersion),
623 }
624}