1use std::error::Error;
2use std::str::FromStr;
3
4use crate::ferron_common::{
5 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
6 ServerModuleHandlers, SocketData,
7};
8use crate::ferron_common::{HyperResponse, WithRuntime};
9use async_trait::async_trait;
10use http_body_util::combinators::BoxBody;
11use http_body_util::BodyExt;
12use hyper::body::Bytes;
13use hyper::{header, Request, StatusCode, Uri};
14use hyper_tungstenite::HyperWebsocket;
15use hyper_util::rt::TokioIo;
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::net::TcpStream;
18use tokio::runtime::Handle;
19
20pub fn server_module_init(
21 _config: &ServerConfig,
22) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
23 Ok(Box::new(ForwardProxyModule::new()))
24}
25
/// Stateless forward proxy module; all per-use state lives in the handlers it creates.
struct ForwardProxyModule;

impl ForwardProxyModule {
    /// Creates the (zero-sized) module value.
    fn new() -> Self {
        ForwardProxyModule
    }
}
33
34impl ServerModule for ForwardProxyModule {
35 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
36 Box::new(ForwardProxyModuleHandlers { handle })
37 }
38}
39
/// Handlers created by `ForwardProxyModule::get_handlers`; they implement the actual
/// forward-proxy behavior (plain HTTP forwarding and `CONNECT` tunnelling).
struct ForwardProxyModuleHandlers {
    /// Runtime handle passed to `WithRuntime::new` around every proxying future.
    handle: Handle,
}
43
44#[async_trait]
45impl ServerModuleHandlers for ForwardProxyModuleHandlers {
46 async fn request_handler(
47 &mut self,
48 request: RequestData,
49 _config: &ServerConfig,
50 _socket_data: &SocketData,
51 _error_logger: &ErrorLogger,
52 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
53 Ok(ResponseData::builder(request).build())
54 }
55
56 async fn proxy_request_handler(
57 &mut self,
58 request: RequestData,
59 _config: &ServerConfig,
60 _socket_data: &SocketData,
61 error_logger: &ErrorLogger,
62 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
63 WithRuntime::new(self.handle.clone(), async move {
64 let (hyper_request, _auth_user, _original_url) = request.into_parts();
66 let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
67
68 match hyper_request_parts.uri.scheme_str() {
69 Some("http") | None => (),
70 _ => {
71 return Ok(
72 ResponseData::builder_without_request()
73 .status(StatusCode::BAD_REQUEST)
74 .build(),
75 );
76 }
77 };
78
79 let host = match hyper_request_parts.uri.host() {
80 Some(host) => host,
81 None => {
82 return Ok(
83 ResponseData::builder_without_request()
84 .status(StatusCode::BAD_REQUEST)
85 .build(),
86 );
87 }
88 };
89
90 let port = hyper_request_parts.uri.port_u16().unwrap_or(80);
91
92 let addr = format!("{}:{}", host, port);
93 let stream = match TcpStream::connect(addr).await {
94 Ok(stream) => stream,
95 Err(err) => {
96 match err.kind() {
97 tokio::io::ErrorKind::ConnectionRefused
98 | tokio::io::ErrorKind::NotFound
99 | tokio::io::ErrorKind::HostUnreachable => {
100 error_logger
101 .log(&format!("Service unavailable: {}", err))
102 .await;
103 return Ok(
104 ResponseData::builder_without_request()
105 .status(StatusCode::SERVICE_UNAVAILABLE)
106 .build(),
107 );
108 }
109 tokio::io::ErrorKind::TimedOut => {
110 error_logger.log(&format!("Gateway timeout: {}", err)).await;
111 return Ok(
112 ResponseData::builder_without_request()
113 .status(StatusCode::GATEWAY_TIMEOUT)
114 .build(),
115 );
116 }
117 _ => {
118 error_logger.log(&format!("Bad gateway: {}", err)).await;
119 return Ok(
120 ResponseData::builder_without_request()
121 .status(StatusCode::BAD_GATEWAY)
122 .build(),
123 );
124 }
125 };
126 }
127 };
128
129 match stream.set_nodelay(true) {
130 Ok(_) => (),
131 Err(err) => {
132 error_logger.log(&format!("Bad gateway: {}", err)).await;
133 return Ok(
134 ResponseData::builder_without_request()
135 .status(StatusCode::BAD_GATEWAY)
136 .build(),
137 );
138 }
139 };
140
141 let hyper_request_path = hyper_request_parts.uri.path();
142
143 hyper_request_parts.uri = Uri::from_str(&format!(
144 "{}{}",
145 hyper_request_path,
146 match hyper_request_parts.uri.query() {
147 Some(query) => format!("?{}", query),
148 None => "".to_string(),
149 }
150 ))?;
151
152 hyper_request_parts
154 .headers
155 .insert(header::CONNECTION, "close".parse()?);
156
157 let proxy_request = Request::from_parts(hyper_request_parts, request_body);
158
159 http_proxy(stream, proxy_request, error_logger).await
160 })
161 .await
162 }
163
164 async fn response_modifying_handler(
165 &mut self,
166 response: HyperResponse,
167 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
168 Ok(response)
169 }
170
171 async fn proxy_response_modifying_handler(
172 &mut self,
173 response: HyperResponse,
174 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
175 Ok(response)
176 }
177
178 async fn connect_proxy_request_handler(
179 &mut self,
180 upgraded_request: HyperUpgraded,
181 connect_address: &str,
182 _config: &ServerConfig,
183 _socket_data: &SocketData,
184 error_logger: &ErrorLogger,
185 ) -> Result<(), Box<dyn Error + Send + Sync>> {
186 WithRuntime::new(self.handle.clone(), async move {
187 let mut stream = match TcpStream::connect(connect_address).await {
188 Ok(stream) => stream,
189 Err(err) => {
190 error_logger
191 .log(&format!("Cannot connect to the remote server: {}", err))
192 .await;
193 return Ok(());
194 }
195 };
196 match stream.set_nodelay(true) {
197 Ok(_) => (),
198 Err(err) => {
199 error_logger
200 .log(&format!(
201 "Cannot disable Nagle algorithm when connecting to the remote server: {}",
202 err
203 ))
204 .await;
205 return Ok(());
206 }
207 };
208
209 let mut upgraded = TokioIo::new(upgraded_request);
210
211 tokio::io::copy_bidirectional(&mut upgraded, &mut stream)
212 .await
213 .unwrap_or_default();
214
215 Ok(())
216 })
217 .await
218 }
219
220 fn does_connect_proxy_requests(&mut self) -> bool {
221 true
222 }
223
224 async fn websocket_request_handler(
225 &mut self,
226 _websocket: HyperWebsocket,
227 _uri: &hyper::Uri,
228 _config: &ServerConfig,
229 _socket_data: &SocketData,
230 _error_logger: &ErrorLogger,
231 ) -> Result<(), Box<dyn Error + Send + Sync>> {
232 Ok(())
233 }
234
235 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
236 false
237 }
238}
239
240async fn http_proxy(
241 stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
242 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
243 error_logger: &ErrorLogger,
244) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
245 let io = TokioIo::new(stream);
246
247 let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
248 Ok(data) => data,
249 Err(err) => {
250 error_logger.log(&format!("Bad gateway: {}", err)).await;
251 return Ok(
252 ResponseData::builder_without_request()
253 .status(StatusCode::BAD_GATEWAY)
254 .build(),
255 );
256 }
257 };
258
259 let send_request = sender.send_request(proxy_request);
260
261 let mut pinned_conn = Box::pin(conn);
262 tokio::pin!(send_request);
263
264 let response;
265
266 loop {
267 tokio::select! {
268 biased;
269
270 proxy_response = &mut send_request => {
271 let proxy_response = match proxy_response {
272 Ok(response) => response,
273 Err(err) => {
274 error_logger.log(&format!("Bad gateway: {}", err)).await;
275 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
276 }
277 };
278
279 response = ResponseData::builder_without_request()
280 .response(proxy_response.map(|b| {
281 b.map_err(|e| std::io::Error::other(e.to_string()))
282 .boxed()
283 }))
284 .parallel_fn(async move {
285 pinned_conn.await.unwrap_or_default();
286 })
287 .build();
288
289 break;
290 },
291 state = &mut pinned_conn => {
292 if state.is_err() {
293 error_logger.log("Bad gateway: incomplete response").await;
294 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
295 }
296 },
297 };
298 }
299
300 Ok(response)
301}