1use std::error::Error;
2use std::path::Path;
3use std::sync::Arc;
4use std::time::Duration;
5
6use crate::ferron_common::{
7 ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
8 ServerModuleHandlers, SocketData,
9};
10use crate::ferron_common::{HyperUpgraded, WithRuntime};
11use async_trait::async_trait;
12use http_body_util::{BodyExt, Empty};
13use hyper::{header, Response, StatusCode};
14use hyper_tungstenite::HyperWebsocket;
15use tokio::fs;
16use tokio::runtime::Handle;
17use tokio::sync::RwLock;
18
19use crate::ferron_util::ttl_cache::TtlCache;
20
21pub fn server_module_init(
22) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
23 let cache = Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(100))));
24 Ok(Box::new(RedirectTrailingSlashesModule::new(cache)))
25}
26
27struct RedirectTrailingSlashesModule {
28 cache: Arc<RwLock<TtlCache<String, bool>>>,
29}
30
31impl RedirectTrailingSlashesModule {
32 fn new(cache: Arc<RwLock<TtlCache<String, bool>>>) -> Self {
33 Self { cache }
34 }
35}
36
37impl ServerModule for RedirectTrailingSlashesModule {
38 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
39 Box::new(RedirectTrailingSlashesModuleHandlers {
40 cache: self.cache.clone(),
41 handle,
42 })
43 }
44}
45
46struct RedirectTrailingSlashesModuleHandlers {
47 cache: Arc<RwLock<TtlCache<String, bool>>>,
48 handle: Handle,
49}
50
51#[async_trait]
52impl ServerModuleHandlers for RedirectTrailingSlashesModuleHandlers {
53 async fn request_handler(
54 &mut self,
55 request: RequestData,
56 config: &ServerConfig,
57 _socket_data: &SocketData,
58 _error_logger: &ErrorLogger,
59 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
60 WithRuntime::new(self.handle.clone(), async move {
61 if config["disableTrailingSlashRedirects"].as_bool() != Some(true) {
62 if let Some(wwwroot) = config["wwwroot"].as_str() {
63 let hyper_request = request.get_hyper_request();
64
65 let request_path = hyper_request.uri().path();
66 let request_query = hyper_request.uri().query();
67 let mut request_path_bytes = request_path.bytes();
68 if request_path_bytes.len() < 1 || request_path_bytes.nth(0) != Some(b'/') {
69 return Ok(
70 ResponseData::builder(request)
71 .status(StatusCode::BAD_REQUEST)
72 .build(),
73 );
74 }
75
76 match request_path_bytes.last() {
77 Some(b'/') | None => {
78 return Ok(ResponseData::builder(request).build());
79 }
80 _ => {
81 let cache_key = format!(
82 "{}{}{}",
83 match config["ip"].as_str() {
84 Some(ip) => format!("{}-", ip),
85 None => String::from(""),
86 },
87 match config["domain"].as_str() {
88 Some(domain) => format!("{}-", domain),
89 None => String::from(""),
90 },
91 request_path
92 );
93
94 let read_rwlock = self.cache.read().await;
95 if let Some(is_directory) = read_rwlock.get(&cache_key) {
96 drop(read_rwlock);
97 if is_directory {
98 let new_request_uri = format!(
99 "{}/{}",
100 request_path,
101 match request_query {
102 Some(query) => format!("?{}", query),
103 None => String::from(""),
104 }
105 );
106 return Ok(
107 ResponseData::builder(request)
108 .response(
109 Response::builder()
110 .status(StatusCode::MOVED_PERMANENTLY)
111 .header(header::LOCATION, new_request_uri)
112 .body(Empty::new().map_err(|e| match e {}).boxed())?,
113 )
114 .build(),
115 );
116 }
117 } else {
118 drop(read_rwlock);
119
120 let path = Path::new(wwwroot);
121 let mut relative_path = &request_path[1..];
122 while relative_path.as_bytes().first().copied() == Some(b'/') {
123 relative_path = &relative_path[1..];
124 }
125
126 let decoded_relative_path = match urlencoding::decode(relative_path) {
127 Ok(path) => path.to_string(),
128 Err(_) => {
129 return Ok(
130 ResponseData::builder(request)
131 .status(StatusCode::BAD_REQUEST)
132 .build(),
133 );
134 }
135 };
136
137 let joined_pathbuf = path.join(decoded_relative_path);
138
139 match fs::metadata(joined_pathbuf).await {
140 Ok(metadata) => {
141 let is_directory = metadata.is_dir();
142 let mut write_rwlock = self.cache.write().await;
143 write_rwlock.cleanup();
144 write_rwlock.insert(cache_key, is_directory);
145 if is_directory {
146 let new_request_uri = format!(
147 "{}/{}",
148 request_path,
149 match request_query {
150 Some(query) => format!("?{}", query),
151 None => String::from(""),
152 }
153 );
154 return Ok(
155 ResponseData::builder(request)
156 .response(
157 Response::builder()
158 .status(StatusCode::MOVED_PERMANENTLY)
159 .header(header::LOCATION, new_request_uri)
160 .body(Empty::new().map_err(|e| match e {}).boxed())?,
161 )
162 .build(),
163 );
164 }
165 }
166 Err(_) => {
167 let mut write_rwlock = self.cache.write().await;
168 write_rwlock.cleanup();
169 write_rwlock.insert(cache_key, false);
170 }
171 }
172 }
173 }
174 };
175 }
176 }
177 Ok(ResponseData::builder(request).build())
178 })
179 .await
180 }
181
182 async fn proxy_request_handler(
183 &mut self,
184 request: RequestData,
185 _config: &ServerConfig,
186 _socket_data: &SocketData,
187 _error_logger: &ErrorLogger,
188 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
189 Ok(ResponseData::builder(request).build())
190 }
191
192 async fn response_modifying_handler(
193 &mut self,
194 response: HyperResponse,
195 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
196 Ok(response)
197 }
198
199 async fn proxy_response_modifying_handler(
200 &mut self,
201 response: HyperResponse,
202 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
203 Ok(response)
204 }
205
206 async fn connect_proxy_request_handler(
207 &mut self,
208 _upgraded_request: HyperUpgraded,
209 _connect_address: &str,
210 _config: &ServerConfig,
211 _socket_data: &SocketData,
212 _error_logger: &ErrorLogger,
213 ) -> Result<(), Box<dyn Error + Send + Sync>> {
214 Ok(())
215 }
216
217 fn does_connect_proxy_requests(&mut self) -> bool {
218 false
219 }
220
221 async fn websocket_request_handler(
222 &mut self,
223 _websocket: HyperWebsocket,
224 _uri: &hyper::Uri,
225 _config: &ServerConfig,
226 _socket_data: &SocketData,
227 _error_logger: &ErrorLogger,
228 ) -> Result<(), Box<dyn Error + Send + Sync>> {
229 Ok(())
230 }
231
232 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
233 false
234 }
235}