ferron/modules/
url_rewrite.rs

1use std::error::Error;
2use std::path::Path;
3
4use crate::ferron_util::obtain_config_struct_vec::ObtainConfigStructVec;
5
6use crate::ferron_common::{
7  ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
8  ServerModuleHandlers, SocketData,
9};
10use crate::ferron_common::{HyperUpgraded, WithRuntime};
11use async_trait::async_trait;
12use fancy_regex::{Regex, RegexBuilder};
13use hyper::{header, Request, StatusCode};
14use hyper_tungstenite::HyperWebsocket;
15use tokio::fs;
16use tokio::runtime::Handle;
17use yaml_rust2::Yaml;
18
/// A single compiled URL rewrite rule, built from one entry of the `rewriteMap` configuration list.
struct UrlRewriteMapEntry {
  /// Pattern applied to the request URL (path plus `?query` when a query is present).
  regex: Regex,
  /// Replacement text handed to the regex replace call.
  replacement: String,
  /// When set, the rule is skipped if the requested path is an existing directory under `wwwroot`.
  is_not_directory: bool,
  /// When set, the rule is skipped if the requested path is an existing file under `wwwroot`.
  is_not_file: bool,
  /// When set, no further rules are processed once this rule has changed the URL.
  last: bool,
  /// When set, repeated slashes are left untouched before this rule is applied.
  allow_double_slashes: bool,
}
27
28impl UrlRewriteMapEntry {
29  fn new(
30    regex: Regex,
31    replacement: String,
32    is_not_directory: bool,
33    is_not_file: bool,
34    last: bool,
35    allow_double_slashes: bool,
36  ) -> Self {
37    Self {
38      regex,
39      replacement,
40      is_not_directory,
41      is_not_file,
42      last,
43      allow_double_slashes,
44    }
45  }
46}
47
48fn url_rewrite_config_init(rewrite_map: &[Yaml]) -> Result<Vec<UrlRewriteMapEntry>, anyhow::Error> {
49  let rewrite_map_iter = rewrite_map.iter();
50  let mut rewrite_map_vec = Vec::new();
51  for rewrite_map_entry in rewrite_map_iter {
52    let regex_str = match rewrite_map_entry["regex"].as_str() {
53      Some(regex_str) => regex_str,
54      None => return Err(anyhow::anyhow!("Invalid URL rewrite regular expression")),
55    };
56    let regex = match RegexBuilder::new(regex_str)
57      .case_insensitive(cfg!(windows))
58      .build()
59    {
60      Ok(regex) => regex,
61      Err(err) => {
62        return Err(anyhow::anyhow!(
63          "Invalid URL rewrite regular expression: {}",
64          err.to_string()
65        ))
66      }
67    };
68    let replacement = match rewrite_map_entry["replacement"].as_str() {
69      Some(replacement) => String::from(replacement),
70      None => return Err(anyhow::anyhow!("URL rewrite rules must have replacements")),
71    };
72    let is_not_file = rewrite_map_entry["isNotFile"].as_bool().unwrap_or(false);
73    let is_not_directory = rewrite_map_entry["isNotDirectory"]
74      .as_bool()
75      .unwrap_or(false);
76    let last = rewrite_map_entry["last"].as_bool().unwrap_or_default();
77    let allow_double_slashes = rewrite_map_entry["allowDoubleSlashes"]
78      .as_bool()
79      .unwrap_or(false);
80    rewrite_map_vec.push(UrlRewriteMapEntry::new(
81      regex,
82      replacement,
83      is_not_directory,
84      is_not_file,
85      last,
86      allow_double_slashes,
87    ));
88  }
89
90  Ok(rewrite_map_vec)
91}
92
93pub fn server_module_init(
94  config: &ServerConfig,
95) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
96  Ok(Box::new(UrlRewriteModule::new(ObtainConfigStructVec::new(
97    config,
98    |config| {
99      if let Some(rewrite_map_yaml) = config["rewriteMap"].as_vec() {
100        Ok(url_rewrite_config_init(rewrite_map_yaml)?)
101      } else {
102        Ok(vec![])
103      }
104    },
105  )?)))
106}
107
/// Server module that owns the rewrite rules compiled from the configuration.
struct UrlRewriteModule {
  /// Rewrite rules as produced by `ObtainConfigStructVec::new` in `server_module_init`.
  url_rewrite_maps: ObtainConfigStructVec<UrlRewriteMapEntry>,
}
111
112impl UrlRewriteModule {
113  fn new(url_rewrite_maps: ObtainConfigStructVec<UrlRewriteMapEntry>) -> Self {
114    Self { url_rewrite_maps }
115  }
116}
117
118impl ServerModule for UrlRewriteModule {
119  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
120    Box::new(UrlRewriteModuleHandlers {
121      url_rewrite_maps: self.url_rewrite_maps.clone(),
122      handle,
123    })
124  }
125}
/// Request handlers of the URL rewrite module, created by `UrlRewriteModule::get_handlers`.
struct UrlRewriteModuleHandlers {
  /// Clone of the module's rewrite rules; queried per request to select the applicable rules.
  url_rewrite_maps: ObtainConfigStructVec<UrlRewriteMapEntry>,
  /// Tokio runtime handle passed to `WithRuntime` when a request is handled.
  handle: Handle,
}
130
131#[async_trait]
132impl ServerModuleHandlers for UrlRewriteModuleHandlers {
133  async fn request_handler(
134    &mut self,
135    request: RequestData,
136    config: &ServerConfig,
137    socket_data: &SocketData,
138    error_logger: &ErrorLogger,
139  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
140    WithRuntime::new(self.handle.clone(), async move {
141      let hyper_request = request.get_hyper_request();
142      let combined_url_rewrite_map = self.url_rewrite_maps.obtain(
143        match hyper_request.headers().get(header::HOST) {
144          Some(value) => value.to_str().ok(),
145          None => None,
146        },
147        socket_data.remote_addr.ip(),
148        request
149          .get_original_url()
150          .unwrap_or(request.get_hyper_request().uri())
151          .path(),
152        request.get_error_status_code().map(|x| x.as_u16()),
153      );
154
155      let original_url = format!(
156        "{}{}",
157        hyper_request.uri().path(),
158        match hyper_request.uri().query() {
159          Some(query) => format!("?{}", query),
160          None => String::from(""),
161        }
162      );
163      let mut rewritten_url = original_url.clone();
164
165      let mut rewritten_url_bytes = rewritten_url.bytes();
166      if rewritten_url_bytes.len() < 1 || rewritten_url_bytes.nth(0) != Some(b'/') {
167        return Ok(
168          ResponseData::builder(request)
169            .status(StatusCode::BAD_REQUEST)
170            .build(),
171        );
172      }
173
174      for url_rewrite_map_entry in combined_url_rewrite_map {
175        // Check if it's a file or a directory according to the rewrite map configuration
176        if url_rewrite_map_entry.is_not_directory || url_rewrite_map_entry.is_not_file {
177          if let Some(wwwroot) = config["wwwroot"].as_str() {
178            let path = Path::new(wwwroot);
179            let mut relative_path = &rewritten_url[1..];
180            while relative_path.as_bytes().first().copied() == Some(b'/') {
181              relative_path = &relative_path[1..];
182            }
183            let relative_path_split: Vec<&str> = relative_path.split("?").collect();
184            if !relative_path_split.is_empty() {
185              relative_path = relative_path_split[0];
186            }
187            let joined_pathbuf = path.join(relative_path);
188            if let Ok(metadata) = fs::metadata(joined_pathbuf).await {
189              if (url_rewrite_map_entry.is_not_file && metadata.is_file())
190                || (url_rewrite_map_entry.is_not_directory && metadata.is_dir())
191              {
192                continue;
193              }
194            }
195          }
196        }
197
198        if !url_rewrite_map_entry.allow_double_slashes {
199          while rewritten_url.contains("//") {
200            rewritten_url = rewritten_url.replace("//", "/");
201          }
202        }
203
204        // Actual URL rewriting
205        let old_rewritten_url = rewritten_url;
206        rewritten_url = url_rewrite_map_entry
207          .regex
208          .replace(&old_rewritten_url, &url_rewrite_map_entry.replacement)
209          .to_string();
210
211        let mut rewritten_url_bytes = rewritten_url.bytes();
212        if rewritten_url_bytes.len() < 1 || rewritten_url_bytes.nth(0) != Some(b'/') {
213          return Ok(
214            ResponseData::builder(request)
215              .status(StatusCode::BAD_REQUEST)
216              .build(),
217          );
218        }
219
220        if url_rewrite_map_entry.last && old_rewritten_url != rewritten_url {
221          break;
222        }
223      }
224
225      if rewritten_url == original_url {
226        Ok(ResponseData::builder(request).build())
227      } else {
228        if config["enableRewriteLogging"].as_bool() == Some(true) {
229          error_logger
230            .log(&format!(
231              "URL rewritten from \"{}\" to \"{}\"",
232              original_url, rewritten_url
233            ))
234            .await;
235        }
236        let (hyper_request, auth_user, _, error_status_code) = request.into_parts();
237        let (mut parts, body) = hyper_request.into_parts();
238        let original_url = parts.uri.clone();
239        let mut url_parts = parts.uri.into_parts();
240        url_parts.path_and_query = Some(rewritten_url.parse()?);
241        parts.uri = hyper::Uri::from_parts(url_parts)?;
242        let hyper_request = Request::from_parts(parts, body);
243        let request = RequestData::new(
244          hyper_request,
245          auth_user,
246          Some(original_url),
247          error_status_code,
248        );
249        Ok(ResponseData::builder(request).build())
250      }
251    })
252    .await
253  }
254
255  async fn proxy_request_handler(
256    &mut self,
257    request: RequestData,
258    _config: &ServerConfig,
259    _socket_data: &SocketData,
260    _error_logger: &ErrorLogger,
261  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
262    Ok(ResponseData::builder(request).build())
263  }
264
265  async fn response_modifying_handler(
266    &mut self,
267    response: HyperResponse,
268  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
269    Ok(response)
270  }
271
272  async fn proxy_response_modifying_handler(
273    &mut self,
274    response: HyperResponse,
275  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
276    Ok(response)
277  }
278
279  async fn connect_proxy_request_handler(
280    &mut self,
281    _upgraded_request: HyperUpgraded,
282    _connect_address: &str,
283    _config: &ServerConfig,
284    _socket_data: &SocketData,
285    _error_logger: &ErrorLogger,
286  ) -> Result<(), Box<dyn Error + Send + Sync>> {
287    Ok(())
288  }
289
290  fn does_connect_proxy_requests(&mut self) -> bool {
291    false
292  }
293
294  async fn websocket_request_handler(
295    &mut self,
296    _websocket: HyperWebsocket,
297    _uri: &hyper::Uri,
298    _headers: &hyper::HeaderMap,
299    _config: &ServerConfig,
300    _socket_data: &SocketData,
301    _error_logger: &ErrorLogger,
302  ) -> Result<(), Box<dyn Error + Send + Sync>> {
303    Ok(())
304  }
305
306  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
307    false
308  }
309}