1use std::error::Error;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::ferron_util::ip_blocklist::IpBlockList;
6use crate::ferron_util::obtain_config_struct_vec::ObtainConfigStructVec;
7use crate::ferron_util::ttl_cache::TtlCache;
8
9use crate::ferron_common::{
10 ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
11 ServerModuleHandlers, SocketData,
12};
13use crate::ferron_common::{HyperUpgraded, WithRuntime};
14use async_trait::async_trait;
15use base64::{engine::general_purpose, Engine};
16use fancy_regex::{Regex, RegexBuilder};
17use http_body_util::{BodyExt, Empty};
18use hyper::header::HeaderValue;
19use hyper::{header, HeaderMap, Response, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use password_auth::verify_password;
22use tokio::runtime::Handle;
23use tokio::sync::RwLock;
24use yaml_rust2::Yaml;
25
26#[allow(dead_code)]
27struct NonStandardCode {
28 status_code: u16,
29 url: Option<String>,
30 regex: Option<Regex>,
31 location: Option<String>,
32 realm: Option<String>,
33 disable_brute_force_protection: bool,
34 user_list: Option<Vec<String>>,
35 users: Option<IpBlockList>,
36}
37
38impl NonStandardCode {
39 #[allow(clippy::too_many_arguments)]
40 fn new(
41 status_code: u16,
42 url: Option<String>,
43 regex: Option<Regex>,
44 location: Option<String>,
45 realm: Option<String>,
46 disable_brute_force_protection: bool,
47 user_list: Option<Vec<String>>,
48 users: Option<IpBlockList>,
49 ) -> Self {
50 Self {
51 status_code,
52 url,
53 regex,
54 location,
55 realm,
56 disable_brute_force_protection,
57 user_list,
58 users,
59 }
60 }
61}
62
63fn non_standard_codes_config_init(
64 non_standard_codes_list: &[Yaml],
65) -> Result<Vec<NonStandardCode>, anyhow::Error> {
66 let non_standard_codes_list_iter = non_standard_codes_list.iter();
67 let mut non_standard_codes_list_vec = Vec::new();
68 for non_standard_codes_list_entry in non_standard_codes_list_iter {
69 let status_code: u16 = match non_standard_codes_list_entry["scode"].as_i64() {
70 Some(scode) => scode.try_into()?,
71 None => {
72 return Err(anyhow::anyhow!(
73 "Non-standard codes must include a status code"
74 ));
75 }
76 };
77 let regex = match non_standard_codes_list_entry["regex"].as_str() {
78 Some(regex_str) => match RegexBuilder::new(regex_str)
79 .case_insensitive(cfg!(windows))
80 .build()
81 {
82 Ok(regex) => Some(regex),
83 Err(err) => {
84 return Err(anyhow::anyhow!(
85 "Invalid non-standard code regular expression: {}",
86 err.to_string()
87 ));
88 }
89 },
90 None => None,
91 };
92 let url = non_standard_codes_list_entry["url"]
93 .as_str()
94 .map(|s| s.to_string());
95
96 if regex.is_none() && url.is_none() {
97 return Err(anyhow::anyhow!(
98 "Non-standard codes must either include URL or a matching regular expression"
99 ));
100 }
101
102 let location = non_standard_codes_list_entry["location"]
103 .as_str()
104 .map(|s| s.to_string());
105 let realm = non_standard_codes_list_entry["realm"]
106 .as_str()
107 .map(|s| s.to_string());
108 let disable_brute_force_protection = non_standard_codes_list_entry["disableBruteProtection"]
109 .as_bool()
110 .unwrap_or(false);
111 let user_list = match non_standard_codes_list_entry["userList"].as_vec() {
112 Some(userlist) => {
113 let mut new_userlist = Vec::new();
114 for user_yaml in userlist.iter() {
115 if let Some(user) = user_yaml.as_str() {
116 new_userlist.push(user.to_string());
117 }
118 }
119 Some(new_userlist)
120 }
121 None => None,
122 };
123 let users = match non_standard_codes_list_entry["users"].as_vec() {
124 Some(users_vec) => {
125 let mut users_str_vec = Vec::new();
126 for user_yaml in users_vec.iter() {
127 if let Some(user) = user_yaml.as_str() {
128 users_str_vec.push(user);
129 }
130 }
131
132 let mut users_init = IpBlockList::new();
133 users_init.load_from_vec(users_str_vec);
134 Some(users_init)
135 }
136 None => None,
137 };
138 non_standard_codes_list_vec.push(NonStandardCode::new(
139 status_code,
140 url,
141 regex,
142 location,
143 realm,
144 disable_brute_force_protection,
145 user_list,
146 users,
147 ));
148 }
149
150 Ok(non_standard_codes_list_vec)
151}
152
153pub fn server_module_init(
154 config: &ServerConfig,
155) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
156 Ok(Box::new(NonStandardCodesModule::new(
157 ObtainConfigStructVec::new(config, |config| {
158 if let Some(non_standard_codes_yaml) = config["nonStandardCodes"].as_vec() {
159 Ok(non_standard_codes_config_init(non_standard_codes_yaml)?)
160 } else {
161 Ok(vec![])
162 }
163 })?,
164 Arc::new(RwLock::new(TtlCache::new(Duration::new(300, 0)))),
165 )))
166}
167
168struct NonStandardCodesModule {
169 non_standard_codes_list: ObtainConfigStructVec<NonStandardCode>,
170 brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
171}
172
173impl NonStandardCodesModule {
174 fn new(
175 non_standard_codes_list: ObtainConfigStructVec<NonStandardCode>,
176 brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
177 ) -> Self {
178 Self {
179 non_standard_codes_list,
180 brute_force_db,
181 }
182 }
183}
184
185impl ServerModule for NonStandardCodesModule {
186 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
187 Box::new(NonStandardCodesModuleHandlers {
188 non_standard_codes_list: self.non_standard_codes_list.clone(),
189 brute_force_db: self.brute_force_db.clone(),
190 handle,
191 })
192 }
193}
194
195fn parse_basic_auth(auth_str: &str) -> Option<(String, String)> {
196 if let Some(base64_credentials) = auth_str.strip_prefix("Basic ") {
197 if let Ok(decoded) = general_purpose::STANDARD.decode(base64_credentials) {
198 if let Ok(decoded_str) = std::str::from_utf8(&decoded) {
199 let parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
200 if parts.len() == 2 {
201 return Some((parts[0].to_string(), parts[1].to_string()));
202 }
203 }
204 }
205 }
206 None
207}
208
209struct NonStandardCodesModuleHandlers {
210 non_standard_codes_list: ObtainConfigStructVec<NonStandardCode>,
211 brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
212 handle: Handle,
213}
214
215#[async_trait]
216impl ServerModuleHandlers for NonStandardCodesModuleHandlers {
217 async fn request_handler(
218 &mut self,
219 request: RequestData,
220 config: &ServerConfig,
221 socket_data: &SocketData,
222 error_logger: &ErrorLogger,
223 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
224 WithRuntime::new(self.handle.clone(), async move {
225 let hyper_request = request.get_hyper_request();
226 let combined_non_standard_codes_list = self.non_standard_codes_list.obtain(
227 match hyper_request.headers().get(header::HOST) {
228 Some(value) => value.to_str().ok(),
229 None => None,
230 },
231 socket_data.remote_addr.ip(),
232 request
233 .get_original_url()
234 .unwrap_or(request.get_hyper_request().uri())
235 .path(),
236 request.get_error_status_code().map(|x| x.as_u16()),
237 );
238
239 let request_url = format!(
240 "{}{}",
241 hyper_request.uri().path(),
242 match hyper_request.uri().query() {
243 Some(query) => format!("?{}", query),
244 None => String::from(""),
245 }
246 );
247
248 let mut auth_user = None;
249
250 for non_standard_code in combined_non_standard_codes_list {
251 let mut redirect_url = None;
252 let mut url_matched = false;
253
254 if let Some(users) = &non_standard_code.users {
255 if !users.is_blocked(socket_data.remote_addr.ip()) {
256 continue;
258 }
259 }
260
261 if let Some(regex) = &non_standard_code.regex {
262 let regex_match_option = regex.find(&request_url)?;
263 if let Some(regex_match) = regex_match_option {
264 url_matched = true;
265 if non_standard_code.status_code == 301
266 || non_standard_code.status_code == 302
267 || non_standard_code.status_code == 307
268 || non_standard_code.status_code == 308
269 {
270 let matched_text = regex_match.as_str();
271 if let Some(location) = &non_standard_code.location {
272 redirect_url = Some(regex.replace(matched_text, location).to_string());
273 }
274 }
275 }
276 }
277
278 if !url_matched {
279 if let Some(url) = &non_standard_code.url {
280 if url == hyper_request.uri().path() {
281 url_matched = true;
282 if non_standard_code.status_code == 301
283 || non_standard_code.status_code == 302
284 || non_standard_code.status_code == 307
285 || non_standard_code.status_code == 308
286 {
287 if let Some(location) = &non_standard_code.location {
288 redirect_url = Some(format!(
289 "{}{}",
290 location,
291 match hyper_request.uri().query() {
292 Some(query) => format!("?{}", query),
293 None => String::from(""),
294 }
295 ));
296 }
297 }
298 }
299 }
300 }
301
302 if url_matched {
303 match non_standard_code.status_code {
304 301 | 302 | 307 | 308 => {
305 return Ok(
306 ResponseData::builder(request)
307 .response(
308 Response::builder()
309 .status(StatusCode::from_u16(non_standard_code.status_code)?)
310 .header(header::LOCATION, redirect_url.unwrap_or(request_url))
311 .body(Empty::new().map_err(|e| match e {}).boxed())?,
312 )
313 .build(),
314 );
315 }
316 401 => {
317 let brute_force_db_key = socket_data.remote_addr.ip().to_string();
318 if !non_standard_code.disable_brute_force_protection {
319 let rwlock_read = self.brute_force_db.read().await;
320 let current_attempts = rwlock_read.get(&brute_force_db_key).unwrap_or(0);
321 if current_attempts >= 10 {
322 error_logger
323 .log(&format!(
324 "Too many failed authorization attempts for client \"{}\"",
325 socket_data.remote_addr.ip()
326 ))
327 .await;
328
329 return Ok(
330 ResponseData::builder(request)
331 .status(StatusCode::TOO_MANY_REQUESTS)
332 .build(),
333 );
334 }
335 }
336 let mut header_map = HeaderMap::new();
337 header_map.insert(
338 header::WWW_AUTHENTICATE,
339 HeaderValue::from_str(&format!(
340 "Basic realm=\"{}\", charset=\"UTF-8\"",
341 non_standard_code
342 .realm
343 .clone()
344 .unwrap_or("Ferron HTTP Basic Authorization".to_string())
345 .replace("\\", "\\\\")
346 .replace("\"", "\\\"")
347 ))?,
348 );
349
350 if let Some(authorization_header_value) =
351 hyper_request.headers().get(header::AUTHORIZATION)
352 {
353 let authorization_str = match authorization_header_value.to_str() {
354 Ok(str) => str,
355 Err(_) => {
356 return Ok(
357 ResponseData::builder(request)
358 .status(StatusCode::BAD_REQUEST)
359 .build(),
360 );
361 }
362 };
363
364 if let Some((username, password)) = parse_basic_auth(authorization_str) {
365 if let Some(users_vec_yaml) = config["users"].as_vec() {
366 let mut authorized_user = None;
367 for user_yaml in users_vec_yaml {
368 if let Some(username_db) = user_yaml["name"].as_str() {
369 if username_db != username {
370 continue;
371 }
372 if let Some(user_list) = &non_standard_code.user_list {
373 if !user_list.contains(&username) {
374 continue;
375 }
376 }
377 if let Some(password_hash_db) = user_yaml["pass"].as_str() {
378 let password_cloned = password.clone();
379 let password_hash_db_cloned = password_hash_db.to_string();
380 let password_valid = tokio::task::spawn_blocking(move || {
382 verify_password(password_cloned, &password_hash_db_cloned).is_ok()
383 })
384 .await?;
385 if password_valid {
386 authorized_user = Some(&username);
387 break;
388 }
389 }
390 }
391 }
392 if let Some(authorized_user) = authorized_user {
393 auth_user = Some(authorized_user.to_owned());
394 continue;
395 }
396 }
397
398 if !non_standard_code.disable_brute_force_protection {
399 let mut rwlock_write = self.brute_force_db.write().await;
400 rwlock_write.cleanup();
401 let current_attempts = rwlock_write.get(&brute_force_db_key).unwrap_or(0);
402 rwlock_write.insert(brute_force_db_key, current_attempts + 1);
403 }
404
405 error_logger
406 .log(&format!(
407 "Authorization failed for user \"{}\" and client \"{}\"",
408 username,
409 socket_data.remote_addr.ip()
410 ))
411 .await;
412 }
413 }
414
415 return Ok(
416 ResponseData::builder(request)
417 .status(StatusCode::UNAUTHORIZED)
418 .headers(header_map)
419 .build(),
420 );
421 }
422 _ => {
423 return Ok(
424 ResponseData::builder(request)
425 .status(StatusCode::from_u16(non_standard_code.status_code)?)
426 .build(),
427 )
428 }
429 }
430 }
431 }
432
433 if auth_user.is_some() {
434 let (hyper_request, _, original_url, error_status_code) = request.into_parts();
435 Ok(
436 ResponseData::builder(RequestData::new(
437 hyper_request,
438 auth_user,
439 original_url,
440 error_status_code,
441 ))
442 .build(),
443 )
444 } else {
445 Ok(ResponseData::builder(request).build())
446 }
447 })
448 .await
449 }
450
451 async fn proxy_request_handler(
452 &mut self,
453 request: RequestData,
454 _config: &ServerConfig,
455 _socket_data: &SocketData,
456 _error_logger: &ErrorLogger,
457 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
458 Ok(ResponseData::builder(request).build())
459 }
460
461 async fn response_modifying_handler(
462 &mut self,
463 response: HyperResponse,
464 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
465 Ok(response)
466 }
467
468 async fn proxy_response_modifying_handler(
469 &mut self,
470 response: HyperResponse,
471 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
472 Ok(response)
473 }
474
475 async fn connect_proxy_request_handler(
476 &mut self,
477 _upgraded_request: HyperUpgraded,
478 _connect_address: &str,
479 _config: &ServerConfig,
480 _socket_data: &SocketData,
481 _error_logger: &ErrorLogger,
482 ) -> Result<(), Box<dyn Error + Send + Sync>> {
483 Ok(())
484 }
485
486 fn does_connect_proxy_requests(&mut self) -> bool {
487 false
488 }
489
490 async fn websocket_request_handler(
491 &mut self,
492 _websocket: HyperWebsocket,
493 _uri: &hyper::Uri,
494 _headers: &hyper::HeaderMap,
495 _config: &ServerConfig,
496 _socket_data: &SocketData,
497 _error_logger: &ErrorLogger,
498 ) -> Result<(), Box<dyn Error + Send + Sync>> {
499 Ok(())
500 }
501
502 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
503 false
504 }
505}