ferron/optional_modules/
rproxy.rs

1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9  ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::header::SEC_WEBSOCKET_PROTOCOL;
15use http::uri::{PathAndQuery, Scheme};
16use http_body_util::combinators::BoxBody;
17use http_body_util::BodyExt;
18use hyper::body::Bytes;
19use hyper::client::conn::http1::SendRequest;
20use hyper::{header, Request, StatusCode, Uri, Version};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
/// Number of connection maps created by `server_module_init`. Each map is keyed by backend
/// address and a request picks one map at random, so this bounds how many keep-alive senders
/// can be pooled for a single backend address.
const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41  config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43  let mut roots: RootCertStore = RootCertStore::empty();
44  let certs_result = load_native_certs();
45  if !certs_result.errors.is_empty() {
46    Err(anyhow::anyhow!(format!(
47      "Couldn't load the native certificate store: {}",
48      certs_result.errors[0]
49    )))?
50  }
51  let certs = certs_result.certs;
52
53  for cert in certs {
54    match roots.add(cert) {
55      Ok(_) => (),
56      Err(err) => Err(anyhow::anyhow!(format!(
57        "Couldn't add a certificate to the certificate store: {}",
58        err
59      )))?,
60    }
61  }
62
63  let mut connections_vec = Vec::new();
64  for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65    connections_vec.push(RwLock::new(HashMap::new()));
66  }
67  Ok(Box::new(ReverseProxyModule::new(
68    Arc::new(roots),
69    Arc::new(connections_vec),
70    Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71      config["global"]["loadBalancerHealthCheckWindow"]
72        .as_i64()
73        .unwrap_or(5000) as u64,
74    )))),
75  )))
76}
77
/// Module-wide state of the reverse proxy. `get_handlers` hands clones of these `Arc`s to every
/// handler set it creates.
#[allow(clippy::type_complexity)]
struct ReverseProxyModule {
  /// Root certificates used when connecting to HTTPS/WSS backends with verification enabled.
  roots: Arc<RootCertStore>,
  /// Connection maps: each maps a backend address (`host:port`) to a reusable HTTP/1.1 sender.
  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
  /// Failure counters keyed by backend URL, stored in a TTL cache for the health check.
  failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
}
84
85impl ReverseProxyModule {
86  #[allow(clippy::type_complexity)]
87  fn new(
88    roots: Arc<RootCertStore>,
89    connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90    failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91  ) -> Self {
92    Self {
93      roots,
94      connections,
95      failed_backends,
96    }
97  }
98}
99
100impl ServerModule for ReverseProxyModule {
101  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102    Box::new(ReverseProxyModuleHandlers {
103      roots: self.roots.clone(),
104      connections: self.connections.clone(),
105      failed_backends: self.failed_backends.clone(),
106      handle,
107    })
108  }
109}
110
/// Handler set produced by `ReverseProxyModule::get_handlers`; holds shared handles to the
/// module state plus the runtime handle the handlers run on.
#[allow(clippy::type_complexity)]
struct ReverseProxyModuleHandlers {
  /// Tokio runtime handle passed to `WithRuntime` when a handler future is driven.
  handle: Handle,
  /// Root certificates used when connecting to HTTPS/WSS backends with verification enabled.
  roots: Arc<RootCertStore>,
  /// Connection maps: each maps a backend address (`host:port`) to a reusable HTTP/1.1 sender.
  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
  /// Failure counters keyed by backend URL, stored in a TTL cache for the health check.
  failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121  async fn request_handler(
122    &mut self,
123    request: RequestData,
124    config: &ServerConfig,
125    socket_data: &SocketData,
126    error_logger: &ErrorLogger,
127  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128    WithRuntime::new(self.handle.clone(), async move {
129      let enable_health_check = config["enableLoadBalancerHealthCheck"]
130        .as_bool()
131        .unwrap_or(false);
132      let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133        .as_i64()
134        .unwrap_or(3) as u64;
135      let disable_certificate_verification = config["disableProxyCertificateVerification"]
136        .as_bool()
137        .unwrap_or(false);
138      let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139      if let Some(proxy_to) = determine_proxy_to(
140        config,
141        socket_data.encrypted,
142        &self.failed_backends,
143        enable_health_check,
144        health_check_max_fails,
145      )
146      .await
147      {
148        let (hyper_request, _, _, _) = request.into_parts();
149        let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152        let scheme_str = proxy_request_url.scheme_str();
153        let mut encrypted = false;
154
155        match scheme_str {
156          Some("http") => {
157            encrypted = false;
158          }
159          Some("https") => {
160            encrypted = true;
161          }
162          _ => Err(anyhow::anyhow!(
163            "Only HTTP and HTTPS reverse proxy URLs are supported."
164          ))?,
165        };
166
167        let host = match proxy_request_url.host() {
168          Some(host) => host,
169          None => Err(anyhow::anyhow!(
170            "The reverse proxy URL doesn't include the host"
171          ))?,
172        };
173
174        let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175          Some("http") => 80,
176          Some("https") => 443,
177          _ => 80,
178        });
179
180        let addr = format!("{host}:{port}");
181        let authority = proxy_request_url.authority().cloned();
182
183        let hyper_request_path = hyper_request_parts.uri.path();
184
185        let path = match hyper_request_path.as_bytes().first() {
186          Some(b'/') => {
187            let mut proxy_request_path = proxy_request_url.path();
188            while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189              proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190            }
191            format!("{proxy_request_path}{hyper_request_path}")
192          }
193          _ => hyper_request_path.to_string(),
194        };
195
196        hyper_request_parts.uri = Uri::from_str(&format!(
197          "{}{}",
198          path,
199          match hyper_request_parts.uri.query() {
200            Some(query) => format!("?{query}"),
201            None => "".to_string(),
202          }
203        ))?;
204
205        let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207        // Host header for host identification
208        match authority {
209          Some(authority) => {
210            hyper_request_parts
211              .headers
212              .insert(header::HOST, authority.to_string().parse()?);
213          }
214          None => {
215            hyper_request_parts.headers.remove(header::HOST);
216          }
217        }
218
219        // Connection header to enable HTTP/1.1 keep-alive
220        hyper_request_parts
221          .headers
222          .insert(header::CONNECTION, "keep-alive".parse()?);
223
224        // X-Forwarded-* headers to send the client's data to a server that's behind the reverse proxy
225        hyper_request_parts.headers.insert(
226          "x-forwarded-for",
227          socket_data
228            .remote_addr
229            .ip()
230            .to_canonical()
231            .to_string()
232            .parse()?,
233        );
234
235        if socket_data.encrypted {
236          hyper_request_parts
237            .headers
238            .insert("x-forwarded-proto", "https".parse()?);
239        } else {
240          hyper_request_parts
241            .headers
242            .insert("x-forwarded-proto", "http".parse()?);
243        }
244
245        if let Some(original_host) = original_host {
246          hyper_request_parts
247            .headers
248            .insert("x-forwarded-host", original_host);
249        }
250
251        hyper_request_parts.version = Version::HTTP_11;
252
253        let proxy_request = Request::from_parts(hyper_request_parts, request_body);
254
255        let connections = &self.connections[rand::random_range(..self.connections.len())];
256
257        let rwlock_read = connections.read().await;
258        let sender_read_option = rwlock_read.get(&addr);
259
260        if let Some(sender_read) = sender_read_option {
261          if !sender_read.is_closed() {
262            drop(rwlock_read);
263            let mut rwlock_write = connections.write().await;
264            let sender_option = rwlock_write.get_mut(&addr);
265
266            if let Some(sender) = sender_option {
267              if !sender.is_closed() && sender.ready().await.is_ok() {
268                let result = http_proxy_kept_alive(
269                  sender,
270                  proxy_request,
271                  error_logger,
272                  proxy_intercept_errors,
273                )
274                .await;
275                drop(rwlock_write);
276                return result;
277              } else {
278                drop(rwlock_write);
279              }
280            } else {
281              drop(rwlock_write);
282            }
283          } else {
284            drop(rwlock_read);
285          }
286        } else {
287          drop(rwlock_read);
288        }
289
290        let stream = match TcpStream::connect(&addr).await {
291          Ok(stream) => stream,
292          Err(err) => {
293            if enable_health_check {
294              let mut failed_backends_write = self.failed_backends.write().await;
295              let proxy_to = proxy_to.clone();
296              let failed_attempts = failed_backends_write.get(&proxy_to);
297              failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
298            }
299            match err.kind() {
300              tokio::io::ErrorKind::ConnectionRefused
301              | tokio::io::ErrorKind::NotFound
302              | tokio::io::ErrorKind::HostUnreachable => {
303                error_logger
304                  .log(&format!("Service unavailable: {err}"))
305                  .await;
306                return Ok(
307                  ResponseData::builder_without_request()
308                    .status(StatusCode::SERVICE_UNAVAILABLE)
309                    .build(),
310                );
311              }
312              tokio::io::ErrorKind::TimedOut => {
313                error_logger.log(&format!("Gateway timeout: {err}")).await;
314                return Ok(
315                  ResponseData::builder_without_request()
316                    .status(StatusCode::GATEWAY_TIMEOUT)
317                    .build(),
318                );
319              }
320              _ => {
321                error_logger.log(&format!("Bad gateway: {err}")).await;
322                return Ok(
323                  ResponseData::builder_without_request()
324                    .status(StatusCode::BAD_GATEWAY)
325                    .build(),
326                );
327              }
328            };
329          }
330        };
331
332        match stream.set_nodelay(true) {
333          Ok(_) => (),
334          Err(err) => {
335            if enable_health_check {
336              let mut failed_backends_write = self.failed_backends.write().await;
337              let proxy_to = proxy_to.clone();
338              let failed_attempts = failed_backends_write.get(&proxy_to);
339              failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
340            }
341            error_logger.log(&format!("Bad gateway: {err}")).await;
342            return Ok(
343              ResponseData::builder_without_request()
344                .status(StatusCode::BAD_GATEWAY)
345                .build(),
346            );
347          }
348        };
349
350        let failed_backends_option_borrowed = if enable_health_check {
351          Some(&*self.failed_backends)
352        } else {
353          None
354        };
355
356        if !encrypted {
357          http_proxy(
358            connections,
359            addr,
360            stream,
361            proxy_request,
362            error_logger,
363            proxy_to,
364            failed_backends_option_borrowed,
365            proxy_intercept_errors,
366          )
367          .await
368        } else {
369          let tls_client_config = (if disable_certificate_verification {
370            rustls::ClientConfig::builder()
371              .dangerous()
372              .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
373          } else {
374            rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
375          })
376          .with_no_client_auth();
377          let connector = TlsConnector::from(Arc::new(tls_client_config));
378          let domain = ServerName::try_from(host)?.to_owned();
379
380          let tls_stream = match connector.connect(domain, stream).await {
381            Ok(stream) => stream,
382            Err(err) => {
383              if enable_health_check {
384                let mut failed_backends_write = self.failed_backends.write().await;
385                let proxy_to = proxy_to.clone();
386                let failed_attempts = failed_backends_write.get(&proxy_to);
387                failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
388              }
389              error_logger.log(&format!("Bad gateway: {err}")).await;
390              return Ok(
391                ResponseData::builder_without_request()
392                  .status(StatusCode::BAD_GATEWAY)
393                  .build(),
394              );
395            }
396          };
397
398          http_proxy(
399            connections,
400            addr,
401            tls_stream,
402            proxy_request,
403            error_logger,
404            proxy_to,
405            failed_backends_option_borrowed,
406            proxy_intercept_errors,
407          )
408          .await
409        }
410      } else {
411        Ok(ResponseData::builder(request).build())
412      }
413    })
414    .await
415  }
416
  /// Forward-proxy hook. This module does nothing here: the request is handed back unchanged
  /// through `ResponseData::builder(request).build()`, and the other arguments are ignored.
  async fn proxy_request_handler(
    &mut self,
    request: RequestData,
    _config: &ServerConfig,
    _socket_data: &SocketData,
    _error_logger: &ErrorLogger,
  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
    Ok(ResponseData::builder(request).build())
  }
426
  /// Response hook. Returns `response` unmodified.
  async fn response_modifying_handler(
    &mut self,
    response: HyperResponse,
  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
    Ok(response)
  }
433
  /// Forward-proxy response hook. Returns `response` unmodified.
  async fn proxy_response_modifying_handler(
    &mut self,
    response: HyperResponse,
  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
    Ok(response)
  }
440
  /// `CONNECT` proxy hook. This module ignores every argument and returns `Ok(())`;
  /// `does_connect_proxy_requests` reports `false` for it.
  async fn connect_proxy_request_handler(
    &mut self,
    _upgraded_request: HyperUpgraded,
    _connect_address: &str,
    _config: &ServerConfig,
    _socket_data: &SocketData,
    _error_logger: &ErrorLogger,
  ) -> Result<(), Box<dyn Error + Send + Sync>> {
    Ok(())
  }
451
  /// Always returns `false`: this module does not handle `CONNECT` proxy requests.
  fn does_connect_proxy_requests(&mut self) -> bool {
    false
  }
455
456  async fn websocket_request_handler(
457    &mut self,
458    websocket: HyperWebsocket,
459    uri: &hyper::Uri,
460    headers: &hyper::HeaderMap,
461    config: &ServerConfig,
462    socket_data: &SocketData,
463    error_logger: &ErrorLogger,
464  ) -> Result<(), Box<dyn Error + Send + Sync>> {
465    WithRuntime::new(self.handle.clone(), async move {
466      let enable_health_check = config["enableLoadBalancerHealthCheck"]
467        .as_bool()
468        .unwrap_or(false);
469      let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
470        .as_i64()
471        .unwrap_or(3) as u64;
472
473      let disable_certificate_verification = config["disableProxyCertificateVerification"]
474        .as_bool()
475        .unwrap_or(false);
476      if let Some(proxy_to) = determine_proxy_to(
477        config,
478        socket_data.encrypted,
479        &self.failed_backends,
480        enable_health_check,
481        health_check_max_fails,
482      )
483      .await
484      {
485        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
486        let scheme_str = proxy_request_url.scheme_str();
487        let mut encrypted = false;
488
489        match scheme_str {
490          Some("http") => {
491            encrypted = false;
492          }
493          Some("https") => {
494            encrypted = true;
495          }
496          _ => Err(anyhow::anyhow!(
497            "Only HTTP and HTTPS reverse proxy URLs are supported."
498          ))?,
499        };
500
501        let request_path = uri.path();
502
503        let path = match request_path.as_bytes().first() {
504          Some(b'/') => {
505            let mut proxy_request_path = proxy_request_url.path();
506            while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
507              proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
508            }
509            format!("{proxy_request_path}{request_path}")
510          }
511          _ => request_path.to_string(),
512        };
513
514        let mut proxy_request_url_parts = proxy_request_url.into_parts();
515        proxy_request_url_parts.scheme = if encrypted {
516          Some(Scheme::from_str("wss")?)
517        } else {
518          Some(Scheme::from_str("ws")?)
519        };
520        match uri.path_and_query() {
521          Some(path_and_query) => {
522            let path_and_query_string = match path_and_query.query() {
523              Some(query) => {
524                format!("{path}?{query}")
525              }
526              None => path,
527            };
528            proxy_request_url_parts.path_and_query =
529              Some(PathAndQuery::from_str(&path_and_query_string)?);
530          }
531          None => {
532            proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
533          }
534        };
535
536        let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
537
538        let connector = if !encrypted {
539          Connector::Plain
540        } else {
541          Connector::Rustls(Arc::new(
542            (if disable_certificate_verification {
543              rustls::ClientConfig::builder()
544                .dangerous()
545                .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
546            } else {
547              rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
548            })
549            .with_no_client_auth(),
550          ))
551        };
552
553        let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
554        for (header_name, header_value) in headers {
555          let header_name_str = header_name.as_str();
556          if header_name == SEC_WEBSOCKET_PROTOCOL {
557            for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
558              proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
559            }
560          } else if !header_name_str.starts_with("sec-websocket-")
561            && header_name_str != "x-forwarded-for"
562          {
563            proxy_request_builder = proxy_request_builder.with_header(
564              header_name_str,
565              String::from_utf8_lossy(header_value.as_bytes()),
566            );
567          }
568        }
569
570        // Add X-Forwarded-For header
571        proxy_request_builder = proxy_request_builder.with_header(
572          "x-forwarded-for",
573          socket_data.remote_addr.ip().to_canonical().to_string(),
574        );
575
576        let proxy_request_constructed = proxy_request_builder.into_client_request()?;
577
578        let client_bi_stream = websocket.await?;
579
580        let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
581          proxy_request_constructed,
582          None,
583          true,
584          Some(connector),
585        )
586        .await
587        {
588          Ok(data) => data,
589          Err(err) => {
590            error_logger
591              .log(&format!("Cannot connect to WebSocket server: {err}"))
592              .await;
593            return Ok(());
594          }
595        };
596
597        let (mut client_sink, mut client_stream) = client_bi_stream.split();
598        let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
599
600        let client_to_proxy = async {
601          while let Some(Ok(value)) = client_stream.next().await {
602            if proxy_sink.send(value).await.is_err() {
603              break;
604            }
605          }
606        };
607
608        let proxy_to_client = async {
609          while let Some(Ok(value)) = proxy_stream.next().await {
610            if client_sink.send(value).await.is_err() {
611              break;
612            }
613          }
614        };
615
616        tokio::pin!(client_to_proxy);
617        tokio::pin!(proxy_to_client);
618
619        let client_to_proxy_first;
620        tokio::select! {
621          _ = &mut client_to_proxy => {
622            client_to_proxy_first = true;
623          }
624          _ = &mut proxy_to_client => {
625            client_to_proxy_first = false;
626          }
627        }
628
629        if client_to_proxy_first {
630          proxy_to_client.await;
631        } else {
632          client_to_proxy.await;
633        }
634      }
635
636      Ok(())
637    })
638    .await
639  }
640
641  fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
642    if socket_data.encrypted {
643      let secure_proxy_to = &config["secureProxyTo"];
644      if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
645        return true;
646      }
647    }
648
649    let proxy_to = &config["proxyTo"];
650    proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
651  }
652}
653
654async fn determine_proxy_to(
655  config: &ServerConfig,
656  encrypted: bool,
657  failed_backends: &RwLock<TtlCache<String, u64>>,
658  enable_health_check: bool,
659  health_check_max_fails: u64,
660) -> Option<String> {
661  let mut proxy_to = None;
662  // When the array is supplied with non-string values, the reverse proxy may have undesirable behavior
663  // The "proxyTo" and "secureProxyTo" are validated though.
664
665  if encrypted {
666    let secure_proxy_to_yaml = &config["secureProxyTo"];
667    if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
668      if enable_health_check {
669        let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
670        loop {
671          if !secure_proxy_to_vector.is_empty() {
672            let index = rand::random_range(..secure_proxy_to_vector.len());
673            if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
674              proxy_to = Some(secure_proxy_to.to_string());
675              let failed_backends_read = failed_backends.read().await;
676              let failed_backend_fails =
677                match failed_backends_read.get(&secure_proxy_to.to_string()) {
678                  Some(fails) => fails,
679                  None => break,
680                };
681              if failed_backend_fails > health_check_max_fails {
682                secure_proxy_to_vector.remove(index);
683              } else {
684                break;
685              }
686            }
687          } else {
688            break;
689          }
690        }
691      } else if !secure_proxy_to_vector.is_empty() {
692        if let Some(secure_proxy_to) =
693          secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
694        {
695          proxy_to = Some(secure_proxy_to.to_string());
696        }
697      }
698    } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
699      proxy_to = Some(secure_proxy_to.to_string());
700    }
701  }
702
703  if proxy_to.is_none() {
704    let proxy_to_yaml = &config["proxyTo"];
705    if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
706      if enable_health_check {
707        let mut proxy_to_vector = proxy_to_vector.clone();
708        loop {
709          if !proxy_to_vector.is_empty() {
710            let index = rand::random_range(..proxy_to_vector.len());
711            if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
712              proxy_to = Some(proxy_to_str.to_string());
713              let failed_backends_read = failed_backends.read().await;
714              let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
715                Some(fails) => fails,
716                None => break,
717              };
718              if failed_backend_fails > health_check_max_fails {
719                proxy_to_vector.remove(index);
720              } else {
721                break;
722              }
723            }
724          } else {
725            break;
726          }
727        }
728      } else if !proxy_to_vector.is_empty() {
729        if let Some(proxy_to_str) =
730          proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
731        {
732          proxy_to = Some(proxy_to_str.to_string());
733        }
734      }
735    } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
736      proxy_to = Some(proxy_to_str.to_string());
737    }
738  }
739
740  proxy_to
741}
742
743#[allow(clippy::too_many_arguments)]
744async fn http_proxy(
745  connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
746  connect_addr: String,
747  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
748  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
749  error_logger: &ErrorLogger,
750  proxy_to: String,
751  failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
752  proxy_intercept_errors: bool,
753) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
754  let io = TokioIo::new(stream);
755
756  let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
757    Ok(data) => data,
758    Err(err) => {
759      if let Some(failed_backends) = failed_backends {
760        let mut failed_backends_write = failed_backends.write().await;
761        let failed_attempts = failed_backends_write.get(&proxy_to);
762        failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
763      }
764      error_logger.log(&format!("Bad gateway: {err}")).await;
765      return Ok(
766        ResponseData::builder_without_request()
767          .status(StatusCode::BAD_GATEWAY)
768          .build(),
769      );
770    }
771  };
772
773  let send_request = sender.send_request(proxy_request);
774
775  let mut pinned_conn = Box::pin(conn);
776  tokio::pin!(send_request);
777
778  let response;
779
780  loop {
781    tokio::select! {
782      biased;
783
784      proxy_response = &mut send_request => {
785        let proxy_response = match proxy_response {
786          Ok(response) => response,
787          Err(err) => {
788            error_logger.log(&format!("Bad gateway: {err}")).await;
789            return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
790          }
791        };
792
793        let status_code = proxy_response.status();
794        response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
795          ResponseData::builder_without_request()
796          .status(status_code)
797          .parallel_fn(async move {
798            pinned_conn.await.unwrap_or_default();
799          })
800          .build()
801        } else {
802          ResponseData::builder_without_request()
803          .response(proxy_response.map(|b| {
804            b.map_err(|e| std::io::Error::other(e.to_string()))
805              .boxed()
806          }))
807          .parallel_fn(async move {
808            pinned_conn.await.unwrap_or_default();
809          })
810          .build()
811        };
812
813        break;
814      },
815      state = &mut pinned_conn => {
816        if state.is_err() {
817          error_logger.log("Bad gateway: incomplete response").await;
818          return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
819        }
820      },
821    };
822  }
823
824  if !sender.is_closed() {
825    let mut rwlock_write = connections.write().await;
826    rwlock_write.insert(connect_addr, sender);
827    drop(rwlock_write);
828  }
829
830  Ok(response)
831}
832
833async fn http_proxy_kept_alive(
834  sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
835  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
836  error_logger: &ErrorLogger,
837  proxy_intercept_errors: bool,
838) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
839  let proxy_response = match sender.send_request(proxy_request).await {
840    Ok(response) => response,
841    Err(err) => {
842      error_logger.log(&format!("Bad gateway: {err}")).await;
843      return Ok(
844        ResponseData::builder_without_request()
845          .status(StatusCode::BAD_GATEWAY)
846          .build(),
847      );
848    }
849  };
850
851  let status_code = proxy_response.status();
852  let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
853    ResponseData::builder_without_request()
854      .status(status_code)
855      .build()
856  } else {
857    ResponseData::builder_without_request()
858      .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
859      .build()
860  };
861
862  Ok(response)
863}