Skip to content

Commit bbbddcc

Browse files
committed
Add stream selection early exit
1 parent 2e30ec3 commit bbbddcc

File tree

4 files changed

+113
-20
lines changed

4 files changed

+113
-20
lines changed

futures-util/src/stream/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ mod poll_immediate;
100100
pub use self::poll_immediate::{poll_immediate, PollImmediate};
101101

102102
mod select;
103-
pub use self::select::{select, Select};
103+
pub use self::select::{select, select_early_exit, Select};
104104

105105
mod select_with_strategy;
106-
pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy};
106+
pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy, ExitStrategy};
107107

108108
mod unfold;
109109
pub use self::unfold::{unfold, Unfold};

futures-util/src/stream/select.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::assert_stream;
2-
use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy};
2+
use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy, ExitStrategy};
33
use core::pin::Pin;
44
use futures_core::stream::{FusedStream, Stream};
55
use futures_core::task::{Context, Poll};
@@ -45,6 +45,23 @@ pin_project! {
4545
/// # });
4646
/// ```
4747
pub fn select<St1, St2>(stream1: St1, stream2: St2) -> Select<St1, St2>
48+
where
49+
St1: Stream,
50+
St2: Stream<Item = St1::Item>,
51+
{
52+
select_with_exit(stream1, stream2, ExitStrategy::WhenBothFinish)
53+
}
54+
55+
/// Same as `select`, but finishes when either stream finishes
56+
pub fn select_early_exit<St1, St2>(stream1: St1, stream2: St2) -> Select<St1, St2>
57+
where
58+
St1: Stream,
59+
St2: Stream<Item = St1::Item>,
60+
{
61+
select_with_exit(stream1, stream2, ExitStrategy::WhenEitherFinish)
62+
}
63+
64+
fn select_with_exit<St1, St2>(stream1: St1, stream2: St2, exit_strategy: ExitStrategy) -> Select<St1, St2>
4865
where
4966
St1: Stream,
5067
St2: Stream<Item = St1::Item>,
@@ -54,7 +71,7 @@ where
5471
}
5572

5673
assert_stream::<St1::Item, _>(Select {
57-
inner: select_with_strategy(stream1, stream2, round_robin),
74+
inner: select_with_strategy(stream1, stream2, round_robin, exit_strategy),
5875
})
5976
}
6077

futures-util/src/stream/select_with_strategy.rs

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ impl Default for PollNext {
3636
}
3737
}
3838

39+
#[derive(PartialEq, Eq, Clone, Copy)]
3940
enum InternalState {
4041
Start,
4142
LeftFinished,
@@ -61,6 +62,29 @@ impl InternalState {
6162
}
6263
}
6364

65+
/// Decides whether to exit when both streams are completed, or only one
66+
/// is completed. If you need to exit when a specific stream has finished,
67+
/// feel free to add a case here.
68+
#[derive(Clone, Copy, Debug)]
69+
pub enum ExitStrategy {
70+
/// Select stream finishes when both substreams finish
71+
WhenBothFinish,
72+
/// Select stream finishes when either substream finishes
73+
WhenEitherFinish,
74+
}
75+
76+
impl ExitStrategy {
77+
#[inline]
78+
fn is_finished(self, state: InternalState) -> bool {
79+
match (state, self) {
80+
(InternalState::BothFinished, _) => true,
81+
(InternalState::Start, ExitStrategy::WhenEitherFinish) => false,
82+
(_, ExitStrategy::WhenBothFinish) => false,
83+
_ => true,
84+
}
85+
}
86+
}
87+
6488
pin_project! {
6589
/// Stream for the [`select_with_strategy()`] function. See function docs for details.
6690
#[must_use = "streams do nothing unless polled"]
@@ -73,6 +97,7 @@ pin_project! {
7397
internal_state: InternalState,
7498
state: State,
7599
clos: Clos,
100+
exit_strategy: ExitStrategy,
76101
}
77102
}
78103

@@ -95,7 +120,7 @@ pin_project! {
95120
///
96121
/// ```rust
97122
/// # futures::executor::block_on(async {
98-
/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
123+
/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt, ExitStrategy };
99124
///
100125
/// let left = repeat(1);
101126
/// let right = repeat(2);
@@ -106,7 +131,7 @@ pin_project! {
106131
/// // use a function pointer instead of a closure.
107132
/// fn prio_left(_: &mut ()) -> PollNext { PollNext::Left }
108133
///
109-
/// let mut out = select_with_strategy(left, right, prio_left);
134+
/// let mut out = select_with_strategy(left, right, prio_left, ExitStrategy::WhenBothFinish);
110135
///
111136
/// for _ in 0..100 {
112137
/// // Whenever we poll out, we will always get `1`.
@@ -121,26 +146,54 @@ pin_project! {
121146
///
122147
/// ```rust
123148
/// # futures::executor::block_on(async {
124-
/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
149+
/// use futures::stream::{ repeat, select_with_strategy, FusedStream, PollNext, StreamExt, ExitStrategy };
125150
///
126-
/// let left = repeat(1);
127-
/// let right = repeat(2);
151+
/// // Finishes when both streams finish
152+
/// {
153+
/// let left = repeat(1).take(10);
154+
/// let right = repeat(2);
128155
///
129-
/// let rrobin = |last: &mut PollNext| last.toggle();
156+
/// let rrobin = |last: &mut PollNext| last.toggle();
130157
///
131-
/// let mut out = select_with_strategy(left, right, rrobin);
158+
/// let mut out = select_with_strategy(left, right, rrobin, ExitStrategy::WhenBothFinish);
132159
///
133-
/// for _ in 0..100 {
134-
/// // We should be alternating now.
135-
/// assert_eq!(1, out.select_next_some().await);
136-
/// assert_eq!(2, out.select_next_some().await);
160+
/// for _ in 0..10 {
161+
/// // We should be alternating now.
162+
/// assert_eq!(1, out.select_next_some().await);
163+
/// assert_eq!(2, out.select_next_some().await);
164+
/// }
165+
/// for _ in 0..100 {
166+
/// // First stream has finished
167+
/// assert_eq!(2, out.select_next_some().await);
168+
/// }
169+
/// assert!(!out.is_terminated());
170+
/// }
171+
///
172+
/// // Finishes when either stream finishes
173+
/// {
174+
/// let left = repeat(1).take(10);
175+
/// let right = repeat(2);
176+
///
177+
/// let rrobin = |last: &mut PollNext| last.toggle();
178+
///
179+
/// let mut out = select_with_strategy(left, right, rrobin, ExitStrategy::WhenEitherFinish);
180+
///
181+
/// for _ in 0..10 {
182+
/// // We should be alternating now.
183+
/// assert_eq!(1, out.select_next_some().await);
184+
/// assert_eq!(2, out.select_next_some().await);
185+
/// }
186+
/// assert_eq!(None, out.next().await);
187+
/// assert!(out.is_terminated());
137188
/// }
138189
/// # });
139190
/// ```
191+
///
140192
pub fn select_with_strategy<St1, St2, Clos, State>(
141193
stream1: St1,
142194
stream2: St2,
143195
which: Clos,
196+
exit_strategy: ExitStrategy,
144197
) -> SelectWithStrategy<St1, St2, Clos, State>
145198
where
146199
St1: Stream,
@@ -154,6 +207,7 @@ where
154207
state: Default::default(),
155208
internal_state: InternalState::Start,
156209
clos: which,
210+
exit_strategy,
157211
})
158212
}
159213

@@ -199,10 +253,7 @@ where
199253
Clos: FnMut(&mut State) -> PollNext,
200254
{
201255
fn is_terminated(&self) -> bool {
202-
match self.internal_state {
203-
InternalState::BothFinished => true,
204-
_ => false,
205-
}
256+
self.exit_strategy.is_finished(self.internal_state)
206257
}
207258
}
208259

@@ -227,6 +278,7 @@ fn poll_inner<St1, St2, Clos, State>(
227278
select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
228279
side: PollNext,
229280
cx: &mut Context<'_>,
281+
exit_strat: ExitStrategy,
230282
) -> Poll<Option<St1::Item>>
231283
where
232284
St1: Stream,
@@ -236,6 +288,9 @@ where
236288
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
237289
Poll::Ready(None) => {
238290
select.internal_state.finish(side);
291+
if exit_strat.is_finished(*select.internal_state) {
292+
return Poll::Ready(None);
293+
}
239294
}
240295
Poll::Pending => (),
241296
};
@@ -259,11 +314,16 @@ where
259314

260315
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
261316
let mut this = self.project();
317+
let exit_strategy: ExitStrategy = *this.exit_strategy;
318+
319+
if exit_strategy.is_finished(*this.internal_state) {
320+
return Poll::Ready(None);
321+
}
262322

263323
match this.internal_state {
264324
InternalState::Start => {
265325
let next_side = (this.clos)(this.state);
266-
poll_inner(&mut this, next_side, cx)
326+
poll_inner(&mut this, next_side, cx, exit_strategy)
267327
}
268328
InternalState::LeftFinished => match this.stream2.poll_next(cx) {
269329
Poll::Ready(None) => {

futures/tests/stream.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ fn select() {
2525
select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5, 6]);
2626
}
2727

28+
#[test]
29+
fn select_early_exit() {
30+
fn select_and_compare(a: Vec<u32>, b: Vec<u32>, expected: Vec<u32>) {
31+
let a = stream::iter(a);
32+
let b = stream::iter(b);
33+
let vec = block_on(stream::select_early_exit(a, b).collect::<Vec<_>>());
34+
assert_eq!(vec, expected);
35+
}
36+
37+
select_and_compare(vec![1, 2, 3], vec![4, 5, 6], vec![1, 4, 2, 5, 3, 6]);
38+
select_and_compare(vec![], vec![4, 5], vec![]);
39+
select_and_compare(vec![4, 5], vec![], vec![4]);
40+
select_and_compare(vec![1, 2, 3], vec![4, 5], vec![1, 4, 2, 5, 3]);
41+
select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5]);
42+
}
43+
2844
#[test]
2945
fn flat_map() {
3046
block_on(async {

0 commit comments

Comments
 (0)