ferron/modules/
non_standard_codes.rs

1use std::error::Error;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::ferron_util::ip_blocklist::IpBlockList;
6use crate::ferron_util::ip_match::ip_match;
7use crate::ferron_util::match_hostname::match_hostname;
8use crate::ferron_util::match_location::match_location;
9use crate::ferron_util::non_standard_code_structs::{
10  NonStandardCode, NonStandardCodesLocationWrap, NonStandardCodesWrap,
11};
12use crate::ferron_util::ttl_cache::TtlCache;
13
14use crate::ferron_common::{
15  ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
16  ServerModuleHandlers, SocketData,
17};
18use crate::ferron_common::{HyperUpgraded, WithRuntime};
19use async_trait::async_trait;
20use base64::{engine::general_purpose, Engine};
21use fancy_regex::RegexBuilder;
22use http_body_util::{BodyExt, Empty};
23use hyper::header::HeaderValue;
24use hyper::{header, HeaderMap, Response, StatusCode};
25use hyper_tungstenite::HyperWebsocket;
26use password_auth::verify_password;
27use tokio::runtime::Handle;
28use tokio::sync::RwLock;
29use yaml_rust2::Yaml;
30
31fn non_standard_codes_config_init(
32  non_standard_codes_list: &[Yaml],
33) -> Result<Vec<NonStandardCode>, anyhow::Error> {
34  let non_standard_codes_list_iter = non_standard_codes_list.iter();
35  let mut non_standard_codes_list_vec = Vec::new();
36  for non_standard_codes_list_entry in non_standard_codes_list_iter {
37    let status_code: u16 = match non_standard_codes_list_entry["scode"].as_i64() {
38      Some(scode) => scode.try_into()?,
39      None => {
40        return Err(anyhow::anyhow!(
41          "Non-standard codes must include a status code"
42        ));
43      }
44    };
45    let regex = match non_standard_codes_list_entry["regex"].as_str() {
46      Some(regex_str) => match RegexBuilder::new(regex_str)
47        .case_insensitive(cfg!(windows))
48        .build()
49      {
50        Ok(regex) => Some(regex),
51        Err(err) => {
52          return Err(anyhow::anyhow!(
53            "Invalid non-standard code regular expression: {}",
54            err.to_string()
55          ));
56        }
57      },
58      None => None,
59    };
60    let url = non_standard_codes_list_entry["url"]
61      .as_str()
62      .map(|s| s.to_string());
63
64    if regex.is_none() && url.is_none() {
65      return Err(anyhow::anyhow!(
66        "Non-standard codes must either include URL or a matching regular expression"
67      ));
68    }
69
70    let location = non_standard_codes_list_entry["location"]
71      .as_str()
72      .map(|s| s.to_string());
73    let realm = non_standard_codes_list_entry["realm"]
74      .as_str()
75      .map(|s| s.to_string());
76    let disable_brute_force_protection = non_standard_codes_list_entry["disableBruteProtection"]
77      .as_bool()
78      .unwrap_or(false);
79    let user_list = match non_standard_codes_list_entry["userList"].as_vec() {
80      Some(userlist) => {
81        let mut new_userlist = Vec::new();
82        for user_yaml in userlist.iter() {
83          if let Some(user) = user_yaml.as_str() {
84            new_userlist.push(user.to_string());
85          }
86        }
87        Some(new_userlist)
88      }
89      None => None,
90    };
91    let users = match non_standard_codes_list_entry["users"].as_vec() {
92      Some(users_vec) => {
93        let mut users_str_vec = Vec::new();
94        for user_yaml in users_vec.iter() {
95          if let Some(user) = user_yaml.as_str() {
96            users_str_vec.push(user);
97          }
98        }
99
100        let mut users_init = IpBlockList::new();
101        users_init.load_from_vec(users_str_vec);
102        Some(users_init)
103      }
104      None => None,
105    };
106    non_standard_codes_list_vec.push(NonStandardCode::new(
107      status_code,
108      url,
109      regex,
110      location,
111      realm,
112      disable_brute_force_protection,
113      user_list,
114      users,
115    ));
116  }
117
118  Ok(non_standard_codes_list_vec)
119}
120
121pub fn server_module_init(
122  config: &ServerConfig,
123) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
124  let mut global_non_standard_codes_list = Vec::new();
125  let mut host_non_standard_codes_lists = Vec::new();
126  if let Some(non_standard_codes_list_yaml) = config["global"]["nonStandardCodes"].as_vec() {
127    global_non_standard_codes_list = non_standard_codes_config_init(non_standard_codes_list_yaml)?;
128  }
129
130  if let Some(hosts) = config["hosts"].as_vec() {
131    for host_yaml in hosts.iter() {
132      let domain = host_yaml["domain"].as_str().map(String::from);
133      let ip = host_yaml["ip"].as_str().map(String::from);
134      let mut locations = Vec::new();
135      if let Some(locations_yaml) = host_yaml["locations"].as_vec() {
136        for location_yaml in locations_yaml.iter() {
137          if let Some(path_str) = location_yaml["path"].as_str() {
138            let path = String::from(path_str);
139            if let Some(non_standard_codes_list_yaml) = location_yaml["nonStandardCodes"].as_vec() {
140              locations.push(NonStandardCodesLocationWrap::new(
141                path,
142                non_standard_codes_config_init(non_standard_codes_list_yaml)?,
143              ));
144            }
145          }
146        }
147      }
148      if let Some(non_standard_codes_list_yaml) = host_yaml["nonStandardCodes"].as_vec() {
149        host_non_standard_codes_lists.push(NonStandardCodesWrap::new(
150          domain,
151          ip,
152          non_standard_codes_config_init(non_standard_codes_list_yaml)?,
153          locations,
154        ));
155      } else if !locations.is_empty() {
156        host_non_standard_codes_lists.push(NonStandardCodesWrap::new(
157          domain,
158          ip,
159          Vec::new(),
160          locations,
161        ));
162      }
163    }
164  }
165
166  Ok(Box::new(NonStandardCodesModule::new(
167    Arc::new(global_non_standard_codes_list),
168    Arc::new(host_non_standard_codes_lists),
169    Arc::new(RwLock::new(TtlCache::new(Duration::new(300, 0)))),
170  )))
171}
172
173struct NonStandardCodesModule {
174  global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
175  host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
176  brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
177}
178
179impl NonStandardCodesModule {
180  fn new(
181    global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
182    host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
183    brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
184  ) -> Self {
185    Self {
186      global_non_standard_codes_list,
187      host_non_standard_codes_lists,
188      brute_force_db,
189    }
190  }
191}
192
193impl ServerModule for NonStandardCodesModule {
194  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
195    Box::new(NonStandardCodesModuleHandlers {
196      global_non_standard_codes_list: self.global_non_standard_codes_list.clone(),
197      host_non_standard_codes_lists: self.host_non_standard_codes_lists.clone(),
198      brute_force_db: self.brute_force_db.clone(),
199      handle,
200    })
201  }
202}
203
204fn parse_basic_auth(auth_str: &str) -> Option<(String, String)> {
205  if let Some(base64_credentials) = auth_str.strip_prefix("Basic ") {
206    if let Ok(decoded) = general_purpose::STANDARD.decode(base64_credentials) {
207      if let Ok(decoded_str) = std::str::from_utf8(&decoded) {
208        let parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
209        if parts.len() == 2 {
210          return Some((parts[0].to_string(), parts[1].to_string()));
211        }
212      }
213    }
214  }
215  None
216}
217
218struct NonStandardCodesModuleHandlers {
219  global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
220  host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
221  brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
222  handle: Handle,
223}
224
225#[async_trait]
226impl ServerModuleHandlers for NonStandardCodesModuleHandlers {
227  async fn request_handler(
228    &mut self,
229    request: RequestData,
230    config: &ServerConfig,
231    socket_data: &SocketData,
232    error_logger: &ErrorLogger,
233  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
234    WithRuntime::new(self.handle.clone(), async move {
235      let hyper_request = request.get_hyper_request();
236      let global_non_standard_codes_list = self.global_non_standard_codes_list.iter();
237      let empty_vector = Vec::new();
238      let another_empty_vector = Vec::new();
239      let mut host_non_standard_codes_list = empty_vector.iter();
240      let mut location_non_standard_codes_list = another_empty_vector.iter();
241
242      // Should have used a HashMap instead of iterating over an array for better performance...
243      for host_non_standard_codes_list_wrap in self.host_non_standard_codes_lists.iter() {
244        if match_hostname(
245          match &host_non_standard_codes_list_wrap.domain {
246            Some(value) => Some(value as &str),
247            None => None,
248          },
249          match hyper_request.headers().get(header::HOST) {
250            Some(value) => value.to_str().ok(),
251            None => None,
252          },
253        ) && match &host_non_standard_codes_list_wrap.ip {
254          Some(value) => ip_match(value as &str, socket_data.remote_addr.ip()),
255          None => true,
256        } {
257          host_non_standard_codes_list =
258            host_non_standard_codes_list_wrap.non_standard_codes.iter();
259          if let Ok(path_decoded) = urlencoding::decode(
260            request
261              .get_original_url()
262              .unwrap_or(request.get_hyper_request().uri())
263              .path(),
264          ) {
265            for location_wrap in host_non_standard_codes_list_wrap.locations.iter() {
266              if match_location(&location_wrap.path, &path_decoded) {
267                location_non_standard_codes_list = location_wrap.non_standard_codes.iter();
268                break;
269              }
270            }
271          }
272          break;
273        }
274      }
275
276      let combined_non_standard_codes_list = global_non_standard_codes_list
277        .chain(host_non_standard_codes_list)
278        .chain(location_non_standard_codes_list);
279
280      let request_url = format!(
281        "{}{}",
282        hyper_request.uri().path(),
283        match hyper_request.uri().query() {
284          Some(query) => format!("?{}", query),
285          None => String::from(""),
286        }
287      );
288
289      let mut auth_user = None;
290
291      for non_standard_code in combined_non_standard_codes_list {
292        let mut redirect_url = None;
293        let mut url_matched = false;
294
295        if let Some(users) = &non_standard_code.users {
296          if !users.is_blocked(socket_data.remote_addr.ip()) {
297            // Don't process this non-standard code
298            continue;
299          }
300        }
301
302        if let Some(regex) = &non_standard_code.regex {
303          let regex_match_option = regex.find(&request_url)?;
304          if let Some(regex_match) = regex_match_option {
305            url_matched = true;
306            if non_standard_code.status_code == 301
307              || non_standard_code.status_code == 302
308              || non_standard_code.status_code == 307
309              || non_standard_code.status_code == 308
310            {
311              let matched_text = regex_match.as_str();
312              if let Some(location) = &non_standard_code.location {
313                redirect_url = Some(regex.replace(matched_text, location).to_string());
314              }
315            }
316          }
317        }
318
319        if !url_matched {
320          if let Some(url) = &non_standard_code.url {
321            if url == hyper_request.uri().path() {
322              url_matched = true;
323              if non_standard_code.status_code == 301
324                || non_standard_code.status_code == 302
325                || non_standard_code.status_code == 307
326                || non_standard_code.status_code == 308
327              {
328                if let Some(location) = &non_standard_code.location {
329                  redirect_url = Some(format!(
330                    "{}{}",
331                    location,
332                    match hyper_request.uri().query() {
333                      Some(query) => format!("?{}", query),
334                      None => String::from(""),
335                    }
336                  ));
337                }
338              }
339            }
340          }
341        }
342
343        if url_matched {
344          match non_standard_code.status_code {
345            301 | 302 | 307 | 308 => {
346              return Ok(
347                ResponseData::builder(request)
348                  .response(
349                    Response::builder()
350                      .status(StatusCode::from_u16(non_standard_code.status_code)?)
351                      .header(header::LOCATION, redirect_url.unwrap_or(request_url))
352                      .body(Empty::new().map_err(|e| match e {}).boxed())?,
353                  )
354                  .build(),
355              );
356            }
357            401 => {
358              let brute_force_db_key = socket_data.remote_addr.ip().to_string();
359              if !non_standard_code.disable_brute_force_protection {
360                let rwlock_read = self.brute_force_db.read().await;
361                let current_attempts = rwlock_read.get(&brute_force_db_key).unwrap_or(0);
362                if current_attempts >= 10 {
363                  error_logger
364                    .log(&format!(
365                      "Too many failed authorization attempts for client \"{}\"",
366                      socket_data.remote_addr.ip()
367                    ))
368                    .await;
369
370                  return Ok(
371                    ResponseData::builder(request)
372                      .status(StatusCode::TOO_MANY_REQUESTS)
373                      .build(),
374                  );
375                }
376              }
377              let mut header_map = HeaderMap::new();
378              header_map.insert(
379                header::WWW_AUTHENTICATE,
380                HeaderValue::from_str(&format!(
381                  "Basic realm=\"{}\", charset=\"UTF-8\"",
382                  non_standard_code
383                    .realm
384                    .clone()
385                    .unwrap_or("Ferron HTTP Basic Authorization".to_string())
386                    .replace("\\", "\\\\")
387                    .replace("\"", "\\\"")
388                ))?,
389              );
390
391              if let Some(authorization_header_value) =
392                hyper_request.headers().get(header::AUTHORIZATION)
393              {
394                let authorization_str = match authorization_header_value.to_str() {
395                  Ok(str) => str,
396                  Err(_) => {
397                    return Ok(
398                      ResponseData::builder(request)
399                        .status(StatusCode::BAD_REQUEST)
400                        .build(),
401                    );
402                  }
403                };
404
405                if let Some((username, password)) = parse_basic_auth(authorization_str) {
406                  if let Some(users_vec_yaml) = config["users"].as_vec() {
407                    let mut authorized_user = None;
408                    for user_yaml in users_vec_yaml {
409                      if let Some(username_db) = user_yaml["name"].as_str() {
410                        if username_db != username {
411                          continue;
412                        }
413                        if let Some(user_list) = &non_standard_code.user_list {
414                          if !user_list.contains(&username) {
415                            continue;
416                          }
417                        }
418                        if let Some(password_hash_db) = user_yaml["pass"].as_str() {
419                          let password_cloned = password.clone();
420                          let password_hash_db_cloned = password_hash_db.to_string();
421                          // Offload verifying the hash into a separate blocking thread.
422                          let password_valid = tokio::task::spawn_blocking(move || {
423                            verify_password(password_cloned, &password_hash_db_cloned).is_ok()
424                          })
425                          .await?;
426                          if password_valid {
427                            authorized_user = Some(&username);
428                            break;
429                          }
430                        }
431                      }
432                    }
433                    if let Some(authorized_user) = authorized_user {
434                      auth_user = Some(authorized_user.to_owned());
435                      continue;
436                    }
437                  }
438
439                  if !non_standard_code.disable_brute_force_protection {
440                    let mut rwlock_write = self.brute_force_db.write().await;
441                    rwlock_write.cleanup();
442                    let current_attempts = rwlock_write.get(&brute_force_db_key).unwrap_or(0);
443                    rwlock_write.insert(brute_force_db_key, current_attempts + 1);
444                  }
445
446                  error_logger
447                    .log(&format!(
448                      "Authorization failed for user \"{}\" and client \"{}\"",
449                      username,
450                      socket_data.remote_addr.ip()
451                    ))
452                    .await;
453                }
454              }
455
456              return Ok(
457                ResponseData::builder(request)
458                  .status(StatusCode::UNAUTHORIZED)
459                  .headers(header_map)
460                  .build(),
461              );
462            }
463            _ => {
464              return Ok(
465                ResponseData::builder(request)
466                  .status(StatusCode::from_u16(non_standard_code.status_code)?)
467                  .build(),
468              )
469            }
470          }
471        }
472      }
473
474      if auth_user.is_some() {
475        let (hyper_request, _, original_url) = request.into_parts();
476        Ok(ResponseData::builder(RequestData::new(hyper_request, auth_user, original_url)).build())
477      } else {
478        Ok(ResponseData::builder(request).build())
479      }
480    })
481    .await
482  }
483
484  async fn proxy_request_handler(
485    &mut self,
486    request: RequestData,
487    _config: &ServerConfig,
488    _socket_data: &SocketData,
489    _error_logger: &ErrorLogger,
490  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
491    Ok(ResponseData::builder(request).build())
492  }
493
494  async fn response_modifying_handler(
495    &mut self,
496    response: HyperResponse,
497  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
498    Ok(response)
499  }
500
501  async fn proxy_response_modifying_handler(
502    &mut self,
503    response: HyperResponse,
504  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
505    Ok(response)
506  }
507
508  async fn connect_proxy_request_handler(
509    &mut self,
510    _upgraded_request: HyperUpgraded,
511    _connect_address: &str,
512    _config: &ServerConfig,
513    _socket_data: &SocketData,
514    _error_logger: &ErrorLogger,
515  ) -> Result<(), Box<dyn Error + Send + Sync>> {
516    Ok(())
517  }
518
519  fn does_connect_proxy_requests(&mut self) -> bool {
520    false
521  }
522
523  async fn websocket_request_handler(
524    &mut self,
525    _websocket: HyperWebsocket,
526    _uri: &hyper::Uri,
527    _config: &ServerConfig,
528    _socket_data: &SocketData,
529    _error_logger: &ErrorLogger,
530  ) -> Result<(), Box<dyn Error + Send + Sync>> {
531    Ok(())
532  }
533
534  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
535    false
536  }
537}