hyper_util/client/legacy/connect/proxy/
tunnel.rs1use std::error::Error as StdError;
2use std::future::Future;
3use std::marker::{PhantomData, Unpin};
4use std::pin::Pin;
5use std::task::{self, Poll};
6
7use http::{HeaderMap, HeaderValue, Uri};
8use hyper::rt::{Read, Write};
9use pin_project_lite::pin_project;
10use tower_service::Service;
11
12#[derive(Debug)]
18pub struct Tunnel<C> {
19 headers: Headers,
20 inner: C,
21 proxy_dst: Uri,
22}
23
24#[derive(Clone, Debug)]
25enum Headers {
26 Empty,
27 Auth(HeaderValue),
28 Extra(HeaderMap),
29}
30
/// Errors that can occur while establishing a tunnel through a proxy.
#[derive(Debug)]
pub enum TunnelError {
    /// The underlying connector reported an error, either from `poll_ready`
    /// or while connecting to the proxy.
    ConnectFailed(Box<dyn StdError + Send + Sync>),
    /// An I/O error occurred while sending the `CONNECT` request or reading
    /// the proxy's reply.
    Io(std::io::Error),
    /// The destination URI has no host component.
    MissingHost,
    /// The proxy answered `407 Proxy Authentication Required`.
    ProxyAuthRequired,
    /// The proxy's successful response head did not fit in the read buffer.
    ProxyHeadersTooLong,
    /// The proxy closed the connection before a complete response head arrived.
    TunnelUnexpectedEof,
    /// The proxy rejected the tunnel with some other response.
    TunnelUnsuccessful,
}
41
42pin_project! {
43 #[must_use = "futures do nothing unless polled"]
49 #[allow(missing_debug_implementations)]
50 pub struct Tunneling<F, T> {
51 #[pin]
52 fut: BoxTunneling<T>,
53 _marker: PhantomData<F>,
54 }
55}
56
57type BoxTunneling<T> = Pin<Box<dyn Future<Output = Result<T, TunnelError>> + Send>>;
58
59impl<C> Tunnel<C> {
60 pub fn new(proxy_dst: Uri, connector: C) -> Self {
69 Self {
70 headers: Headers::Empty,
71 inner: connector,
72 proxy_dst,
73 }
74 }
75
76 pub fn with_auth(mut self, mut auth: HeaderValue) -> Self {
78 auth.set_sensitive(true);
80 match self.headers {
81 Headers::Empty => {
82 self.headers = Headers::Auth(auth);
83 }
84 Headers::Auth(ref mut existing) => {
85 *existing = auth;
86 }
87 Headers::Extra(ref mut extra) => {
88 extra.insert(http::header::PROXY_AUTHORIZATION, auth);
89 }
90 }
91
92 self
93 }
94
95 pub fn with_headers(mut self, mut headers: HeaderMap) -> Self {
99 match self.headers {
100 Headers::Empty => {
101 self.headers = Headers::Extra(headers);
102 }
103 Headers::Auth(auth) => {
104 headers
105 .entry(http::header::PROXY_AUTHORIZATION)
106 .or_insert(auth);
107 self.headers = Headers::Extra(headers);
108 }
109 Headers::Extra(ref mut extra) => {
110 extra.extend(headers);
111 }
112 }
113
114 self
115 }
116}
117
118impl<C> Service<Uri> for Tunnel<C>
119where
120 C: Service<Uri>,
121 C::Future: Send + 'static,
122 C::Response: Read + Write + Unpin + Send + 'static,
123 C::Error: Into<Box<dyn StdError + Send + Sync>>,
124{
125 type Response = C::Response;
126 type Error = TunnelError;
127 type Future = Tunneling<C::Future, C::Response>;
128
129 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
130 futures_util::ready!(self.inner.poll_ready(cx))
131 .map_err(|e| TunnelError::ConnectFailed(e.into()))?;
132 Poll::Ready(Ok(()))
133 }
134
135 fn call(&mut self, dst: Uri) -> Self::Future {
136 let connecting = self.inner.call(self.proxy_dst.clone());
137 let headers = self.headers.clone();
138
139 Tunneling {
140 fut: Box::pin(async move {
141 let conn = connecting
142 .await
143 .map_err(|e| TunnelError::ConnectFailed(e.into()))?;
144 tunnel(
145 conn,
146 dst.host().ok_or(TunnelError::MissingHost)?,
147 dst.port().map(|p| p.as_u16()).unwrap_or(443),
148 &headers,
149 )
150 .await
151 }),
152 _marker: PhantomData,
153 }
154 }
155}
156
157impl<F, T, E> Future for Tunneling<F, T>
158where
159 F: Future<Output = Result<T, E>>,
160{
161 type Output = Result<T, TunnelError>;
162
163 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
164 self.project().fut.poll(cx)
165 }
166}
167
168async fn tunnel<T>(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result<T, TunnelError>
169where
170 T: Read + Write + Unpin,
171{
172 let mut buf = format!(
173 "\
174 CONNECT {host}:{port} HTTP/1.1\r\n\
175 Host: {host}:{port}\r\n\
176 "
177 )
178 .into_bytes();
179
180 match headers {
181 Headers::Auth(auth) => {
182 buf.extend_from_slice(b"Proxy-Authorization: ");
183 buf.extend_from_slice(auth.as_bytes());
184 buf.extend_from_slice(b"\r\n");
185 }
186 Headers::Extra(extra) => {
187 for (name, value) in extra {
188 buf.extend_from_slice(name.as_str().as_bytes());
189 buf.extend_from_slice(b": ");
190 buf.extend_from_slice(value.as_bytes());
191 buf.extend_from_slice(b"\r\n");
192 }
193 }
194 Headers::Empty => (),
195 }
196
197 buf.extend_from_slice(b"\r\n");
199
200 crate::rt::write_all(&mut conn, &buf)
201 .await
202 .map_err(TunnelError::Io)?;
203
204 let mut buf = [0; 8192];
205 let mut pos = 0;
206
207 loop {
208 let n = crate::rt::read(&mut conn, &mut buf[pos..])
209 .await
210 .map_err(TunnelError::Io)?;
211
212 if n == 0 {
213 return Err(TunnelError::TunnelUnexpectedEof);
214 }
215 pos += n;
216
217 let recvd = &buf[..pos];
218 if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
219 if recvd.ends_with(b"\r\n\r\n") {
220 return Ok(conn);
221 }
222 if pos == buf.len() {
223 return Err(TunnelError::ProxyHeadersTooLong);
224 }
225 } else if recvd.starts_with(b"HTTP/1.1 407") {
227 return Err(TunnelError::ProxyAuthRequired);
228 } else {
229 return Err(TunnelError::TunnelUnsuccessful);
230 }
231 }
232}
233
234impl std::fmt::Display for TunnelError {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.write_str("tunnel error: ")?;
237
238 f.write_str(match self {
239 TunnelError::MissingHost => "missing destination host",
240 TunnelError::ProxyAuthRequired => "proxy authorization required",
241 TunnelError::ProxyHeadersTooLong => "proxy response headers too long",
242 TunnelError::TunnelUnexpectedEof => "unexpected end of file",
243 TunnelError::TunnelUnsuccessful => "unsuccessful",
244 TunnelError::ConnectFailed(_) => "failed to create underlying connection",
245 TunnelError::Io(_) => "io error establishing tunnel",
246 })
247 }
248}
249
250impl std::error::Error for TunnelError {
251 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
252 match self {
253 TunnelError::Io(ref e) => Some(e),
254 TunnelError::ConnectFailed(ref e) => Some(&**e),
255 _ => None,
256 }
257 }
258}