1use std::error::Error;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::ferron_util::ip_blocklist::IpBlockList;
6use crate::ferron_util::ip_match::ip_match;
7use crate::ferron_util::match_hostname::match_hostname;
8use crate::ferron_util::match_location::match_location;
9use crate::ferron_util::non_standard_code_structs::{
10 NonStandardCode, NonStandardCodesLocationWrap, NonStandardCodesWrap,
11};
12use crate::ferron_util::ttl_cache::TtlCache;
13
14use crate::ferron_common::{
15 ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
16 ServerModuleHandlers, SocketData,
17};
18use crate::ferron_common::{HyperUpgraded, WithRuntime};
19use async_trait::async_trait;
20use base64::{engine::general_purpose, Engine};
21use fancy_regex::RegexBuilder;
22use http_body_util::{BodyExt, Empty};
23use hyper::header::HeaderValue;
24use hyper::{header, HeaderMap, Response, StatusCode};
25use hyper_tungstenite::HyperWebsocket;
26use password_auth::verify_password;
27use tokio::runtime::Handle;
28use tokio::sync::RwLock;
29use yaml_rust2::Yaml;
30
31fn non_standard_codes_config_init(
32 non_standard_codes_list: &[Yaml],
33) -> Result<Vec<NonStandardCode>, anyhow::Error> {
34 let non_standard_codes_list_iter = non_standard_codes_list.iter();
35 let mut non_standard_codes_list_vec = Vec::new();
36 for non_standard_codes_list_entry in non_standard_codes_list_iter {
37 let status_code: u16 = match non_standard_codes_list_entry["scode"].as_i64() {
38 Some(scode) => scode.try_into()?,
39 None => {
40 return Err(anyhow::anyhow!(
41 "Non-standard codes must include a status code"
42 ));
43 }
44 };
45 let regex = match non_standard_codes_list_entry["regex"].as_str() {
46 Some(regex_str) => match RegexBuilder::new(regex_str)
47 .case_insensitive(cfg!(windows))
48 .build()
49 {
50 Ok(regex) => Some(regex),
51 Err(err) => {
52 return Err(anyhow::anyhow!(
53 "Invalid non-standard code regular expression: {}",
54 err.to_string()
55 ));
56 }
57 },
58 None => None,
59 };
60 let url = non_standard_codes_list_entry["url"]
61 .as_str()
62 .map(|s| s.to_string());
63
64 if regex.is_none() && url.is_none() {
65 return Err(anyhow::anyhow!(
66 "Non-standard codes must either include URL or a matching regular expression"
67 ));
68 }
69
70 let location = non_standard_codes_list_entry["location"]
71 .as_str()
72 .map(|s| s.to_string());
73 let realm = non_standard_codes_list_entry["realm"]
74 .as_str()
75 .map(|s| s.to_string());
76 let disable_brute_force_protection = non_standard_codes_list_entry["disableBruteProtection"]
77 .as_bool()
78 .unwrap_or(false);
79 let user_list = match non_standard_codes_list_entry["userList"].as_vec() {
80 Some(userlist) => {
81 let mut new_userlist = Vec::new();
82 for user_yaml in userlist.iter() {
83 if let Some(user) = user_yaml.as_str() {
84 new_userlist.push(user.to_string());
85 }
86 }
87 Some(new_userlist)
88 }
89 None => None,
90 };
91 let users = match non_standard_codes_list_entry["users"].as_vec() {
92 Some(users_vec) => {
93 let mut users_str_vec = Vec::new();
94 for user_yaml in users_vec.iter() {
95 if let Some(user) = user_yaml.as_str() {
96 users_str_vec.push(user);
97 }
98 }
99
100 let mut users_init = IpBlockList::new();
101 users_init.load_from_vec(users_str_vec);
102 Some(users_init)
103 }
104 None => None,
105 };
106 non_standard_codes_list_vec.push(NonStandardCode::new(
107 status_code,
108 url,
109 regex,
110 location,
111 realm,
112 disable_brute_force_protection,
113 user_list,
114 users,
115 ));
116 }
117
118 Ok(non_standard_codes_list_vec)
119}
120
121pub fn server_module_init(
122 config: &ServerConfig,
123) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
124 let mut global_non_standard_codes_list = Vec::new();
125 let mut host_non_standard_codes_lists = Vec::new();
126 if let Some(non_standard_codes_list_yaml) = config["global"]["nonStandardCodes"].as_vec() {
127 global_non_standard_codes_list = non_standard_codes_config_init(non_standard_codes_list_yaml)?;
128 }
129
130 if let Some(hosts) = config["hosts"].as_vec() {
131 for host_yaml in hosts.iter() {
132 let domain = host_yaml["domain"].as_str().map(String::from);
133 let ip = host_yaml["ip"].as_str().map(String::from);
134 let mut locations = Vec::new();
135 if let Some(locations_yaml) = host_yaml["locations"].as_vec() {
136 for location_yaml in locations_yaml.iter() {
137 if let Some(path_str) = location_yaml["path"].as_str() {
138 let path = String::from(path_str);
139 if let Some(non_standard_codes_list_yaml) = location_yaml["nonStandardCodes"].as_vec() {
140 locations.push(NonStandardCodesLocationWrap::new(
141 path,
142 non_standard_codes_config_init(non_standard_codes_list_yaml)?,
143 ));
144 }
145 }
146 }
147 }
148 if let Some(non_standard_codes_list_yaml) = host_yaml["nonStandardCodes"].as_vec() {
149 host_non_standard_codes_lists.push(NonStandardCodesWrap::new(
150 domain,
151 ip,
152 non_standard_codes_config_init(non_standard_codes_list_yaml)?,
153 locations,
154 ));
155 } else if !locations.is_empty() {
156 host_non_standard_codes_lists.push(NonStandardCodesWrap::new(
157 domain,
158 ip,
159 Vec::new(),
160 locations,
161 ));
162 }
163 }
164 }
165
166 Ok(Box::new(NonStandardCodesModule::new(
167 Arc::new(global_non_standard_codes_list),
168 Arc::new(host_non_standard_codes_lists),
169 Arc::new(RwLock::new(TtlCache::new(Duration::new(300, 0)))),
170 )))
171}
172
173struct NonStandardCodesModule {
174 global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
175 host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
176 brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
177}
178
179impl NonStandardCodesModule {
180 fn new(
181 global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
182 host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
183 brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
184 ) -> Self {
185 Self {
186 global_non_standard_codes_list,
187 host_non_standard_codes_lists,
188 brute_force_db,
189 }
190 }
191}
192
193impl ServerModule for NonStandardCodesModule {
194 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
195 Box::new(NonStandardCodesModuleHandlers {
196 global_non_standard_codes_list: self.global_non_standard_codes_list.clone(),
197 host_non_standard_codes_lists: self.host_non_standard_codes_lists.clone(),
198 brute_force_db: self.brute_force_db.clone(),
199 handle,
200 })
201 }
202}
203
204fn parse_basic_auth(auth_str: &str) -> Option<(String, String)> {
205 if let Some(base64_credentials) = auth_str.strip_prefix("Basic ") {
206 if let Ok(decoded) = general_purpose::STANDARD.decode(base64_credentials) {
207 if let Ok(decoded_str) = std::str::from_utf8(&decoded) {
208 let parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
209 if parts.len() == 2 {
210 return Some((parts[0].to_string(), parts[1].to_string()));
211 }
212 }
213 }
214 }
215 None
216}
217
218struct NonStandardCodesModuleHandlers {
219 global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
220 host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
221 brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
222 handle: Handle,
223}
224
225#[async_trait]
226impl ServerModuleHandlers for NonStandardCodesModuleHandlers {
227 async fn request_handler(
228 &mut self,
229 request: RequestData,
230 config: &ServerConfig,
231 socket_data: &SocketData,
232 error_logger: &ErrorLogger,
233 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
234 WithRuntime::new(self.handle.clone(), async move {
235 let hyper_request = request.get_hyper_request();
236 let global_non_standard_codes_list = self.global_non_standard_codes_list.iter();
237 let empty_vector = Vec::new();
238 let another_empty_vector = Vec::new();
239 let mut host_non_standard_codes_list = empty_vector.iter();
240 let mut location_non_standard_codes_list = another_empty_vector.iter();
241
242 for host_non_standard_codes_list_wrap in self.host_non_standard_codes_lists.iter() {
244 if match_hostname(
245 match &host_non_standard_codes_list_wrap.domain {
246 Some(value) => Some(value as &str),
247 None => None,
248 },
249 match hyper_request.headers().get(header::HOST) {
250 Some(value) => value.to_str().ok(),
251 None => None,
252 },
253 ) && match &host_non_standard_codes_list_wrap.ip {
254 Some(value) => ip_match(value as &str, socket_data.remote_addr.ip()),
255 None => true,
256 } {
257 host_non_standard_codes_list =
258 host_non_standard_codes_list_wrap.non_standard_codes.iter();
259 if let Ok(path_decoded) = urlencoding::decode(
260 request
261 .get_original_url()
262 .unwrap_or(request.get_hyper_request().uri())
263 .path(),
264 ) {
265 for location_wrap in host_non_standard_codes_list_wrap.locations.iter() {
266 if match_location(&location_wrap.path, &path_decoded) {
267 location_non_standard_codes_list = location_wrap.non_standard_codes.iter();
268 break;
269 }
270 }
271 }
272 break;
273 }
274 }
275
276 let combined_non_standard_codes_list = global_non_standard_codes_list
277 .chain(host_non_standard_codes_list)
278 .chain(location_non_standard_codes_list);
279
280 let request_url = format!(
281 "{}{}",
282 hyper_request.uri().path(),
283 match hyper_request.uri().query() {
284 Some(query) => format!("?{}", query),
285 None => String::from(""),
286 }
287 );
288
289 let mut auth_user = None;
290
291 for non_standard_code in combined_non_standard_codes_list {
292 let mut redirect_url = None;
293 let mut url_matched = false;
294
295 if let Some(users) = &non_standard_code.users {
296 if !users.is_blocked(socket_data.remote_addr.ip()) {
297 continue;
299 }
300 }
301
302 if let Some(regex) = &non_standard_code.regex {
303 let regex_match_option = regex.find(&request_url)?;
304 if let Some(regex_match) = regex_match_option {
305 url_matched = true;
306 if non_standard_code.status_code == 301
307 || non_standard_code.status_code == 302
308 || non_standard_code.status_code == 307
309 || non_standard_code.status_code == 308
310 {
311 let matched_text = regex_match.as_str();
312 if let Some(location) = &non_standard_code.location {
313 redirect_url = Some(regex.replace(matched_text, location).to_string());
314 }
315 }
316 }
317 }
318
319 if !url_matched {
320 if let Some(url) = &non_standard_code.url {
321 if url == hyper_request.uri().path() {
322 url_matched = true;
323 if non_standard_code.status_code == 301
324 || non_standard_code.status_code == 302
325 || non_standard_code.status_code == 307
326 || non_standard_code.status_code == 308
327 {
328 if let Some(location) = &non_standard_code.location {
329 redirect_url = Some(format!(
330 "{}{}",
331 location,
332 match hyper_request.uri().query() {
333 Some(query) => format!("?{}", query),
334 None => String::from(""),
335 }
336 ));
337 }
338 }
339 }
340 }
341 }
342
343 if url_matched {
344 match non_standard_code.status_code {
345 301 | 302 | 307 | 308 => {
346 return Ok(
347 ResponseData::builder(request)
348 .response(
349 Response::builder()
350 .status(StatusCode::from_u16(non_standard_code.status_code)?)
351 .header(header::LOCATION, redirect_url.unwrap_or(request_url))
352 .body(Empty::new().map_err(|e| match e {}).boxed())?,
353 )
354 .build(),
355 );
356 }
357 401 => {
358 let brute_force_db_key = socket_data.remote_addr.ip().to_string();
359 if !non_standard_code.disable_brute_force_protection {
360 let rwlock_read = self.brute_force_db.read().await;
361 let current_attempts = rwlock_read.get(&brute_force_db_key).unwrap_or(0);
362 if current_attempts >= 10 {
363 error_logger
364 .log(&format!(
365 "Too many failed authorization attempts for client \"{}\"",
366 socket_data.remote_addr.ip()
367 ))
368 .await;
369
370 return Ok(
371 ResponseData::builder(request)
372 .status(StatusCode::TOO_MANY_REQUESTS)
373 .build(),
374 );
375 }
376 }
377 let mut header_map = HeaderMap::new();
378 header_map.insert(
379 header::WWW_AUTHENTICATE,
380 HeaderValue::from_str(&format!(
381 "Basic realm=\"{}\", charset=\"UTF-8\"",
382 non_standard_code
383 .realm
384 .clone()
385 .unwrap_or("Ferron HTTP Basic Authorization".to_string())
386 .replace("\\", "\\\\")
387 .replace("\"", "\\\"")
388 ))?,
389 );
390
391 if let Some(authorization_header_value) =
392 hyper_request.headers().get(header::AUTHORIZATION)
393 {
394 let authorization_str = match authorization_header_value.to_str() {
395 Ok(str) => str,
396 Err(_) => {
397 return Ok(
398 ResponseData::builder(request)
399 .status(StatusCode::BAD_REQUEST)
400 .build(),
401 );
402 }
403 };
404
405 if let Some((username, password)) = parse_basic_auth(authorization_str) {
406 if let Some(users_vec_yaml) = config["users"].as_vec() {
407 let mut authorized_user = None;
408 for user_yaml in users_vec_yaml {
409 if let Some(username_db) = user_yaml["name"].as_str() {
410 if username_db != username {
411 continue;
412 }
413 if let Some(user_list) = &non_standard_code.user_list {
414 if !user_list.contains(&username) {
415 continue;
416 }
417 }
418 if let Some(password_hash_db) = user_yaml["pass"].as_str() {
419 let password_cloned = password.clone();
420 let password_hash_db_cloned = password_hash_db.to_string();
421 let password_valid = tokio::task::spawn_blocking(move || {
423 verify_password(password_cloned, &password_hash_db_cloned).is_ok()
424 })
425 .await?;
426 if password_valid {
427 authorized_user = Some(&username);
428 break;
429 }
430 }
431 }
432 }
433 if let Some(authorized_user) = authorized_user {
434 auth_user = Some(authorized_user.to_owned());
435 continue;
436 }
437 }
438
439 if !non_standard_code.disable_brute_force_protection {
440 let mut rwlock_write = self.brute_force_db.write().await;
441 rwlock_write.cleanup();
442 let current_attempts = rwlock_write.get(&brute_force_db_key).unwrap_or(0);
443 rwlock_write.insert(brute_force_db_key, current_attempts + 1);
444 }
445
446 error_logger
447 .log(&format!(
448 "Authorization failed for user \"{}\" and client \"{}\"",
449 username,
450 socket_data.remote_addr.ip()
451 ))
452 .await;
453 }
454 }
455
456 return Ok(
457 ResponseData::builder(request)
458 .status(StatusCode::UNAUTHORIZED)
459 .headers(header_map)
460 .build(),
461 );
462 }
463 _ => {
464 return Ok(
465 ResponseData::builder(request)
466 .status(StatusCode::from_u16(non_standard_code.status_code)?)
467 .build(),
468 )
469 }
470 }
471 }
472 }
473
474 if auth_user.is_some() {
475 let (hyper_request, _, original_url) = request.into_parts();
476 Ok(ResponseData::builder(RequestData::new(hyper_request, auth_user, original_url)).build())
477 } else {
478 Ok(ResponseData::builder(request).build())
479 }
480 })
481 .await
482 }
483
484 async fn proxy_request_handler(
485 &mut self,
486 request: RequestData,
487 _config: &ServerConfig,
488 _socket_data: &SocketData,
489 _error_logger: &ErrorLogger,
490 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
491 Ok(ResponseData::builder(request).build())
492 }
493
494 async fn response_modifying_handler(
495 &mut self,
496 response: HyperResponse,
497 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
498 Ok(response)
499 }
500
501 async fn proxy_response_modifying_handler(
502 &mut self,
503 response: HyperResponse,
504 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
505 Ok(response)
506 }
507
508 async fn connect_proxy_request_handler(
509 &mut self,
510 _upgraded_request: HyperUpgraded,
511 _connect_address: &str,
512 _config: &ServerConfig,
513 _socket_data: &SocketData,
514 _error_logger: &ErrorLogger,
515 ) -> Result<(), Box<dyn Error + Send + Sync>> {
516 Ok(())
517 }
518
519 fn does_connect_proxy_requests(&mut self) -> bool {
520 false
521 }
522
523 async fn websocket_request_handler(
524 &mut self,
525 _websocket: HyperWebsocket,
526 _uri: &hyper::Uri,
527 _config: &ServerConfig,
528 _socket_data: &SocketData,
529 _error_logger: &ErrorLogger,
530 ) -> Result<(), Box<dyn Error + Send + Sync>> {
531 Ok(())
532 }
533
534 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
535 false
536 }
537}