Skip to content

Commit fabea29

Browse files
committed
FEAT: Implement Zip::apply_collect for non-Copy elements
1 parent 3a767c0 commit fabea29

File tree

2 files changed

+163
-7
lines changed

2 files changed

+163
-7
lines changed

src/zip/mod.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#[macro_use]
1010
mod zipmacro;
11+
mod partial_array;
1112

1213
use std::mem::MaybeUninit;
1314

@@ -20,6 +21,8 @@ use crate::NdIndex;
2021
use crate::indexes::{indices, Indices};
2122
use crate::layout::{CORDER, FORDER};
2223

24+
use partial_array::PartialArray;
25+
2326
/// Return if the expression is a break value.
2427
macro_rules! fold_while {
2528
($e:expr) => {
@@ -195,6 +198,7 @@ pub trait NdProducer {
195198
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
196199
where
197200
Self: Sized;
201+
198202
private_decl! {}
199203
}
200204

@@ -1070,16 +1074,24 @@ macro_rules! map_impl {
10701074
/// inputs.
10711075
///
10721076
/// If all inputs are c- or f-order respectively, that is preserved in the output.
1073-
///
1074-
/// Restricted to functions that produce copyable results for technical reasons; other
1075-
/// cases are not yet implemented.
10761077
pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
1077-
where R: Copy,
10781078
{
1079-
// To support non-Copy elements, implementation of dropping partial array (on
1080-
// panic) is needed
1079+
// Make uninit result
10811080
let mut output = self.uninitalized_for_current_layout::<R>();
1082-
self.apply_assign_into(&mut output, f);
1081+
if !std::mem::needs_drop::<R>() {
1082+
// For elements with no drop glue, just overwrite into the array
1083+
self.apply_assign_into(&mut output, f);
1084+
} else {
1085+
// For generic elements, use a proxy that counts the number of filled elements,
1086+
// and can drop the right number of elements on unwinding
1087+
unsafe {
1088+
PartialArray::scope(output.view_mut(), move |partial| {
1089+
debug_assert_eq!(partial.layout().tendency() >= 0, self.layout_tendency >= 0);
1090+
self.apply_assign_into(partial, f);
1091+
});
1092+
}
1093+
}
1094+
10831095
unsafe {
10841096
output.assume_init()
10851097
}

src/zip/partial_array.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright 2020 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::imp_prelude::*;
10+
use crate::{
11+
AssignElem,
12+
Layout,
13+
NdProducer,
14+
Zip,
15+
FoldWhile,
16+
};
17+
18+
use std::cell::Cell;
19+
use std::mem;
20+
use std::mem::MaybeUninit;
21+
use std::ptr;
22+
23+
/// An assignable element reference that increments a counter when assigned
24+
pub(crate) struct ProxyElem<'a, 'b, A> {
25+
item: &'a mut MaybeUninit<A>,
26+
filled: &'b Cell<usize>
27+
}
28+
29+
impl<'a, 'b, A> AssignElem<A> for ProxyElem<'a, 'b, A> {
30+
fn assign_elem(self, item: A) {
31+
self.filled.set(self.filled.get() + 1);
32+
*self.item = MaybeUninit::new(item);
33+
}
34+
}
35+
36+
/// Handles progress of assigning to a part of an array, for elements that need
37+
/// to be dropped on unwinding. See Self::scope.
38+
pub(crate) struct PartialArray<'a, 'b, A, D>
39+
where D: Dimension
40+
{
41+
data: ArrayViewMut<'a, MaybeUninit<A>, D>,
42+
filled: &'b Cell<usize>,
43+
}
44+
45+
impl<'a, 'b, A, D> PartialArray<'a, 'b, A, D>
46+
where D: Dimension
47+
{
48+
/// Create a temporary PartialArray that wraps the array view `data`;
49+
/// if the end of the scope is reached, the partial array is marked complete;
50+
/// if execution unwinds at any time before them, the elements written until then
51+
/// are dropped.
52+
///
53+
/// Safety: the caller *must* ensure that elements will be written in `data`'s preferred order.
54+
/// PartialArray can not handle arbitrary writes, only in the memory order.
55+
pub(crate) unsafe fn scope(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
56+
scope_fn: impl FnOnce(&mut PartialArray<A, D>))
57+
{
58+
let filled = Cell::new(0);
59+
let mut partial = PartialArray::new(data, &filled);
60+
scope_fn(&mut partial);
61+
filled.set(0); // mark complete
62+
}
63+
64+
unsafe fn new(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
65+
filled: &'b Cell<usize>) -> Self
66+
{
67+
debug_assert_eq!(filled.get(), 0);
68+
Self { data, filled }
69+
}
70+
}
71+
72+
impl<'a, 'b, A, D> Drop for PartialArray<'a, 'b, A, D>
73+
where D: Dimension
74+
{
75+
fn drop(&mut self) {
76+
if !mem::needs_drop::<A>() {
77+
return;
78+
}
79+
80+
let mut count = self.filled.get();
81+
if count == 0 {
82+
return;
83+
}
84+
85+
Zip::from(self).fold_while((), move |(), elt| {
86+
if count > 0 {
87+
count -= 1;
88+
unsafe {
89+
ptr::drop_in_place::<A>(elt.item.as_mut_ptr());
90+
}
91+
FoldWhile::Continue(())
92+
} else {
93+
FoldWhile::Done(())
94+
}
95+
});
96+
}
97+
}
98+
99+
impl<'a: 'c, 'b: 'c, 'c, A, D: Dimension> NdProducer for &'c mut PartialArray<'a, 'b, A, D> {
100+
// This just wraps ArrayViewMut as NdProducer and maps the item
101+
type Item = ProxyElem<'a, 'b, A>;
102+
type Dim = D;
103+
type Ptr = *mut MaybeUninit<A>;
104+
type Stride = isize;
105+
106+
private_impl! {}
107+
fn raw_dim(&self) -> Self::Dim {
108+
self.data.raw_dim()
109+
}
110+
111+
fn equal_dim(&self, dim: &Self::Dim) -> bool {
112+
self.data.equal_dim(dim)
113+
}
114+
115+
fn as_ptr(&self) -> Self::Ptr {
116+
NdProducer::as_ptr(&self.data)
117+
}
118+
119+
fn layout(&self) -> Layout {
120+
self.data.layout()
121+
}
122+
123+
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
124+
ProxyElem { filled: self.filled, item: &mut *ptr }
125+
}
126+
127+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
128+
self.data.uget_ptr(i)
129+
}
130+
131+
fn stride_of(&self, axis: Axis) -> Self::Stride {
132+
self.data.stride_of(axis)
133+
}
134+
135+
#[inline(always)]
136+
fn contiguous_stride(&self) -> Self::Stride {
137+
self.data.contiguous_stride()
138+
}
139+
140+
fn split_at(self, _axis: Axis, _index: usize) -> (Self, Self) {
141+
unimplemented!();
142+
}
143+
}
144+

0 commit comments

Comments
 (0)