tungstenite/
tls.rs

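//! Connection helpers: wrap a client stream in TLS when the request URI requires it and perform the WebSocket handshake over it.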
2use std::io::{Read, Write};
3
4use crate::{
5    client::{client_with_config, uri_mode, IntoClientRequest},
6    error::UrlError,
7    handshake::client::Response,
8    protocol::WebSocketConfig,
9    stream::MaybeTlsStream,
10    ClientHandshake, Error, HandshakeError, Result, WebSocket,
11};
12
/// Chooses how [`client_tls_with_config`] secures a connection: through `native-tls`,
/// through `rustls`, or not at all (the `Plain` variant).
///
/// Which TLS variants are available depends on the crate features that are enabled.
#[allow(missing_debug_implementations)]
#[non_exhaustive]
pub enum Connector {
    /// No TLS: the stream is used as-is, and URIs that require TLS are rejected.
    Plain,
    /// Perform the TLS handshake with this pre-configured `native-tls` connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// Perform the TLS handshake with this shared `rustls` client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
28
29mod encryption {
30    #[cfg(feature = "native-tls")]
31    pub mod native_tls {
32        use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
33
34        use std::io::{Read, Write};
35
36        use crate::{
37            error::TlsError,
38            stream::{MaybeTlsStream, Mode},
39            Error, Result,
40        };
41
42        pub fn wrap_stream<S>(
43            socket: S,
44            domain: &str,
45            mode: Mode,
46            tls_connector: Option<TlsConnector>,
47        ) -> Result<MaybeTlsStream<S>>
48        where
49            S: Read + Write,
50        {
51            match mode {
52                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
53                Mode::Tls => {
54                    let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
55                    let connector = try_connector.map_err(TlsError::Native)?;
56                    let connected = connector.connect(domain, socket);
57                    match connected {
58                        Err(e) => match e {
59                            TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
60                            TlsHandshakeError::WouldBlock(_) => {
61                                panic!("Bug: TLS handshake not blocked")
62                            }
63                        },
64                        Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
65                    }
66                }
67            }
68        }
69    }
70
71    #[cfg(feature = "__rustls-tls")]
72    pub mod rustls {
73        use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
74        use rustls_pki_types::ServerName;
75
76        use std::{
77            io::{Read, Write},
78            sync::Arc,
79        };
80
81        use crate::{
82            error::TlsError,
83            stream::{MaybeTlsStream, Mode},
84            Result,
85        };
86
87        pub fn wrap_stream<S>(
88            socket: S,
89            domain: &str,
90            mode: Mode,
91            tls_connector: Option<Arc<ClientConfig>>,
92        ) -> Result<MaybeTlsStream<S>>
93        where
94            S: Read + Write,
95        {
96            match mode {
97                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
98                Mode::Tls => {
99                    let config = match tls_connector {
100                        Some(config) => config,
101                        None => {
102                            #[allow(unused_mut)]
103                            let mut root_store = RootCertStore::empty();
104
105                            #[cfg(feature = "rustls-tls-native-roots")]
106                            {
107                                let rustls_native_certs::CertificateResult {
108                                    certs, errors, ..
109                                } = rustls_native_certs::load_native_certs();
110
111                                if !errors.is_empty() {
112                                    log::warn!(
113                                        "native root CA certificate loading errors: {errors:?}"
114                                    );
115                                }
116
117                                // Not finding any native root CA certificates is not fatal if the
118                                // "rustls-tls-webpki-roots" feature is enabled.
119                                #[cfg(not(feature = "rustls-tls-webpki-roots"))]
120                                if certs.is_empty() {
121                                    return Err(std::io::Error::new(std::io::ErrorKind::NotFound, format!("no native root CA certificates found (errors: {errors:?})")).into());
122                                }
123
124                                let total_number = certs.len();
125                                let (number_added, number_ignored) =
126                                    root_store.add_parsable_certificates(certs);
127                                log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
128                            }
129                            #[cfg(feature = "rustls-tls-webpki-roots")]
130                            {
131                                root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
132                            }
133
134                            Arc::new(
135                                ClientConfig::builder()
136                                    .with_root_certificates(root_store)
137                                    .with_no_client_auth(),
138                            )
139                        }
140                    };
141                    let domain = ServerName::try_from(domain)
142                        .map_err(|_| TlsError::InvalidDnsName)?
143                        .to_owned();
144                    let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
145                    let stream = StreamOwned::new(client, socket);
146
147                    Ok(MaybeTlsStream::Rustls(stream))
148                }
149            }
150        }
151    }
152
153    pub mod plain {
154        use std::io::{Read, Write};
155
156        use crate::{
157            error::UrlError,
158            stream::{MaybeTlsStream, Mode},
159            Error, Result,
160        };
161
162        pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
163        where
164            S: Read + Write,
165        {
166            match mode {
167                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
168                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
169            }
170        }
171    }
172}
173
174type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
175
176/// Creates a WebSocket handshake from a request and a stream,
177/// upgrading the stream to TLS if required.
178pub fn client_tls<R, S>(
179    request: R,
180    stream: S,
181) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
182where
183    R: IntoClientRequest,
184    S: Read + Write,
185{
186    client_tls_with_config(request, stream, None, None)
187}
188
189/// The same as [`client_tls()`] but one can specify a websocket configuration,
190/// and an optional connector. If no connector is specified, a default one will
191/// be created.
192///
193/// Please refer to [`client_tls()`] for more details.
194pub fn client_tls_with_config<R, S>(
195    request: R,
196    stream: S,
197    config: Option<WebSocketConfig>,
198    connector: Option<Connector>,
199) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
200where
201    R: IntoClientRequest,
202    S: Read + Write,
203{
204    let request = request.into_client_request()?;
205
206    #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
207    let domain = match request.uri().host() {
208        Some(d) => Ok(d.to_string()),
209        None => Err(Error::Url(UrlError::NoHostName)),
210    }?;
211
212    let mode = uri_mode(request.uri())?;
213
214    let stream = match connector {
215        Some(conn) => match conn {
216            #[cfg(feature = "native-tls")]
217            Connector::NativeTls(conn) => {
218                self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
219            }
220            #[cfg(feature = "__rustls-tls")]
221            Connector::Rustls(conn) => {
222                self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
223            }
224            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
225        },
226        None => {
227            #[cfg(feature = "native-tls")]
228            {
229                self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
230            }
231            #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
232            {
233                self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
234            }
235            #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
236            {
237                self::encryption::plain::wrap_stream(stream, mode)
238            }
239        }
240    }?;
241
242    client_with_config(request, stream, config)
243}