// tower_http/follow_redirect/policy/same_origin.rs

use super::{eq_origin, Action, Attempt, Policy};
use std::fmt;

/// A redirection [`Policy`] that stops cross-origin redirections.
///
/// A redirect is followed only when its target (`Attempt::location`) has the same
/// origin as the previous request URI (`Attempt::previous`), as decided by
/// `eq_origin`. Any other redirect ends the chain with `Action::Stop`. For example,
/// a redirect from `http://example.com/old` to `https://example.com/new` is treated
/// as cross-origin and is not followed.
///
/// The policy carries no state; it is zero-sized and `Copy`.
#[derive(Clone, Copy, Default)]
pub struct SameOrigin {
    // Private unit field. The struct stays zero-sized, and code outside this module
    // cannot build it with a struct literal, only through `new` / `default`.
    _priv: (),
}

impl SameOrigin {
    /// Returns a new [`SameOrigin`] policy.
    ///
    /// Equivalent to `SameOrigin::default()`.
    pub fn new() -> Self {
        SameOrigin { _priv: () }
    }
}

impl fmt::Debug for SameOrigin {
    /// Writes only the type name, `SameOrigin`, leaving out the private placeholder field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SameOrigin")
    }
}

23impl<B, E> Policy<B, E> for SameOrigin {
24    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
25        if eq_origin(attempt.previous(), attempt.location()) {
26            Ok(Action::Follow)
27        } else {
28            Ok(Action::Stop)
29        }
30    }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Request, Uri};

    /// Runs `policy` over a request for `from`, then asks whether a redirect to `to`
    /// should be followed.
    fn decide(policy: &mut SameOrigin, from: Uri, to: &Uri) -> Action {
        let mut req = Request::builder().uri(from).body(()).unwrap();
        Policy::<(), ()>::on_request(policy, &mut req);

        let attempt = Attempt {
            status: Default::default(),
            location: to,
            previous: req.uri(),
        };
        Policy::<(), ()>::redirect(policy, &attempt).unwrap()
    }

    #[test]
    fn works() {
        let mut policy = SameOrigin::default();

        let start = Uri::from_static("http://example.com/old");
        let same_origin = Uri::from_static("http://example.com/new");
        // Only the scheme differs, which is enough to count as a different origin.
        let cross_origin = Uri::from_static("https://example.com/new");

        assert!(decide(&mut policy, start, &same_origin).is_follow());
        assert!(decide(&mut policy, same_origin, &cross_origin).is_stop());
    }
}