Skip to content

Commit 4277296

Browse files
committed
FEAT: Implement generic parallel collect
Allow non-copy elements by implementing dropping partial results from collect (needed if there is a panic with unwinding during the apply-collect process). It is implemented by: 1. allocate an uninit output array of the right size and layout 2. use parallelsplits to split the Zip into chunks processed in parallel 3. for each chunk keep track of the slice of written elements 4. each output chunk is contiguous due to the layout being picked to match the Zip's preferred layout 5. Use reduce to merge adjacent partial results; this ensures we drop all the rests correctly, if there is a panic in any thread
1 parent fe2ebf6 commit 4277296

File tree

3 files changed

+204
-6
lines changed

3 files changed

+204
-6
lines changed

src/parallel/impl_par_methods.rs

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use crate::{Array, ArrayBase, DataMut, Dimension, IntoNdProducer, NdProducer, Zi
22
use crate::AssignElem;
33

44
use crate::parallel::prelude::*;
5+
use crate::parallel::par::ParallelSplits;
6+
use super::send_producer::SendProducer;
57

68
/// # Parallel methods
79
///
@@ -43,6 +45,8 @@ where
4345

4446
// Zip
4547

48+
const COLLECT_MAX_PARTS: usize = 256;
49+
4650
macro_rules! zip_impl {
4751
($([$notlast:ident $($p:ident)*],)+) => {
4852
$(
@@ -71,14 +75,56 @@ macro_rules! zip_impl {
7175
/// inputs.
7276
///
7377
/// If all inputs are c- or f-order respectively, that is preserved in the output.
74-
///
75-
/// Restricted to functions that produce copyable results for technical reasons; other
76-
/// cases are not yet implemented.
77-
pub fn par_apply_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send) -> Array<R, D>
78-
where R: Copy + Send
78+
pub fn par_apply_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
79+
-> Array<R, D>
80+
where R: Send
7981
{
8082
let mut output = self.uninitalized_for_current_layout::<R>();
81-
self.par_apply_assign_into(&mut output, f);
83+
let total_len = output.len();
84+
85+
// Create a parallel iterator that produces chunks of the zip with the output
86+
// array. It's crucial that both parts split in the same way, and in a way
87+
// so that the chunks of the output are still contig.
88+
//
89+
// Use a raw view so that we can alias the output data here and in the partial
90+
// result.
91+
let splits = unsafe {
92+
ParallelSplits {
93+
iter: self.and(SendProducer::new(output.raw_view_mut().cast::<R>())),
94+
// Keep it from splitting the Zip down too small
95+
min_size: total_len / COLLECT_MAX_PARTS,
96+
}
97+
};
98+
99+
let collect_result = splits.map(move |zip| {
100+
// Create a partial result for the contiguous slice of data being written to
101+
let output = zip.last_producer();
102+
debug_assert!(output.is_contiguous());
103+
104+
let mut partial = Partial::new(output.as_ptr());
105+
106+
// Apply the mapping function on this chunk of the zip
107+
let partial_len = &mut partial.len;
108+
let f = &f;
109+
zip.apply(move |$($p,)* output_elem: *mut R| unsafe {
110+
output_elem.write(f($($p),*));
111+
if std::mem::needs_drop::<R>() {
112+
*partial_len += 1;
113+
}
114+
});
115+
116+
partial
117+
})
118+
.reduce(Partial::stub, Partial::try_merge);
119+
120+
if std::mem::needs_drop::<R>() {
121+
debug_assert_eq!(total_len, collect_result.len, "collect len is not correct, expected {}", total_len);
122+
assert!(collect_result.len == total_len, "Collect: Expected number of writes not completed");
123+
}
124+
125+
// Here the collect result is complete, and we release its ownership and transfer
126+
// it to the output array.
127+
collect_result.release_ownership();
82128
unsafe {
83129
output.assume_init()
84130
}
@@ -113,3 +159,71 @@ zip_impl! {
113159
[true P1 P2 P3 P4 P5],
114160
[false P1 P2 P3 P4 P5 P6],
115161
}
162+
163+
/// Partial is a partially written contiguous slice of data;
164+
/// it is the owner of the elements, but not the allocation,
165+
/// and will drop the elements on drop.
166+
#[must_use]
167+
pub(crate) struct Partial<T> {
168+
/// Data pointer
169+
ptr: *mut T,
170+
/// Current length
171+
len: usize,
172+
}
173+
174+
impl<T> Partial<T> {
175+
/// Create an empty partial for this data pointer
176+
pub(crate) fn new(ptr: *mut T) -> Self {
177+
Self {
178+
ptr,
179+
len: 0,
180+
}
181+
}
182+
183+
pub(crate) fn stub() -> Self {
184+
Self { len: 0, ptr: 0 as *mut _ }
185+
}
186+
187+
pub(crate) fn is_stub(&self) -> bool {
188+
self.ptr.is_null()
189+
}
190+
191+
/// Release Partial's ownership of the written elements, and return the current length
192+
pub(crate) fn release_ownership(mut self) -> usize {
193+
let ret = self.len;
194+
self.len = 0;
195+
ret
196+
}
197+
198+
/// Merge if they are in order (left to right) and contiguous.
199+
/// Skips merge if T does not need drop.
200+
pub(crate) fn try_merge(mut left: Self, right: Self) -> Self {
201+
if !std::mem::needs_drop::<T>() {
202+
return left;
203+
}
204+
// Merge the partial collect results; the final result will be a slice that
205+
// covers the whole output.
206+
if left.is_stub() {
207+
right
208+
} else if left.ptr.wrapping_add(left.len) == right.ptr {
209+
left.len += right.release_ownership();
210+
left
211+
} else {
212+
// failure to merge; this is a bug in collect, so we will never reach this
213+
debug_assert!(false, "Partial: failure to merge left and right parts");
214+
left
215+
}
216+
}
217+
}
218+
219+
unsafe impl<T> Send for Partial<T> where T: Send { }
220+
221+
impl<T> Drop for Partial<T> {
222+
fn drop(&mut self) {
223+
if !self.ptr.is_null() {
224+
unsafe {
225+
std::ptr::drop_in_place(std::slice::from_raw_parts_mut(self.ptr, self.len));
226+
}
227+
}
228+
}
229+
}

src/parallel/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,5 @@ pub use crate::par_azip;
155155
mod impl_par_methods;
156156
mod into_impls;
157157
mod par;
158+
mod send_producer;
158159
mod zipmacro;

src/parallel/send_producer.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
2+
use crate::imp_prelude::*;
3+
use crate::{Layout, NdProducer};
4+
use std::ops::{Deref, DerefMut};
5+
6+
/// An NdProducer that is unconditionally `Send`.
7+
#[repr(transparent)]
8+
pub(crate) struct SendProducer<T> {
9+
inner: T
10+
}
11+
12+
impl<T> SendProducer<T> {
13+
/// Create an unconditionally `Send` ndproducer from the producer
14+
pub(crate) unsafe fn new(producer: T) -> Self { Self { inner: producer } }
15+
}
16+
17+
unsafe impl<P> Send for SendProducer<P> { }
18+
19+
impl<P> Deref for SendProducer<P> {
20+
type Target = P;
21+
fn deref(&self) -> &P { &self.inner }
22+
}
23+
24+
impl<P> DerefMut for SendProducer<P> {
25+
fn deref_mut(&mut self) -> &mut P { &mut self.inner }
26+
}
27+
28+
impl<P: NdProducer> NdProducer for SendProducer<P>
29+
where P: NdProducer,
30+
{
31+
type Item = P::Item;
32+
type Dim = P::Dim;
33+
type Ptr = P::Ptr;
34+
type Stride = P::Stride;
35+
36+
private_impl! {}
37+
38+
#[inline(always)]
39+
fn raw_dim(&self) -> Self::Dim {
40+
self.inner.raw_dim()
41+
}
42+
43+
#[inline(always)]
44+
fn equal_dim(&self, dim: &Self::Dim) -> bool {
45+
self.inner.equal_dim(dim)
46+
}
47+
48+
#[inline(always)]
49+
fn as_ptr(&self) -> Self::Ptr {
50+
self.inner.as_ptr()
51+
}
52+
53+
#[inline(always)]
54+
fn layout(&self) -> Layout {
55+
self.inner.layout()
56+
}
57+
58+
#[inline(always)]
59+
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
60+
self.inner.as_ref(ptr)
61+
}
62+
63+
#[inline(always)]
64+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
65+
self.inner.uget_ptr(i)
66+
}
67+
68+
#[inline(always)]
69+
fn stride_of(&self, axis: Axis) -> Self::Stride {
70+
self.inner.stride_of(axis)
71+
}
72+
73+
#[inline(always)]
74+
fn contiguous_stride(&self) -> Self::Stride {
75+
self.inner.contiguous_stride()
76+
}
77+
78+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
79+
let (a, b) = self.inner.split_at(axis, index);
80+
(Self { inner: a }, Self { inner: b })
81+
}
82+
}
83+

0 commit comments

Comments
 (0)