ferron/util/
split_stream_by_map.rs

1// Copyright (c) Andrew Burkett
2// Portions of this file are derived from `split-stream-by` (https://github.com/drewkett/split-stream-by).
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in all
12// copies or substantial portions of the Software.
13
14use std::marker::PhantomData;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Poll, Waker};
18
19pub use futures_util::future::Either;
20use futures_util::stream::Stream;
21use pin_project_lite::pin_project;
22use tokio::sync::Mutex;
23
24pin_project! {
25struct SplitByMap<I, L, R, S, P> {
26    buf_left: Option<L>,
27    buf_right: Option<R>,
28    waker_left: Option<Waker>,
29    waker_right: Option<Waker>,
30    #[pin]
31    stream: S,
32    predicate: P,
33    item: PhantomData<I>,
34}
35}
36
37impl<I, L, R, S, P> SplitByMap<I, L, R, S, P>
38where
39  S: Stream<Item = I>,
40  P: Fn(I) -> Either<L, R>,
41{
42  fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
43    Arc::new(Mutex::new(Self {
44      buf_right: None,
45      buf_left: None,
46      waker_right: None,
47      waker_left: None,
48      stream,
49      predicate,
50      item: PhantomData,
51    }))
52  }
53
54  fn poll_next_left(
55    self: std::pin::Pin<&mut Self>,
56    cx: &mut futures_util::task::Context<'_>,
57  ) -> std::task::Poll<Option<L>> {
58    let this = self.project();
59    // Assign the waker multiple times, because if it was only once, the waking might fail
60    *this.waker_left = Some(cx.waker().clone());
61    if let Some(item) = this.buf_left.take() {
62      // There was already a value in the buffer. Return that value
63      return Poll::Ready(Some(item));
64    }
65    if this.buf_right.is_some() {
66      // There is a value available for the other stream. Wake that stream if possible
67      // and return pending since we can't store multiple values for a stream
68      if let Some(waker) = this.waker_right {
69        waker.wake_by_ref();
70      }
71      return Poll::Pending;
72    }
73    match this.stream.poll_next(cx) {
74      Poll::Ready(Some(item)) => {
75        match (this.predicate)(item) {
76          Either::Left(left_item) => Poll::Ready(Some(left_item)),
77          Either::Right(right_item) => {
78            // This value is not what we wanted. Store it and notify other partition
79            // task if it exists
80            let _ = this.buf_right.replace(right_item);
81            if let Some(waker) = this.waker_right {
82              waker.wake_by_ref();
83            }
84            Poll::Pending
85          }
86        }
87      }
88      Poll::Ready(None) => {
89        // If the underlying stream is finished, the `right` stream also must be
90        // finished, so wake it in case nothing else polls it
91        if let Some(waker) = this.waker_right {
92          waker.wake_by_ref();
93        }
94        Poll::Ready(None)
95      }
96      Poll::Pending => Poll::Pending,
97    }
98  }
99
100  fn poll_next_right(
101    self: std::pin::Pin<&mut Self>,
102    cx: &mut futures_util::task::Context<'_>,
103  ) -> std::task::Poll<Option<R>> {
104    let this = self.project();
105    // Assign the waker multiple times, because if it was only once, the waking might fail
106    *this.waker_right = Some(cx.waker().clone());
107    if let Some(item) = this.buf_right.take() {
108      // There was already a value in the buffer. Return that value
109      return Poll::Ready(Some(item));
110    }
111    if this.buf_left.is_some() {
112      // There is a value available for the other stream. Wake that stream if possible
113      // and return pending since we can't store multiple values for a stream
114      if let Some(waker) = this.waker_left {
115        waker.wake_by_ref();
116      }
117      return Poll::Pending;
118    }
119    match this.stream.poll_next(cx) {
120      Poll::Ready(Some(item)) => {
121        match (this.predicate)(item) {
122          Either::Left(left_item) => {
123            // This value is not what we wanted. Store it and notify other partition
124            // task if it exists
125            let _ = this.buf_left.replace(left_item);
126            if let Some(waker) = this.waker_left {
127              waker.wake_by_ref();
128            }
129            Poll::Pending
130          }
131          Either::Right(right_item) => Poll::Ready(Some(right_item)),
132        }
133      }
134      Poll::Ready(None) => {
135        // If the underlying stream is finished, the `left` stream also must be
136        // finished, so wake it in case nothing else polls it
137        if let Some(waker) = this.waker_left {
138          waker.wake_by_ref();
139        }
140        Poll::Ready(None)
141      }
142      Poll::Pending => Poll::Pending,
143    }
144  }
145}
146
/// The `Left` half of a stream split with `split_by_map`.
///
/// Implements `Stream` and yields the inner values of the items for which the
/// predicate returned `Either::Left(..)`. It holds a shared handle to the same
/// split state as the matching `RightSplitByMap`.
#[allow(clippy::type_complexity)]
pub struct LeftSplitByMap<I, L, R, S, P> {
  // Split state shared with the right half; a lock attempt is made on every poll.
  stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>,
}
153
154impl<I, L, R, S, P> LeftSplitByMap<I, L, R, S, P> {
155  #[allow(clippy::type_complexity)]
156  fn new(stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>) -> Self {
157    Self { stream }
158  }
159}
160
161impl<I, L, R, S, P> Stream for LeftSplitByMap<I, L, R, S, P>
162where
163  S: Stream<Item = I> + Unpin,
164  P: Fn(I) -> Either<L, R>,
165{
166  type Item = L;
167  fn poll_next(
168    self: std::pin::Pin<&mut Self>,
169    cx: &mut futures_util::task::Context<'_>,
170  ) -> std::task::Poll<Option<Self::Item>> {
171    let response = if let Ok(mut guard) = self.stream.try_lock() {
172      SplitByMap::poll_next_left(Pin::new(&mut guard), cx)
173    } else {
174      cx.waker().wake_by_ref();
175      Poll::Pending
176    };
177    response
178  }
179}
180
/// The `Right` half of a stream split with `split_by_map`.
///
/// Implements `Stream` and yields the inner values of the items for which the
/// predicate returned `Either::Right(..)`. It holds a shared handle to the same
/// split state as the matching `LeftSplitByMap`.
#[allow(clippy::type_complexity)]
pub struct RightSplitByMap<I, L, R, S, P> {
  // Split state shared with the left half; a lock attempt is made on every poll.
  stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>,
}
187
188impl<I, L, R, S, P> RightSplitByMap<I, L, R, S, P> {
189  #[allow(clippy::type_complexity)]
190  fn new(stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>) -> Self {
191    Self { stream }
192  }
193}
194
195impl<I, L, R, S, P> Stream for RightSplitByMap<I, L, R, S, P>
196where
197  S: Stream<Item = I> + Unpin,
198  P: Fn(I) -> Either<L, R>,
199{
200  type Item = R;
201  fn poll_next(
202    self: std::pin::Pin<&mut Self>,
203    cx: &mut futures_util::task::Context<'_>,
204  ) -> std::task::Poll<Option<Self::Item>> {
205    let response = if let Ok(mut guard) = self.stream.try_lock() {
206      SplitByMap::poll_next_right(Pin::new(&mut guard), cx)
207    } else {
208      cx.waker().wake_by_ref();
209      Poll::Pending
210    };
211    response
212  }
213}
214
215/// This extension trait provides the functionality for splitting a
216/// stream by a predicate of type `Fn(Self::Item) -> Either<L,R>`. The resulting
217/// streams will yield types `L` and `R` respectively
218pub trait SplitStreamByMapExt<P, L, R>: Stream {
219  /// This takes ownership of a stream and returns two streams based on a
220  /// predicate. The predicate takes an item by value and returns
221  /// `Either::Left(..)` or `Either::Right(..)` where the inner
222  /// values of `Left` and `Right` become the items of the two respective
223  /// streams
224  ///
225  /// ```
226  /// use split_stream_by::{Either, SplitStreamByMapExt};
227  /// struct Request {
228  ///   //...
229  /// }
230  /// struct Response {
231  ///   //...
232  /// }
233  /// enum Message {
234  ///   Request(Request),
235  ///   Response(Response)
236  /// }
237  /// let incoming_stream = futures::stream::iter([
238  ///   Message::Request(Request {}),
239  ///   Message::Response(Response {}),
240  ///   Message::Response(Response {}),
241  /// ]);
242  /// let (mut request_stream, mut response_stream) = incoming_stream.split_by_map(|item| match item {
243  ///   Message::Request(req) => Either::Left(req),
244  ///   Message::Response(res) => Either::Right(res),
245  /// });
246  /// ```
247  #[allow(clippy::type_complexity)]
248  fn split_by_map(
249    self,
250    predicate: P,
251  ) -> (
252    LeftSplitByMap<Self::Item, L, R, Self, P>,
253    RightSplitByMap<Self::Item, L, R, Self, P>,
254  )
255  where
256    P: Fn(Self::Item) -> Either<L, R>,
257    Self: Sized,
258  {
259    let stream = SplitByMap::new(self, predicate);
260    let true_stream = LeftSplitByMap::new(stream.clone());
261    let false_stream = RightSplitByMap::new(stream);
262    (true_stream, false_stream)
263  }
264}
265
266impl<T, P, L, R> SplitStreamByMapExt<P, L, R> for T where T: Stream + ?Sized {}
267
#[cfg(test)]
mod tests {
  use super::*;
  use futures_util::{stream, StreamExt};

  // Each half is drained in its own spawned task, and the task handles are
  // awaited so that the collected values are checked before the test returns.
  // A detached task could be dropped before its assertions run, and a panic
  // inside it would not fail the test.

  #[tokio::test]
  async fn test_split_by_map_basic() {
    let input_stream = stream::iter(vec![1, 2, 3, 4, 5, 6]);
    let (evens, odds) = input_stream.split_by_map(|x| {
      if x % 2 == 0 {
        Either::Left(x)
      } else {
        Either::Right(x)
      }
    });

    let evens_task = tokio::spawn(async move { evens.collect::<Vec<i32>>().await });
    let odds_task = tokio::spawn(async move { odds.collect::<Vec<i32>>().await });

    let evens_collected = evens_task.await.expect("evens task panicked");
    let odds_collected = odds_task.await.expect("odds task panicked");
    assert_eq!(evens_collected, vec![2, 4, 6]);
    assert_eq!(odds_collected, vec![1, 3, 5]);
  }

  #[tokio::test]
  async fn test_split_by_map_empty_stream() {
    let input_stream = stream::iter(Vec::<i32>::new());
    let (left, right) = input_stream.split_by_map(|x| {
      if x % 2 == 0 {
        Either::Left(x)
      } else {
        Either::Right(x)
      }
    });

    let left_task = tokio::spawn(async move { left.collect::<Vec<i32>>().await });
    let right_task = tokio::spawn(async move { right.collect::<Vec<i32>>().await });

    let left_collected = left_task.await.expect("left task panicked");
    let right_collected = right_task.await.expect("right task panicked");
    assert!(left_collected.is_empty());
    assert!(right_collected.is_empty());
  }

  #[tokio::test]
  async fn test_split_by_map_all_left() {
    let input_stream = stream::iter(vec![2, 4, 6, 8]);
    let (left, right) = input_stream.split_by_map(Either::<i32, i32>::Left);

    let left_task = tokio::spawn(async move { left.collect::<Vec<i32>>().await });
    let right_task = tokio::spawn(async move { right.collect::<Vec<i32>>().await });

    let left_collected = left_task.await.expect("left task panicked");
    let right_collected = right_task.await.expect("right task panicked");
    assert_eq!(left_collected, vec![2, 4, 6, 8]);
    assert!(right_collected.is_empty());
  }

  #[tokio::test]
  async fn test_split_by_map_all_right() {
    let input_stream = stream::iter(vec![1, 3, 5, 7]);
    let (left, right) = input_stream.split_by_map(Either::<i32, i32>::Right);

    let left_task = tokio::spawn(async move { left.collect::<Vec<i32>>().await });
    let right_task = tokio::spawn(async move { right.collect::<Vec<i32>>().await });

    let left_collected = left_task.await.expect("left task panicked");
    let right_collected = right_task.await.expect("right task panicked");
    assert!(left_collected.is_empty());
    assert_eq!(right_collected, vec![1, 3, 5, 7]);
  }
}