1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9 ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::header::SEC_WEBSOCKET_PROTOCOL;
15use http::uri::{PathAndQuery, Scheme};
16use http_body_util::combinators::BoxBody;
17use http_body_util::BodyExt;
18use hyper::body::Bytes;
19use hyper::client::conn::http1::SendRequest;
20use hyper::{header, Request, StatusCode, Uri};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
38const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41 config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43 let mut roots: RootCertStore = RootCertStore::empty();
44 let certs_result = load_native_certs();
45 if !certs_result.errors.is_empty() {
46 Err(anyhow::anyhow!(format!(
47 "Couldn't load the native certificate store: {}",
48 certs_result.errors[0]
49 )))?
50 }
51 let certs = certs_result.certs;
52
53 for cert in certs {
54 match roots.add(cert) {
55 Ok(_) => (),
56 Err(err) => Err(anyhow::anyhow!(format!(
57 "Couldn't add a certificate to the certificate store: {}",
58 err
59 )))?,
60 }
61 }
62
63 let mut connections_vec = Vec::new();
64 for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65 connections_vec.push(RwLock::new(HashMap::new()));
66 }
67 Ok(Box::new(ReverseProxyModule::new(
68 Arc::new(roots),
69 Arc::new(connections_vec),
70 Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71 config["global"]["loadBalancerHealthCheckWindow"]
72 .as_i64()
73 .unwrap_or(5000) as u64,
74 )))),
75 )))
76}
77
78#[allow(clippy::type_complexity)]
79struct ReverseProxyModule {
80 roots: Arc<RootCertStore>,
81 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
82 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
83}
84
85impl ReverseProxyModule {
86 #[allow(clippy::type_complexity)]
87 fn new(
88 roots: Arc<RootCertStore>,
89 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91 ) -> Self {
92 Self {
93 roots,
94 connections,
95 failed_backends,
96 }
97 }
98}
99
100impl ServerModule for ReverseProxyModule {
101 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102 Box::new(ReverseProxyModuleHandlers {
103 roots: self.roots.clone(),
104 connections: self.connections.clone(),
105 failed_backends: self.failed_backends.clone(),
106 handle,
107 })
108 }
109}
110
111#[allow(clippy::type_complexity)]
112struct ReverseProxyModuleHandlers {
113 handle: Handle,
114 roots: Arc<RootCertStore>,
115 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
116 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
117}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121 async fn request_handler(
122 &mut self,
123 request: RequestData,
124 config: &ServerConfig,
125 socket_data: &SocketData,
126 error_logger: &ErrorLogger,
127 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128 WithRuntime::new(self.handle.clone(), async move {
129 let enable_health_check = config["enableLoadBalancerHealthCheck"]
130 .as_bool()
131 .unwrap_or(false);
132 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133 .as_i64()
134 .unwrap_or(3) as u64;
135 let disable_certificate_verification = config["disableProxyCertificateVerification"]
136 .as_bool()
137 .unwrap_or(false);
138 let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139 if let Some(proxy_to) = determine_proxy_to(
140 config,
141 socket_data.encrypted,
142 &self.failed_backends,
143 enable_health_check,
144 health_check_max_fails,
145 )
146 .await
147 {
148 let (hyper_request, _, _, _) = request.into_parts();
149 let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152 let scheme_str = proxy_request_url.scheme_str();
153 let mut encrypted = false;
154
155 match scheme_str {
156 Some("http") => {
157 encrypted = false;
158 }
159 Some("https") => {
160 encrypted = true;
161 }
162 _ => Err(anyhow::anyhow!(
163 "Only HTTP and HTTPS reverse proxy URLs are supported."
164 ))?,
165 };
166
167 let host = match proxy_request_url.host() {
168 Some(host) => host,
169 None => Err(anyhow::anyhow!(
170 "The reverse proxy URL doesn't include the host"
171 ))?,
172 };
173
174 let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175 Some("http") => 80,
176 Some("https") => 443,
177 _ => 80,
178 });
179
180 let addr = format!("{}:{}", host, port);
181 let authority = proxy_request_url.authority().cloned();
182
183 let hyper_request_path = hyper_request_parts.uri.path();
184
185 let path = match hyper_request_path.as_bytes().first() {
186 Some(b'/') => {
187 let mut proxy_request_path = proxy_request_url.path();
188 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190 }
191 format!("{}{}", proxy_request_path, hyper_request_path)
192 }
193 _ => hyper_request_path.to_string(),
194 };
195
196 hyper_request_parts.uri = Uri::from_str(&format!(
197 "{}{}",
198 path,
199 match hyper_request_parts.uri.query() {
200 Some(query) => format!("?{}", query),
201 None => "".to_string(),
202 }
203 ))?;
204
205 let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207 match authority {
209 Some(authority) => {
210 hyper_request_parts
211 .headers
212 .insert(header::HOST, authority.to_string().parse()?);
213 }
214 None => {
215 hyper_request_parts.headers.remove(header::HOST);
216 }
217 }
218
219 hyper_request_parts
221 .headers
222 .insert(header::CONNECTION, "keep-alive".parse()?);
223
224 hyper_request_parts.headers.insert(
226 "x-forwarded-for",
227 socket_data
228 .remote_addr
229 .ip()
230 .to_canonical()
231 .to_string()
232 .parse()?,
233 );
234
235 if socket_data.encrypted {
236 hyper_request_parts
237 .headers
238 .insert("x-forwarded-proto", "https".parse()?);
239 } else {
240 hyper_request_parts
241 .headers
242 .insert("x-forwarded-proto", "http".parse()?);
243 }
244
245 if let Some(original_host) = original_host {
246 hyper_request_parts
247 .headers
248 .insert("x-forwarded-host", original_host);
249 }
250
251 let proxy_request = Request::from_parts(hyper_request_parts, request_body);
252
253 let connections = &self.connections[rand::random_range(..self.connections.len())];
254
255 let rwlock_read = connections.read().await;
256 let sender_read_option = rwlock_read.get(&addr);
257
258 if let Some(sender_read) = sender_read_option {
259 if !sender_read.is_closed() {
260 drop(rwlock_read);
261 let mut rwlock_write = connections.write().await;
262 let sender_option = rwlock_write.get_mut(&addr);
263
264 if let Some(sender) = sender_option {
265 if !sender.is_closed() {
266 let result = http_proxy_kept_alive(
267 sender,
268 proxy_request,
269 error_logger,
270 proxy_intercept_errors,
271 )
272 .await;
273 drop(rwlock_write);
274 return result;
275 } else {
276 drop(rwlock_write);
277 }
278 } else {
279 drop(rwlock_write);
280 }
281 } else {
282 drop(rwlock_read);
283 }
284 } else {
285 drop(rwlock_read);
286 }
287
288 let stream = match TcpStream::connect(&addr).await {
289 Ok(stream) => stream,
290 Err(err) => {
291 if enable_health_check {
292 let mut failed_backends_write = self.failed_backends.write().await;
293 let proxy_to = proxy_to.clone();
294 let failed_attempts = failed_backends_write.get(&proxy_to);
295 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
296 }
297 match err.kind() {
298 tokio::io::ErrorKind::ConnectionRefused
299 | tokio::io::ErrorKind::NotFound
300 | tokio::io::ErrorKind::HostUnreachable => {
301 error_logger
302 .log(&format!("Service unavailable: {}", err))
303 .await;
304 return Ok(
305 ResponseData::builder_without_request()
306 .status(StatusCode::SERVICE_UNAVAILABLE)
307 .build(),
308 );
309 }
310 tokio::io::ErrorKind::TimedOut => {
311 error_logger.log(&format!("Gateway timeout: {}", err)).await;
312 return Ok(
313 ResponseData::builder_without_request()
314 .status(StatusCode::GATEWAY_TIMEOUT)
315 .build(),
316 );
317 }
318 _ => {
319 error_logger.log(&format!("Bad gateway: {}", err)).await;
320 return Ok(
321 ResponseData::builder_without_request()
322 .status(StatusCode::BAD_GATEWAY)
323 .build(),
324 );
325 }
326 };
327 }
328 };
329
330 match stream.set_nodelay(true) {
331 Ok(_) => (),
332 Err(err) => {
333 if enable_health_check {
334 let mut failed_backends_write = self.failed_backends.write().await;
335 let proxy_to = proxy_to.clone();
336 let failed_attempts = failed_backends_write.get(&proxy_to);
337 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
338 }
339 error_logger.log(&format!("Bad gateway: {}", err)).await;
340 return Ok(
341 ResponseData::builder_without_request()
342 .status(StatusCode::BAD_GATEWAY)
343 .build(),
344 );
345 }
346 };
347
348 let failed_backends_option_borrowed = if enable_health_check {
349 Some(&*self.failed_backends)
350 } else {
351 None
352 };
353
354 if !encrypted {
355 http_proxy(
356 connections,
357 addr,
358 stream,
359 proxy_request,
360 error_logger,
361 proxy_to,
362 failed_backends_option_borrowed,
363 proxy_intercept_errors,
364 )
365 .await
366 } else {
367 let tls_client_config = (if disable_certificate_verification {
368 rustls::ClientConfig::builder()
369 .dangerous()
370 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
371 } else {
372 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
373 })
374 .with_no_client_auth();
375 let connector = TlsConnector::from(Arc::new(tls_client_config));
376 let domain = ServerName::try_from(host)?.to_owned();
377
378 let tls_stream = match connector.connect(domain, stream).await {
379 Ok(stream) => stream,
380 Err(err) => {
381 if enable_health_check {
382 let mut failed_backends_write = self.failed_backends.write().await;
383 let proxy_to = proxy_to.clone();
384 let failed_attempts = failed_backends_write.get(&proxy_to);
385 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
386 }
387 error_logger.log(&format!("Bad gateway: {}", err)).await;
388 return Ok(
389 ResponseData::builder_without_request()
390 .status(StatusCode::BAD_GATEWAY)
391 .build(),
392 );
393 }
394 };
395
396 http_proxy(
397 connections,
398 addr,
399 tls_stream,
400 proxy_request,
401 error_logger,
402 proxy_to,
403 failed_backends_option_borrowed,
404 proxy_intercept_errors,
405 )
406 .await
407 }
408 } else {
409 Ok(ResponseData::builder(request).build())
410 }
411 })
412 .await
413 }
414
415 async fn proxy_request_handler(
416 &mut self,
417 request: RequestData,
418 _config: &ServerConfig,
419 _socket_data: &SocketData,
420 _error_logger: &ErrorLogger,
421 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
422 Ok(ResponseData::builder(request).build())
423 }
424
425 async fn response_modifying_handler(
426 &mut self,
427 response: HyperResponse,
428 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
429 Ok(response)
430 }
431
432 async fn proxy_response_modifying_handler(
433 &mut self,
434 response: HyperResponse,
435 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
436 Ok(response)
437 }
438
439 async fn connect_proxy_request_handler(
440 &mut self,
441 _upgraded_request: HyperUpgraded,
442 _connect_address: &str,
443 _config: &ServerConfig,
444 _socket_data: &SocketData,
445 _error_logger: &ErrorLogger,
446 ) -> Result<(), Box<dyn Error + Send + Sync>> {
447 Ok(())
448 }
449
450 fn does_connect_proxy_requests(&mut self) -> bool {
451 false
452 }
453
454 async fn websocket_request_handler(
455 &mut self,
456 websocket: HyperWebsocket,
457 uri: &hyper::Uri,
458 headers: &hyper::HeaderMap,
459 config: &ServerConfig,
460 socket_data: &SocketData,
461 error_logger: &ErrorLogger,
462 ) -> Result<(), Box<dyn Error + Send + Sync>> {
463 WithRuntime::new(self.handle.clone(), async move {
464 let enable_health_check = config["enableLoadBalancerHealthCheck"]
465 .as_bool()
466 .unwrap_or(false);
467 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
468 .as_i64()
469 .unwrap_or(3) as u64;
470
471 let disable_certificate_verification = config["disableProxyCertificateVerification"]
472 .as_bool()
473 .unwrap_or(false);
474 if let Some(proxy_to) = determine_proxy_to(
475 config,
476 socket_data.encrypted,
477 &self.failed_backends,
478 enable_health_check,
479 health_check_max_fails,
480 )
481 .await
482 {
483 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
484 let scheme_str = proxy_request_url.scheme_str();
485 let mut encrypted = false;
486
487 match scheme_str {
488 Some("http") => {
489 encrypted = false;
490 }
491 Some("https") => {
492 encrypted = true;
493 }
494 _ => Err(anyhow::anyhow!(
495 "Only HTTP and HTTPS reverse proxy URLs are supported."
496 ))?,
497 };
498
499 let request_path = uri.path();
500
501 let path = match request_path.as_bytes().first() {
502 Some(b'/') => {
503 let mut proxy_request_path = proxy_request_url.path();
504 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
505 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
506 }
507 format!("{}{}", proxy_request_path, request_path)
508 }
509 _ => request_path.to_string(),
510 };
511
512 let mut proxy_request_url_parts = proxy_request_url.into_parts();
513 proxy_request_url_parts.scheme = if encrypted {
514 Some(Scheme::from_str("wss")?)
515 } else {
516 Some(Scheme::from_str("ws")?)
517 };
518 match uri.path_and_query() {
519 Some(path_and_query) => {
520 let path_and_query_string = match path_and_query.query() {
521 Some(query) => {
522 format!("{}?{}", path, query)
523 }
524 None => path,
525 };
526 proxy_request_url_parts.path_and_query =
527 Some(PathAndQuery::from_str(&path_and_query_string)?);
528 }
529 None => {
530 proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
531 }
532 };
533
534 let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
535
536 let connector = if !encrypted {
537 Connector::Plain
538 } else {
539 Connector::Rustls(Arc::new(
540 (if disable_certificate_verification {
541 rustls::ClientConfig::builder()
542 .dangerous()
543 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
544 } else {
545 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
546 })
547 .with_no_client_auth(),
548 ))
549 };
550
551 let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
552 for (header_name, header_value) in headers {
553 let header_name_str = header_name.as_str();
554 if header_name == SEC_WEBSOCKET_PROTOCOL {
555 for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
556 proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
557 }
558 } else if !header_name_str.starts_with("sec-websocket-")
559 && header_name_str != "x-forwarded-for"
560 {
561 proxy_request_builder = proxy_request_builder.with_header(
562 header_name_str,
563 String::from_utf8_lossy(header_value.as_bytes()),
564 );
565 }
566 }
567
568 proxy_request_builder = proxy_request_builder.with_header(
570 "x-forwarded-for",
571 socket_data.remote_addr.ip().to_canonical().to_string(),
572 );
573
574 let proxy_request_constructed = proxy_request_builder.into_client_request()?;
575
576 let client_bi_stream = websocket.await?;
577
578 let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
579 proxy_request_constructed,
580 None,
581 true,
582 Some(connector),
583 )
584 .await
585 {
586 Ok(data) => data,
587 Err(err) => {
588 error_logger
589 .log(&format!("Cannot connect to WebSocket server: {}", err))
590 .await;
591 return Ok(());
592 }
593 };
594
595 let (mut client_sink, mut client_stream) = client_bi_stream.split();
596 let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
597
598 let client_to_proxy = async {
599 while let Some(Ok(value)) = client_stream.next().await {
600 if proxy_sink.send(value).await.is_err() {
601 break;
602 }
603 }
604 };
605
606 let proxy_to_client = async {
607 while let Some(Ok(value)) = proxy_stream.next().await {
608 if client_sink.send(value).await.is_err() {
609 break;
610 }
611 }
612 };
613
614 tokio::pin!(client_to_proxy);
615 tokio::pin!(proxy_to_client);
616
617 let client_to_proxy_first;
618 tokio::select! {
619 _ = &mut client_to_proxy => {
620 client_to_proxy_first = true;
621 }
622 _ = &mut proxy_to_client => {
623 client_to_proxy_first = false;
624 }
625 }
626
627 if client_to_proxy_first {
628 proxy_to_client.await;
629 } else {
630 client_to_proxy.await;
631 }
632 }
633
634 Ok(())
635 })
636 .await
637 }
638
639 fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
640 if socket_data.encrypted {
641 let secure_proxy_to = &config["secureProxyTo"];
642 if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
643 return true;
644 }
645 }
646
647 let proxy_to = &config["proxyTo"];
648 proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
649 }
650}
651
652async fn determine_proxy_to(
653 config: &ServerConfig,
654 encrypted: bool,
655 failed_backends: &RwLock<TtlCache<String, u64>>,
656 enable_health_check: bool,
657 health_check_max_fails: u64,
658) -> Option<String> {
659 let mut proxy_to = None;
660 if encrypted {
664 let secure_proxy_to_yaml = &config["secureProxyTo"];
665 if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
666 if enable_health_check {
667 let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
668 loop {
669 if !secure_proxy_to_vector.is_empty() {
670 let index = rand::random_range(..secure_proxy_to_vector.len());
671 if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
672 proxy_to = Some(secure_proxy_to.to_string());
673 let failed_backends_read = failed_backends.read().await;
674 let failed_backend_fails =
675 match failed_backends_read.get(&secure_proxy_to.to_string()) {
676 Some(fails) => fails,
677 None => break,
678 };
679 if failed_backend_fails > health_check_max_fails {
680 secure_proxy_to_vector.remove(index);
681 } else {
682 break;
683 }
684 }
685 } else {
686 break;
687 }
688 }
689 } else if !secure_proxy_to_vector.is_empty() {
690 if let Some(secure_proxy_to) =
691 secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
692 {
693 proxy_to = Some(secure_proxy_to.to_string());
694 }
695 }
696 } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
697 proxy_to = Some(secure_proxy_to.to_string());
698 }
699 }
700
701 if proxy_to.is_none() {
702 let proxy_to_yaml = &config["proxyTo"];
703 if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
704 if enable_health_check {
705 let mut proxy_to_vector = proxy_to_vector.clone();
706 loop {
707 if !proxy_to_vector.is_empty() {
708 let index = rand::random_range(..proxy_to_vector.len());
709 if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
710 proxy_to = Some(proxy_to_str.to_string());
711 let failed_backends_read = failed_backends.read().await;
712 let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
713 Some(fails) => fails,
714 None => break,
715 };
716 if failed_backend_fails > health_check_max_fails {
717 proxy_to_vector.remove(index);
718 } else {
719 break;
720 }
721 }
722 } else {
723 break;
724 }
725 }
726 } else if !proxy_to_vector.is_empty() {
727 if let Some(proxy_to_str) =
728 proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
729 {
730 proxy_to = Some(proxy_to_str.to_string());
731 }
732 }
733 } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
734 proxy_to = Some(proxy_to_str.to_string());
735 }
736 }
737
738 proxy_to
739}
740
/// Sends `proxy_request` over a newly opened backend connection `stream` and builds the
/// response.
///
/// Steps:
/// 1. Performs an HTTP/1.1 client handshake on `stream`.
/// 2. Drives the request future and the connection future together until the response head
///    arrives.
/// 3. Moves the connection future into the response's `parallel_fn`. This presumably keeps the
///    connection driven while the body is sent — confirm in `ResponseData`.
/// 4. If the sender is still open afterwards, stores it in `connections` under `connect_addr`,
///    replacing any previous entry, so later requests can reuse it.
///
/// Parameters:
/// - `failed_backends`: when `Some`, a handshake failure is counted against `proxy_to`.
/// - `proxy_intercept_errors`: when set, an upstream status >= 400 is replaced by a status-only
///   response, and the upstream body is dropped.
///
/// Handshake, send and connection failures are logged and answered with `502 Bad Gateway`. No
/// path in this function returns `Err`.
#[allow(clippy::too_many_arguments)]
async fn http_proxy(
  connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
  connect_addr: String,
  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
  error_logger: &ErrorLogger,
  proxy_to: String,
  failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
  proxy_intercept_errors: bool,
) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
  let io = TokioIo::new(stream);

  let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
    Ok(data) => data,
    Err(err) => {
      // A failed handshake counts against the backend when the health check is on.
      if let Some(failed_backends) = failed_backends {
        let mut failed_backends_write = failed_backends.write().await;
        let failed_attempts = failed_backends_write.get(&proxy_to);
        failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
      }
      error_logger.log(&format!("Bad gateway: {}", err)).await;
      return Ok(
        ResponseData::builder_without_request()
          .status(StatusCode::BAD_GATEWAY)
          .build(),
      );
    }
  };

  let send_request = sender.send_request(proxy_request);

  // The connection future must be polled for the request to make progress, so both are polled
  // together below; it is boxed so it can later be moved into the response.
  let mut pinned_conn = Box::pin(conn);
  tokio::pin!(send_request);

  let response;

  loop {
    tokio::select! {
      // Poll the request first, so a ready response wins over a finished connection.
      biased;

      proxy_response = &mut send_request => {
        let proxy_response = match proxy_response {
          Ok(response) => response,
          Err(err) => {
            error_logger.log(&format!("Bad gateway: {}", err)).await;
            return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
          }
        };

        let status_code = proxy_response.status();
        response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
          // Only the status is kept; the upstream body is dropped with `proxy_response`.
          ResponseData::builder_without_request()
            .status(status_code)
            .parallel_fn(async move {
              pinned_conn.await.unwrap_or_default();
            })
            .build()
        } else {
          ResponseData::builder_without_request()
            .response(proxy_response.map(|b| {
              b.map_err(|e| std::io::Error::other(e.to_string()))
                .boxed()
            }))
            .parallel_fn(async move {
              pinned_conn.await.unwrap_or_default();
            })
            .build()
        };

        break;
      },
      state = &mut pinned_conn => {
        // The connection ended before the response head arrived. On error, answer 502.
        // On success, loop again; the pending request then resolves (with an error if
        // the connection is gone).
        if state.is_err() {
          error_logger.log("Bad gateway: incomplete response").await;
          return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
        }
      },
    };
  }

  // Keep the connection for reuse by later requests (an existing entry is replaced).
  if !sender.is_closed() {
    let mut rwlock_write = connections.write().await;
    rwlock_write.insert(connect_addr, sender);
    drop(rwlock_write);
  }

  Ok(response)
}
830
831async fn http_proxy_kept_alive(
832 sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
833 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
834 error_logger: &ErrorLogger,
835 proxy_intercept_errors: bool,
836) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
837 let proxy_response = match sender.send_request(proxy_request).await {
838 Ok(response) => response,
839 Err(err) => {
840 error_logger.log(&format!("Bad gateway: {}", err)).await;
841 return Ok(
842 ResponseData::builder_without_request()
843 .status(StatusCode::BAD_GATEWAY)
844 .build(),
845 );
846 }
847 };
848
849 let status_code = proxy_response.status();
850 let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
851 ResponseData::builder_without_request()
852 .status(status_code)
853 .build()
854 } else {
855 ResponseData::builder_without_request()
856 .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
857 .build()
858 };
859
860 Ok(response)
861}