ferron/util/
copy_move.rs

1use std::{
2  future::Future,
3  pin::Pin,
4  task::{Context, Poll},
5};
6
7use futures_util::ready;
8use tokio::io::{AsyncRead, AsyncWrite};
9
10struct ZeroWriter<I> {
11  inner: I,
12}
13
14impl<I> Future for ZeroWriter<I>
15where
16  I: AsyncWrite + Unpin,
17{
18  type Output = Result<(), tokio::io::Error>;
19
20  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
21    let empty_slice = [0u8; 0];
22    ready!(Pin::new(&mut self.inner).poll_write(cx, &empty_slice))?;
23    ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
24    Poll::Ready(Ok(()))
25  }
26}
27
28pub struct Copier<R, W> {
29  reader: R,
30  writer: W,
31  zero_packet: bool,
32}
33
34impl<R, W> Copier<R, W>
35where
36  R: AsyncRead + Unpin,
37  W: AsyncWrite + Unpin,
38{
39  pub fn new(reader: R, writer: W) -> Self {
40    Self {
41      reader,
42      writer,
43      zero_packet: false,
44    }
45  }
46
47  pub fn with_zero_packet_writing(reader: R, writer: W) -> Self {
48    Self {
49      reader,
50      writer,
51      zero_packet: true,
52    }
53  }
54
55  pub async fn copy(mut self) -> Result<u64, tokio::io::Error> {
56    let copied_size = tokio::io::copy(&mut self.reader, &mut self.writer).await?;
57    if self.zero_packet {
58      let zero_writer = ZeroWriter { inner: self.writer };
59      zero_writer.await?;
60    }
61    Ok(copied_size)
62  }
63}
64
#[cfg(test)]
mod tests {
  use super::*;
  use std::pin::Pin;
  use tokio::io::{ReadBuf, Result};

  /// In-memory reader that yields `data` and then reports EOF.
  struct MockReader {
    data: Vec<u8>,
    position: usize,
  }

  impl MockReader {
    fn new(data: Vec<u8>) -> Self {
      Self { data, position: 0 }
    }
  }

  impl AsyncRead for MockReader {
    fn poll_read(
      self: Pin<&mut Self>,
      _cx: &mut Context<'_>,
      buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
      let this = self.get_mut();
      let remaining = this.data.len() - this.position;

      if remaining == 0 {
        return Poll::Ready(Ok(())); // EOF: nothing placed into `buf`
      }

      let to_read = remaining.min(buf.remaining());
      buf.put_slice(&this.data[this.position..this.position + to_read]);
      this.position += to_read;

      Poll::Ready(Ok(()))
    }
  }

  /// In-memory writer that records written bytes and counts zero-length writes and flushes.
  #[derive(Default)]
  struct MockWriter {
    data: Vec<u8>,
    zero_writes: usize,
    flushes: usize,
  }

  impl MockWriter {
    fn new() -> Self {
      Self::default()
    }
  }

  impl AsyncWrite for MockWriter {
    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
      let this = self.get_mut();
      if buf.is_empty() {
        this.zero_writes += 1;
      }
      this.data.extend_from_slice(buf);
      Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
      self.get_mut().flushes += 1;
      Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
      Poll::Ready(Ok(()))
    }
  }

  #[tokio::test]
  async fn test_copy() {
    let data = b"Hello, world!".to_vec();
    let reader = MockReader::new(data.clone());
    let mut writer = MockWriter::new();

    // Lend the writer by `&mut` so its contents can be inspected after the copy.
    let copied = Copier::new(reader, &mut writer).copy().await.unwrap();

    assert_eq!(copied, data.len() as u64);
    assert_eq!(writer.data, data);
    assert_eq!(writer.zero_writes, 0);
  }

  #[tokio::test]
  async fn test_copy_empty() {
    let reader = MockReader::new(Vec::new());
    let mut writer = MockWriter::new();

    let copied = Copier::new(reader, &mut writer).copy().await.unwrap();

    assert_eq!(copied, 0);
    assert!(writer.data.is_empty());
    assert_eq!(writer.zero_writes, 0);
  }

  #[tokio::test]
  async fn test_copy_with_zero_packet() {
    let data = b"Hello, world!".to_vec();
    let reader = MockReader::new(data.clone());
    let mut writer = MockWriter::new();

    let copied = Copier::with_zero_packet_writing(reader, &mut writer)
      .copy()
      .await
      .unwrap();

    // The trailing empty write is not counted and adds no bytes.
    assert_eq!(copied, data.len() as u64);
    assert_eq!(writer.data, data);
    assert_eq!(writer.zero_writes, 1);
    // At least the flush from the copy itself plus the one after the zero-length write.
    assert!(writer.flushes >= 2);
  }
}