1use std::error::Error;
2use std::path::Path;
3use std::sync::Arc;
4
5use crate::ferron_util::ip_match::ip_match;
6use crate::ferron_util::match_hostname::match_hostname;
7use crate::ferron_util::match_location::match_location;
8use crate::ferron_util::url_rewrite_structs::{
9 UrlRewriteMapEntry, UrlRewriteMapLocationWrap, UrlRewriteMapWrap,
10};
11
12use crate::ferron_common::{
13 ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
14 ServerModuleHandlers, SocketData,
15};
16use crate::ferron_common::{HyperUpgraded, WithRuntime};
17use async_trait::async_trait;
18use fancy_regex::RegexBuilder;
19use hyper::{header, Request, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use tokio::fs;
22use tokio::runtime::Handle;
23use yaml_rust2::Yaml;
24
25fn url_rewrite_config_init(rewrite_map: &[Yaml]) -> Result<Vec<UrlRewriteMapEntry>, anyhow::Error> {
26 let rewrite_map_iter = rewrite_map.iter();
27 let mut rewrite_map_vec = Vec::new();
28 for rewrite_map_entry in rewrite_map_iter {
29 let regex_str = match rewrite_map_entry["regex"].as_str() {
30 Some(regex_str) => regex_str,
31 None => return Err(anyhow::anyhow!("Invalid URL rewrite regular expression")),
32 };
33 let regex = match RegexBuilder::new(regex_str)
34 .case_insensitive(cfg!(windows))
35 .build()
36 {
37 Ok(regex) => regex,
38 Err(err) => {
39 return Err(anyhow::anyhow!(
40 "Invalid URL rewrite regular expression: {}",
41 err.to_string()
42 ))
43 }
44 };
45 let replacement = match rewrite_map_entry["replacement"].as_str() {
46 Some(replacement) => String::from(replacement),
47 None => return Err(anyhow::anyhow!("URL rewrite rules must have replacements")),
48 };
49 let is_not_file = rewrite_map_entry["isNotFile"].as_bool().unwrap_or(false);
50 let is_not_directory = rewrite_map_entry["isNotDirectory"]
51 .as_bool()
52 .unwrap_or(false);
53 let last = rewrite_map_entry["last"].as_bool().unwrap_or_default();
54 let allow_double_slashes = rewrite_map_entry["allowDoubleSlashes"]
55 .as_bool()
56 .unwrap_or(false);
57 rewrite_map_vec.push(UrlRewriteMapEntry::new(
58 regex,
59 replacement,
60 is_not_directory,
61 is_not_file,
62 last,
63 allow_double_slashes,
64 ));
65 }
66
67 Ok(rewrite_map_vec)
68}
69
70pub fn server_module_init(
71 config: &ServerConfig,
72) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
73 let mut global_url_rewrite_map = Vec::new();
74 let mut host_url_rewrite_maps = Vec::new();
75 if let Some(rewrite_map_yaml) = config["global"]["rewriteMap"].as_vec() {
76 global_url_rewrite_map = url_rewrite_config_init(rewrite_map_yaml)?;
77 }
78
79 if let Some(hosts) = config["hosts"].as_vec() {
80 for host_yaml in hosts.iter() {
81 let domain = host_yaml["domain"].as_str().map(String::from);
82 let ip = host_yaml["ip"].as_str().map(String::from);
83 let mut locations = Vec::new();
84 if let Some(locations_yaml) = host_yaml["locations"].as_vec() {
85 for location_yaml in locations_yaml.iter() {
86 if let Some(path_str) = location_yaml["path"].as_str() {
87 let path = String::from(path_str);
88 if let Some(rewrite_map_yaml) = location_yaml["rewriteMap"].as_vec() {
89 locations.push(UrlRewriteMapLocationWrap::new(
90 path,
91 url_rewrite_config_init(rewrite_map_yaml)?,
92 ));
93 }
94 }
95 }
96 }
97 if let Some(rewrite_map_yaml) = host_yaml["rewriteMap"].as_vec() {
98 host_url_rewrite_maps.push(UrlRewriteMapWrap::new(
99 domain,
100 ip,
101 url_rewrite_config_init(rewrite_map_yaml)?,
102 locations,
103 ));
104 } else if !locations.is_empty() {
105 host_url_rewrite_maps.push(UrlRewriteMapWrap::new(domain, ip, Vec::new(), locations));
106 }
107 }
108 }
109
110 Ok(Box::new(UrlRewriteModule::new(
111 Arc::new(global_url_rewrite_map),
112 Arc::new(host_url_rewrite_maps),
113 )))
114}
115
116struct UrlRewriteModule {
117 global_url_rewrite_map: Arc<Vec<UrlRewriteMapEntry>>,
118 host_url_rewrite_maps: Arc<Vec<UrlRewriteMapWrap>>,
119}
120
121impl UrlRewriteModule {
122 fn new(
123 global_url_rewrite_map: Arc<Vec<UrlRewriteMapEntry>>,
124 host_url_rewrite_maps: Arc<Vec<UrlRewriteMapWrap>>,
125 ) -> Self {
126 Self {
127 global_url_rewrite_map,
128 host_url_rewrite_maps,
129 }
130 }
131}
132
133impl ServerModule for UrlRewriteModule {
134 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
135 Box::new(UrlRewriteModuleHandlers {
136 global_url_rewrite_map: self.global_url_rewrite_map.clone(),
137 host_url_rewrite_maps: self.host_url_rewrite_maps.clone(),
138 handle,
139 })
140 }
141}
142struct UrlRewriteModuleHandlers {
143 global_url_rewrite_map: Arc<Vec<UrlRewriteMapEntry>>,
144 host_url_rewrite_maps: Arc<Vec<UrlRewriteMapWrap>>,
145 handle: Handle,
146}
147
148#[async_trait]
149impl ServerModuleHandlers for UrlRewriteModuleHandlers {
150 async fn request_handler(
151 &mut self,
152 request: RequestData,
153 config: &ServerConfig,
154 socket_data: &SocketData,
155 error_logger: &ErrorLogger,
156 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
157 WithRuntime::new(self.handle.clone(), async move {
158 let hyper_request = request.get_hyper_request();
159 let global_url_rewrite_map = self.global_url_rewrite_map.iter();
160 let empty_vector = Vec::new();
161 let another_empty_vector = Vec::new();
162 let mut host_url_rewrite_map = empty_vector.iter();
163 let mut location_url_rewrite_map = another_empty_vector.iter();
164
165 for host_url_rewrite_map_wrap in self.host_url_rewrite_maps.iter() {
167 if match_hostname(
168 match &host_url_rewrite_map_wrap.domain {
169 Some(value) => Some(value as &str),
170 None => None,
171 },
172 match hyper_request.headers().get(header::HOST) {
173 Some(value) => value.to_str().ok(),
174 None => None,
175 },
176 ) && match &host_url_rewrite_map_wrap.ip {
177 Some(value) => ip_match(value as &str, socket_data.remote_addr.ip()),
178 None => true,
179 } {
180 host_url_rewrite_map = host_url_rewrite_map_wrap.rewrite_map.iter();
181 if let Ok(path_decoded) = urlencoding::decode(
182 request
183 .get_original_url()
184 .unwrap_or(request.get_hyper_request().uri())
185 .path(),
186 ) {
187 for location_wrap in host_url_rewrite_map_wrap.locations.iter() {
188 if match_location(&location_wrap.path, &path_decoded) {
189 location_url_rewrite_map = location_wrap.rewrite_map.iter();
190 break;
191 }
192 }
193 }
194 break;
195 }
196 }
197
198 let combined_url_rewrite_map = global_url_rewrite_map
199 .chain(host_url_rewrite_map)
200 .chain(location_url_rewrite_map);
201
202 let original_url = format!(
203 "{}{}",
204 hyper_request.uri().path(),
205 match hyper_request.uri().query() {
206 Some(query) => format!("?{}", query),
207 None => String::from(""),
208 }
209 );
210 let mut rewritten_url = original_url.clone();
211
212 let mut rewritten_url_bytes = rewritten_url.bytes();
213 if rewritten_url_bytes.len() < 1 || rewritten_url_bytes.nth(0) != Some(b'/') {
214 return Ok(
215 ResponseData::builder(request)
216 .status(StatusCode::BAD_REQUEST)
217 .build(),
218 );
219 }
220
221 for url_rewrite_map_entry in combined_url_rewrite_map {
222 if url_rewrite_map_entry.is_not_directory || url_rewrite_map_entry.is_not_file {
224 if let Some(wwwroot) = config["wwwroot"].as_str() {
225 let path = Path::new(wwwroot);
226 let mut relative_path = &rewritten_url[1..];
227 while relative_path.as_bytes().first().copied() == Some(b'/') {
228 relative_path = &relative_path[1..];
229 }
230 let relative_path_split: Vec<&str> = relative_path.split("?").collect();
231 if !relative_path_split.is_empty() {
232 relative_path = relative_path_split[0];
233 }
234 let joined_pathbuf = path.join(relative_path);
235 if let Ok(metadata) = fs::metadata(joined_pathbuf).await {
236 if (url_rewrite_map_entry.is_not_file && metadata.is_file())
237 || (url_rewrite_map_entry.is_not_directory && metadata.is_dir())
238 {
239 continue;
240 }
241 }
242 }
243 }
244
245 if !url_rewrite_map_entry.allow_double_slashes {
246 while rewritten_url.contains("//") {
247 rewritten_url = rewritten_url.replace("//", "/");
248 }
249 }
250
251 let old_rewritten_url = rewritten_url;
253 rewritten_url = url_rewrite_map_entry
254 .regex
255 .replace(&old_rewritten_url, &url_rewrite_map_entry.replacement)
256 .to_string();
257
258 let mut rewritten_url_bytes = rewritten_url.bytes();
259 if rewritten_url_bytes.len() < 1 || rewritten_url_bytes.nth(0) != Some(b'/') {
260 return Ok(
261 ResponseData::builder(request)
262 .status(StatusCode::BAD_REQUEST)
263 .build(),
264 );
265 }
266
267 if url_rewrite_map_entry.last && old_rewritten_url != rewritten_url {
268 break;
269 }
270 }
271
272 if rewritten_url == original_url {
273 Ok(ResponseData::builder(request).build())
274 } else {
275 if config["enableRewriteLogging"].as_bool() == Some(true) {
276 error_logger
277 .log(&format!(
278 "URL rewritten from \"{}\" to \"{}\"",
279 original_url, rewritten_url
280 ))
281 .await;
282 }
283 let (hyper_request, auth_user, _) = request.into_parts();
284 let (mut parts, body) = hyper_request.into_parts();
285 let original_url = parts.uri.clone();
286 let mut url_parts = parts.uri.into_parts();
287 url_parts.path_and_query = Some(rewritten_url.parse()?);
288 parts.uri = hyper::Uri::from_parts(url_parts)?;
289 let hyper_request = Request::from_parts(parts, body);
290 let request = RequestData::new(hyper_request, auth_user, Some(original_url));
291 Ok(ResponseData::builder(request).build())
292 }
293 })
294 .await
295 }
296
297 async fn proxy_request_handler(
298 &mut self,
299 request: RequestData,
300 _config: &ServerConfig,
301 _socket_data: &SocketData,
302 _error_logger: &ErrorLogger,
303 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
304 Ok(ResponseData::builder(request).build())
305 }
306
307 async fn response_modifying_handler(
308 &mut self,
309 response: HyperResponse,
310 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
311 Ok(response)
312 }
313
314 async fn proxy_response_modifying_handler(
315 &mut self,
316 response: HyperResponse,
317 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
318 Ok(response)
319 }
320
321 async fn connect_proxy_request_handler(
322 &mut self,
323 _upgraded_request: HyperUpgraded,
324 _connect_address: &str,
325 _config: &ServerConfig,
326 _socket_data: &SocketData,
327 _error_logger: &ErrorLogger,
328 ) -> Result<(), Box<dyn Error + Send + Sync>> {
329 Ok(())
330 }
331
332 fn does_connect_proxy_requests(&mut self) -> bool {
333 false
334 }
335
336 async fn websocket_request_handler(
337 &mut self,
338 _websocket: HyperWebsocket,
339 _uri: &hyper::Uri,
340 _config: &ServerConfig,
341 _socket_data: &SocketData,
342 _error_logger: &ErrorLogger,
343 ) -> Result<(), Box<dyn Error + Send + Sync>> {
344 Ok(())
345 }
346
347 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
348 false
349 }
350}