ferron/optional_modules/
fproxy.rs

1use std::error::Error;
2use std::str::FromStr;
3
4use crate::ferron_common::{
5  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
6  ServerModuleHandlers, SocketData,
7};
8use crate::ferron_common::{HyperResponse, WithRuntime};
9use async_trait::async_trait;
10use http_body_util::combinators::BoxBody;
11use http_body_util::BodyExt;
12use hyper::body::Bytes;
13use hyper::{header, Request, StatusCode, Uri};
14use hyper_tungstenite::HyperWebsocket;
15use hyper_util::rt::TokioIo;
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::net::TcpStream;
18use tokio::runtime::Handle;
19
20pub fn server_module_init(
21  _config: &ServerConfig,
22) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
23  Ok(Box::new(ForwardProxyModule::new()))
24}
25
/// Forward proxy server module. The type is a unit struct and holds no data.
struct ForwardProxyModule;

impl ForwardProxyModule {
  /// Returns the module value.
  fn new() -> Self {
    ForwardProxyModule
  }
}
33
34impl ServerModule for ForwardProxyModule {
35  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
36    Box::new(ForwardProxyModuleHandlers { handle })
37  }
38}
39
40struct ForwardProxyModuleHandlers {
41  handle: Handle,
42}
43
44#[async_trait]
45impl ServerModuleHandlers for ForwardProxyModuleHandlers {
46  async fn request_handler(
47    &mut self,
48    request: RequestData,
49    _config: &ServerConfig,
50    _socket_data: &SocketData,
51    _error_logger: &ErrorLogger,
52  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
53    Ok(ResponseData::builder(request).build())
54  }
55
56  async fn proxy_request_handler(
57    &mut self,
58    request: RequestData,
59    _config: &ServerConfig,
60    _socket_data: &SocketData,
61    error_logger: &ErrorLogger,
62  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
63    WithRuntime::new(self.handle.clone(), async move {
64      // Code taken from reverse proxy module
65      let (hyper_request, _auth_user, _original_url) = request.into_parts();
66      let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
67
68      match hyper_request_parts.uri.scheme_str() {
69        Some("http") | None => (),
70        _ => {
71          return Ok(
72            ResponseData::builder_without_request()
73              .status(StatusCode::BAD_REQUEST)
74              .build(),
75          );
76        }
77      };
78
79      let host = match hyper_request_parts.uri.host() {
80        Some(host) => host,
81        None => {
82          return Ok(
83            ResponseData::builder_without_request()
84              .status(StatusCode::BAD_REQUEST)
85              .build(),
86          );
87        }
88      };
89
90      let port = hyper_request_parts.uri.port_u16().unwrap_or(80);
91
92      let addr = format!("{}:{}", host, port);
93      let stream = match TcpStream::connect(addr).await {
94        Ok(stream) => stream,
95        Err(err) => {
96          match err.kind() {
97            tokio::io::ErrorKind::ConnectionRefused
98            | tokio::io::ErrorKind::NotFound
99            | tokio::io::ErrorKind::HostUnreachable => {
100              error_logger
101                .log(&format!("Service unavailable: {}", err))
102                .await;
103              return Ok(
104                ResponseData::builder_without_request()
105                  .status(StatusCode::SERVICE_UNAVAILABLE)
106                  .build(),
107              );
108            }
109            tokio::io::ErrorKind::TimedOut => {
110              error_logger.log(&format!("Gateway timeout: {}", err)).await;
111              return Ok(
112                ResponseData::builder_without_request()
113                  .status(StatusCode::GATEWAY_TIMEOUT)
114                  .build(),
115              );
116            }
117            _ => {
118              error_logger.log(&format!("Bad gateway: {}", err)).await;
119              return Ok(
120                ResponseData::builder_without_request()
121                  .status(StatusCode::BAD_GATEWAY)
122                  .build(),
123              );
124            }
125          };
126        }
127      };
128
129      match stream.set_nodelay(true) {
130        Ok(_) => (),
131        Err(err) => {
132          error_logger.log(&format!("Bad gateway: {}", err)).await;
133          return Ok(
134            ResponseData::builder_without_request()
135              .status(StatusCode::BAD_GATEWAY)
136              .build(),
137          );
138        }
139      };
140
141      let hyper_request_path = hyper_request_parts.uri.path();
142
143      hyper_request_parts.uri = Uri::from_str(&format!(
144        "{}{}",
145        hyper_request_path,
146        match hyper_request_parts.uri.query() {
147          Some(query) => format!("?{}", query),
148          None => "".to_string(),
149        }
150      ))?;
151
152      // Connection header to disable HTTP/1.1 keep-alive
153      hyper_request_parts
154        .headers
155        .insert(header::CONNECTION, "close".parse()?);
156
157      let proxy_request = Request::from_parts(hyper_request_parts, request_body);
158
159      http_proxy(stream, proxy_request, error_logger).await
160    })
161    .await
162  }
163
164  async fn response_modifying_handler(
165    &mut self,
166    response: HyperResponse,
167  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
168    Ok(response)
169  }
170
171  async fn proxy_response_modifying_handler(
172    &mut self,
173    response: HyperResponse,
174  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
175    Ok(response)
176  }
177
178  async fn connect_proxy_request_handler(
179    &mut self,
180    upgraded_request: HyperUpgraded,
181    connect_address: &str,
182    _config: &ServerConfig,
183    _socket_data: &SocketData,
184    error_logger: &ErrorLogger,
185  ) -> Result<(), Box<dyn Error + Send + Sync>> {
186    WithRuntime::new(self.handle.clone(), async move {
187      let mut stream = match TcpStream::connect(connect_address).await {
188        Ok(stream) => stream,
189        Err(err) => {
190          error_logger
191            .log(&format!("Cannot connect to the remote server: {}", err))
192            .await;
193          return Ok(());
194        }
195      };
196      match stream.set_nodelay(true) {
197        Ok(_) => (),
198        Err(err) => {
199          error_logger
200            .log(&format!(
201              "Cannot disable Nagle algorithm when connecting to the remote server: {}",
202              err
203            ))
204            .await;
205          return Ok(());
206        }
207      };
208
209      let mut upgraded = TokioIo::new(upgraded_request);
210
211      tokio::io::copy_bidirectional(&mut upgraded, &mut stream)
212        .await
213        .unwrap_or_default();
214
215      Ok(())
216    })
217    .await
218  }
219
220  fn does_connect_proxy_requests(&mut self) -> bool {
221    true
222  }
223
224  async fn websocket_request_handler(
225    &mut self,
226    _websocket: HyperWebsocket,
227    _uri: &hyper::Uri,
228    _config: &ServerConfig,
229    _socket_data: &SocketData,
230    _error_logger: &ErrorLogger,
231  ) -> Result<(), Box<dyn Error + Send + Sync>> {
232    Ok(())
233  }
234
235  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
236    false
237  }
238}
239
240async fn http_proxy(
241  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
242  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
243  error_logger: &ErrorLogger,
244) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
245  let io = TokioIo::new(stream);
246
247  let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
248    Ok(data) => data,
249    Err(err) => {
250      error_logger.log(&format!("Bad gateway: {}", err)).await;
251      return Ok(
252        ResponseData::builder_without_request()
253          .status(StatusCode::BAD_GATEWAY)
254          .build(),
255      );
256    }
257  };
258
259  let send_request = sender.send_request(proxy_request);
260
261  let mut pinned_conn = Box::pin(conn);
262  tokio::pin!(send_request);
263
264  let response;
265
266  loop {
267    tokio::select! {
268      biased;
269
270       proxy_response = &mut send_request => {
271        let proxy_response = match proxy_response {
272          Ok(response) => response,
273          Err(err) => {
274            error_logger.log(&format!("Bad gateway: {}", err)).await;
275            return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
276          }
277        };
278
279        response = ResponseData::builder_without_request()
280                  .response(proxy_response.map(|b| {
281                    b.map_err(|e| std::io::Error::other(e.to_string()))
282                      .boxed()
283                  }))
284                  .parallel_fn(async move {
285                    pinned_conn.await.unwrap_or_default();
286                  })
287                  .build();
288
289        break;
290      },
291      state = &mut pinned_conn => {
292        if state.is_err() {
293          error_logger.log("Bad gateway: incomplete response").await;
294          return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
295        }
296      },
297    };
298  }
299
300  Ok(response)
301}