ferron/modules/
non_standard_codes.rs

1use std::error::Error;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::ferron_util::ip_blocklist::IpBlockList;
6use crate::ferron_util::obtain_config_struct_vec::ObtainConfigStructVec;
7use crate::ferron_util::ttl_cache::TtlCache;
8
9use crate::ferron_common::{
10  ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
11  ServerModuleHandlers, SocketData,
12};
13use crate::ferron_common::{HyperUpgraded, WithRuntime};
14use async_trait::async_trait;
15use base64::{engine::general_purpose, Engine};
16use fancy_regex::{Regex, RegexBuilder};
17use http_body_util::{BodyExt, Empty};
18use hyper::header::HeaderValue;
19use hyper::{header, HeaderMap, Response, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use password_auth::verify_password;
22use tokio::runtime::Handle;
23use tokio::sync::RwLock;
24use yaml_rust2::Yaml;
25
26#[allow(dead_code)]
27struct NonStandardCode {
28  status_code: u16,
29  url: Option<String>,
30  regex: Option<Regex>,
31  location: Option<String>,
32  realm: Option<String>,
33  disable_brute_force_protection: bool,
34  user_list: Option<Vec<String>>,
35  users: Option<IpBlockList>,
36}
37
38impl NonStandardCode {
39  #[allow(clippy::too_many_arguments)]
40  fn new(
41    status_code: u16,
42    url: Option<String>,
43    regex: Option<Regex>,
44    location: Option<String>,
45    realm: Option<String>,
46    disable_brute_force_protection: bool,
47    user_list: Option<Vec<String>>,
48    users: Option<IpBlockList>,
49  ) -> Self {
50    Self {
51      status_code,
52      url,
53      regex,
54      location,
55      realm,
56      disable_brute_force_protection,
57      user_list,
58      users,
59    }
60  }
61}
62
63fn non_standard_codes_config_init(
64  non_standard_codes_list: &[Yaml],
65) -> Result<Vec<NonStandardCode>, anyhow::Error> {
66  let non_standard_codes_list_iter = non_standard_codes_list.iter();
67  let mut non_standard_codes_list_vec = Vec::new();
68  for non_standard_codes_list_entry in non_standard_codes_list_iter {
69    let status_code: u16 = match non_standard_codes_list_entry["scode"].as_i64() {
70      Some(scode) => scode.try_into()?,
71      None => {
72        return Err(anyhow::anyhow!(
73          "Non-standard codes must include a status code"
74        ));
75      }
76    };
77    let regex = match non_standard_codes_list_entry["regex"].as_str() {
78      Some(regex_str) => match RegexBuilder::new(regex_str)
79        .case_insensitive(cfg!(windows))
80        .build()
81      {
82        Ok(regex) => Some(regex),
83        Err(err) => {
84          return Err(anyhow::anyhow!(
85            "Invalid non-standard code regular expression: {}",
86            err.to_string()
87          ));
88        }
89      },
90      None => None,
91    };
92    let url = non_standard_codes_list_entry["url"]
93      .as_str()
94      .map(|s| s.to_string());
95
96    if regex.is_none() && url.is_none() {
97      return Err(anyhow::anyhow!(
98        "Non-standard codes must either include URL or a matching regular expression"
99      ));
100    }
101
102    let location = non_standard_codes_list_entry["location"]
103      .as_str()
104      .map(|s| s.to_string());
105    let realm = non_standard_codes_list_entry["realm"]
106      .as_str()
107      .map(|s| s.to_string());
108    let disable_brute_force_protection = non_standard_codes_list_entry["disableBruteProtection"]
109      .as_bool()
110      .unwrap_or(false);
111    let user_list = match non_standard_codes_list_entry["userList"].as_vec() {
112      Some(userlist) => {
113        let mut new_userlist = Vec::new();
114        for user_yaml in userlist.iter() {
115          if let Some(user) = user_yaml.as_str() {
116            new_userlist.push(user.to_string());
117          }
118        }
119        Some(new_userlist)
120      }
121      None => None,
122    };
123    let users = match non_standard_codes_list_entry["users"].as_vec() {
124      Some(users_vec) => {
125        let mut users_str_vec = Vec::new();
126        for user_yaml in users_vec.iter() {
127          if let Some(user) = user_yaml.as_str() {
128            users_str_vec.push(user);
129          }
130        }
131
132        let mut users_init = IpBlockList::new();
133        users_init.load_from_vec(users_str_vec);
134        Some(users_init)
135      }
136      None => None,
137    };
138    non_standard_codes_list_vec.push(NonStandardCode::new(
139      status_code,
140      url,
141      regex,
142      location,
143      realm,
144      disable_brute_force_protection,
145      user_list,
146      users,
147    ));
148  }
149
150  Ok(non_standard_codes_list_vec)
151}
152
153pub fn server_module_init(
154  config: &ServerConfig,
155) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
156  Ok(Box::new(NonStandardCodesModule::new(
157    ObtainConfigStructVec::new(config, |config| {
158      if let Some(non_standard_codes_yaml) = config["nonStandardCodes"].as_vec() {
159        Ok(non_standard_codes_config_init(non_standard_codes_yaml)?)
160      } else {
161        Ok(vec![])
162      }
163    })?,
164    Arc::new(RwLock::new(TtlCache::new(Duration::new(300, 0)))),
165  )))
166}
167
168struct NonStandardCodesModule {
169  non_standard_codes_list: ObtainConfigStructVec<NonStandardCode>,
170  brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
171}
172
173impl NonStandardCodesModule {
174  fn new(
175    non_standard_codes_list: ObtainConfigStructVec<NonStandardCode>,
176    brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
177  ) -> Self {
178    Self {
179      non_standard_codes_list,
180      brute_force_db,
181    }
182  }
183}
184
185impl ServerModule for NonStandardCodesModule {
186  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
187    Box::new(NonStandardCodesModuleHandlers {
188      non_standard_codes_list: self.non_standard_codes_list.clone(),
189      brute_force_db: self.brute_force_db.clone(),
190      handle,
191    })
192  }
193}
194
195fn parse_basic_auth(auth_str: &str) -> Option<(String, String)> {
196  if let Some(base64_credentials) = auth_str.strip_prefix("Basic ") {
197    if let Ok(decoded) = general_purpose::STANDARD.decode(base64_credentials) {
198      if let Ok(decoded_str) = std::str::from_utf8(&decoded) {
199        let parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
200        if parts.len() == 2 {
201          return Some((parts[0].to_string(), parts[1].to_string()));
202        }
203      }
204    }
205  }
206  None
207}
208
209struct NonStandardCodesModuleHandlers {
210  non_standard_codes_list: ObtainConfigStructVec<NonStandardCode>,
211  brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
212  handle: Handle,
213}
214
215#[async_trait]
216impl ServerModuleHandlers for NonStandardCodesModuleHandlers {
217  async fn request_handler(
218    &mut self,
219    request: RequestData,
220    config: &ServerConfig,
221    socket_data: &SocketData,
222    error_logger: &ErrorLogger,
223  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
224    WithRuntime::new(self.handle.clone(), async move {
225      let hyper_request = request.get_hyper_request();
226      let combined_non_standard_codes_list = self.non_standard_codes_list.obtain(
227        match hyper_request.headers().get(header::HOST) {
228          Some(value) => value.to_str().ok(),
229          None => None,
230        },
231        socket_data.remote_addr.ip(),
232        request
233          .get_original_url()
234          .unwrap_or(request.get_hyper_request().uri())
235          .path(),
236        request.get_error_status_code().map(|x| x.as_u16()),
237      );
238
239      let request_url = format!(
240        "{}{}",
241        hyper_request.uri().path(),
242        match hyper_request.uri().query() {
243          Some(query) => format!("?{}", query),
244          None => String::from(""),
245        }
246      );
247
248      let mut auth_user = None;
249
250      for non_standard_code in combined_non_standard_codes_list {
251        let mut redirect_url = None;
252        let mut url_matched = false;
253
254        if let Some(users) = &non_standard_code.users {
255          if !users.is_blocked(socket_data.remote_addr.ip()) {
256            // Don't process this non-standard code
257            continue;
258          }
259        }
260
261        if let Some(regex) = &non_standard_code.regex {
262          let regex_match_option = regex.find(&request_url)?;
263          if let Some(regex_match) = regex_match_option {
264            url_matched = true;
265            if non_standard_code.status_code == 301
266              || non_standard_code.status_code == 302
267              || non_standard_code.status_code == 307
268              || non_standard_code.status_code == 308
269            {
270              let matched_text = regex_match.as_str();
271              if let Some(location) = &non_standard_code.location {
272                redirect_url = Some(regex.replace(matched_text, location).to_string());
273              }
274            }
275          }
276        }
277
278        if !url_matched {
279          if let Some(url) = &non_standard_code.url {
280            if url == hyper_request.uri().path() {
281              url_matched = true;
282              if non_standard_code.status_code == 301
283                || non_standard_code.status_code == 302
284                || non_standard_code.status_code == 307
285                || non_standard_code.status_code == 308
286              {
287                if let Some(location) = &non_standard_code.location {
288                  redirect_url = Some(format!(
289                    "{}{}",
290                    location,
291                    match hyper_request.uri().query() {
292                      Some(query) => format!("?{}", query),
293                      None => String::from(""),
294                    }
295                  ));
296                }
297              }
298            }
299          }
300        }
301
302        if url_matched {
303          match non_standard_code.status_code {
304            301 | 302 | 307 | 308 => {
305              return Ok(
306                ResponseData::builder(request)
307                  .response(
308                    Response::builder()
309                      .status(StatusCode::from_u16(non_standard_code.status_code)?)
310                      .header(header::LOCATION, redirect_url.unwrap_or(request_url))
311                      .body(Empty::new().map_err(|e| match e {}).boxed())?,
312                  )
313                  .build(),
314              );
315            }
316            401 => {
317              let brute_force_db_key = socket_data.remote_addr.ip().to_string();
318              if !non_standard_code.disable_brute_force_protection {
319                let rwlock_read = self.brute_force_db.read().await;
320                let current_attempts = rwlock_read.get(&brute_force_db_key).unwrap_or(0);
321                if current_attempts >= 10 {
322                  error_logger
323                    .log(&format!(
324                      "Too many failed authorization attempts for client \"{}\"",
325                      socket_data.remote_addr.ip()
326                    ))
327                    .await;
328
329                  return Ok(
330                    ResponseData::builder(request)
331                      .status(StatusCode::TOO_MANY_REQUESTS)
332                      .build(),
333                  );
334                }
335              }
336              let mut header_map = HeaderMap::new();
337              header_map.insert(
338                header::WWW_AUTHENTICATE,
339                HeaderValue::from_str(&format!(
340                  "Basic realm=\"{}\", charset=\"UTF-8\"",
341                  non_standard_code
342                    .realm
343                    .clone()
344                    .unwrap_or("Ferron HTTP Basic Authorization".to_string())
345                    .replace("\\", "\\\\")
346                    .replace("\"", "\\\"")
347                ))?,
348              );
349
350              if let Some(authorization_header_value) =
351                hyper_request.headers().get(header::AUTHORIZATION)
352              {
353                let authorization_str = match authorization_header_value.to_str() {
354                  Ok(str) => str,
355                  Err(_) => {
356                    return Ok(
357                      ResponseData::builder(request)
358                        .status(StatusCode::BAD_REQUEST)
359                        .build(),
360                    );
361                  }
362                };
363
364                if let Some((username, password)) = parse_basic_auth(authorization_str) {
365                  if let Some(users_vec_yaml) = config["users"].as_vec() {
366                    let mut authorized_user = None;
367                    for user_yaml in users_vec_yaml {
368                      if let Some(username_db) = user_yaml["name"].as_str() {
369                        if username_db != username {
370                          continue;
371                        }
372                        if let Some(user_list) = &non_standard_code.user_list {
373                          if !user_list.contains(&username) {
374                            continue;
375                          }
376                        }
377                        if let Some(password_hash_db) = user_yaml["pass"].as_str() {
378                          let password_cloned = password.clone();
379                          let password_hash_db_cloned = password_hash_db.to_string();
380                          // Offload verifying the hash into a separate blocking thread.
381                          let password_valid = tokio::task::spawn_blocking(move || {
382                            verify_password(password_cloned, &password_hash_db_cloned).is_ok()
383                          })
384                          .await?;
385                          if password_valid {
386                            authorized_user = Some(&username);
387                            break;
388                          }
389                        }
390                      }
391                    }
392                    if let Some(authorized_user) = authorized_user {
393                      auth_user = Some(authorized_user.to_owned());
394                      continue;
395                    }
396                  }
397
398                  if !non_standard_code.disable_brute_force_protection {
399                    let mut rwlock_write = self.brute_force_db.write().await;
400                    rwlock_write.cleanup();
401                    let current_attempts = rwlock_write.get(&brute_force_db_key).unwrap_or(0);
402                    rwlock_write.insert(brute_force_db_key, current_attempts + 1);
403                  }
404
405                  error_logger
406                    .log(&format!(
407                      "Authorization failed for user \"{}\" and client \"{}\"",
408                      username,
409                      socket_data.remote_addr.ip()
410                    ))
411                    .await;
412                }
413              }
414
415              return Ok(
416                ResponseData::builder(request)
417                  .status(StatusCode::UNAUTHORIZED)
418                  .headers(header_map)
419                  .build(),
420              );
421            }
422            _ => {
423              return Ok(
424                ResponseData::builder(request)
425                  .status(StatusCode::from_u16(non_standard_code.status_code)?)
426                  .build(),
427              )
428            }
429          }
430        }
431      }
432
433      if auth_user.is_some() {
434        let (hyper_request, _, original_url, error_status_code) = request.into_parts();
435        Ok(
436          ResponseData::builder(RequestData::new(
437            hyper_request,
438            auth_user,
439            original_url,
440            error_status_code,
441          ))
442          .build(),
443        )
444      } else {
445        Ok(ResponseData::builder(request).build())
446      }
447    })
448    .await
449  }
450
451  async fn proxy_request_handler(
452    &mut self,
453    request: RequestData,
454    _config: &ServerConfig,
455    _socket_data: &SocketData,
456    _error_logger: &ErrorLogger,
457  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
458    Ok(ResponseData::builder(request).build())
459  }
460
461  async fn response_modifying_handler(
462    &mut self,
463    response: HyperResponse,
464  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
465    Ok(response)
466  }
467
468  async fn proxy_response_modifying_handler(
469    &mut self,
470    response: HyperResponse,
471  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
472    Ok(response)
473  }
474
475  async fn connect_proxy_request_handler(
476    &mut self,
477    _upgraded_request: HyperUpgraded,
478    _connect_address: &str,
479    _config: &ServerConfig,
480    _socket_data: &SocketData,
481    _error_logger: &ErrorLogger,
482  ) -> Result<(), Box<dyn Error + Send + Sync>> {
483    Ok(())
484  }
485
486  fn does_connect_proxy_requests(&mut self) -> bool {
487    false
488  }
489
490  async fn websocket_request_handler(
491    &mut self,
492    _websocket: HyperWebsocket,
493    _uri: &hyper::Uri,
494    _headers: &hyper::HeaderMap,
495    _config: &ServerConfig,
496    _socket_data: &SocketData,
497    _error_logger: &ErrorLogger,
498  ) -> Result<(), Box<dyn Error + Send + Sync>> {
499    Ok(())
500  }
501
502  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
503    false
504  }
505}