1use std::collections::HashMap;
2use std::error::Error;
3use std::hash::RandomState;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use crate::ferron_common::{
8 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9 ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use cache_control::{Cachability, CacheControl};
14use futures_util::{StreamExt, TryStreamExt};
15use hashlink::LinkedHashMap;
16use http_body_util::{BodyExt, Full, StreamBody};
17use hyper::body::{Bytes, Frame};
18use hyper::header::HeaderValue;
19use hyper::{header, HeaderMap, Method, Response, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use itertools::Itertools;
22use tokio::runtime::Handle;
23use tokio::sync::RwLock;
24
/// Response header reporting how the cache handled a request
/// (`HIT`, `MISS` or `BYPASS`).
const CACHE_HEADER_NAME: &str = "X-Ferron-Cache";
/// Freshness lifetime, in seconds, used for a stored response whose
/// Cache-Control carries neither `s-maxage` nor `max-age`.
const DEFAULT_MAX_AGE: u64 = 5 * 60;
27
28pub fn server_module_init(
29 config: &ServerConfig,
30) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
31 let maximum_cache_entries = config["global"]["maximumCacheEntries"]
32 .as_i64()
33 .map(|v| v as usize);
34
35 Ok(Box::new(CacheModule::new(
36 Arc::new(RwLock::new(LinkedHashMap::with_hasher(RandomState::new()))),
37 Arc::new(RwLock::new(HashMap::new())),
38 maximum_cache_entries,
39 )))
40}
41
42#[allow(clippy::type_complexity)]
43struct CacheModule {
44 cache: Arc<
45 RwLock<
46 LinkedHashMap<
47 String,
48 (
49 StatusCode,
50 HeaderMap,
51 Vec<u8>,
52 Instant,
53 Option<CacheControl>,
54 ),
55 RandomState,
56 >,
57 >,
58 >,
59 vary_cache: Arc<RwLock<HashMap<String, Vec<String>>>>,
60 maximum_cache_entries: Option<usize>,
61}
62
63impl CacheModule {
64 #[allow(clippy::type_complexity)]
65 fn new(
66 cache: Arc<
67 RwLock<
68 LinkedHashMap<
69 String,
70 (
71 StatusCode,
72 HeaderMap,
73 Vec<u8>,
74 Instant,
75 Option<CacheControl>,
76 ),
77 RandomState,
78 >,
79 >,
80 >,
81 vary_cache: Arc<RwLock<HashMap<String, Vec<String>>>>,
82 maximum_cache_entries: Option<usize>,
83 ) -> Self {
84 Self {
85 cache,
86 vary_cache,
87 maximum_cache_entries,
88 }
89 }
90}
91
92impl ServerModule for CacheModule {
93 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
94 Box::new(CacheModuleHandlers {
95 cache: self.cache.clone(),
96 vary_cache: self.vary_cache.clone(),
97 maximum_cache_entries: self.maximum_cache_entries,
98 cache_vary_headers_configured: Vec::new(),
99 cache_ignore_headers_configured: Vec::new(),
100 maximum_cached_response_size: None,
101 cache_key: None,
102 request_headers: HeaderMap::new(),
103 has_authorization: false,
104 cached: false,
105 no_store: false,
106 handle,
107 })
108 }
109}
110
111#[allow(clippy::type_complexity)]
112struct CacheModuleHandlers {
113 handle: Handle,
114 cache: Arc<
115 RwLock<
116 LinkedHashMap<
117 String,
118 (
119 StatusCode,
120 HeaderMap,
121 Vec<u8>,
122 Instant,
123 Option<CacheControl>,
124 ),
125 RandomState,
126 >,
127 >,
128 >,
129 vary_cache: Arc<RwLock<HashMap<String, Vec<String>>>>,
130 maximum_cache_entries: Option<usize>,
131 cache_vary_headers_configured: Vec<String>,
132 cache_ignore_headers_configured: Vec<String>,
133 maximum_cached_response_size: Option<u64>,
134 cache_key: Option<String>,
135 request_headers: HeaderMap<HeaderValue>,
136 has_authorization: bool,
137 cached: bool,
138 no_store: bool,
139}
140
141#[async_trait]
142impl ServerModuleHandlers for CacheModuleHandlers {
143 async fn request_handler(
144 &mut self,
145 request: RequestData,
146 config: &ServerConfig,
147 socket_data: &SocketData,
148 _error_logger: &ErrorLogger,
149 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
150 WithRuntime::new(self.handle.clone(), async move {
151 self.cache_vary_headers_configured = match config["cacheVaryHeaders"].as_vec() {
152 Some(vector) => {
153 let mut new_vector = Vec::new();
154 for yaml_value in vector.iter() {
155 if let Some(str_value) = yaml_value.as_str() {
156 new_vector.push(str_value.to_string());
157 }
158 }
159 new_vector
160 }
161 None => Vec::new(),
162 };
163 self.cache_ignore_headers_configured = match config["cacheIgnoreHeaders"].as_vec() {
164 Some(vector) => {
165 let mut new_vector = Vec::new();
166 for yaml_value in vector.iter() {
167 if let Some(str_value) = yaml_value.as_str() {
168 new_vector.push(str_value.to_string());
169 }
170 }
171 new_vector
172 }
173 None => Vec::new(),
174 };
175 self.maximum_cached_response_size = config["maximumCachedResponseSize"]
176 .as_i64()
177 .map(|f| f as u64);
178
179 let hyper_request = request.get_hyper_request();
180 let cache_key = format!(
181 "{} {}{}{}{}",
182 hyper_request.method().as_str(),
183 match socket_data.encrypted {
184 false => "http://",
185 true => "https://",
186 },
187 match hyper_request.headers().get(header::HOST) {
188 Some(host) => String::from_utf8_lossy(host.as_bytes()).into_owned(),
189 None => "".to_string(),
190 },
191 hyper_request.uri().path(),
192 match hyper_request.uri().query() {
193 Some(query) => format!("?{}", query),
194 None => "".to_string(),
195 }
196 );
197
198 let request_cache_control = match hyper_request.headers().get(header::CACHE_CONTROL) {
199 Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
200 None => None,
201 };
202
203 let mut no_store = false;
204 let mut no_cache = false;
205
206 if let Some(request_cache_control) = request_cache_control {
207 no_store = request_cache_control.no_store;
208 if let Some(cachability) = request_cache_control.cachability {
209 if cachability == Cachability::NoCache {
210 no_cache = true;
211 }
212 }
213 }
214
215 match hyper_request.method() {
216 &Method::GET | &Method::HEAD => (),
217 _ => {
218 no_store = true;
219 }
220 };
221
222 if no_store {
223 self.no_store = true;
224 return Ok(ResponseData::builder(request).build());
225 }
226
227 if !no_cache {
228 let rwlock_read = self.vary_cache.read().await;
229 let processed_vary = rwlock_read.get(&cache_key);
230 if let Some(processed_vary) = processed_vary {
231 let cache_key_with_vary = format!(
232 "{}\n{}",
233 &cache_key,
234 processed_vary
235 .iter()
236 .map(|header_name| {
237 match hyper_request.headers().get(header_name) {
238 Some(header_value) => format!(
239 "{}: {}",
240 header_name,
241 String::from_utf8_lossy(header_value.as_bytes()).into_owned()
242 ),
243 None => "".to_string(),
244 }
245 })
246 .collect::<Vec<String>>()
247 .join("\n")
248 );
249
250 drop(rwlock_read);
251
252 let rwlock_read = self.cache.read().await;
253 let cached_entry_option = rwlock_read.get(&cache_key_with_vary);
254
255 if let Some((status_code, headers, body, timestamp, response_cache_control)) =
256 cached_entry_option
257 {
258 let max_age = match response_cache_control {
259 Some(response_cache_control) => match response_cache_control.s_max_age {
260 Some(s_max_age) => Some(s_max_age),
261 None => response_cache_control.max_age,
262 },
263 None => None,
264 };
265
266 let mut cached = true;
267
268 if timestamp.elapsed() > max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE)) {
269 cached = false;
270 }
271
272 if cached {
273 self.cached = true;
274 let mut hyper_response_builder = Response::builder().status(status_code);
275 for (header_name, header_value) in headers.iter() {
276 hyper_response_builder = hyper_response_builder.header(header_name, header_value);
277 }
278 let hyper_response = hyper_response_builder.body(
279 Full::new(Bytes::from(body.clone()))
280 .map_err(|e| match e {})
281 .boxed(),
282 )?;
283 return Ok(
284 ResponseData::builder(request)
285 .response(hyper_response)
286 .build(),
287 );
288 } else {
289 drop(rwlock_read);
290 }
291 }
292 } else {
293 drop(rwlock_read);
294 }
295 }
296
297 self.request_headers = hyper_request.headers().clone();
298 self.cache_key = Some(cache_key);
299 self.has_authorization = hyper_request.headers().contains_key(header::AUTHORIZATION);
300
301 Ok(ResponseData::builder(request).build())
302 })
303 .await
304 }
305
306 async fn proxy_request_handler(
307 &mut self,
308 request: RequestData,
309 _config: &ServerConfig,
310 _socket_data: &SocketData,
311 _error_logger: &ErrorLogger,
312 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
313 Ok(ResponseData::builder(request).build())
314 }
315
316 async fn response_modifying_handler(
317 &mut self,
318 mut response: HyperResponse,
319 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
320 WithRuntime::new(self.handle.clone(), async move {
321 if self.no_store {
322 response
323 .headers_mut()
324 .insert(CACHE_HEADER_NAME, HeaderValue::from_str("BYPASS")?);
325 Ok(response)
326 } else if self.cached {
327 response
328 .headers_mut()
329 .insert(CACHE_HEADER_NAME, HeaderValue::from_str("HIT")?);
330 Ok(response)
331 } else if let Some(cache_key) = &self.cache_key {
332 let (mut response_parts, mut response_body) = response.into_parts();
333 let response_cache_control = match response_parts.headers.get(header::CACHE_CONTROL) {
334 Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
335 None => None,
336 };
337
338 let should_cache_response = match &response_cache_control {
339 Some(response_cache_control) => {
340 let is_private = response_cache_control.cachability == Some(Cachability::Private);
341 let is_public = response_cache_control.cachability == Some(Cachability::Public);
342
343 !response_cache_control.no_store
344 && !is_private
345 && (is_public
346 || (!self.has_authorization
347 && (response_cache_control.max_age.is_some()
348 || response_cache_control.s_max_age.is_some())))
349 }
350 None => false,
351 };
352
353 if should_cache_response {
354 let mut response_body_buffer = Vec::new();
355 let mut maximum_cached_response_size_exceeded = false;
356
357 while let Some(frame) = response_body.frame().await {
358 let frame_unwrapped = frame?;
359 if frame_unwrapped.is_data() {
360 if let Some(bytes) = frame_unwrapped.data_ref() {
361 response_body_buffer.extend_from_slice(bytes);
362 if let Some(maximum_cached_response_size) = self.maximum_cached_response_size {
363 if response_body_buffer.len() as u64 > maximum_cached_response_size {
364 maximum_cached_response_size_exceeded = true;
365 break;
366 }
367 }
368 }
369 }
370 }
371
372 if maximum_cached_response_size_exceeded {
373 let cached_stream =
374 futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
375 let response_stream = response_body.into_data_stream();
376 let chained_stream = cached_stream.chain(response_stream);
377 let stream_body = StreamBody::new(chained_stream.map_ok(Frame::data));
378 let response_body = BodyExt::boxed(stream_body);
379 response_parts
380 .headers
381 .insert(CACHE_HEADER_NAME, HeaderValue::from_str("MISS")?);
382 let response = Response::from_parts(response_parts, response_body);
383 Ok(response)
384 } else {
385 let mut response_vary = match response_parts.headers.get(header::VARY) {
386 Some(value) => String::from_utf8_lossy(value.as_bytes())
387 .split(",")
388 .map(|s| s.trim().to_owned())
389 .collect(),
390 None => Vec::new(),
391 };
392
393 let mut processed_vary_orig = self.cache_vary_headers_configured.clone();
394 processed_vary_orig.append(&mut response_vary);
395
396 let processed_vary = processed_vary_orig
397 .iter()
398 .unique()
399 .map(|s| s.to_owned())
400 .collect::<Vec<String>>();
401
402 if !processed_vary.contains(&"*".to_string()) {
403 let cache_key_with_vary = format!(
404 "{}\n{}",
405 &cache_key,
406 processed_vary
407 .iter()
408 .map(|header_name| {
409 match self.request_headers.get(header_name) {
410 Some(header_value) => format!(
411 "{}: {}",
412 header_name,
413 String::from_utf8_lossy(header_value.as_bytes()).into_owned()
414 ),
415 None => "".to_string(),
416 }
417 })
418 .collect::<Vec<String>>()
419 .join("\n")
420 );
421
422 let mut rwlock_write = self.vary_cache.write().await;
423 rwlock_write.insert(cache_key.clone(), processed_vary);
424 drop(rwlock_write);
425
426 let mut written_headers = response_parts.headers.clone();
427 for header in self.cache_ignore_headers_configured.iter() {
428 while written_headers.remove(header).is_some() {}
429 }
430
431 let mut rwlock_write = self.cache.write().await;
432 rwlock_write.retain(|_, (_, _, _, timestamp, response_cache_control)| {
433 let max_age = match response_cache_control {
434 Some(response_cache_control) => match response_cache_control.s_max_age {
435 Some(s_max_age) => Some(s_max_age),
436 None => response_cache_control.max_age,
437 },
438 None => None,
439 };
440
441 timestamp.elapsed() <= max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE))
442 });
443
444 if let Some(maximum_cache_entries) = self.maximum_cache_entries {
445 while !rwlock_write.is_empty() && rwlock_write.len() >= maximum_cache_entries {
447 rwlock_write.pop_front();
448 }
449 }
450
451 rwlock_write.insert(
453 cache_key_with_vary,
454 (
455 response_parts.status,
456 written_headers,
457 response_body_buffer.clone(),
458 Instant::now(),
459 response_cache_control,
460 ),
461 );
462 drop(rwlock_write);
463 }
464
465 let cached_stream =
466 futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
467 let stream_body = StreamBody::new(cached_stream.map_ok(Frame::data));
468 let response_body = BodyExt::boxed(stream_body);
469 response_parts
470 .headers
471 .insert(CACHE_HEADER_NAME, HeaderValue::from_str("MISS")?);
472 let response = Response::from_parts(response_parts, response_body);
473 Ok(response)
474 }
475 } else {
476 response_parts
477 .headers
478 .insert(CACHE_HEADER_NAME, HeaderValue::from_str("MISS")?);
479 let response = Response::from_parts(response_parts, response_body);
480 Ok(response)
481 }
482 } else {
483 Ok(response)
484 }
485 })
486 .await
487 }
488
489 async fn proxy_response_modifying_handler(
490 &mut self,
491 response: HyperResponse,
492 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
493 Ok(response)
494 }
495
496 async fn connect_proxy_request_handler(
497 &mut self,
498 _upgraded_request: HyperUpgraded,
499 _connect_address: &str,
500 _config: &ServerConfig,
501 _socket_data: &SocketData,
502 _error_logger: &ErrorLogger,
503 ) -> Result<(), Box<dyn Error + Send + Sync>> {
504 Ok(())
505 }
506
507 fn does_connect_proxy_requests(&mut self) -> bool {
508 false
509 }
510
511 async fn websocket_request_handler(
512 &mut self,
513 _websocket: HyperWebsocket,
514 _uri: &hyper::Uri,
515 _config: &ServerConfig,
516 _socket_data: &SocketData,
517 _error_logger: &ErrorLogger,
518 ) -> Result<(), Box<dyn Error + Send + Sync>> {
519 Ok(())
520 }
521
522 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
523 false
524 }
525}