1use std::error::Error;
2use std::path::Path;
3
4use crate::ferron_util::obtain_config_struct_vec::ObtainConfigStructVec;
5
6use crate::ferron_common::{
7 ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
8 ServerModuleHandlers, SocketData,
9};
10use crate::ferron_common::{HyperUpgraded, WithRuntime};
11use async_trait::async_trait;
12use fancy_regex::{Regex, RegexBuilder};
13use hyper::{header, Request, StatusCode};
14use hyper_tungstenite::HyperWebsocket;
15use tokio::fs;
16use tokio::runtime::Handle;
17use yaml_rust2::Yaml;
18
19struct UrlRewriteMapEntry {
20 regex: Regex,
21 replacement: String,
22 is_not_directory: bool,
23 is_not_file: bool,
24 last: bool,
25 allow_double_slashes: bool,
26}
27
28impl UrlRewriteMapEntry {
29 fn new(
30 regex: Regex,
31 replacement: String,
32 is_not_directory: bool,
33 is_not_file: bool,
34 last: bool,
35 allow_double_slashes: bool,
36 ) -> Self {
37 Self {
38 regex,
39 replacement,
40 is_not_directory,
41 is_not_file,
42 last,
43 allow_double_slashes,
44 }
45 }
46}
47
48fn url_rewrite_config_init(rewrite_map: &[Yaml]) -> Result<Vec<UrlRewriteMapEntry>, anyhow::Error> {
49 let rewrite_map_iter = rewrite_map.iter();
50 let mut rewrite_map_vec = Vec::new();
51 for rewrite_map_entry in rewrite_map_iter {
52 let regex_str = match rewrite_map_entry["regex"].as_str() {
53 Some(regex_str) => regex_str,
54 None => return Err(anyhow::anyhow!("Invalid URL rewrite regular expression")),
55 };
56 let regex = match RegexBuilder::new(regex_str)
57 .case_insensitive(cfg!(windows))
58 .build()
59 {
60 Ok(regex) => regex,
61 Err(err) => {
62 return Err(anyhow::anyhow!(
63 "Invalid URL rewrite regular expression: {}",
64 err.to_string()
65 ))
66 }
67 };
68 let replacement = match rewrite_map_entry["replacement"].as_str() {
69 Some(replacement) => String::from(replacement),
70 None => return Err(anyhow::anyhow!("URL rewrite rules must have replacements")),
71 };
72 let is_not_file = rewrite_map_entry["isNotFile"].as_bool().unwrap_or(false);
73 let is_not_directory = rewrite_map_entry["isNotDirectory"]
74 .as_bool()
75 .unwrap_or(false);
76 let last = rewrite_map_entry["last"].as_bool().unwrap_or_default();
77 let allow_double_slashes = rewrite_map_entry["allowDoubleSlashes"]
78 .as_bool()
79 .unwrap_or(false);
80 rewrite_map_vec.push(UrlRewriteMapEntry::new(
81 regex,
82 replacement,
83 is_not_directory,
84 is_not_file,
85 last,
86 allow_double_slashes,
87 ));
88 }
89
90 Ok(rewrite_map_vec)
91}
92
93pub fn server_module_init(
94 config: &ServerConfig,
95) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
96 Ok(Box::new(UrlRewriteModule::new(ObtainConfigStructVec::new(
97 config,
98 |config| {
99 if let Some(rewrite_map_yaml) = config["rewriteMap"].as_vec() {
100 Ok(url_rewrite_config_init(rewrite_map_yaml)?)
101 } else {
102 Ok(vec![])
103 }
104 },
105 )?)))
106}
107
108struct UrlRewriteModule {
109 url_rewrite_maps: ObtainConfigStructVec<UrlRewriteMapEntry>,
110}
111
112impl UrlRewriteModule {
113 fn new(url_rewrite_maps: ObtainConfigStructVec<UrlRewriteMapEntry>) -> Self {
114 Self { url_rewrite_maps }
115 }
116}
117
118impl ServerModule for UrlRewriteModule {
119 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
120 Box::new(UrlRewriteModuleHandlers {
121 url_rewrite_maps: self.url_rewrite_maps.clone(),
122 handle,
123 })
124 }
125}
126struct UrlRewriteModuleHandlers {
127 url_rewrite_maps: ObtainConfigStructVec<UrlRewriteMapEntry>,
128 handle: Handle,
129}
130
131#[async_trait]
132impl ServerModuleHandlers for UrlRewriteModuleHandlers {
133 async fn request_handler(
134 &mut self,
135 request: RequestData,
136 config: &ServerConfig,
137 socket_data: &SocketData,
138 error_logger: &ErrorLogger,
139 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
140 WithRuntime::new(self.handle.clone(), async move {
141 let hyper_request = request.get_hyper_request();
142 let combined_url_rewrite_map = self.url_rewrite_maps.obtain(
143 match hyper_request.headers().get(header::HOST) {
144 Some(value) => value.to_str().ok(),
145 None => None,
146 },
147 socket_data.remote_addr.ip(),
148 request
149 .get_original_url()
150 .unwrap_or(request.get_hyper_request().uri())
151 .path(),
152 request.get_error_status_code().map(|x| x.as_u16()),
153 );
154
155 let original_url = format!(
156 "{}{}",
157 hyper_request.uri().path(),
158 match hyper_request.uri().query() {
159 Some(query) => format!("?{}", query),
160 None => String::from(""),
161 }
162 );
163 let mut rewritten_url = original_url.clone();
164
165 let mut rewritten_url_bytes = rewritten_url.bytes();
166 if rewritten_url_bytes.len() < 1 || rewritten_url_bytes.nth(0) != Some(b'/') {
167 return Ok(
168 ResponseData::builder(request)
169 .status(StatusCode::BAD_REQUEST)
170 .build(),
171 );
172 }
173
174 for url_rewrite_map_entry in combined_url_rewrite_map {
175 if url_rewrite_map_entry.is_not_directory || url_rewrite_map_entry.is_not_file {
177 if let Some(wwwroot) = config["wwwroot"].as_str() {
178 let path = Path::new(wwwroot);
179 let mut relative_path = &rewritten_url[1..];
180 while relative_path.as_bytes().first().copied() == Some(b'/') {
181 relative_path = &relative_path[1..];
182 }
183 let relative_path_split: Vec<&str> = relative_path.split("?").collect();
184 if !relative_path_split.is_empty() {
185 relative_path = relative_path_split[0];
186 }
187 let joined_pathbuf = path.join(relative_path);
188 if let Ok(metadata) = fs::metadata(joined_pathbuf).await {
189 if (url_rewrite_map_entry.is_not_file && metadata.is_file())
190 || (url_rewrite_map_entry.is_not_directory && metadata.is_dir())
191 {
192 continue;
193 }
194 }
195 }
196 }
197
198 if !url_rewrite_map_entry.allow_double_slashes {
199 while rewritten_url.contains("//") {
200 rewritten_url = rewritten_url.replace("//", "/");
201 }
202 }
203
204 let old_rewritten_url = rewritten_url;
206 rewritten_url = url_rewrite_map_entry
207 .regex
208 .replace(&old_rewritten_url, &url_rewrite_map_entry.replacement)
209 .to_string();
210
211 let mut rewritten_url_bytes = rewritten_url.bytes();
212 if rewritten_url_bytes.len() < 1 || rewritten_url_bytes.nth(0) != Some(b'/') {
213 return Ok(
214 ResponseData::builder(request)
215 .status(StatusCode::BAD_REQUEST)
216 .build(),
217 );
218 }
219
220 if url_rewrite_map_entry.last && old_rewritten_url != rewritten_url {
221 break;
222 }
223 }
224
225 if rewritten_url == original_url {
226 Ok(ResponseData::builder(request).build())
227 } else {
228 if config["enableRewriteLogging"].as_bool() == Some(true) {
229 error_logger
230 .log(&format!(
231 "URL rewritten from \"{}\" to \"{}\"",
232 original_url, rewritten_url
233 ))
234 .await;
235 }
236 let (hyper_request, auth_user, _, error_status_code) = request.into_parts();
237 let (mut parts, body) = hyper_request.into_parts();
238 let original_url = parts.uri.clone();
239 let mut url_parts = parts.uri.into_parts();
240 url_parts.path_and_query = Some(rewritten_url.parse()?);
241 parts.uri = hyper::Uri::from_parts(url_parts)?;
242 let hyper_request = Request::from_parts(parts, body);
243 let request = RequestData::new(
244 hyper_request,
245 auth_user,
246 Some(original_url),
247 error_status_code,
248 );
249 Ok(ResponseData::builder(request).build())
250 }
251 })
252 .await
253 }
254
255 async fn proxy_request_handler(
256 &mut self,
257 request: RequestData,
258 _config: &ServerConfig,
259 _socket_data: &SocketData,
260 _error_logger: &ErrorLogger,
261 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
262 Ok(ResponseData::builder(request).build())
263 }
264
265 async fn response_modifying_handler(
266 &mut self,
267 response: HyperResponse,
268 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
269 Ok(response)
270 }
271
272 async fn proxy_response_modifying_handler(
273 &mut self,
274 response: HyperResponse,
275 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
276 Ok(response)
277 }
278
279 async fn connect_proxy_request_handler(
280 &mut self,
281 _upgraded_request: HyperUpgraded,
282 _connect_address: &str,
283 _config: &ServerConfig,
284 _socket_data: &SocketData,
285 _error_logger: &ErrorLogger,
286 ) -> Result<(), Box<dyn Error + Send + Sync>> {
287 Ok(())
288 }
289
290 fn does_connect_proxy_requests(&mut self) -> bool {
291 false
292 }
293
294 async fn websocket_request_handler(
295 &mut self,
296 _websocket: HyperWebsocket,
297 _uri: &hyper::Uri,
298 _headers: &hyper::HeaderMap,
299 _config: &ServerConfig,
300 _socket_data: &SocketData,
301 _error_logger: &ErrorLogger,
302 ) -> Result<(), Box<dyn Error + Send + Sync>> {
303 Ok(())
304 }
305
306 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
307 false
308 }
309}