1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9 ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::header::SEC_WEBSOCKET_PROTOCOL;
15use http::uri::{PathAndQuery, Scheme};
16use http_body_util::combinators::BoxBody;
17use http_body_util::BodyExt;
18use hyper::body::Bytes;
19use hyper::client::conn::http1::SendRequest;
20use hyper::{header, Request, StatusCode, Uri, Version};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
38const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41 config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43 let mut roots: RootCertStore = RootCertStore::empty();
44 let certs_result = load_native_certs();
45 if !certs_result.errors.is_empty() {
46 Err(anyhow::anyhow!(format!(
47 "Couldn't load the native certificate store: {}",
48 certs_result.errors[0]
49 )))?
50 }
51 let certs = certs_result.certs;
52
53 for cert in certs {
54 match roots.add(cert) {
55 Ok(_) => (),
56 Err(err) => Err(anyhow::anyhow!(format!(
57 "Couldn't add a certificate to the certificate store: {}",
58 err
59 )))?,
60 }
61 }
62
63 let mut connections_vec = Vec::new();
64 for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65 connections_vec.push(RwLock::new(HashMap::new()));
66 }
67 Ok(Box::new(ReverseProxyModule::new(
68 Arc::new(roots),
69 Arc::new(connections_vec),
70 Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71 config["global"]["loadBalancerHealthCheckWindow"]
72 .as_i64()
73 .unwrap_or(5000) as u64,
74 )))),
75 )))
76}
77
78#[allow(clippy::type_complexity)]
79struct ReverseProxyModule {
80 roots: Arc<RootCertStore>,
81 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
82 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
83}
84
85impl ReverseProxyModule {
86 #[allow(clippy::type_complexity)]
87 fn new(
88 roots: Arc<RootCertStore>,
89 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91 ) -> Self {
92 Self {
93 roots,
94 connections,
95 failed_backends,
96 }
97 }
98}
99
100impl ServerModule for ReverseProxyModule {
101 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102 Box::new(ReverseProxyModuleHandlers {
103 roots: self.roots.clone(),
104 connections: self.connections.clone(),
105 failed_backends: self.failed_backends.clone(),
106 handle,
107 })
108 }
109}
110
111#[allow(clippy::type_complexity)]
112struct ReverseProxyModuleHandlers {
113 handle: Handle,
114 roots: Arc<RootCertStore>,
115 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
116 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
117}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121 async fn request_handler(
122 &mut self,
123 request: RequestData,
124 config: &ServerConfig,
125 socket_data: &SocketData,
126 error_logger: &ErrorLogger,
127 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128 WithRuntime::new(self.handle.clone(), async move {
129 let enable_health_check = config["enableLoadBalancerHealthCheck"]
130 .as_bool()
131 .unwrap_or(false);
132 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133 .as_i64()
134 .unwrap_or(3) as u64;
135 let disable_certificate_verification = config["disableProxyCertificateVerification"]
136 .as_bool()
137 .unwrap_or(false);
138 let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139 if let Some(proxy_to) = determine_proxy_to(
140 config,
141 socket_data.encrypted,
142 &self.failed_backends,
143 enable_health_check,
144 health_check_max_fails,
145 )
146 .await
147 {
148 let (hyper_request, _, _, _) = request.into_parts();
149 let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152 let scheme_str = proxy_request_url.scheme_str();
153 let mut encrypted = false;
154
155 match scheme_str {
156 Some("http") => {
157 encrypted = false;
158 }
159 Some("https") => {
160 encrypted = true;
161 }
162 _ => Err(anyhow::anyhow!(
163 "Only HTTP and HTTPS reverse proxy URLs are supported."
164 ))?,
165 };
166
167 let host = match proxy_request_url.host() {
168 Some(host) => host,
169 None => Err(anyhow::anyhow!(
170 "The reverse proxy URL doesn't include the host"
171 ))?,
172 };
173
174 let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175 Some("http") => 80,
176 Some("https") => 443,
177 _ => 80,
178 });
179
180 let addr = format!("{host}:{port}");
181 let authority = proxy_request_url.authority().cloned();
182
183 let hyper_request_path = hyper_request_parts.uri.path();
184
185 let path = match hyper_request_path.as_bytes().first() {
186 Some(b'/') => {
187 let mut proxy_request_path = proxy_request_url.path();
188 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190 }
191 format!("{proxy_request_path}{hyper_request_path}")
192 }
193 _ => hyper_request_path.to_string(),
194 };
195
196 hyper_request_parts.uri = Uri::from_str(&format!(
197 "{}{}",
198 path,
199 match hyper_request_parts.uri.query() {
200 Some(query) => format!("?{query}"),
201 None => "".to_string(),
202 }
203 ))?;
204
205 let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207 match authority {
209 Some(authority) => {
210 hyper_request_parts
211 .headers
212 .insert(header::HOST, authority.to_string().parse()?);
213 }
214 None => {
215 hyper_request_parts.headers.remove(header::HOST);
216 }
217 }
218
219 hyper_request_parts
221 .headers
222 .insert(header::CONNECTION, "keep-alive".parse()?);
223
224 hyper_request_parts.headers.insert(
226 "x-forwarded-for",
227 socket_data
228 .remote_addr
229 .ip()
230 .to_canonical()
231 .to_string()
232 .parse()?,
233 );
234
235 if socket_data.encrypted {
236 hyper_request_parts
237 .headers
238 .insert("x-forwarded-proto", "https".parse()?);
239 } else {
240 hyper_request_parts
241 .headers
242 .insert("x-forwarded-proto", "http".parse()?);
243 }
244
245 if let Some(original_host) = original_host {
246 hyper_request_parts
247 .headers
248 .insert("x-forwarded-host", original_host);
249 }
250
251 hyper_request_parts.version = Version::HTTP_11;
252
253 let proxy_request = Request::from_parts(hyper_request_parts, request_body);
254
255 let connections = &self.connections[rand::random_range(..self.connections.len())];
256
257 let rwlock_read = connections.read().await;
258 let sender_read_option = rwlock_read.get(&addr);
259
260 if let Some(sender_read) = sender_read_option {
261 if !sender_read.is_closed() {
262 drop(rwlock_read);
263 let mut rwlock_write = connections.write().await;
264 let sender_option = rwlock_write.get_mut(&addr);
265
266 if let Some(sender) = sender_option {
267 if !sender.is_closed() && sender.ready().await.is_ok() {
268 let result = http_proxy_kept_alive(
269 sender,
270 proxy_request,
271 error_logger,
272 proxy_intercept_errors,
273 )
274 .await;
275 drop(rwlock_write);
276 return result;
277 } else {
278 drop(rwlock_write);
279 }
280 } else {
281 drop(rwlock_write);
282 }
283 } else {
284 drop(rwlock_read);
285 }
286 } else {
287 drop(rwlock_read);
288 }
289
290 let stream = match TcpStream::connect(&addr).await {
291 Ok(stream) => stream,
292 Err(err) => {
293 if enable_health_check {
294 let mut failed_backends_write = self.failed_backends.write().await;
295 let proxy_to = proxy_to.clone();
296 let failed_attempts = failed_backends_write.get(&proxy_to);
297 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
298 }
299 match err.kind() {
300 tokio::io::ErrorKind::ConnectionRefused
301 | tokio::io::ErrorKind::NotFound
302 | tokio::io::ErrorKind::HostUnreachable => {
303 error_logger
304 .log(&format!("Service unavailable: {err}"))
305 .await;
306 return Ok(
307 ResponseData::builder_without_request()
308 .status(StatusCode::SERVICE_UNAVAILABLE)
309 .build(),
310 );
311 }
312 tokio::io::ErrorKind::TimedOut => {
313 error_logger.log(&format!("Gateway timeout: {err}")).await;
314 return Ok(
315 ResponseData::builder_without_request()
316 .status(StatusCode::GATEWAY_TIMEOUT)
317 .build(),
318 );
319 }
320 _ => {
321 error_logger.log(&format!("Bad gateway: {err}")).await;
322 return Ok(
323 ResponseData::builder_without_request()
324 .status(StatusCode::BAD_GATEWAY)
325 .build(),
326 );
327 }
328 };
329 }
330 };
331
332 match stream.set_nodelay(true) {
333 Ok(_) => (),
334 Err(err) => {
335 if enable_health_check {
336 let mut failed_backends_write = self.failed_backends.write().await;
337 let proxy_to = proxy_to.clone();
338 let failed_attempts = failed_backends_write.get(&proxy_to);
339 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
340 }
341 error_logger.log(&format!("Bad gateway: {err}")).await;
342 return Ok(
343 ResponseData::builder_without_request()
344 .status(StatusCode::BAD_GATEWAY)
345 .build(),
346 );
347 }
348 };
349
350 let failed_backends_option_borrowed = if enable_health_check {
351 Some(&*self.failed_backends)
352 } else {
353 None
354 };
355
356 if !encrypted {
357 http_proxy(
358 connections,
359 addr,
360 stream,
361 proxy_request,
362 error_logger,
363 proxy_to,
364 failed_backends_option_borrowed,
365 proxy_intercept_errors,
366 )
367 .await
368 } else {
369 let tls_client_config = (if disable_certificate_verification {
370 rustls::ClientConfig::builder()
371 .dangerous()
372 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
373 } else {
374 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
375 })
376 .with_no_client_auth();
377 let connector = TlsConnector::from(Arc::new(tls_client_config));
378 let domain = ServerName::try_from(host)?.to_owned();
379
380 let tls_stream = match connector.connect(domain, stream).await {
381 Ok(stream) => stream,
382 Err(err) => {
383 if enable_health_check {
384 let mut failed_backends_write = self.failed_backends.write().await;
385 let proxy_to = proxy_to.clone();
386 let failed_attempts = failed_backends_write.get(&proxy_to);
387 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
388 }
389 error_logger.log(&format!("Bad gateway: {err}")).await;
390 return Ok(
391 ResponseData::builder_without_request()
392 .status(StatusCode::BAD_GATEWAY)
393 .build(),
394 );
395 }
396 };
397
398 http_proxy(
399 connections,
400 addr,
401 tls_stream,
402 proxy_request,
403 error_logger,
404 proxy_to,
405 failed_backends_option_borrowed,
406 proxy_intercept_errors,
407 )
408 .await
409 }
410 } else {
411 Ok(ResponseData::builder(request).build())
412 }
413 })
414 .await
415 }
416
417 async fn proxy_request_handler(
418 &mut self,
419 request: RequestData,
420 _config: &ServerConfig,
421 _socket_data: &SocketData,
422 _error_logger: &ErrorLogger,
423 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
424 Ok(ResponseData::builder(request).build())
425 }
426
427 async fn response_modifying_handler(
428 &mut self,
429 response: HyperResponse,
430 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
431 Ok(response)
432 }
433
434 async fn proxy_response_modifying_handler(
435 &mut self,
436 response: HyperResponse,
437 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
438 Ok(response)
439 }
440
441 async fn connect_proxy_request_handler(
442 &mut self,
443 _upgraded_request: HyperUpgraded,
444 _connect_address: &str,
445 _config: &ServerConfig,
446 _socket_data: &SocketData,
447 _error_logger: &ErrorLogger,
448 ) -> Result<(), Box<dyn Error + Send + Sync>> {
449 Ok(())
450 }
451
452 fn does_connect_proxy_requests(&mut self) -> bool {
453 false
454 }
455
456 async fn websocket_request_handler(
457 &mut self,
458 websocket: HyperWebsocket,
459 uri: &hyper::Uri,
460 headers: &hyper::HeaderMap,
461 config: &ServerConfig,
462 socket_data: &SocketData,
463 error_logger: &ErrorLogger,
464 ) -> Result<(), Box<dyn Error + Send + Sync>> {
465 WithRuntime::new(self.handle.clone(), async move {
466 let enable_health_check = config["enableLoadBalancerHealthCheck"]
467 .as_bool()
468 .unwrap_or(false);
469 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
470 .as_i64()
471 .unwrap_or(3) as u64;
472
473 let disable_certificate_verification = config["disableProxyCertificateVerification"]
474 .as_bool()
475 .unwrap_or(false);
476 if let Some(proxy_to) = determine_proxy_to(
477 config,
478 socket_data.encrypted,
479 &self.failed_backends,
480 enable_health_check,
481 health_check_max_fails,
482 )
483 .await
484 {
485 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
486 let scheme_str = proxy_request_url.scheme_str();
487 let mut encrypted = false;
488
489 match scheme_str {
490 Some("http") => {
491 encrypted = false;
492 }
493 Some("https") => {
494 encrypted = true;
495 }
496 _ => Err(anyhow::anyhow!(
497 "Only HTTP and HTTPS reverse proxy URLs are supported."
498 ))?,
499 };
500
501 let request_path = uri.path();
502
503 let path = match request_path.as_bytes().first() {
504 Some(b'/') => {
505 let mut proxy_request_path = proxy_request_url.path();
506 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
507 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
508 }
509 format!("{proxy_request_path}{request_path}")
510 }
511 _ => request_path.to_string(),
512 };
513
514 let mut proxy_request_url_parts = proxy_request_url.into_parts();
515 proxy_request_url_parts.scheme = if encrypted {
516 Some(Scheme::from_str("wss")?)
517 } else {
518 Some(Scheme::from_str("ws")?)
519 };
520 match uri.path_and_query() {
521 Some(path_and_query) => {
522 let path_and_query_string = match path_and_query.query() {
523 Some(query) => {
524 format!("{path}?{query}")
525 }
526 None => path,
527 };
528 proxy_request_url_parts.path_and_query =
529 Some(PathAndQuery::from_str(&path_and_query_string)?);
530 }
531 None => {
532 proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
533 }
534 };
535
536 let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
537
538 let connector = if !encrypted {
539 Connector::Plain
540 } else {
541 Connector::Rustls(Arc::new(
542 (if disable_certificate_verification {
543 rustls::ClientConfig::builder()
544 .dangerous()
545 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
546 } else {
547 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
548 })
549 .with_no_client_auth(),
550 ))
551 };
552
553 let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
554 for (header_name, header_value) in headers {
555 let header_name_str = header_name.as_str();
556 if header_name == SEC_WEBSOCKET_PROTOCOL {
557 for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
558 proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
559 }
560 } else if !header_name_str.starts_with("sec-websocket-")
561 && header_name_str != "x-forwarded-for"
562 {
563 proxy_request_builder = proxy_request_builder.with_header(
564 header_name_str,
565 String::from_utf8_lossy(header_value.as_bytes()),
566 );
567 }
568 }
569
570 proxy_request_builder = proxy_request_builder.with_header(
572 "x-forwarded-for",
573 socket_data.remote_addr.ip().to_canonical().to_string(),
574 );
575
576 let proxy_request_constructed = proxy_request_builder.into_client_request()?;
577
578 let client_bi_stream = websocket.await?;
579
580 let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
581 proxy_request_constructed,
582 None,
583 true,
584 Some(connector),
585 )
586 .await
587 {
588 Ok(data) => data,
589 Err(err) => {
590 error_logger
591 .log(&format!("Cannot connect to WebSocket server: {err}"))
592 .await;
593 return Ok(());
594 }
595 };
596
597 let (mut client_sink, mut client_stream) = client_bi_stream.split();
598 let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
599
600 let client_to_proxy = async {
601 while let Some(Ok(value)) = client_stream.next().await {
602 if proxy_sink.send(value).await.is_err() {
603 break;
604 }
605 }
606 };
607
608 let proxy_to_client = async {
609 while let Some(Ok(value)) = proxy_stream.next().await {
610 if client_sink.send(value).await.is_err() {
611 break;
612 }
613 }
614 };
615
616 tokio::pin!(client_to_proxy);
617 tokio::pin!(proxy_to_client);
618
619 let client_to_proxy_first;
620 tokio::select! {
621 _ = &mut client_to_proxy => {
622 client_to_proxy_first = true;
623 }
624 _ = &mut proxy_to_client => {
625 client_to_proxy_first = false;
626 }
627 }
628
629 if client_to_proxy_first {
630 proxy_to_client.await;
631 } else {
632 client_to_proxy.await;
633 }
634 }
635
636 Ok(())
637 })
638 .await
639 }
640
641 fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
642 if socket_data.encrypted {
643 let secure_proxy_to = &config["secureProxyTo"];
644 if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
645 return true;
646 }
647 }
648
649 let proxy_to = &config["proxyTo"];
650 proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
651 }
652}
653
654async fn determine_proxy_to(
655 config: &ServerConfig,
656 encrypted: bool,
657 failed_backends: &RwLock<TtlCache<String, u64>>,
658 enable_health_check: bool,
659 health_check_max_fails: u64,
660) -> Option<String> {
661 let mut proxy_to = None;
662 if encrypted {
666 let secure_proxy_to_yaml = &config["secureProxyTo"];
667 if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
668 if enable_health_check {
669 let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
670 loop {
671 if !secure_proxy_to_vector.is_empty() {
672 let index = rand::random_range(..secure_proxy_to_vector.len());
673 if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
674 proxy_to = Some(secure_proxy_to.to_string());
675 let failed_backends_read = failed_backends.read().await;
676 let failed_backend_fails =
677 match failed_backends_read.get(&secure_proxy_to.to_string()) {
678 Some(fails) => fails,
679 None => break,
680 };
681 if failed_backend_fails > health_check_max_fails {
682 secure_proxy_to_vector.remove(index);
683 } else {
684 break;
685 }
686 }
687 } else {
688 break;
689 }
690 }
691 } else if !secure_proxy_to_vector.is_empty() {
692 if let Some(secure_proxy_to) =
693 secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
694 {
695 proxy_to = Some(secure_proxy_to.to_string());
696 }
697 }
698 } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
699 proxy_to = Some(secure_proxy_to.to_string());
700 }
701 }
702
703 if proxy_to.is_none() {
704 let proxy_to_yaml = &config["proxyTo"];
705 if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
706 if enable_health_check {
707 let mut proxy_to_vector = proxy_to_vector.clone();
708 loop {
709 if !proxy_to_vector.is_empty() {
710 let index = rand::random_range(..proxy_to_vector.len());
711 if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
712 proxy_to = Some(proxy_to_str.to_string());
713 let failed_backends_read = failed_backends.read().await;
714 let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
715 Some(fails) => fails,
716 None => break,
717 };
718 if failed_backend_fails > health_check_max_fails {
719 proxy_to_vector.remove(index);
720 } else {
721 break;
722 }
723 }
724 } else {
725 break;
726 }
727 }
728 } else if !proxy_to_vector.is_empty() {
729 if let Some(proxy_to_str) =
730 proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
731 {
732 proxy_to = Some(proxy_to_str.to_string());
733 }
734 }
735 } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
736 proxy_to = Some(proxy_to_str.to_string());
737 }
738 }
739
740 proxy_to
741}
742
743#[allow(clippy::too_many_arguments)]
744async fn http_proxy(
745 connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
746 connect_addr: String,
747 stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
748 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
749 error_logger: &ErrorLogger,
750 proxy_to: String,
751 failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
752 proxy_intercept_errors: bool,
753) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
754 let io = TokioIo::new(stream);
755
756 let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
757 Ok(data) => data,
758 Err(err) => {
759 if let Some(failed_backends) = failed_backends {
760 let mut failed_backends_write = failed_backends.write().await;
761 let failed_attempts = failed_backends_write.get(&proxy_to);
762 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
763 }
764 error_logger.log(&format!("Bad gateway: {err}")).await;
765 return Ok(
766 ResponseData::builder_without_request()
767 .status(StatusCode::BAD_GATEWAY)
768 .build(),
769 );
770 }
771 };
772
773 let send_request = sender.send_request(proxy_request);
774
775 let mut pinned_conn = Box::pin(conn);
776 tokio::pin!(send_request);
777
778 let response;
779
780 loop {
781 tokio::select! {
782 biased;
783
784 proxy_response = &mut send_request => {
785 let proxy_response = match proxy_response {
786 Ok(response) => response,
787 Err(err) => {
788 error_logger.log(&format!("Bad gateway: {err}")).await;
789 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
790 }
791 };
792
793 let status_code = proxy_response.status();
794 response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
795 ResponseData::builder_without_request()
796 .status(status_code)
797 .parallel_fn(async move {
798 pinned_conn.await.unwrap_or_default();
799 })
800 .build()
801 } else {
802 ResponseData::builder_without_request()
803 .response(proxy_response.map(|b| {
804 b.map_err(|e| std::io::Error::other(e.to_string()))
805 .boxed()
806 }))
807 .parallel_fn(async move {
808 pinned_conn.await.unwrap_or_default();
809 })
810 .build()
811 };
812
813 break;
814 },
815 state = &mut pinned_conn => {
816 if state.is_err() {
817 error_logger.log("Bad gateway: incomplete response").await;
818 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
819 }
820 },
821 };
822 }
823
824 if !sender.is_closed() {
825 let mut rwlock_write = connections.write().await;
826 rwlock_write.insert(connect_addr, sender);
827 drop(rwlock_write);
828 }
829
830 Ok(response)
831}
832
833async fn http_proxy_kept_alive(
834 sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
835 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
836 error_logger: &ErrorLogger,
837 proxy_intercept_errors: bool,
838) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
839 let proxy_response = match sender.send_request(proxy_request).await {
840 Ok(response) => response,
841 Err(err) => {
842 error_logger.log(&format!("Bad gateway: {err}")).await;
843 return Ok(
844 ResponseData::builder_without_request()
845 .status(StatusCode::BAD_GATEWAY)
846 .build(),
847 );
848 }
849 };
850
851 let status_code = proxy_response.status();
852 let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
853 ResponseData::builder_without_request()
854 .status(status_code)
855 .build()
856 } else {
857 ResponseData::builder_without_request()
858 .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
859 .build()
860 };
861
862 Ok(response)
863}