// Unwraps a `std::task::Poll`: yields the inner value when it is `Ready`,
// otherwise returns `Poll::Pending` from the enclosing function.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
11
12pub mod client;
13mod common;
14pub mod server;
15
16use common::{MidHandshake, Stream, TlsState};
17use futures_io::{AsyncRead, AsyncWrite};
18use rustls::server::AcceptedAlert;
19use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
20use std::future::Future;
21use std::io;
22#[cfg(unix)]
23use std::os::unix::io::{AsRawFd, RawFd};
24#[cfg(windows)]
25use std::os::windows::io::{AsRawSocket, RawSocket};
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30pub use pki_types;
31pub use rustls;
32
/// A cloneable handle for starting client-side TLS connections.
///
/// Wraps a shared [`ClientConfig`]; every call to [`TlsConnector::connect`]
/// creates a new [`ClientConnection`] from it.
#[derive(Clone)]
pub struct TlsConnector {
    /// Shared client configuration used for every new session.
    inner: Arc<ClientConfig>,
    /// Whether new streams may start in the early-data state when the freshly
    /// created session reports that early data is available.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
40
/// A cloneable handle for accepting server-side TLS connections.
///
/// Wraps a shared [`ServerConfig`]; every call to [`TlsAcceptor::accept`]
/// creates a new [`ServerConnection`] from it.
#[derive(Clone)]
pub struct TlsAcceptor {
    /// Shared server configuration used for every new session.
    inner: Arc<ServerConfig>,
}
46
47impl From<Arc<ClientConfig>> for TlsConnector {
48 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
49 TlsConnector {
50 inner,
51 #[cfg(feature = "early-data")]
52 early_data: false,
53 }
54 }
55}
56
57impl From<Arc<ServerConfig>> for TlsAcceptor {
58 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
59 TlsAcceptor { inner }
60 }
61}
62
63impl TlsConnector {
64 #[cfg(feature = "early-data")]
69 pub fn early_data(mut self, flag: bool) -> TlsConnector {
70 self.early_data = flag;
71 self
72 }
73
74 #[inline]
75 pub fn connect<IO>(&self, domain: pki_types::ServerName<'static>, stream: IO) -> Connect<IO>
76 where
77 IO: AsyncRead + AsyncWrite + Unpin,
78 {
79 self.connect_with(domain, stream, |_| ())
80 }
81
82 pub fn connect_with<IO, F>(
83 &self,
84 domain: pki_types::ServerName<'static>,
85 stream: IO,
86 f: F,
87 ) -> Connect<IO>
88 where
89 IO: AsyncRead + AsyncWrite + Unpin,
90 F: FnOnce(&mut ClientConnection),
91 {
92 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
93 Ok(session) => session,
94 Err(error) => {
95 return Connect(MidHandshake::Error {
96 io: stream,
97 error: io::Error::new(io::ErrorKind::Other, error),
100 });
101 }
102 };
103 f(&mut session);
104
105 Connect(MidHandshake::Handshaking(client::TlsStream {
106 io: stream,
107
108 #[cfg(not(feature = "early-data"))]
109 state: TlsState::Stream,
110
111 #[cfg(feature = "early-data")]
112 state: if self.early_data && session.early_data().is_some() {
113 TlsState::EarlyData(0, Vec::new())
114 } else {
115 TlsState::Stream
116 },
117
118 #[cfg(feature = "early-data")]
119 early_waker: None,
120
121 session,
122 }))
123 }
124}
125
126impl TlsAcceptor {
127 #[inline]
128 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
129 where
130 IO: AsyncRead + AsyncWrite + Unpin,
131 {
132 self.accept_with(stream, |_| ())
133 }
134
135 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
136 where
137 IO: AsyncRead + AsyncWrite + Unpin,
138 F: FnOnce(&mut ServerConnection),
139 {
140 let mut session = match ServerConnection::new(self.inner.clone()) {
141 Ok(session) => session,
142 Err(error) => {
143 return Accept(MidHandshake::Error {
144 io: stream,
145 error: io::Error::new(io::ErrorKind::Other, error),
148 });
149 }
150 };
151 f(&mut session);
152
153 Accept(MidHandshake::Handshaking(server::TlsStream {
154 session,
155 io: stream,
156 state: TlsState::Stream,
157 }))
158 }
159}
160
/// A future that reads the start of a TLS handshake from `io` before any
/// [`ServerConfig`] has been chosen.
///
/// It resolves to a [`StartHandshake`], from which the caller can inspect the
/// ClientHello and then pick the configuration to continue with.
pub struct LazyConfigAcceptor<IO> {
    /// rustls acceptor fed with the bytes read from `io`.
    acceptor: rustls::server::Acceptor,
    /// The transport; `None` once it has been moved into a `StartHandshake`.
    io: Option<IO>,
    /// An error from `acceptor.accept()` together with the alert that must be
    /// written to `io` before the error is reported.
    alert: Option<(rustls::Error, AcceptedAlert)>,
}
166
167impl<IO> LazyConfigAcceptor<IO>
168where
169 IO: AsyncRead + AsyncWrite + Unpin,
170{
171 #[inline]
172 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
173 Self {
174 acceptor,
175 io: Some(io),
176 alert: None,
177 }
178 }
179}
180
impl<IO> Future for LazyConfigAcceptor<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Resolves to a [`StartHandshake`] once `acceptor.accept()` yields an `Accepted`.
    type Output = Result<StartHandshake<IO>, io::Error>;

    /// Drives the acceptor. Each loop iteration does three things in order:
    /// write any pending alert, read more TLS bytes, then try `accept()`.
    ///
    /// # Errors
    ///
    /// * `io::ErrorKind::Other` if polled again after it already resolved
    ///   successfully (the IO object has been moved out).
    /// * `io::ErrorKind::InvalidData` wrapping the rustls error from `accept()`,
    ///   reported after its alert write returns `Ok(0)` or fails.
    /// * `io::ErrorKind::UnexpectedEof` if a read returns 0 bytes.
    /// * Any other error returned while reading the underlying stream.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        loop {
            // `io` is only `None` after a successful acceptance took it.
            let io = match this.io.as_mut() {
                Some(io) => io,
                None => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::Other,
                        "acceptor cannot be polled after acceptance",
                    )))
                }
            };

            // A previous `accept()` failed: send its alert to the peer before
            // surfacing the error.
            if let Some((err, mut alert)) = this.alert.take() {
                match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
                    // Transport not writable yet: keep the alert and wait.
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                        this.alert = Some((err, alert));
                        return Poll::Pending;
                    }
                    // Nothing written (presumably the alert is fully sent) or the
                    // write failed: report the original TLS error. A write error
                    // is intentionally discarded here.
                    Ok(0) | Err(_) => {
                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
                    }
                    // Some bytes written: keep the alert and try writing again.
                    Ok(_) => {
                        this.alert = Some((err, alert));
                        continue;
                    }
                };
            }

            // Feed more bytes from the transport into the acceptor.
            let mut reader = common::SyncReadAdapter { io, cx };
            match this.acceptor.read_tls(&mut reader) {
                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
                Ok(_) => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
                Err(e) => return Err(e).into(),
            }

            match this.acceptor.accept() {
                // Enough data received: hand the transport to the caller.
                Ok(Some(accepted)) => {
                    let io = this.io.take().unwrap();
                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
                }
                // Not enough data yet: loop and read again.
                Ok(None) => {}
                // Stash the error; the next iteration writes the alert first.
                Err((err, alert)) => {
                    this.alert = Some((err, alert));
                }
            }
        }
    }
}
237
/// The output of a [`LazyConfigAcceptor`]: the rustls acceptor has produced an
/// `Accepted`, and the caller can now choose the [`ServerConfig`] with which to
/// continue the handshake.
pub struct StartHandshake<IO> {
    /// rustls handle obtained from `Acceptor::accept`.
    accepted: rustls::server::Accepted,
    /// The transport, moved into the server stream by `into_stream_with`.
    io: IO,
}
242
243impl<IO> StartHandshake<IO>
244where
245 IO: AsyncRead + AsyncWrite + Unpin,
246{
247 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
248 self.accepted.client_hello()
249 }
250
251 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
252 self.into_stream_with(config, |_| ())
253 }
254
255 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
256 where
257 F: FnOnce(&mut ServerConnection),
258 {
259 let mut conn = match self.accepted.into_connection(config) {
260 Ok(conn) => conn,
261 Err((error, alert)) => {
262 return Accept(MidHandshake::SendAlert {
263 io: self.io,
264 error: io::Error::new(io::ErrorKind::Other, error),
267 alert,
268 });
269 }
270 };
271 f(&mut conn);
272
273 Accept(MidHandshake::Handshaking(server::TlsStream {
274 session: conn,
275 io: self.io,
276 state: TlsState::Stream,
277 }))
278 }
279}
280
/// Future returned by [`TlsConnector::connect`]; resolves to a
/// [`client::TlsStream`] or an `io::Error`.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Future returned by [`TlsAcceptor::accept`] and
/// [`StartHandshake::into_stream`]; resolves to a [`server::TlsStream`] or an
/// `io::Error`.
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);

/// Like [`Connect`], but on failure also hands back the underlying IO object.
pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Like [`Accept`], but on failure also hands back the underlying IO object.
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
294
295impl<IO> Connect<IO> {
296 #[inline]
297 pub fn into_fallible(self) -> FallibleConnect<IO> {
298 FallibleConnect(self.0)
299 }
300}
301
302impl<IO> Accept<IO> {
303 #[inline]
304 pub fn into_fallible(self) -> FallibleAccept<IO> {
305 FallibleAccept(self.0)
306 }
307}
308
309impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
310 type Output = io::Result<client::TlsStream<IO>>;
311
312 #[inline]
313 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
314 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
315 }
316}
317
318impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
319 type Output = io::Result<server::TlsStream<IO>>;
320
321 #[inline]
322 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
323 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
324 }
325}
326
327impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
328 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
329
330 #[inline]
331 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
332 Pin::new(&mut self.0).poll(cx)
333 }
334}
335
336impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
337 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
338
339 #[inline]
340 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
341 Pin::new(&mut self.0).poll(cx)
342 }
343}
344
/// A TLS stream that is either the client side or the server side of a connection.
///
/// Either side converts into it via `From`, and it implements `AsyncRead` and
/// `AsyncWrite` by delegating to the wrapped stream.
#[derive(Debug)]
pub enum TlsStream<T> {
    /// Client-side stream, as produced by [`Connect`].
    Client(client::TlsStream<T>),
    /// Server-side stream, as produced by [`Accept`].
    Server(server::TlsStream<T>),
}
354
355impl<T> TlsStream<T> {
356 pub fn get_ref(&self) -> (&T, &CommonState) {
357 use TlsStream::*;
358 match self {
359 Client(io) => {
360 let (io, session) = io.get_ref();
361 (io, session)
362 }
363 Server(io) => {
364 let (io, session) = io.get_ref();
365 (io, session)
366 }
367 }
368 }
369
370 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
371 use TlsStream::*;
372 match self {
373 Client(io) => {
374 let (io, session) = io.get_mut();
375 (io, &mut *session)
376 }
377 Server(io) => {
378 let (io, session) = io.get_mut();
379 (io, &mut *session)
380 }
381 }
382 }
383}
384
385impl<T> From<client::TlsStream<T>> for TlsStream<T> {
386 fn from(s: client::TlsStream<T>) -> Self {
387 Self::Client(s)
388 }
389}
390
391impl<T> From<server::TlsStream<T>> for TlsStream<T> {
392 fn from(s: server::TlsStream<T>) -> Self {
393 Self::Server(s)
394 }
395}
396
397#[cfg(unix)]
398impl<S> AsRawFd for TlsStream<S>
399where
400 S: AsRawFd,
401{
402 fn as_raw_fd(&self) -> RawFd {
403 self.get_ref().0.as_raw_fd()
404 }
405}
406
/// Exposes the raw socket of the underlying IO object.
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _common) = self.get_ref();
        io.as_raw_socket()
    }
}
416
417impl<T> AsyncRead for TlsStream<T>
418where
419 T: AsyncRead + AsyncWrite + Unpin,
420{
421 #[inline]
422 fn poll_read(
423 self: Pin<&mut Self>,
424 cx: &mut Context<'_>,
425 buf: &mut [u8],
426 ) -> Poll<io::Result<usize>> {
427 match self.get_mut() {
428 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
429 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
430 }
431 }
432}
433
434impl<T> AsyncWrite for TlsStream<T>
435where
436 T: AsyncRead + AsyncWrite + Unpin,
437{
438 #[inline]
439 fn poll_write(
440 self: Pin<&mut Self>,
441 cx: &mut Context<'_>,
442 buf: &[u8],
443 ) -> Poll<io::Result<usize>> {
444 match self.get_mut() {
445 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
446 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
447 }
448 }
449
450 #[inline]
451 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
452 match self.get_mut() {
453 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
454 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
455 }
456 }
457
458 #[inline]
459 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
460 match self.get_mut() {
461 TlsStream::Client(x) => Pin::new(x).poll_close(cx),
462 TlsStream::Server(x) => Pin::new(x).poll_close(cx),
463 }
464 }
465
466 #[inline]
467 fn poll_write_vectored(
468 self: Pin<&mut Self>,
469 cx: &mut Context<'_>,
470 bufs: &[io::IoSlice<'_>],
471 ) -> Poll<io::Result<usize>> {
472 match self.get_mut() {
473 TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
474 TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
475 }
476 }
477}