ferron/util/
cgi_response.rs

1use memmem::{Searcher, TwoWaySearcher};
2use std::io::Error;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
6
/// Size in bytes (16 KiB) of the buffer used while looking for the end of the response head.
const RESPONSE_BUFFER_CAPACITY: usize = 16 * 1024;
9
10// Struct representing a response, which wraps an async read stream
11pub struct CgiResponse<R>
12where
13  R: AsyncRead + Unpin,
14{
15  stream: R,
16  response_buf: Vec<u8>,
17  response_head_length: Option<usize>,
18}
19
20impl<R> CgiResponse<R>
21where
22  R: AsyncRead + Unpin,
23{
24  // Constructor to create a new CgiResponse instance
25  pub fn new(stream: R) -> Self {
26    Self {
27      stream,
28      response_buf: Vec::with_capacity(RESPONSE_BUFFER_CAPACITY),
29      response_head_length: None,
30    }
31  }
32
33  // Asynchronous method to get the response headers
34  pub async fn get_head(&mut self) -> Result<&[u8], Error> {
35    let mut temp_buf = [0u8; RESPONSE_BUFFER_CAPACITY];
36    let rnrn = TwoWaySearcher::new(b"\r\n\r\n");
37    let nrnr = TwoWaySearcher::new(b"\n\r\n\r");
38    let nn = TwoWaySearcher::new(b"\n\n");
39    let rr = TwoWaySearcher::new(b"\r\r");
40    let to_parse_length;
41
42    loop {
43      // Read data from the stream into the temporary buffer
44      let read_bytes = self.stream.read(&mut temp_buf).await?;
45
46      // If no bytes are read, return an empty response head
47      if read_bytes == 0 {
48        self.response_head_length = Some(0);
49        return Ok(&[0u8; 0]);
50      }
51
52      // If the response buffer exceeds the capacity, return an empty response head
53      if self.response_buf.len() + read_bytes > RESPONSE_BUFFER_CAPACITY {
54        self.response_head_length = Some(0);
55        return Ok(&[0u8; 0]);
56      }
57
58      // Determine the starting point for searching the "\r\n\r\n" sequence
59      let begin_rnrn_or_nrnr_search = self.response_buf.len().saturating_sub(3);
60      let begin_rr_or_nn_search = self.response_buf.len().saturating_sub(1);
61      self.response_buf.extend_from_slice(&temp_buf[..read_bytes]);
62
63      // Search for the "\r\n\r\n" sequence in the response buffer
64      if let Some(rnrn_index) = rnrn.search_in(&self.response_buf[begin_rnrn_or_nrnr_search..]) {
65        to_parse_length = begin_rnrn_or_nrnr_search + rnrn_index + 4;
66        break;
67      } else if let Some(nrnr_index) =
68        nrnr.search_in(&self.response_buf[begin_rnrn_or_nrnr_search..])
69      {
70        to_parse_length = begin_rnrn_or_nrnr_search + nrnr_index + 4;
71        break;
72      } else if let Some(nn_index) = nn.search_in(&self.response_buf[begin_rr_or_nn_search..]) {
73        to_parse_length = begin_rr_or_nn_search + nn_index + 2;
74        break;
75      } else if let Some(rr_index) = rr.search_in(&self.response_buf[begin_rr_or_nn_search..]) {
76        to_parse_length = begin_rr_or_nn_search + rr_index + 2;
77        break;
78      }
79    }
80
81    // Set the length of the response header
82    self.response_head_length = Some(to_parse_length);
83
84    // Return the response header as a byte slice
85    Ok(&self.response_buf[..to_parse_length])
86  }
87}
88
89// Implementation of AsyncRead for the CgiResponse struct
90impl<R> AsyncRead for CgiResponse<R>
91where
92  R: AsyncRead + Unpin,
93{
94  fn poll_read(
95    mut self: Pin<&mut Self>,
96    cx: &mut Context<'_>,
97    buf: &mut ReadBuf<'_>,
98  ) -> Poll<std::io::Result<()>> {
99    // If the response header length is known and the buffer contains more data than the header length
100    if let Some(response_head_length) = self.response_head_length {
101      if self.response_buf.len() > response_head_length {
102        let remaining_data = &self.response_buf[response_head_length..];
103        let to_read = remaining_data.len().min(buf.remaining());
104        buf.put_slice(&remaining_data[..to_read]);
105        self.response_head_length = Some(response_head_length + to_read);
106        return Poll::Ready(Ok(()));
107      }
108    }
109
110    // Create a temporary buffer to hold the data to be consumed
111    let stream = Pin::new(&mut self.stream);
112    match stream.poll_read(cx, buf) {
113      Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
114      other => other,
115    }
116  }
117}
118
#[cfg(test)]
mod tests {
  use super::*;
  use tokio::io::AsyncReadExt;
  use tokio_test::io::Builder;

  /// A head terminated by `\r\n\r\n` is returned in full, terminator included.
  #[tokio::test]
  async fn test_get_head() {
    let mut mock = Builder::new()
      .read(b"Content-Type: text/plain\r\n\r\n")
      .build();
    let mut cgi = CgiResponse::new(&mut mock);

    let head = cgi.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");
  }

  /// A head terminated by `\n\n` is returned in full, terminator included.
  #[tokio::test]
  async fn test_get_head_nn() {
    let mut mock = Builder::new()
      .read(b"Content-Type: text/plain\n\n")
      .build();
    let mut cgi = CgiResponse::new(&mut mock);

    let head = cgi.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\n\n");
  }

  /// Headers that do not fit into the buffer produce an empty head.
  #[tokio::test]
  async fn test_get_head_large_headers() {
    let first_line = b"Content-Type: text/plain\r\n";
    let oversized = vec![b'A'; RESPONSE_BUFFER_CAPACITY + 10];
    let mut mock = Builder::new().read(first_line).read(&oversized).build();
    let mut cgi = CgiResponse::new(&mut mock);

    let outcome = cgi.get_head().await;
    assert_eq!(outcome.unwrap().len(), 0);

    // Drain what is left in the mock so that it does not panic when dropped
    let mut sink = vec![0u8; RESPONSE_BUFFER_CAPACITY + 10];
    let _ = cgi.stream.read(&mut sink).await;
  }

  /// A stream that ends before the blank line produces an empty head.
  #[tokio::test]
  async fn test_get_head_premature_eof() {
    let mut mock = Builder::new()
      .read(b"Content-Type: text/plain\r\n")
      .build();
    let mut cgi = CgiResponse::new(&mut mock);

    let outcome = cgi.get_head().await;
    assert_eq!(outcome.unwrap().len(), 0);
  }

  /// Bytes buffered past the head are served by the first read after `get_head`.
  #[tokio::test]
  async fn test_poll_read() {
    let mut mock = Builder::new()
      .read(b"Content-Type: text/plain\r\n\r\nHello, world!")
      .build();
    let mut cgi = CgiResponse::new(&mut mock);

    let head = cgi.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");

    let mut body = vec![0u8; 13];
    let read_len = cgi.read(&mut body).await.unwrap();
    assert_eq!(read_len, 13);
    assert_eq!(&body[..read_len], b"Hello, world!");
  }
}