hyper_tungstenite/
lib.rs

//! This crate allows [`hyper`](https://docs.rs/hyper) servers to accept websocket connections, backed by [`tungstenite`](https://docs.rs/tungstenite).
//!
//! The [`upgrade`] function allows you to upgrade an HTTP connection to a websocket connection.
//! It returns an HTTP response to send to the client, and a future that resolves to a [`WebSocketStream`].
//! The response must be sent to the client for the future to be resolved.
//! In practice this means that you must spawn the future in a different task.
//!
//! Note that the [`upgrade`] function itself does not check if the request is actually an upgrade request.
//! For simple cases, you can check this using the [`is_upgrade_request`] function before calling [`upgrade`].
//! For more complicated cases where the server should support multiple upgrade protocols,
//! you can manually inspect the `Connection` and `Upgrade` headers.
//!
//! # Example
//! ```no_run
//! use futures::sink::SinkExt;
//! use futures::stream::StreamExt;
//! use http_body_util::Full;
//! use hyper::body::{Bytes, Incoming};
//! use hyper::{Request, Response};
//! use hyper_tungstenite::{tungstenite, HyperWebsocket};
//! use hyper_util::rt::TokioIo;
//! use tungstenite::Message;
//!
//! type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
//!
//! /// Handle an HTTP or WebSocket request.
//! async fn handle_request(mut request: Request<Incoming>) -> Result<Response<Full<Bytes>>, Error> {
//!     // Check if the request is a websocket upgrade request.
//!     if hyper_tungstenite::is_upgrade_request(&request) {
//!         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)?;
//!
//!         // Spawn a task to handle the websocket connection.
//!         tokio::spawn(async move {
//!             if let Err(e) = serve_websocket(websocket).await {
//!                 eprintln!("Error in websocket connection: {e}");
//!             }
//!         });
//!
//!         // Return the response so the spawned future can continue.
//!         Ok(response)
//!     } else {
//!         // Handle regular HTTP requests here.
//!         Ok(Response::new(Full::<Bytes>::from("Hello HTTP!")))
//!     }
//! }
//!
//! /// Handle a websocket connection.
//! async fn serve_websocket(websocket: HyperWebsocket) -> Result<(), Error> {
//!     let mut websocket = websocket.await?;
//!     while let Some(message) = websocket.next().await {
//!         match message? {
//!             Message::Text(msg) => {
//!                 println!("Received text message: {msg}");
//!                 websocket.send(Message::text("Thank you, come again.")).await?;
//!             },
//!             Message::Binary(msg) => {
//!                 println!("Received binary message: {msg:02X?}");
//!                 websocket.send(Message::binary(b"Thank you, come again.".to_vec())).await?;
//!             },
//!             Message::Ping(msg) => {
//!                 // No need to send a reply: tungstenite takes care of this for you.
//!                 println!("Received ping message: {msg:02X?}");
//!             },
//!             Message::Pong(msg) => {
//!                 println!("Received pong message: {msg:02X?}");
//!             }
//!             Message::Close(msg) => {
//!                 // No need to send a reply: tungstenite takes care of this for you.
//!                 if let Some(msg) = &msg {
//!                     println!("Received close message with code {} and message: {}", msg.code, msg.reason);
//!                 } else {
//!                     println!("Received close message");
//!                 }
//!             },
//!             Message::Frame(_msg) => {
//!                 unreachable!();
//!             }
//!         }
//!     }
//!
//!     Ok(())
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Error> {
//!     let addr: std::net::SocketAddr = "[::1]:3000".parse()?;
//!     let listener = tokio::net::TcpListener::bind(&addr).await?;
//!     println!("Listening on http://{addr}");
//!
//!     let mut http = hyper::server::conn::http1::Builder::new();
//!     http.keep_alive(true);
//!
//!     loop {
//!         let (stream, _) = listener.accept().await?;
//!         let connection = http
//!             .serve_connection(TokioIo::new(stream), hyper::service::service_fn(handle_request))
//!             .with_upgrades();
//!         tokio::spawn(async move {
//!             if let Err(err) = connection.await {
//!                 println!("Error serving HTTP connection: {err:?}");
//!             }
//!         });
//!     }
//! }
//! ```
106
107use http_body_util::Full;
108use hyper::body::Bytes;
109use hyper::{Request, Response};
110use hyper_util::rt::TokioIo;
111use std::task::{Context, Poll};
112use std::pin::Pin;
113use pin_project_lite::pin_project;
114
115use tungstenite::{Error, error::ProtocolError};
116use tungstenite::handshake::derive_accept_key;
117use tungstenite::protocol::{Role, WebSocketConfig};
118
119pub use hyper;
120pub use tungstenite;
121
122pub use tokio_tungstenite::WebSocketStream;
123
124/// A [`WebSocketStream`] that wraps an upgraded HTTP connection from hyper.
125pub type HyperWebsocketStream = WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>;
126
127pin_project! {
128	/// A future that resolves to a websocket stream when the associated HTTP upgrade completes.
129	#[derive(Debug)]
130	pub struct HyperWebsocket {
131		#[pin]
132		inner: hyper::upgrade::OnUpgrade,
133		config: Option<WebSocketConfig>,
134	}
135}
136
137/// Try to upgrade a received `hyper::Request` to a websocket connection.
138///
139/// The function returns a HTTP response and a future that resolves to the websocket stream.
140/// The response body *MUST* be sent to the client before the future can be resolved.
141///
142/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
143/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
144/// You can inspect the headers manually before calling this function,
145/// and modify the response headers appropriately.
146///
147/// This function also does not look at the `Connection` or `Upgrade` headers.
148/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
149/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
150///
151pub fn upgrade<B>(
152	mut request: impl std::borrow::BorrowMut<Request<B>>,
153	config: Option<WebSocketConfig>,
154) -> Result<(Response<Full<Bytes>>, HyperWebsocket), ProtocolError> {
155	let request = request.borrow_mut();
156
157	let key = request.headers().get("Sec-WebSocket-Key")
158		.ok_or(ProtocolError::MissingSecWebSocketKey)?;
159	if request.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") {
160		return Err(ProtocolError::MissingSecWebSocketVersionHeader);
161	}
162
163	let response = Response::builder()
164		.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
165		.header(hyper::header::CONNECTION, "upgrade")
166		.header(hyper::header::UPGRADE, "websocket")
167		.header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
168		.body(Full::<Bytes>::from("switching to websocket protocol"))
169		.expect("bug: failed to build response");
170
171	let stream = HyperWebsocket {
172		inner: hyper::upgrade::on(request),
173		config,
174	};
175
176	Ok((response, stream))
177}
178
179/// Check if a request is a websocket upgrade request.
180///
181/// If the `Upgrade` header lists multiple protocols,
182/// this function returns true if of them are `"websocket"`,
183/// If the server supports multiple upgrade protocols,
184/// it would be more appropriate to try each listed protocol in order.
185pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
186	header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
187		&& header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
188}
189
190/// Check if there is a header of the given name containing the wanted value.
191fn header_contains_value(headers: &hyper::HeaderMap, header: impl hyper::header::AsHeaderName, value: impl AsRef<[u8]>) -> bool {
192	let value = value.as_ref();
193	for header in headers.get_all(header) {
194		if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) {
195			return true;
196		}
197	}
198	false
199}
200
/// Strip leading and trailing ASCII whitespace from a byte slice.
fn trim(data: &[u8]) -> &[u8] {
	let without_leading = trim_start(data);
	trim_end(without_leading)
}

/// Strip leading ASCII whitespace; empty or all-whitespace input yields an empty slice.
fn trim_start(data: &[u8]) -> &[u8] {
	let first_kept = data
		.iter()
		.position(|byte| !byte.is_ascii_whitespace())
		.unwrap_or(data.len());
	&data[first_kept..]
}

/// Strip trailing ASCII whitespace; empty or all-whitespace input yields an empty slice.
fn trim_end(data: &[u8]) -> &[u8] {
	let kept_len = data
		.iter()
		.rposition(|byte| !byte.is_ascii_whitespace())
		.map_or(0, |last| last + 1);
	&data[..kept_len]
}
220
221impl std::future::Future for HyperWebsocket {
222	type Output = Result<HyperWebsocketStream, Error>;
223
224	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
225		let this = self.project();
226		let upgraded = match this.inner.poll(cx) {
227			Poll::Pending => return Poll::Pending,
228			Poll::Ready(x) => x,
229		};
230
231		let upgraded = upgraded.map_err(|_| Error::Protocol(ProtocolError::HandshakeIncomplete))?;
232
233		let stream = WebSocketStream::from_raw_socket(
234			TokioIo::new(upgraded),
235			Role::Server,
236			this.config.take(),
237		);
238		tokio::pin!(stream);
239
240		// The future returned by `from_raw_socket` is always ready.
241		// Not sure why it is a future in the first place.
242		match stream.as_mut().poll(cx) {
243			Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
244			Poll::Ready(x) => Poll::Ready(Ok(x)),
245		}
246	}
247}