1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io,
6 io::IoSliceMut,
7 mem,
8 net::{SocketAddr, SocketAddrV6},
9 pin::Pin,
10 str,
11 sync::{Arc, Mutex},
12 task::{Context, Poll, Waker},
13};
14
15#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
16use crate::runtime::default_runtime;
17use crate::{
18 Instant,
19 runtime::{AsyncUdpSocket, Runtime},
20 udp_transmit,
21};
22use bytes::{Bytes, BytesMut};
23use pin_project_lite::pin_project;
24use proto::{
25 self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
26 EndpointEvent, ServerConfig,
27};
28use rustc_hash::FxHashMap;
29#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"),))]
30use socket2::{Domain, Protocol, Socket, Type};
31use tokio::sync::{Notify, futures::Notified, mpsc};
32use tracing::{Instrument, Span};
33use udp::{BATCH_SIZE, RecvMeta};
34
35use crate::{
36 ConnectionEvent, EndpointConfig, IO_LOOP_BOUND, RECV_TIME_BOUND, VarInt,
37 connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter,
38};
39
/// An endpoint that drives connections over a UDP socket.
///
/// Cloning an `Endpoint` yields another handle to the same shared state: the `inner`
/// `EndpointRef` is cloned, which bumps the shared reference count. The background I/O task
/// (`EndpointDriver`) is spawned on `runtime` when the endpoint is constructed.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Counted handle to the state shared with the driver task and with connections.
    pub(crate) inner: EndpointRef,
    /// Config used by [`Endpoint::connect`]; `None` until `set_default_client_config` is called.
    pub(crate) default_client_config: Option<ClientConfig>,
    /// Runtime used for timestamps, for wrapping sockets, and for spawning the driver.
    runtime: Arc<dyn Runtime>,
}
52
53impl Endpoint {
54 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn client(addr: SocketAddr) -> io::Result<Self> {
73 let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
74 if addr.is_ipv6() {
75 if let Err(e) = socket.set_only_v6(false) {
76 tracing::debug!(%e, "unable to make socket dual-stack");
77 }
78 }
79 socket.bind(&addr.into())?;
80 let runtime = default_runtime()
81 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
82 Self::new_with_abstract_socket(
83 EndpointConfig::default(),
84 None,
85 runtime.wrap_udp_socket(socket.into())?,
86 runtime,
87 )
88 }
89
90 pub fn stats(&self) -> EndpointStats {
92 self.inner.state.lock().unwrap().stats
93 }
94
95 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
103 let socket = std::net::UdpSocket::bind(addr)?;
104 let runtime = default_runtime()
105 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
106 Self::new_with_abstract_socket(
107 EndpointConfig::default(),
108 Some(config),
109 runtime.wrap_udp_socket(socket)?,
110 runtime,
111 )
112 }
113
114 #[cfg(not(wasm_browser))]
116 pub fn new(
117 config: EndpointConfig,
118 server_config: Option<ServerConfig>,
119 socket: std::net::UdpSocket,
120 runtime: Arc<dyn Runtime>,
121 ) -> io::Result<Self> {
122 let socket = runtime.wrap_udp_socket(socket)?;
123 Self::new_with_abstract_socket(config, server_config, socket, runtime)
124 }
125
126 pub fn new_with_abstract_socket(
131 config: EndpointConfig,
132 server_config: Option<ServerConfig>,
133 socket: Arc<dyn AsyncUdpSocket>,
134 runtime: Arc<dyn Runtime>,
135 ) -> io::Result<Self> {
136 let addr = socket.local_addr()?;
137 let allow_mtud = !socket.may_fragment();
138 let rc = EndpointRef::new(
139 socket,
140 proto::Endpoint::new(
141 Arc::new(config),
142 server_config.map(Arc::new),
143 allow_mtud,
144 None,
145 ),
146 addr.is_ipv6(),
147 runtime.clone(),
148 );
149 let driver = EndpointDriver(rc.clone());
150 runtime.spawn(Box::pin(
151 async {
152 if let Err(e) = driver.await {
153 tracing::error!("I/O error: {}", e);
154 }
155 }
156 .instrument(Span::current()),
157 ));
158 Ok(Self {
159 inner: rc,
160 default_client_config: None,
161 runtime,
162 })
163 }
164
165 pub fn accept(&self) -> Accept<'_> {
172 Accept {
173 endpoint: self,
174 notify: self.inner.shared.incoming.notified(),
175 }
176 }
177
178 pub fn set_default_client_config(&mut self, config: ClientConfig) {
180 self.default_client_config = Some(config);
181 }
182
183 pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
192 let config = match &self.default_client_config {
193 Some(config) => config.clone(),
194 None => return Err(ConnectError::NoDefaultClientConfig),
195 };
196
197 self.connect_with(config, addr, server_name)
198 }
199
200 pub fn connect_with(
206 &self,
207 config: ClientConfig,
208 addr: SocketAddr,
209 server_name: &str,
210 ) -> Result<Connecting, ConnectError> {
211 let mut endpoint = self.inner.state.lock().unwrap();
212 if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
213 return Err(ConnectError::EndpointStopping);
214 }
215 if addr.is_ipv6() && !endpoint.ipv6 {
216 return Err(ConnectError::InvalidRemoteAddress(addr));
217 }
218 let addr = if endpoint.ipv6 {
219 SocketAddr::V6(ensure_ipv6(addr))
220 } else {
221 addr
222 };
223
224 let (ch, conn) = endpoint
225 .inner
226 .connect(self.runtime.now(), config, addr, server_name)?;
227
228 let socket = endpoint.socket.clone();
229 endpoint.stats.outgoing_handshakes += 1;
230 Ok(endpoint
231 .recv_state
232 .connections
233 .insert(ch, conn, socket, self.runtime.clone()))
234 }
235
236 #[cfg(not(wasm_browser))]
240 pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
241 self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
242 }
243
244 pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
251 let addr = socket.local_addr()?;
252 let mut inner = self.inner.state.lock().unwrap();
253 inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
254 inner.ipv6 = addr.is_ipv6();
255
256 for sender in inner.recv_state.connections.senders.values() {
258 let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
260 }
261
262 Ok(())
263 }
264
265 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
269 self.inner
270 .state
271 .lock()
272 .unwrap()
273 .inner
274 .set_server_config(server_config.map(Arc::new))
275 }
276
277 pub fn local_addr(&self) -> io::Result<SocketAddr> {
279 self.inner.state.lock().unwrap().socket.local_addr()
280 }
281
282 pub fn open_connections(&self) -> usize {
284 self.inner.state.lock().unwrap().inner.open_connections()
285 }
286
287 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
293 let reason = Bytes::copy_from_slice(reason);
294 let mut endpoint = self.inner.state.lock().unwrap();
295 endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
296 for sender in endpoint.recv_state.connections.senders.values() {
297 let _ = sender.send(ConnectionEvent::Close {
299 error_code,
300 reason: reason.clone(),
301 });
302 }
303 self.inner.shared.incoming.notify_waiters();
304 }
305
306 pub async fn wait_idle(&self) {
317 loop {
318 {
319 let endpoint = &mut *self.inner.state.lock().unwrap();
320 if endpoint.recv_state.connections.is_empty() {
321 break;
322 }
323 self.inner.shared.idle.notified()
325 }
326 .await;
327 }
328 }
329}
330
/// Handshake counters for an [`Endpoint`], as returned by [`Endpoint::stats`].
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Incoming connection attempts that were successfully accepted.
    pub accepted_handshakes: u64,
    /// Outgoing connections successfully initiated through `connect` / `connect_with`.
    pub outgoing_handshakes: u64,
    /// Incoming connection attempts that were explicitly refused. Attempts refused
    /// automatically because the endpoint is closing are not counted.
    pub refused_handshakes: u64,
    /// Incoming connection attempts that were explicitly ignored.
    pub ignored_handshakes: u64,
}
344
/// Future that performs an endpoint's I/O.
///
/// It receives datagrams on the endpoint's socket(s) and exchanges events between the protocol
/// endpoint and its connections. It completes once it holds the only remaining `EndpointRef`
/// (the shared `ref_count` is zero) and no connections remain.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
358
359impl Future for EndpointDriver {
360 type Output = Result<(), io::Error>;
361
362 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
363 let mut endpoint = self.0.state.lock().unwrap();
364 if endpoint.driver.is_none() {
365 endpoint.driver = Some(cx.waker().clone());
366 }
367
368 let now = endpoint.runtime.now();
369 let mut keep_going = false;
370 keep_going |= endpoint.drive_recv(cx, now)?;
371 keep_going |= endpoint.handle_events(cx, &self.0.shared);
372
373 if !endpoint.recv_state.incoming.is_empty() {
374 self.0.shared.incoming.notify_waiters();
375 }
376
377 if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
378 Poll::Ready(Ok(()))
379 } else {
380 drop(endpoint);
381 if keep_going {
385 cx.waker().wake_by_ref();
386 }
387 Poll::Pending
388 }
389 }
390}
391
392impl Drop for EndpointDriver {
393 fn drop(&mut self) {
394 let mut endpoint = self.0.state.lock().unwrap();
395 endpoint.driver_lost = true;
396 self.0.shared.incoming.notify_waiters();
397 endpoint.recv_state.connections.senders.clear();
400 }
401}
402
/// State shared by every `EndpointRef` handle.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state; all access goes through this lock.
    pub(crate) state: Mutex<State>,
    /// Notifications used to wake `accept` and `wait_idle` waiters.
    pub(crate) shared: Shared,
}
408
409impl EndpointInner {
410 pub(crate) fn accept(
411 &self,
412 incoming: proto::Incoming,
413 server_config: Option<Arc<ServerConfig>>,
414 ) -> Result<Connecting, ConnectionError> {
415 let mut state = self.state.lock().unwrap();
416 let mut response_buffer = Vec::new();
417 let now = state.runtime.now();
418 match state
419 .inner
420 .accept(incoming, now, &mut response_buffer, server_config)
421 {
422 Ok((handle, conn)) => {
423 state.stats.accepted_handshakes += 1;
424 let socket = state.socket.clone();
425 let runtime = state.runtime.clone();
426 Ok(state
427 .recv_state
428 .connections
429 .insert(handle, conn, socket, runtime))
430 }
431 Err(error) => {
432 if let Some(transmit) = error.response {
433 respond(transmit, &response_buffer, &*state.socket);
434 }
435 Err(error.cause)
436 }
437 }
438 }
439
440 pub(crate) fn refuse(&self, incoming: proto::Incoming) {
441 let mut state = self.state.lock().unwrap();
442 state.stats.refused_handshakes += 1;
443 let mut response_buffer = Vec::new();
444 let transmit = state.inner.refuse(incoming, &mut response_buffer);
445 respond(transmit, &response_buffer, &*state.socket);
446 }
447
448 pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
449 let mut state = self.state.lock().unwrap();
450 let mut response_buffer = Vec::new();
451 let transmit = state.inner.retry(incoming, &mut response_buffer)?;
452 respond(transmit, &response_buffer, &*state.socket);
453 Ok(())
454 }
455
456 pub(crate) fn ignore(&self, incoming: proto::Incoming) {
457 let mut state = self.state.lock().unwrap();
458 state.stats.ignored_handshakes += 1;
459 state.inner.ignore(incoming);
460 }
461}
462
/// Mutable endpoint state, guarded by `EndpointInner::state`.
#[derive(Debug)]
pub(crate) struct State {
    /// Socket currently used for receiving and for sending endpoint-level responses.
    socket: Arc<dyn AsyncUdpSocket>,
    /// Socket replaced by the most recent rebind. It is still polled for incoming datagrams
    /// until a connection packet arrives on `socket` or polling it fails (see `drive_recv`).
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    /// Protocol-level endpoint state machine.
    inner: proto::Endpoint,
    /// Receive buffer, queued incoming attempts, and the connection set.
    recv_state: RecvState,
    /// Waker of the driver task. It is recorded on the driver's first poll, and taken and woken
    /// when the counted `EndpointRef`s reach zero.
    driver: Option<Waker>,
    /// Whether `socket` is bound to an IPv6 address; if so, IPv4 targets are mapped to IPv6.
    ipv6: bool,
    /// Endpoint events sent by connections, consumed by `handle_events`.
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of outstanding `EndpointRef` clones: incremented on clone, decremented on drop,
    /// never below zero. The driver can finish only once this is zero.
    ref_count: usize,
    /// Set when the `EndpointDriver` is dropped; the driver task is gone from then on.
    driver_lost: bool,
    /// Runtime used for timestamps and handed to newly registered connections.
    runtime: Arc<dyn Runtime>,
    /// Handshake counters reported by `Endpoint::stats`.
    stats: EndpointStats,
}
480
/// Notifications kept next to (not inside) the state mutex in `EndpointInner`.
#[derive(Debug)]
pub(crate) struct Shared {
    /// Notified when the driver has queued incoming attempts, when the endpoint is closed, and
    /// when the driver is dropped; awaited by `Accept`.
    incoming: Notify,
    /// Notified when the last connection leaves the connection set; awaited by
    /// `Endpoint::wait_idle`.
    idle: Notify,
}
486
487impl State {
488 fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
489 let get_time = || self.runtime.now();
490 self.recv_state.recv_limiter.start_cycle(get_time);
491 if let Some(socket) = &self.prev_socket {
492 let poll_res =
494 self.recv_state
495 .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
496 if poll_res.is_err() {
497 self.prev_socket = None;
498 }
499 };
500 let poll_res =
501 self.recv_state
502 .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
503 self.recv_state.recv_limiter.finish_cycle(get_time);
504 let poll_res = poll_res?;
505 if poll_res.received_connection_packet {
506 self.prev_socket = None;
509 }
510 Ok(poll_res.keep_going)
511 }
512
513 fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
514 for _ in 0..IO_LOOP_BOUND {
515 let (ch, event) = match self.events.poll_recv(cx) {
516 Poll::Ready(Some(x)) => x,
517 Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
518 Poll::Pending => {
519 return false;
520 }
521 };
522
523 if event.is_drained() {
524 self.recv_state.connections.senders.remove(&ch);
525 if self.recv_state.connections.is_empty() {
526 shared.idle.notify_waiters();
527 }
528 }
529 let Some(event) = self.inner.handle_event(ch, event) else {
530 continue;
531 };
532 let _ = self
534 .recv_state
535 .connections
536 .senders
537 .get_mut(&ch)
538 .unwrap()
539 .send(ConnectionEvent::Proto(event));
540 }
541
542 true
543 }
544}
545
546impl Drop for State {
547 fn drop(&mut self) {
548 for incoming in self.recv_state.incoming.drain(..) {
549 self.inner.ignore(incoming);
550 }
551 }
552}
553
554fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
555 _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
576}
577
578#[inline]
579fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
580 match ecn {
581 udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
582 udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
583 udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
584 }
585}
586
/// The connections owned by an endpoint, plus the channel plumbing shared with them.
#[derive(Debug)]
struct ConnectionSet {
    /// Per-connection channels for delivering `ConnectionEvent`s, keyed by handle.
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Cloned into each new connection so it can send endpoint events back to the driver.
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Error code and reason set by `Endpoint::close`; `Some` means the endpoint is closing.
    close: Option<(VarInt, Bytes)>,
}
596
597impl ConnectionSet {
598 fn insert(
599 &mut self,
600 handle: ConnectionHandle,
601 conn: proto::Connection,
602 socket: Arc<dyn AsyncUdpSocket>,
603 runtime: Arc<dyn Runtime>,
604 ) -> Connecting {
605 let (send, recv) = mpsc::unbounded_channel();
606 if let Some((error_code, ref reason)) = self.close {
607 send.send(ConnectionEvent::Close {
608 error_code,
609 reason: reason.clone(),
610 })
611 .unwrap();
612 }
613 self.senders.insert(handle, send);
614 Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
615 }
616
617 fn is_empty(&self) -> bool {
618 self.senders.is_empty()
619 }
620}
621
/// Converts `addr` to an IPv6 socket address.
///
/// IPv6 inputs are returned unchanged. IPv4 inputs become IPv4-mapped IPv6 addresses with the
/// same port, and flow info and scope id set to zero.
fn ensure_ipv6(addr: SocketAddr) -> SocketAddrV6 {
    match addr {
        SocketAddr::V4(v4) => {
            let mapped = v4.ip().to_ipv6_mapped();
            SocketAddrV6::new(mapped, v4.port(), 0, 0)
        }
        SocketAddr::V6(v6) => v6,
    }
}
628
pin_project! {
    /// Future produced by [`Endpoint::accept`].
    pub struct Accept<'a> {
        // Endpoint whose queue of incoming connection attempts is polled.
        endpoint: &'a Endpoint,
        // Woken when the driver has queued incoming attempts, when the endpoint is closed, or
        // when the driver is dropped. It is re-armed each time it fires.
        #[pin]
        notify: Notified<'a>,
    }
}
637
638impl Future for Accept<'_> {
639 type Output = Option<Incoming>;
640 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
641 let mut this = self.project();
642 let mut endpoint = this.endpoint.inner.state.lock().unwrap();
643 if endpoint.driver_lost {
644 return Poll::Ready(None);
645 }
646 if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
647 drop(endpoint);
649 let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
650 return Poll::Ready(Some(incoming));
651 }
652 if endpoint.recv_state.connections.close.is_some() {
653 return Poll::Ready(None);
654 }
655 loop {
656 match this.notify.as_mut().poll(ctx) {
657 Poll::Pending => return Poll::Pending,
659 Poll::Ready(()) => this
661 .notify
662 .set(this.endpoint.inner.shared.incoming.notified()),
663 }
664 }
665 }
666}
667
/// Counted handle to the shared endpoint state.
///
/// Cloning increments `State::ref_count` and dropping decrements it. When the count returns to
/// zero, the driver task's stored waker is woken.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
670
671impl EndpointRef {
672 pub(crate) fn new(
673 socket: Arc<dyn AsyncUdpSocket>,
674 inner: proto::Endpoint,
675 ipv6: bool,
676 runtime: Arc<dyn Runtime>,
677 ) -> Self {
678 let (sender, events) = mpsc::unbounded_channel();
679 let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
680 Self(Arc::new(EndpointInner {
681 shared: Shared {
682 incoming: Notify::new(),
683 idle: Notify::new(),
684 },
685 state: Mutex::new(State {
686 socket,
687 prev_socket: None,
688 inner,
689 ipv6,
690 events,
691 driver: None,
692 ref_count: 0,
693 driver_lost: false,
694 recv_state,
695 runtime,
696 stats: EndpointStats::default(),
697 }),
698 }))
699 }
700}
701
702impl Clone for EndpointRef {
703 fn clone(&self) -> Self {
704 self.0.state.lock().unwrap().ref_count += 1;
705 Self(self.0.clone())
706 }
707}
708
709impl Drop for EndpointRef {
710 fn drop(&mut self) {
711 let endpoint = &mut *self.0.state.lock().unwrap();
712 if let Some(x) = endpoint.ref_count.checked_sub(1) {
713 endpoint.ref_count = x;
714 if x == 0 {
715 if let Some(task) = endpoint.driver.take() {
718 task.wake();
719 }
720 }
721 }
722 }
723}
724
725impl std::ops::Deref for EndpointRef {
726 type Target = EndpointInner;
727 fn deref(&self) -> &Self::Target {
728 &self.0
729 }
730}
731
/// Receive-side state owned by the endpoint.
struct RecvState {
    /// Incoming connection attempts waiting to be handed out by `Endpoint::accept`.
    incoming: VecDeque<proto::Incoming>,
    /// Live connections and their event channels.
    connections: ConnectionSet,
    /// Backing storage for one batch of receives, split into `BATCH_SIZE` equal slices.
    recv_buf: Box<[u8]>,
    /// Limits receive work per driver poll (built from `RECV_TIME_BOUND`); presumably time-based.
    /// TODO confirm against `WorkLimiter`.
    recv_limiter: WorkLimiter,
}
739
impl RecvState {
    /// Creates receive state with a zeroed buffer sized for one full batch of receives.
    ///
    /// The buffer has `BATCH_SIZE` slots. Each slot is `max_receive_segments` times the
    /// endpoint's configured max UDP payload size, with that size capped at 64 KiB.
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Receives datagrams from `socket` in batches and feeds them to the protocol `endpoint`.
    ///
    /// Loops until the socket returns `Pending`, which yields `keep_going: false`. It also stops
    /// when the receive work limiter disallows more work, which yields `keep_going: true`. Each
    /// received buffer is split into datagrams of `meta.stride` bytes. For each datagram:
    /// - a new connection attempt is queued for `accept`, or refused if the endpoint is closing;
    /// - an event for an existing connection is forwarded to that connection, and
    ///   `received_connection_packet` is set;
    /// - a direct response is sent best-effort on `socket`.
    ///
    /// `ConnectionReset` errors from the socket are skipped without consulting the limiter.
    /// Any other socket error is returned.
    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut proto::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        // One mutable slice of `recv_buf` per receive slot.
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // Split the received buffer into `stride`-sized datagrams.
                        // NOTE(review): a `meta.stride` of 0 with non-empty data would make this
                        // loop spin forever. It relies on the socket always reporting a non-zero
                        // stride; confirm for custom `AsyncUdpSocket` implementations.
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        // The endpoint is closing: refuse immediately. This is
                                        // not counted in `EndpointStats::refused_handshakes`.
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    received_connection_packet = true;
                                    // A send error only means the connection's receiver is gone.
                                    // NOTE(review): the `unwrap` assumes every handle reported
                                    // here is still present in `senders`.
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    // Skip the reset and try receiving again; the work limiter is not consulted
                    // for this iteration.
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                // Budget exhausted: stop here and ask the caller to poll again soon.
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
853
854impl fmt::Debug for RecvState {
855 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
856 f.debug_struct("RecvState")
857 .field("incoming", &self.incoming)
858 .field("connections", &self.connections)
859 .field("recv_limiter", &self.recv_limiter)
861 .finish_non_exhaustive()
862 }
863}
864
/// Outcome of one `RecvState::poll_socket` call.
#[derive(Default)]
struct PollProgress {
    /// Whether any received datagram was routed to an existing connection.
    received_connection_packet: bool,
    /// Whether receiving stopped because the work limiter ran out rather than because the
    /// socket returned `Pending`. If so, the caller should poll again.
    keep_going: bool,
}