1use std::io::{Read, Write};
3
4use crate::{
5 client::{client_with_config, uri_mode, IntoClientRequest},
6 error::UrlError,
7 handshake::client::Response,
8 protocol::WebSocketConfig,
9 stream::MaybeTlsStream,
10 ClientHandshake, Error, HandshakeError, Result, WebSocket,
11};
12
/// Selects how the client stream is wrapped when a connection is established.
///
/// The TLS variants only exist when the corresponding crate feature is enabled; `Plain`
/// performs no TLS wrapping at all.
#[non_exhaustive]
// NOTE(review): the reason for opting out of a `Debug` impl is not visible here; presumably
// not every wrapped connector type provided `Debug` across supported versions — confirm.
#[allow(missing_debug_implementations)]
pub enum Connector {
    /// Do not wrap the stream in TLS.
    Plain,
    /// Wrap the stream using the given `native-tls` connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// Wrap the stream using the given shared `rustls` client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
28
29mod encryption {
30 #[cfg(feature = "native-tls")]
31 pub mod native_tls {
32 use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
33
34 use std::io::{Read, Write};
35
36 use crate::{
37 error::TlsError,
38 stream::{MaybeTlsStream, Mode},
39 Error, Result,
40 };
41
42 pub fn wrap_stream<S>(
43 socket: S,
44 domain: &str,
45 mode: Mode,
46 tls_connector: Option<TlsConnector>,
47 ) -> Result<MaybeTlsStream<S>>
48 where
49 S: Read + Write,
50 {
51 match mode {
52 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
53 Mode::Tls => {
54 let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
55 let connector = try_connector.map_err(TlsError::Native)?;
56 let connected = connector.connect(domain, socket);
57 match connected {
58 Err(e) => match e {
59 TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
60 TlsHandshakeError::WouldBlock(_) => {
61 panic!("Bug: TLS handshake not blocked")
62 }
63 },
64 Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
65 }
66 }
67 }
68 }
69 }
70
71 #[cfg(feature = "__rustls-tls")]
72 pub mod rustls {
73 use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
74 use rustls_pki_types::ServerName;
75
76 use std::{
77 io::{Read, Write},
78 sync::Arc,
79 };
80
81 use crate::{
82 error::TlsError,
83 stream::{MaybeTlsStream, Mode},
84 Result,
85 };
86
87 pub fn wrap_stream<S>(
88 socket: S,
89 domain: &str,
90 mode: Mode,
91 tls_connector: Option<Arc<ClientConfig>>,
92 ) -> Result<MaybeTlsStream<S>>
93 where
94 S: Read + Write,
95 {
96 match mode {
97 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
98 Mode::Tls => {
99 let config = match tls_connector {
100 Some(config) => config,
101 None => {
102 #[allow(unused_mut)]
103 let mut root_store = RootCertStore::empty();
104
105 #[cfg(feature = "rustls-tls-native-roots")]
106 {
107 let rustls_native_certs::CertificateResult {
108 certs, errors, ..
109 } = rustls_native_certs::load_native_certs();
110
111 if !errors.is_empty() {
112 log::warn!(
113 "native root CA certificate loading errors: {errors:?}"
114 );
115 }
116
117 #[cfg(not(feature = "rustls-tls-webpki-roots"))]
120 if certs.is_empty() {
121 return Err(std::io::Error::new(std::io::ErrorKind::NotFound, format!("no native root CA certificates found (errors: {errors:?})")).into());
122 }
123
124 let total_number = certs.len();
125 let (number_added, number_ignored) =
126 root_store.add_parsable_certificates(certs);
127 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
128 }
129 #[cfg(feature = "rustls-tls-webpki-roots")]
130 {
131 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
132 }
133
134 Arc::new(
135 ClientConfig::builder()
136 .with_root_certificates(root_store)
137 .with_no_client_auth(),
138 )
139 }
140 };
141 let domain = ServerName::try_from(domain)
142 .map_err(|_| TlsError::InvalidDnsName)?
143 .to_owned();
144 let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
145 let stream = StreamOwned::new(client, socket);
146
147 Ok(MaybeTlsStream::Rustls(stream))
148 }
149 }
150 }
151 }
152
153 pub mod plain {
154 use std::io::{Read, Write};
155
156 use crate::{
157 error::UrlError,
158 stream::{MaybeTlsStream, Mode},
159 Error, Result,
160 };
161
162 pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
163 where
164 S: Read + Write,
165 {
166 match mode {
167 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
168 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
169 }
170 }
171 }
172}
173
/// Error type returned by the TLS client helpers in this module: a [`HandshakeError`] for a
/// [`ClientHandshake`] running over a possibly TLS-wrapped stream [`MaybeTlsStream<S>`].
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
175
176pub fn client_tls<R, S>(
179 request: R,
180 stream: S,
181) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
182where
183 R: IntoClientRequest,
184 S: Read + Write,
185{
186 client_tls_with_config(request, stream, None, None)
187}
188
189pub fn client_tls_with_config<R, S>(
195 request: R,
196 stream: S,
197 config: Option<WebSocketConfig>,
198 connector: Option<Connector>,
199) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
200where
201 R: IntoClientRequest,
202 S: Read + Write,
203{
204 let request = request.into_client_request()?;
205
206 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
207 let domain = match request.uri().host() {
208 Some(d) => Ok(d.to_string()),
209 None => Err(Error::Url(UrlError::NoHostName)),
210 }?;
211
212 let mode = uri_mode(request.uri())?;
213
214 let stream = match connector {
215 Some(conn) => match conn {
216 #[cfg(feature = "native-tls")]
217 Connector::NativeTls(conn) => {
218 self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
219 }
220 #[cfg(feature = "__rustls-tls")]
221 Connector::Rustls(conn) => {
222 self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
223 }
224 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
225 },
226 None => {
227 #[cfg(feature = "native-tls")]
228 {
229 self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
230 }
231 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
232 {
233 self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
234 }
235 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
236 {
237 self::encryption::plain::wrap_stream(stream, mode)
238 }
239 }
240 }?;
241
242 client_with_config(request, stream, config)
243}