hyper_tungstenite/lib.rs
1//! This crate allows [`hyper`](https://docs.rs/hyper) servers to accept websocket connections, backed by [`tungstenite`](https://docs.rs/tungstenite).
2//!
3//! The [`upgrade`] function allows you to upgrade a HTTP connection to a websocket connection.
4//! It returns a HTTP response to send to the client, and a future that resolves to a [`WebSocketStream`].
5//! The response must be sent to the client for the future to be resolved.
//! In practice, this means that you must spawn the future in a different task.
7//!
8//! Note that the [`upgrade`] function itself does not check if the request is actually an upgrade request.
9//! For simple cases, you can check this using the [`is_upgrade_request`] function before calling [`upgrade`].
10//! For more complicated cases where the server should support multiple upgrade protocols,
11//! you can manually inspect the `Connection` and `Upgrade` headers.
12//!
13//! # Example
14//! ```no_run
15//! use futures::sink::SinkExt;
16//! use futures::stream::StreamExt;
17//! use http_body_util::Full;
18//! use hyper::body::{Bytes, Incoming};
19//! use hyper::{Request, Response};
20//! use hyper_tungstenite::{tungstenite, HyperWebsocket};
21//! use hyper_util::rt::TokioIo;
22//! use tungstenite::Message;
23//!
24//! type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
25//!
26//! /// Handle a HTTP or WebSocket request.
27//! async fn handle_request(mut request: Request<Incoming>) -> Result<Response<Full<Bytes>>, Error> {
28//! // Check if the request is a websocket upgrade request.
29//! if hyper_tungstenite::is_upgrade_request(&request) {
30//! let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)?;
31//!
32//! // Spawn a task to handle the websocket connection.
33//! tokio::spawn(async move {
34//! if let Err(e) = serve_websocket(websocket).await {
35//! eprintln!("Error in websocket connection: {e}");
36//! }
37//! });
38//!
39//! // Return the response so the spawned future can continue.
40//! Ok(response)
41//! } else {
42//! // Handle regular HTTP requests here.
43//! Ok(Response::new(Full::<Bytes>::from("Hello HTTP!")))
44//! }
45//! }
46//!
47//! /// Handle a websocket connection.
48//! async fn serve_websocket(websocket: HyperWebsocket) -> Result<(), Error> {
49//! let mut websocket = websocket.await?;
50//! while let Some(message) = websocket.next().await {
51//! match message? {
52//! Message::Text(msg) => {
53//! println!("Received text message: {msg}");
54//! websocket.send(Message::text("Thank you, come again.")).await?;
55//! },
56//! Message::Binary(msg) => {
57//! println!("Received binary message: {msg:02X?}");
58//! websocket.send(Message::binary(b"Thank you, come again.".to_vec())).await?;
59//! },
60//! Message::Ping(msg) => {
61//! // No need to send a reply: tungstenite takes care of this for you.
62//! println!("Received ping message: {msg:02X?}");
63//! },
64//! Message::Pong(msg) => {
65//! println!("Received pong message: {msg:02X?}");
66//! }
67//! Message::Close(msg) => {
68//! // No need to send a reply: tungstenite takes care of this for you.
69//! if let Some(msg) = &msg {
70//! println!("Received close message with code {} and message: {}", msg.code, msg.reason);
71//! } else {
72//! println!("Received close message");
73//! }
74//! },
75//! Message::Frame(_msg) => {
76//! unreachable!();
77//! }
78//! }
79//! }
80//!
81//! Ok(())
82//! }
83//!
84//! #[tokio::main]
85//! async fn main() -> Result<(), Error> {
86//! let addr: std::net::SocketAddr = "[::1]:3000".parse()?;
87//! let listener = tokio::net::TcpListener::bind(&addr).await?;
88//! println!("Listening on http://{addr}");
89//!
90//! let mut http = hyper::server::conn::http1::Builder::new();
91//! http.keep_alive(true);
92//!
93//! loop {
94//! let (stream, _) = listener.accept().await?;
95//! let connection = http
96//! .serve_connection(TokioIo::new(stream), hyper::service::service_fn(handle_request))
97//! .with_upgrades();
98//! tokio::spawn(async move {
99//! if let Err(err) = connection.await {
100//! println!("Error serving HTTP connection: {err:?}");
101//! }
102//! });
103//! }
104//! }
105//! ```
106
107use http_body_util::Full;
108use hyper::body::Bytes;
109use hyper::{Request, Response};
110use hyper_util::rt::TokioIo;
111use std::task::{Context, Poll};
112use std::pin::Pin;
113use pin_project_lite::pin_project;
114
115use tungstenite::{Error, error::ProtocolError};
116use tungstenite::handshake::derive_accept_key;
117use tungstenite::protocol::{Role, WebSocketConfig};
118
119pub use hyper;
120pub use tungstenite;
121
122pub use tokio_tungstenite::WebSocketStream;
123
124/// A [`WebSocketStream`] that wraps an upgraded HTTP connection from hyper.
125pub type HyperWebsocketStream = WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>;
126
127pin_project! {
128 /// A future that resolves to a websocket stream when the associated HTTP upgrade completes.
129 #[derive(Debug)]
130 pub struct HyperWebsocket {
131 #[pin]
132 inner: hyper::upgrade::OnUpgrade,
133 config: Option<WebSocketConfig>,
134 }
135}
136
137/// Try to upgrade a received `hyper::Request` to a websocket connection.
138///
139/// The function returns a HTTP response and a future that resolves to the websocket stream.
140/// The response body *MUST* be sent to the client before the future can be resolved.
141///
142/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
143/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
144/// You can inspect the headers manually before calling this function,
145/// and modify the response headers appropriately.
146///
147/// This function also does not look at the `Connection` or `Upgrade` headers.
148/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
149/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
150///
151pub fn upgrade<B>(
152 mut request: impl std::borrow::BorrowMut<Request<B>>,
153 config: Option<WebSocketConfig>,
154) -> Result<(Response<Full<Bytes>>, HyperWebsocket), ProtocolError> {
155 let request = request.borrow_mut();
156
157 let key = request.headers().get("Sec-WebSocket-Key")
158 .ok_or(ProtocolError::MissingSecWebSocketKey)?;
159 if request.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") {
160 return Err(ProtocolError::MissingSecWebSocketVersionHeader);
161 }
162
163 let response = Response::builder()
164 .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
165 .header(hyper::header::CONNECTION, "upgrade")
166 .header(hyper::header::UPGRADE, "websocket")
167 .header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
168 .body(Full::<Bytes>::from("switching to websocket protocol"))
169 .expect("bug: failed to build response");
170
171 let stream = HyperWebsocket {
172 inner: hyper::upgrade::on(request),
173 config,
174 };
175
176 Ok((response, stream))
177}
178
179/// Check if a request is a websocket upgrade request.
180///
181/// If the `Upgrade` header lists multiple protocols,
182/// this function returns true if of them are `"websocket"`,
183/// If the server supports multiple upgrade protocols,
184/// it would be more appropriate to try each listed protocol in order.
185pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
186 header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
187 && header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
188}
189
190/// Check if there is a header of the given name containing the wanted value.
191fn header_contains_value(headers: &hyper::HeaderMap, header: impl hyper::header::AsHeaderName, value: impl AsRef<[u8]>) -> bool {
192 let value = value.as_ref();
193 for header in headers.get_all(header) {
194 if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) {
195 return true;
196 }
197 }
198 false
199}
200
201fn trim(data: &[u8]) -> &[u8] {
202 trim_end(trim_start(data))
203}
204
205fn trim_start(data: &[u8]) -> &[u8] {
206 if let Some(start) =data.iter().position(|x| !x.is_ascii_whitespace()) {
207 &data[start..]
208 } else {
209 b""
210 }
211}
212
213fn trim_end(data: &[u8]) -> &[u8] {
214 if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
215 &data[..last + 1]
216 } else {
217 b""
218 }
219}
220
221impl std::future::Future for HyperWebsocket {
222 type Output = Result<HyperWebsocketStream, Error>;
223
224 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
225 let this = self.project();
226 let upgraded = match this.inner.poll(cx) {
227 Poll::Pending => return Poll::Pending,
228 Poll::Ready(x) => x,
229 };
230
231 let upgraded = upgraded.map_err(|_| Error::Protocol(ProtocolError::HandshakeIncomplete))?;
232
233 let stream = WebSocketStream::from_raw_socket(
234 TokioIo::new(upgraded),
235 Role::Server,
236 this.config.take(),
237 );
238 tokio::pin!(stream);
239
240 // The future returned by `from_raw_socket` is always ready.
241 // Not sure why it is a future in the first place.
242 match stream.as_mut().poll(cx) {
243 Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
244 Poll::Ready(x) => Poll::Ready(Ok(x)),
245 }
246 }
247}