1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures_util::ready;
8use tokio::io::{AsyncRead, AsyncWrite};
9
10struct ZeroWriter<I> {
11 inner: I,
12}
13
14impl<I> Future for ZeroWriter<I>
15where
16 I: AsyncWrite + Unpin,
17{
18 type Output = Result<(), tokio::io::Error>;
19
20 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
21 let empty_slice = [0u8; 0];
22 ready!(Pin::new(&mut self.inner).poll_write(cx, &empty_slice))?;
23 ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
24 Poll::Ready(Ok(()))
25 }
26}
27
28pub struct Copier<R, W> {
29 reader: R,
30 writer: W,
31 zero_packet: bool,
32}
33
34impl<R, W> Copier<R, W>
35where
36 R: AsyncRead + Unpin,
37 W: AsyncWrite + Unpin,
38{
39 pub fn new(reader: R, writer: W) -> Self {
40 Self {
41 reader,
42 writer,
43 zero_packet: false,
44 }
45 }
46
47 pub fn with_zero_packet_writing(reader: R, writer: W) -> Self {
48 Self {
49 reader,
50 writer,
51 zero_packet: true,
52 }
53 }
54
55 pub async fn copy(mut self) -> Result<u64, tokio::io::Error> {
56 let copied_size = tokio::io::copy(&mut self.reader, &mut self.writer).await?;
57 if self.zero_packet {
58 let zero_writer = ZeroWriter { inner: self.writer };
59 zero_writer.await?;
60 }
61 Ok(copied_size)
62 }
63}
64
#[cfg(test)]
mod tests {
    // `Pin`, `Context` and `Poll` come from the parent module's imports.
    use super::*;
    use tokio::io::{ReadBuf, Result};

    /// In-memory reader that serves `data` starting at `position`, then reports EOF.
    struct MockReader {
        data: Vec<u8>,
        position: usize,
    }

    impl MockReader {
        fn new(data: Vec<u8>) -> Self {
            Self { data, position: 0 }
        }
    }

    impl AsyncRead for MockReader {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<()>> {
            let this = self.get_mut();
            let rest = &this.data[this.position..];
            // Filling zero bytes signals EOF to the caller.
            let n = rest.len().min(buf.remaining());
            buf.put_slice(&rest[..n]);
            this.position += n;
            Poll::Ready(Ok(()))
        }
    }

    /// In-memory writer that records written bytes and counts zero-length writes.
    struct MockWriter {
        data: Vec<u8>,
        /// Number of `poll_write` calls that received an empty buffer.
        empty_writes: usize,
    }

    impl MockWriter {
        fn new() -> Self {
            Self {
                data: Vec::new(),
                empty_writes: 0,
            }
        }
    }

    impl AsyncWrite for MockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            let this = self.get_mut();
            if buf.is_empty() {
                this.empty_writes += 1;
            }
            this.data.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    // The writer is passed as `&mut MockWriter` so its contents can be
    // inspected after `copy` consumes the `Copier`.

    #[tokio::test]
    async fn test_copy() {
        let data = b"Hello, world!".to_vec();
        let reader = MockReader::new(data.clone());
        let mut writer = MockWriter::new();

        let copied = Copier::new(reader, &mut writer)
            .copy()
            .await
            .expect("copy failed");

        assert_eq!(copied, data.len() as u64);
        assert_eq!(writer.data, data);
        assert_eq!(writer.empty_writes, 0);
    }

    #[tokio::test]
    async fn test_copy_empty() {
        let reader = MockReader::new(Vec::new());
        let mut writer = MockWriter::new();

        let copied = Copier::new(reader, &mut writer)
            .copy()
            .await
            .expect("copy failed");

        assert_eq!(copied, 0);
        assert!(writer.data.is_empty());
        assert_eq!(writer.empty_writes, 0);
    }

    #[tokio::test]
    async fn test_copy_with_zero_packet() {
        let data = b"Hello, world!".to_vec();
        let reader = MockReader::new(data.clone());
        let mut writer = MockWriter::new();

        let copied = Copier::with_zero_packet_writing(reader, &mut writer)
            .copy()
            .await
            .expect("copy failed");

        // The trailing zero-length write adds no bytes and is not counted.
        assert_eq!(copied, data.len() as u64);
        assert_eq!(writer.data, data);
        assert_eq!(writer.empty_writes, 1);
    }
}