Skip to content

Commit e472612

Browse files
committed
FEAT: Combine common parts of apply_collect and par_apply_collect
Factor out the common part of the parallel and regular apply_collect implementations; the non-parallel version came first and ended up more complicated originally. With the parallel version in place, both can use the same main operation, which is implemented in the method Zip::collect_with_partial.
1 parent d02b757 commit e472612

File tree

2 files changed

+58
-51
lines changed

2 files changed

+58
-51
lines changed

src/parallel/impl_par_methods.rs

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,25 +99,11 @@ macro_rules! zip_impl {
9999
};
100100

101101
let collect_result = splits.map(move |zip| {
102+
// Apply the mapping function on this chunk of the zip
102103
// Create a partial result for the contiguous slice of data being written to
103-
let output = zip.last_producer();
104-
debug_assert!(output.is_contiguous());
105-
let mut partial;
106104
unsafe {
107-
partial = Partial::new(output.as_ptr());
105+
zip.collect_with_partial(&f)
108106
}
109-
110-
// Apply the mapping function on this chunk of the zip
111-
let partial_len = &mut partial.len;
112-
let f = &f;
113-
zip.apply(move |$($p,)* output_elem: *mut R| unsafe {
114-
output_elem.write(f($($p),*));
115-
if std::mem::needs_drop::<R>() {
116-
*partial_len += 1;
117-
}
118-
});
119-
120-
partial
121107
})
122108
.reduce(Partial::stub, Partial::try_merge);
123109

src/zip/mod.rs

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -914,12 +914,6 @@ zipt_impl! {
914914
[A B C D E F][ a b c d e f],
915915
}
916916

917-
#[cfg(feature = "rayon")]
918-
macro_rules! last_of {
919-
($q:ty) => { $q };
920-
($p:ty, $($q:ty),+) => { last_of!($($q),+) };
921-
}
922-
923917
macro_rules! map_impl {
924918
($([$notlast:ident $($p:ident)*],)+) => {
925919
$(
@@ -1016,14 +1010,6 @@ macro_rules! map_impl {
10161010
}).is_done()
10171011
}
10181012

1019-
#[cfg(feature = "rayon")]
1020-
#[allow(dead_code)] // unused for the first of the Zip arities
1021-
/// Return a reference to the last producer
1022-
pub(crate) fn last_producer(&self) -> &last_of!($($p),*) {
1023-
let (.., ref last) = &self.parts;
1024-
last
1025-
}
1026-
10271013
expand_if!(@bool [$notlast]
10281014

10291015
/// Include the producer `p` in the Zip.
@@ -1068,32 +1054,19 @@ macro_rules! map_impl {
10681054
/// inputs.
10691055
///
10701056
/// If all inputs are c- or f-order respectively, that is preserved in the output.
1071-
pub fn apply_collect<R>(self, mut f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
1057+
pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
10721058
{
10731059
// Make uninit result
10741060
let mut output = self.uninitalized_for_current_layout::<R>();
1075-
if !std::mem::needs_drop::<R>() {
1076-
// For elements with no drop glue, just overwrite into the array
1077-
self.apply_assign_into(&mut output, f);
1078-
} else {
1079-
// For generic elements, use a Partial to counts the number of filled elements,
1080-
// and can drop the right number of elements on unwinding
1081-
unsafe {
1082-
let mut output = output.raw_view_mut().cast::<R>();
1083-
let mut partial = Partial::new(output.as_mut_ptr());
1084-
let partial_ref = &mut partial;
1085-
debug_assert!(output.is_contiguous());
1086-
debug_assert_eq!(output.layout().tendency() >= 0, self.layout_tendency >= 0);
1087-
self.and(output)
1088-
.apply(move |$($p, )* output_: *mut R| {
1089-
output_.write(f($($p ),*));
1090-
partial_ref.len += 1;
1091-
});
1092-
partial.release_ownership();
1093-
}
1094-
}
10951061

1062+
// Use partial to counts the number of filled elements, and can drop the right
1063+
// number of elements on unwinding (if it happens during apply/collect).
10961064
unsafe {
1065+
let output_view = output.raw_view_mut().cast::<R>();
1066+
self.and(output_view)
1067+
.collect_with_partial(f)
1068+
.release_ownership();
1069+
10971070
output.assume_init()
10981071
}
10991072
}
@@ -1126,6 +1099,54 @@ macro_rules! map_impl {
11261099
}
11271100
}
11281101

1102+
expand_if!(@bool [$notlast]
1103+
// For collect; Last producer is a RawViewMut
1104+
#[allow(non_snake_case)]
1105+
impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
1106+
where D: Dimension,
1107+
$($p: NdProducer<Dim=D> ,)*
1108+
PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
1109+
{
1110+
/// The inner workings of apply_collect and par_apply_collect
1111+
///
1112+
/// Apply the function and collect the results into the output (last producer)
1113+
/// which should be a raw array view; a Partial that owns the written
1114+
/// elements is returned.
1115+
///
1116+
/// Elements will be overwritten in place (in the sense of std::ptr::write).
1117+
///
1118+
/// ## Safety
1119+
///
1120+
/// The last producer is a RawArrayViewMut and must be safe to write into.
1121+
/// The producer must be c- or f-contig and have the same layout tendency
1122+
/// as the whole Zip.
1123+
///
1124+
/// The returned Partial's proxy ownership of the elements must be handled,
1125+
/// before the array the raw view points to realizes its ownership.
1126+
pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
1127+
where F: FnMut($($p::Item,)* ) -> R
1128+
{
1129+
// Get the last producer; and make a Partial that aliases its data pointer
1130+
let (.., ref output) = &self.parts;
1131+
debug_assert!(output.layout().is(CORDER | FORDER));
1132+
debug_assert_eq!(output.layout().tendency() >= 0, self.layout_tendency >= 0);
1133+
let mut partial = Partial::new(output.as_ptr());
1134+
1135+
// Apply the mapping function on this zip
1136+
// if we panic with unwinding; Partial will drop the written elements.
1137+
let partial_len = &mut partial.len;
1138+
self.apply(move |$($p,)* output_elem: *mut R| {
1139+
output_elem.write(f($($p),*));
1140+
if std::mem::needs_drop::<R>() {
1141+
*partial_len += 1;
1142+
}
1143+
});
1144+
1145+
partial
1146+
}
1147+
}
1148+
);
1149+
11291150
impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
11301151
where D: Dimension,
11311152
$($p: NdProducer<Dim=D> ,)*

0 commit comments

Comments
 (0)