1use std::marker::PhantomData;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Poll, Waker};
18
19pub use futures_util::future::Either;
20use futures_util::stream::Stream;
21use pin_project_lite::pin_project;
22use tokio::sync::Mutex;
23
24pin_project! {
25struct SplitByMap<I, L, R, S, P> {
26 buf_left: Option<L>,
27 buf_right: Option<R>,
28 waker_left: Option<Waker>,
29 waker_right: Option<Waker>,
30 #[pin]
31 stream: S,
32 predicate: P,
33 item: PhantomData<I>,
34}
35}
36
37impl<I, L, R, S, P> SplitByMap<I, L, R, S, P>
38where
39 S: Stream<Item = I>,
40 P: Fn(I) -> Either<L, R>,
41{
42 fn new(stream: S, predicate: P) -> Arc<Mutex<Self>> {
43 Arc::new(Mutex::new(Self {
44 buf_right: None,
45 buf_left: None,
46 waker_right: None,
47 waker_left: None,
48 stream,
49 predicate,
50 item: PhantomData,
51 }))
52 }
53
54 fn poll_next_left(
55 self: std::pin::Pin<&mut Self>,
56 cx: &mut futures_util::task::Context<'_>,
57 ) -> std::task::Poll<Option<L>> {
58 let this = self.project();
59 *this.waker_left = Some(cx.waker().clone());
61 if let Some(item) = this.buf_left.take() {
62 return Poll::Ready(Some(item));
64 }
65 if this.buf_right.is_some() {
66 if let Some(waker) = this.waker_right {
69 waker.wake_by_ref();
70 }
71 return Poll::Pending;
72 }
73 match this.stream.poll_next(cx) {
74 Poll::Ready(Some(item)) => {
75 match (this.predicate)(item) {
76 Either::Left(left_item) => Poll::Ready(Some(left_item)),
77 Either::Right(right_item) => {
78 let _ = this.buf_right.replace(right_item);
81 if let Some(waker) = this.waker_right {
82 waker.wake_by_ref();
83 }
84 Poll::Pending
85 }
86 }
87 }
88 Poll::Ready(None) => {
89 if let Some(waker) = this.waker_right {
92 waker.wake_by_ref();
93 }
94 Poll::Ready(None)
95 }
96 Poll::Pending => Poll::Pending,
97 }
98 }
99
100 fn poll_next_right(
101 self: std::pin::Pin<&mut Self>,
102 cx: &mut futures_util::task::Context<'_>,
103 ) -> std::task::Poll<Option<R>> {
104 let this = self.project();
105 *this.waker_right = Some(cx.waker().clone());
107 if let Some(item) = this.buf_right.take() {
108 return Poll::Ready(Some(item));
110 }
111 if this.buf_left.is_some() {
112 if let Some(waker) = this.waker_left {
115 waker.wake_by_ref();
116 }
117 return Poll::Pending;
118 }
119 match this.stream.poll_next(cx) {
120 Poll::Ready(Some(item)) => {
121 match (this.predicate)(item) {
122 Either::Left(left_item) => {
123 let _ = this.buf_left.replace(left_item);
126 if let Some(waker) = this.waker_left {
127 waker.wake_by_ref();
128 }
129 Poll::Pending
130 }
131 Either::Right(right_item) => Poll::Ready(Some(right_item)),
132 }
133 }
134 Poll::Ready(None) => {
135 if let Some(waker) = this.waker_left {
138 waker.wake_by_ref();
139 }
140 Poll::Ready(None)
141 }
142 Poll::Pending => Poll::Pending,
143 }
144 }
145}
146
147#[allow(clippy::type_complexity)]
150pub struct LeftSplitByMap<I, L, R, S, P> {
151 stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>,
152}
153
154impl<I, L, R, S, P> LeftSplitByMap<I, L, R, S, P> {
155 #[allow(clippy::type_complexity)]
156 fn new(stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>) -> Self {
157 Self { stream }
158 }
159}
160
161impl<I, L, R, S, P> Stream for LeftSplitByMap<I, L, R, S, P>
162where
163 S: Stream<Item = I> + Unpin,
164 P: Fn(I) -> Either<L, R>,
165{
166 type Item = L;
167 fn poll_next(
168 self: std::pin::Pin<&mut Self>,
169 cx: &mut futures_util::task::Context<'_>,
170 ) -> std::task::Poll<Option<Self::Item>> {
171 let response = if let Ok(mut guard) = self.stream.try_lock() {
172 SplitByMap::poll_next_left(Pin::new(&mut guard), cx)
173 } else {
174 cx.waker().wake_by_ref();
175 Poll::Pending
176 };
177 response
178 }
179}
180
181#[allow(clippy::type_complexity)]
184pub struct RightSplitByMap<I, L, R, S, P> {
185 stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>,
186}
187
188impl<I, L, R, S, P> RightSplitByMap<I, L, R, S, P> {
189 #[allow(clippy::type_complexity)]
190 fn new(stream: Arc<Mutex<SplitByMap<I, L, R, S, P>>>) -> Self {
191 Self { stream }
192 }
193}
194
195impl<I, L, R, S, P> Stream for RightSplitByMap<I, L, R, S, P>
196where
197 S: Stream<Item = I> + Unpin,
198 P: Fn(I) -> Either<L, R>,
199{
200 type Item = R;
201 fn poll_next(
202 self: std::pin::Pin<&mut Self>,
203 cx: &mut futures_util::task::Context<'_>,
204 ) -> std::task::Poll<Option<Self::Item>> {
205 let response = if let Ok(mut guard) = self.stream.try_lock() {
206 SplitByMap::poll_next_right(Pin::new(&mut guard), cx)
207 } else {
208 cx.waker().wake_by_ref();
209 Poll::Pending
210 };
211 response
212 }
213}
214
215pub trait SplitStreamByMapExt<P, L, R>: Stream {
219 #[allow(clippy::type_complexity)]
248 fn split_by_map(
249 self,
250 predicate: P,
251 ) -> (
252 LeftSplitByMap<Self::Item, L, R, Self, P>,
253 RightSplitByMap<Self::Item, L, R, Self, P>,
254 )
255 where
256 P: Fn(Self::Item) -> Either<L, R>,
257 Self: Sized,
258 {
259 let stream = SplitByMap::new(self, predicate);
260 let true_stream = LeftSplitByMap::new(stream.clone());
261 let false_stream = RightSplitByMap::new(stream);
262 (true_stream, false_stream)
263 }
264}
265
266impl<T, P, L, R> SplitStreamByMapExt<P, L, R> for T where T: Stream + ?Sized {}
267
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    // Each test drives both halves concurrently on separate tasks. It then awaits the join
    // handles, so a failing assertion inside a task fails the test instead of being swallowed
    // by the runtime.

    #[tokio::test]
    async fn test_split_by_map_basic() {
        let input_stream = stream::iter(vec![1, 2, 3, 4, 5, 6]);
        let (evens, odds) = input_stream.split_by_map(|x| {
            if x % 2 == 0 {
                Either::Left(x)
            } else {
                Either::Right(x)
            }
        });

        let evens_task = tokio::spawn(async move {
            let evens_collected: Vec<i32> = evens.collect().await;
            assert_eq!(evens_collected, vec![2, 4, 6]);
        });

        let odds_task = tokio::spawn(async move {
            let odds_collected: Vec<i32> = odds.collect().await;
            assert_eq!(odds_collected, vec![1, 3, 5]);
        });

        evens_task.await.expect("evens task panicked");
        odds_task.await.expect("odds task panicked");
    }

    #[tokio::test]
    async fn test_split_by_map_empty_stream() {
        let input_stream = stream::iter(Vec::<i32>::new());
        let (left, right) = input_stream.split_by_map(|x| {
            if x % 2 == 0 {
                Either::Left(x)
            } else {
                Either::Right(x)
            }
        });

        let left_task = tokio::spawn(async move {
            let left_collected: Vec<i32> = left.collect().await;
            assert!(left_collected.is_empty());
        });

        let right_task = tokio::spawn(async move {
            let right_collected: Vec<i32> = right.collect().await;
            assert!(right_collected.is_empty());
        });

        left_task.await.expect("left task panicked");
        right_task.await.expect("right task panicked");
    }

    #[tokio::test]
    async fn test_split_by_map_all_left() {
        let input_stream = stream::iter(vec![2, 4, 6, 8]);
        let (left, right) = input_stream.split_by_map(Either::<i32, i32>::Left);

        let left_task = tokio::spawn(async move {
            let left_collected: Vec<i32> = left.collect().await;
            assert_eq!(left_collected, vec![2, 4, 6, 8]);
        });

        let right_task = tokio::spawn(async move {
            let right_collected: Vec<i32> = right.collect().await;
            assert!(right_collected.is_empty());
        });

        left_task.await.expect("left task panicked");
        right_task.await.expect("right task panicked");
    }

    #[tokio::test]
    async fn test_split_by_map_all_right() {
        let input_stream = stream::iter(vec![1, 3, 5, 7]);
        let (left, right) = input_stream.split_by_map(Either::<i32, i32>::Right);

        let left_task = tokio::spawn(async move {
            let left_collected: Vec<i32> = left.collect().await;
            assert!(left_collected.is_empty());
        });

        let right_task = tokio::spawn(async move {
            let right_collected: Vec<i32> = right.collect().await;
            assert_eq!(right_collected, vec![1, 3, 5, 7]);
        });

        left_task.await.expect("left task panicked");
        right_task.await.expect("right task panicked");
    }
}