Skip to content

Commit f69248e

Browse files
authored
Merge pull request #817 from rust-ndarray/par-collect
Implement parallel collect to array for non-Copy elements
2 parents 79d99c7 + e472612 commit f69248e

File tree

14 files changed

+478
-189
lines changed

14 files changed

+478
-189
lines changed

benches/par_rayon.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,39 @@ fn rayon_add(bench: &mut Bencher) {
136136
});
137137
});
138138
}
139+
140+
const COLL_STRING_N: usize = 64;
141+
const COLL_F64_N: usize = 128;
142+
143+
#[bench]
144+
fn vec_string_collect(bench: &mut test::Bencher) {
145+
let v = vec![""; COLL_STRING_N * COLL_STRING_N];
146+
bench.iter(|| {
147+
v.iter().map(|s| s.to_owned()).collect::<Vec<_>>()
148+
});
149+
}
150+
151+
#[bench]
152+
fn array_string_collect(bench: &mut test::Bencher) {
153+
let v = Array::from_elem((COLL_STRING_N, COLL_STRING_N), "");
154+
bench.iter(|| {
155+
Zip::from(&v).par_apply_collect(|s| s.to_owned())
156+
});
157+
}
158+
159+
#[bench]
160+
fn vec_f64_collect(bench: &mut test::Bencher) {
161+
let v = vec![1.; COLL_F64_N * COLL_F64_N];
162+
bench.iter(|| {
163+
v.iter().map(|s| s + 1.).collect::<Vec<_>>()
164+
});
165+
}
166+
167+
#[bench]
168+
fn array_f64_collect(bench: &mut test::Bencher) {
169+
let v = Array::from_elem((COLL_F64_N, COLL_F64_N), 1.);
170+
bench.iter(|| {
171+
Zip::from(&v).par_apply_collect(|s| s + 1.)
172+
});
173+
}
174+

src/dimension/dimension_trait.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,14 @@ impl Dimension for Dim<[Ix; 1]> {
540540
fn try_remove_axis(&self, axis: Axis) -> Self::Smaller {
541541
self.remove_axis(axis)
542542
}
543+
544+
fn from_dimension<D2: Dimension>(d: &D2) -> Option<Self> {
545+
if 1 == d.ndim() {
546+
Some(Ix1(d[0]))
547+
} else {
548+
None
549+
}
550+
}
543551
private_impl! {}
544552
}
545553

src/impl_methods.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,11 @@ where
12931293
is_standard_layout(&self.dim, &self.strides)
12941294
}
12951295

1296-
fn is_contiguous(&self) -> bool {
1296+
/// Return true if the array is known to be contiguous.
1297+
///
1298+
/// Will detect c- and f-contiguous arrays correctly, but otherwise
1299+
/// there are some false negatives.
1300+
pub(crate) fn is_contiguous(&self) -> bool {
12971301
D::is_contiguous(&self.dim, &self.strides)
12981302
}
12991303

src/indexes.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
// except according to those terms.
88
use super::Dimension;
99
use crate::dimension::IntoDimension;
10-
use crate::zip::{Offset, Splittable};
10+
use crate::zip::Offset;
11+
use crate::split_at::SplitAt;
1112
use crate::Axis;
1213
use crate::Layout;
1314
use crate::NdProducer;

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,11 @@ mod linalg_traits;
176176
mod linspace;
177177
mod logspace;
178178
mod numeric_util;
179+
mod partial;
179180
mod shape_builder;
180181
#[macro_use]
181182
mod slice;
183+
mod split_at;
182184
mod stacking;
183185
#[macro_use]
184186
mod zip;

src/parallel/impl_par_methods.rs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ use crate::{Array, ArrayBase, DataMut, Dimension, IntoNdProducer, NdProducer, Zi
22
use crate::AssignElem;
33

44
use crate::parallel::prelude::*;
5+
use crate::parallel::par::ParallelSplits;
6+
use super::send_producer::SendProducer;
7+
8+
use crate::partial::Partial;
59

610
/// # Parallel methods
711
///
@@ -43,6 +47,8 @@ where
4347

4448
// Zip
4549

50+
const COLLECT_MAX_SPLITS: usize = 10;
51+
4652
macro_rules! zip_impl {
4753
($([$notlast:ident $($p:ident)*],)+) => {
4854
$(
@@ -71,14 +77,46 @@ macro_rules! zip_impl {
7177
/// inputs.
7278
///
7379
/// If all inputs are c- or f-order respectively, that is preserved in the output.
74-
///
75-
/// Restricted to functions that produce copyable results for technical reasons; other
76-
/// cases are not yet implemented.
77-
pub fn par_apply_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send) -> Array<R, D>
78-
where R: Copy + Send
80+
pub fn par_apply_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
81+
-> Array<R, D>
82+
where R: Send
7983
{
8084
let mut output = self.uninitalized_for_current_layout::<R>();
81-
self.par_apply_assign_into(&mut output, f);
85+
let total_len = output.len();
86+
87+
// Create a parallel iterator that produces chunks of the zip with the output
88+
// array. It's crucial that both parts split in the same way, and in a way
89+
// so that the chunks of the output are still contiguous.
90+
//
91+
// Use a raw view so that we can alias the output data here and in the partial
92+
// result.
93+
let splits = unsafe {
94+
ParallelSplits {
95+
iter: self.and(SendProducer::new(output.raw_view_mut().cast::<R>())),
96+
// Keep it from splitting the Zip down too small
97+
max_splits: COLLECT_MAX_SPLITS,
98+
}
99+
};
100+
101+
let collect_result = splits.map(move |zip| {
102+
// Apply the mapping function on this chunk of the zip
103+
// Create a partial result for the contiguous slice of data being written to
104+
unsafe {
105+
zip.collect_with_partial(&f)
106+
}
107+
})
108+
.reduce(Partial::stub, Partial::try_merge);
109+
110+
if std::mem::needs_drop::<R>() {
111+
debug_assert_eq!(total_len, collect_result.len,
112+
"collect len is not correct, expected {}", total_len);
113+
assert!(collect_result.len == total_len,
114+
"Collect: Expected number of writes not completed");
115+
}
116+
117+
// Here the collect result is complete, and we release its ownership and transfer
118+
// it to the output array.
119+
collect_result.release_ownership();
82120
unsafe {
83121
output.assume_init()
84122
}

src/parallel/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,5 @@ pub use crate::par_azip;
155155
mod impl_par_methods;
156156
mod into_impls;
157157
mod par;
158+
mod send_producer;
158159
mod zipmacro;

src/parallel/par.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use crate::iter::AxisIter;
1515
use crate::iter::AxisIterMut;
1616
use crate::Dimension;
1717
use crate::{ArrayView, ArrayViewMut};
18+
use crate::split_at::SplitPreference;
1819

1920
/// Parallel iterator wrapper.
2021
#[derive(Copy, Clone, Debug)]
@@ -170,7 +171,14 @@ macro_rules! par_iter_view_wrapper {
170171
fn fold_with<F>(self, folder: F) -> F
171172
where F: Folder<Self::Item>,
172173
{
173-
self.into_iter().fold(folder, move |f, elt| f.consume(elt))
174+
Zip::from(self.0).fold_while(folder, |mut folder, elt| {
175+
folder = folder.consume(elt);
176+
if folder.full() {
177+
FoldWhile::Done(folder)
178+
} else {
179+
FoldWhile::Continue(folder)
180+
}
181+
}).into_inner()
174182
}
175183
}
176184

@@ -243,7 +251,7 @@ macro_rules! zip_impl {
243251
type Item = ($($p::Item ,)*);
244252

245253
fn split(self) -> (Self, Option<Self>) {
246-
if self.0.size() <= 1 {
254+
if !self.0.can_split() {
247255
return (self, None)
248256
}
249257
let (a, b) = self.0.split();
@@ -275,3 +283,53 @@ zip_impl! {
275283
[P1 P2 P3 P4 P5],
276284
[P1 P2 P3 P4 P5 P6],
277285
}
286+
287+
/// A parallel iterator (unindexed) that produces the splits of the array
288+
/// or producer `P`.
289+
pub(crate) struct ParallelSplits<P> {
290+
pub(crate) iter: P,
291+
pub(crate) max_splits: usize,
292+
}
293+
294+
impl<P> ParallelIterator for ParallelSplits<P>
295+
where P: SplitPreference + Send,
296+
{
297+
type Item = P;
298+
299+
fn drive_unindexed<C>(self, consumer: C) -> C::Result
300+
where C: UnindexedConsumer<Self::Item>
301+
{
302+
bridge_unindexed(self, consumer)
303+
}
304+
305+
fn opt_len(&self) -> Option<usize> {
306+
None
307+
}
308+
}
309+
310+
impl<P> UnindexedProducer for ParallelSplits<P>
311+
where P: SplitPreference + Send,
312+
{
313+
type Item = P;
314+
315+
fn split(self) -> (Self, Option<Self>) {
316+
if self.max_splits == 0 || !self.iter.can_split() {
317+
return (self, None)
318+
}
319+
let (a, b) = self.iter.split();
320+
(ParallelSplits {
321+
iter: a,
322+
max_splits: self.max_splits - 1,
323+
},
324+
Some(ParallelSplits {
325+
iter: b,
326+
max_splits: self.max_splits - 1,
327+
}))
328+
}
329+
330+
fn fold_with<Fold>(self, folder: Fold) -> Fold
331+
where Fold: Folder<Self::Item>,
332+
{
333+
folder.consume(self.iter)
334+
}
335+
}

src/parallel/send_producer.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
2+
use crate::imp_prelude::*;
3+
use crate::{Layout, NdProducer};
4+
use std::ops::{Deref, DerefMut};
5+
6+
/// An NdProducer that is unconditionally `Send`.
7+
#[repr(transparent)]
8+
pub(crate) struct SendProducer<T> {
9+
inner: T
10+
}
11+
12+
impl<T> SendProducer<T> {
13+
/// Wrap `producer` so that it is unconditionally `Send`; the caller must ensure that sending it to another thread is sound.
14+
pub(crate) unsafe fn new(producer: T) -> Self { Self { inner: producer } }
15+
}
16+
17+
unsafe impl<P> Send for SendProducer<P> { }
18+
19+
impl<P> Deref for SendProducer<P> {
20+
type Target = P;
21+
fn deref(&self) -> &P { &self.inner }
22+
}
23+
24+
impl<P> DerefMut for SendProducer<P> {
25+
fn deref_mut(&mut self) -> &mut P { &mut self.inner }
26+
}
27+
28+
impl<P: NdProducer> NdProducer for SendProducer<P>
29+
where P: NdProducer,
30+
{
31+
type Item = P::Item;
32+
type Dim = P::Dim;
33+
type Ptr = P::Ptr;
34+
type Stride = P::Stride;
35+
36+
private_impl! {}
37+
38+
#[inline(always)]
39+
fn raw_dim(&self) -> Self::Dim {
40+
self.inner.raw_dim()
41+
}
42+
43+
#[inline(always)]
44+
fn equal_dim(&self, dim: &Self::Dim) -> bool {
45+
self.inner.equal_dim(dim)
46+
}
47+
48+
#[inline(always)]
49+
fn as_ptr(&self) -> Self::Ptr {
50+
self.inner.as_ptr()
51+
}
52+
53+
#[inline(always)]
54+
fn layout(&self) -> Layout {
55+
self.inner.layout()
56+
}
57+
58+
#[inline(always)]
59+
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
60+
self.inner.as_ref(ptr)
61+
}
62+
63+
#[inline(always)]
64+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
65+
self.inner.uget_ptr(i)
66+
}
67+
68+
#[inline(always)]
69+
fn stride_of(&self, axis: Axis) -> Self::Stride {
70+
self.inner.stride_of(axis)
71+
}
72+
73+
#[inline(always)]
74+
fn contiguous_stride(&self) -> Self::Stride {
75+
self.inner.contiguous_stride()
76+
}
77+
78+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
79+
let (a, b) = self.inner.split_at(axis, index);
80+
(Self { inner: a }, Self { inner: b })
81+
}
82+
}
83+

0 commit comments

Comments
 (0)