Skip to content

Commit 449df3b

Browse files
committed
Add TakeAny
1 parent ed98853 commit 449df3b

File tree

2 files changed

+169
-0
lines changed

2 files changed

+169
-0
lines changed

src/iter/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ mod splitter;
145145
mod step_by;
146146
mod sum;
147147
mod take;
148+
mod take_any;
148149
mod try_fold;
149150
mod try_reduce;
150151
mod try_reduce_with;
@@ -188,6 +189,7 @@ pub use self::{
188189
splitter::{split, Split},
189190
step_by::StepBy,
190191
take::Take,
192+
take_any::TakeAny,
191193
try_fold::{TryFold, TryFoldWith},
192194
update::Update,
193195
while_some::WhileSome,
@@ -2194,6 +2196,25 @@ pub trait ParallelIterator: Sized + Send {
21942196
Intersperse::new(self, element)
21952197
}
21962198

2199+
/// Creates an iterator that yields the first `n` elements.
2200+
///
2201+
/// # Examples
2202+
///
2203+
/// ```
2204+
/// use rayon::prelude::*;
2205+
///
2206+
/// let result: Vec<_> = (0..100)
2207+
/// .into_par_iter()
2208+
/// .filter(|&x| x % 2 == 0)
2209+
/// .take_any(5)
2210+
/// .collect();
2211+
///
2212+
/// assert_eq!(result.len(), 5);
2213+
/// ```
2214+
fn take_any(self, n: usize) -> TakeAny<Self> {
2215+
TakeAny::new(self, n)
2216+
}
2217+
21972218
/// Internal method used to define the behavior of this parallel
21982219
/// iterator. You should not need to call this directly.
21992220
///

src/iter/take_any.rs

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
use super::plumbing::*;
2+
use super::*;
3+
use std::sync::atomic::{AtomicUsize, Ordering};
4+
5+
/// `TakeAny` is an iterator that iterates over the first `n` elements.
6+
/// This struct is created by the [`take_any()`] method on [`ParallelIterator`]
7+
///
8+
/// [`take_any()`]: trait.ParallelIterator.html#method.take_any
9+
/// [`ParallelIterator`]: trait.ParallelIterator.html
10+
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
11+
#[derive(Debug)]
12+
pub struct TakeAny<I: ParallelIterator> {
13+
base: I,
14+
count: AtomicUsize,
15+
}
16+
17+
impl<I> TakeAny<I>
18+
where
19+
I: ParallelIterator,
20+
{
21+
/// Creates a new `TakeAny` iterator.
22+
pub(super) fn new(base: I, count: usize) -> Self {
23+
TakeAny {
24+
base,
25+
count: AtomicUsize::new(count),
26+
}
27+
}
28+
}
29+
30+
impl<I, T> ParallelIterator for TakeAny<I>
31+
where
32+
I: ParallelIterator<Item = T>,
33+
T: Send,
34+
{
35+
type Item = T;
36+
37+
fn drive_unindexed<C>(self, consumer: C) -> C::Result
38+
where
39+
C: UnindexedConsumer<Self::Item>,
40+
{
41+
let consumer1 = TakeAnyConsumer {
42+
base: consumer,
43+
count: &self.count,
44+
};
45+
self.base.drive_unindexed(consumer1)
46+
}
47+
}
48+
49+
/// ////////////////////////////////////////////////////////////////////////
50+
/// Consumer implementation
51+
52+
struct TakeAnyConsumer<'f, C> {
53+
base: C,
54+
count: &'f AtomicUsize,
55+
}
56+
57+
impl<'f, T, C> Consumer<T> for TakeAnyConsumer<'f, C>
58+
where
59+
C: Consumer<T>,
60+
T: Send,
61+
{
62+
type Folder = TakeAnyFolder<'f, C::Folder>;
63+
type Reducer = C::Reducer;
64+
type Result = C::Result;
65+
66+
fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
67+
let (left, right, reducer) = self.base.split_at(index);
68+
(
69+
TakeAnyConsumer { base: left, ..self },
70+
TakeAnyConsumer {
71+
base: right,
72+
..self
73+
},
74+
reducer,
75+
)
76+
}
77+
78+
fn into_folder(self) -> Self::Folder {
79+
TakeAnyFolder {
80+
base: self.base.into_folder(),
81+
count: self.count,
82+
}
83+
}
84+
85+
fn full(&self) -> bool {
86+
self.count.load(Ordering::Relaxed) == 0 || self.base.full()
87+
}
88+
}
89+
90+
impl<'f, T, C> UnindexedConsumer<T> for TakeAnyConsumer<'f, C>
91+
where
92+
C: UnindexedConsumer<T>,
93+
T: Send,
94+
{
95+
fn split_off_left(&self) -> Self {
96+
TakeAnyConsumer {
97+
base: self.base.split_off_left(),
98+
..*self
99+
}
100+
}
101+
102+
fn to_reducer(&self) -> Self::Reducer {
103+
self.base.to_reducer()
104+
}
105+
}
106+
107+
struct TakeAnyFolder<'f, C> {
108+
base: C,
109+
count: &'f AtomicUsize,
110+
}
111+
112+
fn checked_decrement(u: &AtomicUsize) -> bool {
113+
u.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |u| u.checked_sub(1))
114+
.is_ok()
115+
}
116+
117+
impl<'f, T, C> Folder<T> for TakeAnyFolder<'f, C>
118+
where
119+
C: Folder<T>,
120+
{
121+
type Result = C::Result;
122+
123+
fn consume(mut self, item: T) -> Self {
124+
if checked_decrement(self.count) {
125+
self.base = self.base.consume(item);
126+
}
127+
self
128+
}
129+
130+
fn consume_iter<I>(mut self, iter: I) -> Self
131+
where
132+
I: IntoIterator<Item = T>,
133+
{
134+
self.base = self.base.consume_iter(
135+
iter.into_iter()
136+
.take_while(move |_| checked_decrement(self.count)),
137+
);
138+
self
139+
}
140+
141+
fn complete(self) -> C::Result {
142+
self.base.complete()
143+
}
144+
145+
fn full(&self) -> bool {
146+
self.count.load(Ordering::Relaxed) == 0 || self.base.full()
147+
}
148+
}

0 commit comments

Comments
 (0)