ferron/optional_modules/
cache.rs

1use std::collections::HashMap;
2use std::error::Error;
3use std::hash::RandomState;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use crate::ferron_common::{
8  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9  ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use cache_control::{Cachability, CacheControl};
14use futures_util::{StreamExt, TryStreamExt};
15use hashlink::LinkedHashMap;
16use http_body_util::{BodyExt, Full, StreamBody};
17use hyper::body::{Bytes, Frame};
18use hyper::header::HeaderValue;
19use hyper::{header, HeaderMap, Method, Response, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use itertools::Itertools;
22use tokio::runtime::Handle;
23use tokio::sync::RwLock;
24
/// Response header that reports how the cache treated a request (`HIT`, `MISS` or `BYPASS`).
const CACHE_HEADER_NAME: &str = "X-Ferron-Cache";
/// Freshness lifetime, in seconds, applied to entries whose `Cache-Control` value carries
/// neither `s-maxage` nor `max-age` (five minutes).
const DEFAULT_MAX_AGE: u64 = 5 * 60;
27
28pub fn server_module_init(
29  config: &ServerConfig,
30) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
31  let maximum_cache_entries = config["global"]["maximumCacheEntries"]
32    .as_i64()
33    .map(|v| v as usize);
34
35  Ok(Box::new(CacheModule::new(
36    Arc::new(RwLock::new(LinkedHashMap::with_hasher(RandomState::new()))),
37    Arc::new(RwLock::new(HashMap::new())),
38    maximum_cache_entries,
39  )))
40}
41
42#[allow(clippy::type_complexity)]
43struct CacheModule {
44  cache: Arc<
45    RwLock<
46      LinkedHashMap<
47        String,
48        (
49          StatusCode,
50          HeaderMap,
51          Vec<u8>,
52          Instant,
53          Option<CacheControl>,
54        ),
55        RandomState,
56      >,
57    >,
58  >,
59  vary_cache: Arc<RwLock<HashMap<String, Vec<String>>>>,
60  maximum_cache_entries: Option<usize>,
61}
62
63impl CacheModule {
64  #[allow(clippy::type_complexity)]
65  fn new(
66    cache: Arc<
67      RwLock<
68        LinkedHashMap<
69          String,
70          (
71            StatusCode,
72            HeaderMap,
73            Vec<u8>,
74            Instant,
75            Option<CacheControl>,
76          ),
77          RandomState,
78        >,
79      >,
80    >,
81    vary_cache: Arc<RwLock<HashMap<String, Vec<String>>>>,
82    maximum_cache_entries: Option<usize>,
83  ) -> Self {
84    Self {
85      cache,
86      vary_cache,
87      maximum_cache_entries,
88    }
89  }
90}
91
92impl ServerModule for CacheModule {
93  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
94    Box::new(CacheModuleHandlers {
95      cache: self.cache.clone(),
96      vary_cache: self.vary_cache.clone(),
97      maximum_cache_entries: self.maximum_cache_entries,
98      cache_vary_headers_configured: Vec::new(),
99      cache_ignore_headers_configured: Vec::new(),
100      maximum_cached_response_size: None,
101      cache_key: None,
102      request_headers: HeaderMap::new(),
103      has_authorization: false,
104      cached: false,
105      no_store: false,
106      handle,
107    })
108  }
109}
110
111#[allow(clippy::type_complexity)]
112struct CacheModuleHandlers {
113  handle: Handle,
114  cache: Arc<
115    RwLock<
116      LinkedHashMap<
117        String,
118        (
119          StatusCode,
120          HeaderMap,
121          Vec<u8>,
122          Instant,
123          Option<CacheControl>,
124        ),
125        RandomState,
126      >,
127    >,
128  >,
129  vary_cache: Arc<RwLock<HashMap<String, Vec<String>>>>,
130  maximum_cache_entries: Option<usize>,
131  cache_vary_headers_configured: Vec<String>,
132  cache_ignore_headers_configured: Vec<String>,
133  maximum_cached_response_size: Option<u64>,
134  cache_key: Option<String>,
135  request_headers: HeaderMap<HeaderValue>,
136  has_authorization: bool,
137  cached: bool,
138  no_store: bool,
139}
140
141#[async_trait]
142impl ServerModuleHandlers for CacheModuleHandlers {
143  async fn request_handler(
144    &mut self,
145    request: RequestData,
146    config: &ServerConfig,
147    socket_data: &SocketData,
148    _error_logger: &ErrorLogger,
149  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
150    WithRuntime::new(self.handle.clone(), async move {
151      self.cache_vary_headers_configured = match config["cacheVaryHeaders"].as_vec() {
152        Some(vector) => {
153          let mut new_vector = Vec::new();
154          for yaml_value in vector.iter() {
155            if let Some(str_value) = yaml_value.as_str() {
156              new_vector.push(str_value.to_string());
157            }
158          }
159          new_vector
160        }
161        None => Vec::new(),
162      };
163      self.cache_ignore_headers_configured = match config["cacheIgnoreHeaders"].as_vec() {
164        Some(vector) => {
165          let mut new_vector = Vec::new();
166          for yaml_value in vector.iter() {
167            if let Some(str_value) = yaml_value.as_str() {
168              new_vector.push(str_value.to_string());
169            }
170          }
171          new_vector
172        }
173        None => Vec::new(),
174      };
175      self.maximum_cached_response_size = config["maximumCachedResponseSize"]
176        .as_i64()
177        .map(|f| f as u64);
178
179      let hyper_request = request.get_hyper_request();
180      let cache_key = format!(
181        "{} {}{}{}{}",
182        hyper_request.method().as_str(),
183        match socket_data.encrypted {
184          false => "http://",
185          true => "https://",
186        },
187        match hyper_request.headers().get(header::HOST) {
188          Some(host) => String::from_utf8_lossy(host.as_bytes()).into_owned(),
189          None => "".to_string(),
190        },
191        hyper_request.uri().path(),
192        match hyper_request.uri().query() {
193          Some(query) => format!("?{}", query),
194          None => "".to_string(),
195        }
196      );
197
198      let request_cache_control = match hyper_request.headers().get(header::CACHE_CONTROL) {
199        Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
200        None => None,
201      };
202
203      let mut no_store = false;
204      let mut no_cache = false;
205
206      if let Some(request_cache_control) = request_cache_control {
207        no_store = request_cache_control.no_store;
208        if let Some(cachability) = request_cache_control.cachability {
209          if cachability == Cachability::NoCache {
210            no_cache = true;
211          }
212        }
213      }
214
215      match hyper_request.method() {
216        &Method::GET | &Method::HEAD => (),
217        _ => {
218          no_store = true;
219        }
220      };
221
222      if no_store {
223        self.no_store = true;
224        return Ok(ResponseData::builder(request).build());
225      }
226
227      if !no_cache {
228        let rwlock_read = self.vary_cache.read().await;
229        let processed_vary = rwlock_read.get(&cache_key);
230        if let Some(processed_vary) = processed_vary {
231          let cache_key_with_vary = format!(
232            "{}\n{}",
233            &cache_key,
234            processed_vary
235              .iter()
236              .map(|header_name| {
237                match hyper_request.headers().get(header_name) {
238                  Some(header_value) => format!(
239                    "{}: {}",
240                    header_name,
241                    String::from_utf8_lossy(header_value.as_bytes()).into_owned()
242                  ),
243                  None => "".to_string(),
244                }
245              })
246              .collect::<Vec<String>>()
247              .join("\n")
248          );
249
250          drop(rwlock_read);
251
252          let rwlock_read = self.cache.read().await;
253          let cached_entry_option = rwlock_read.get(&cache_key_with_vary);
254
255          if let Some((status_code, headers, body, timestamp, response_cache_control)) =
256            cached_entry_option
257          {
258            let max_age = match response_cache_control {
259              Some(response_cache_control) => match response_cache_control.s_max_age {
260                Some(s_max_age) => Some(s_max_age),
261                None => response_cache_control.max_age,
262              },
263              None => None,
264            };
265
266            let mut cached = true;
267
268            if timestamp.elapsed() > max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE)) {
269              cached = false;
270            }
271
272            if cached {
273              self.cached = true;
274              let mut hyper_response_builder = Response::builder().status(status_code);
275              for (header_name, header_value) in headers.iter() {
276                hyper_response_builder = hyper_response_builder.header(header_name, header_value);
277              }
278              let hyper_response = hyper_response_builder.body(
279                Full::new(Bytes::from(body.clone()))
280                  .map_err(|e| match e {})
281                  .boxed(),
282              )?;
283              return Ok(
284                ResponseData::builder(request)
285                  .response(hyper_response)
286                  .build(),
287              );
288            } else {
289              drop(rwlock_read);
290            }
291          }
292        } else {
293          drop(rwlock_read);
294        }
295      }
296
297      self.request_headers = hyper_request.headers().clone();
298      self.cache_key = Some(cache_key);
299      self.has_authorization = hyper_request.headers().contains_key(header::AUTHORIZATION);
300
301      Ok(ResponseData::builder(request).build())
302    })
303    .await
304  }
305
306  async fn proxy_request_handler(
307    &mut self,
308    request: RequestData,
309    _config: &ServerConfig,
310    _socket_data: &SocketData,
311    _error_logger: &ErrorLogger,
312  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
313    Ok(ResponseData::builder(request).build())
314  }
315
316  async fn response_modifying_handler(
317    &mut self,
318    mut response: HyperResponse,
319  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
320    WithRuntime::new(self.handle.clone(), async move {
321      if self.no_store {
322        response
323          .headers_mut()
324          .insert(CACHE_HEADER_NAME, HeaderValue::from_str("BYPASS")?);
325        Ok(response)
326      } else if self.cached {
327        response
328          .headers_mut()
329          .insert(CACHE_HEADER_NAME, HeaderValue::from_str("HIT")?);
330        Ok(response)
331      } else if let Some(cache_key) = &self.cache_key {
332        let (mut response_parts, mut response_body) = response.into_parts();
333        let response_cache_control = match response_parts.headers.get(header::CACHE_CONTROL) {
334          Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
335          None => None,
336        };
337
338        let should_cache_response = match &response_cache_control {
339          Some(response_cache_control) => {
340            let is_private = response_cache_control.cachability == Some(Cachability::Private);
341            let is_public = response_cache_control.cachability == Some(Cachability::Public);
342
343            !response_cache_control.no_store
344              && !is_private
345              && (is_public
346                || (!self.has_authorization
347                  && (response_cache_control.max_age.is_some()
348                    || response_cache_control.s_max_age.is_some())))
349          }
350          None => false,
351        };
352
353        if should_cache_response {
354          let mut response_body_buffer = Vec::new();
355          let mut maximum_cached_response_size_exceeded = false;
356
357          while let Some(frame) = response_body.frame().await {
358            let frame_unwrapped = frame?;
359            if frame_unwrapped.is_data() {
360              if let Some(bytes) = frame_unwrapped.data_ref() {
361                response_body_buffer.extend_from_slice(bytes);
362                if let Some(maximum_cached_response_size) = self.maximum_cached_response_size {
363                  if response_body_buffer.len() as u64 > maximum_cached_response_size {
364                    maximum_cached_response_size_exceeded = true;
365                    break;
366                  }
367                }
368              }
369            }
370          }
371
372          if maximum_cached_response_size_exceeded {
373            let cached_stream =
374              futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
375            let response_stream = response_body.into_data_stream();
376            let chained_stream = cached_stream.chain(response_stream);
377            let stream_body = StreamBody::new(chained_stream.map_ok(Frame::data));
378            let response_body = BodyExt::boxed(stream_body);
379            response_parts
380              .headers
381              .insert(CACHE_HEADER_NAME, HeaderValue::from_str("MISS")?);
382            let response = Response::from_parts(response_parts, response_body);
383            Ok(response)
384          } else {
385            let mut response_vary = match response_parts.headers.get(header::VARY) {
386              Some(value) => String::from_utf8_lossy(value.as_bytes())
387                .split(",")
388                .map(|s| s.trim().to_owned())
389                .collect(),
390              None => Vec::new(),
391            };
392
393            let mut processed_vary_orig = self.cache_vary_headers_configured.clone();
394            processed_vary_orig.append(&mut response_vary);
395
396            let processed_vary = processed_vary_orig
397              .iter()
398              .unique()
399              .map(|s| s.to_owned())
400              .collect::<Vec<String>>();
401
402            if !processed_vary.contains(&"*".to_string()) {
403              let cache_key_with_vary = format!(
404                "{}\n{}",
405                &cache_key,
406                processed_vary
407                  .iter()
408                  .map(|header_name| {
409                    match self.request_headers.get(header_name) {
410                      Some(header_value) => format!(
411                        "{}: {}",
412                        header_name,
413                        String::from_utf8_lossy(header_value.as_bytes()).into_owned()
414                      ),
415                      None => "".to_string(),
416                    }
417                  })
418                  .collect::<Vec<String>>()
419                  .join("\n")
420              );
421
422              let mut rwlock_write = self.vary_cache.write().await;
423              rwlock_write.insert(cache_key.clone(), processed_vary);
424              drop(rwlock_write);
425
426              let mut written_headers = response_parts.headers.clone();
427              for header in self.cache_ignore_headers_configured.iter() {
428                while written_headers.remove(header).is_some() {}
429              }
430
431              let mut rwlock_write = self.cache.write().await;
432              rwlock_write.retain(|_, (_, _, _, timestamp, response_cache_control)| {
433                let max_age = match response_cache_control {
434                  Some(response_cache_control) => match response_cache_control.s_max_age {
435                    Some(s_max_age) => Some(s_max_age),
436                    None => response_cache_control.max_age,
437                  },
438                  None => None,
439                };
440
441                timestamp.elapsed() <= max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE))
442              });
443
444              if let Some(maximum_cache_entries) = self.maximum_cache_entries {
445                // Remove a value at the front of the list
446                while !rwlock_write.is_empty() && rwlock_write.len() >= maximum_cache_entries {
447                  rwlock_write.pop_front();
448                }
449              }
450
451              // This inserts a value at the back of the list
452              rwlock_write.insert(
453                cache_key_with_vary,
454                (
455                  response_parts.status,
456                  written_headers,
457                  response_body_buffer.clone(),
458                  Instant::now(),
459                  response_cache_control,
460                ),
461              );
462              drop(rwlock_write);
463            }
464
465            let cached_stream =
466              futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
467            let stream_body = StreamBody::new(cached_stream.map_ok(Frame::data));
468            let response_body = BodyExt::boxed(stream_body);
469            response_parts
470              .headers
471              .insert(CACHE_HEADER_NAME, HeaderValue::from_str("MISS")?);
472            let response = Response::from_parts(response_parts, response_body);
473            Ok(response)
474          }
475        } else {
476          response_parts
477            .headers
478            .insert(CACHE_HEADER_NAME, HeaderValue::from_str("MISS")?);
479          let response = Response::from_parts(response_parts, response_body);
480          Ok(response)
481        }
482      } else {
483        Ok(response)
484      }
485    })
486    .await
487  }
488
489  async fn proxy_response_modifying_handler(
490    &mut self,
491    response: HyperResponse,
492  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
493    Ok(response)
494  }
495
496  async fn connect_proxy_request_handler(
497    &mut self,
498    _upgraded_request: HyperUpgraded,
499    _connect_address: &str,
500    _config: &ServerConfig,
501    _socket_data: &SocketData,
502    _error_logger: &ErrorLogger,
503  ) -> Result<(), Box<dyn Error + Send + Sync>> {
504    Ok(())
505  }
506
507  fn does_connect_proxy_requests(&mut self) -> bool {
508    false
509  }
510
511  async fn websocket_request_handler(
512    &mut self,
513    _websocket: HyperWebsocket,
514    _uri: &hyper::Uri,
515    _config: &ServerConfig,
516    _socket_data: &SocketData,
517    _error_logger: &ErrorLogger,
518  ) -> Result<(), Box<dyn Error + Send + Sync>> {
519    Ok(())
520  }
521
522  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
523    false
524  }
525}