Skip to content

Commit 79d99c7

Browse files
authored
Merge pull request #819 from rust-ndarray/reduce-generated-code
Combine common code / reduce codegen by factoring out common parts
2 parents 4bac2c1 + a76bb92 commit 79d99c7

File tree

3 files changed

+52
-21
lines changed

3 files changed

+52
-21
lines changed

src/dimension/mod.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ pub fn can_index_slice_not_custom<A, D: Dimension>(data: &[A], dim: &D) -> Resul
139139
/// also implies that the length of any individual axis does not exceed
140140
/// `isize::MAX`.)
141141
pub fn max_abs_offset_check_overflow<A, D>(dim: &D, strides: &D) -> Result<usize, ShapeError>
142+
where
143+
D: Dimension,
144+
{
145+
max_abs_offset_check_overflow_impl(mem::size_of::<A>(), dim, strides)
146+
}
147+
148+
fn max_abs_offset_check_overflow_impl<D>(elem_size: usize, dim: &D, strides: &D)
149+
-> Result<usize, ShapeError>
142150
where
143151
D: Dimension,
144152
{
@@ -168,7 +176,7 @@ where
168176
// Determine absolute difference in units of bytes between least and
169177
// greatest address accessible by moving along all axes
170178
let max_offset_bytes = max_offset
171-
.checked_mul(mem::size_of::<A>())
179+
.checked_mul(elem_size)
172180
.ok_or_else(|| from_kind(ErrorKind::Overflow))?;
173181
// Condition 2b.
174182
if max_offset_bytes > isize::MAX as usize {
@@ -216,13 +224,21 @@ pub fn can_index_slice<A, D: Dimension>(
216224
) -> Result<(), ShapeError> {
217225
// Check conditions 1 and 2 and calculate `max_offset`.
218226
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
227+
can_index_slice_impl(max_offset, data.len(), dim, strides)
228+
}
219229

230+
fn can_index_slice_impl<D: Dimension>(
231+
max_offset: usize,
232+
data_len: usize,
233+
dim: &D,
234+
strides: &D,
235+
) -> Result<(), ShapeError> {
220236
// Check condition 4.
221237
let is_empty = dim.slice().iter().any(|&d| d == 0);
222-
if is_empty && max_offset > data.len() {
238+
if is_empty && max_offset > data_len {
223239
return Err(from_kind(ErrorKind::OutOfBounds));
224240
}
225-
if !is_empty && max_offset >= data.len() {
241+
if !is_empty && max_offset >= data_len {
226242
return Err(from_kind(ErrorKind::OutOfBounds));
227243
}
228244

src/iterators/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,18 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D> {
7979
let ndim = self.dim.ndim();
8080
debug_assert_ne!(ndim, 0);
8181
let mut accum = init;
82-
while let Some(mut index) = self.index.clone() {
82+
while let Some(mut index) = self.index {
8383
let stride = self.strides.last_elem() as isize;
8484
let elem_index = index.last_elem();
8585
let len = self.dim.last_elem();
8686
let offset = D::stride_offset(&index, &self.strides);
8787
unsafe {
8888
let row_ptr = self.ptr.offset(offset);
89-
for i in 0..(len - elem_index) {
89+
let mut i = 0;
90+
let i_end = len - elem_index;
91+
while i < i_end {
9092
accum = g(accum, row_ptr.offset(i as isize * stride));
93+
i += 1;
9194
}
9295
}
9396
index.set_last_elem(len - 1);

src/zip/mod.rs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ where
723723
}
724724
}
725725

726-
fn apply_core_contiguous<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
726+
fn apply_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
727727
where
728728
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
729729
P: ZippableTuple<Dim = D>,
@@ -732,15 +732,35 @@ where
732732
let size = self.dimension.size();
733733
let ptrs = self.parts.as_ptr();
734734
let inner_strides = self.parts.contiguous_stride();
735-
for i in 0..size {
736-
unsafe {
737-
let ptr_i = ptrs.stride_offset(inner_strides, i);
738-
acc = fold_while![function(acc, self.parts.as_ref(ptr_i))];
739-
}
735+
unsafe {
736+
self.inner(acc, ptrs, inner_strides, size, &mut function)
737+
}
738+
}
739+
740+
/// The innermost loop of the Zip apply methods
741+
///
742+
/// Run the fold while operation on a stretch of elements with constant strides
743+
///
744+
/// `ptr`: base pointer for the first element in this stretch
745+
/// `strides`: strides for the elements in this stretch
746+
/// `len`: number of elements
747+
/// `function`: closure
748+
unsafe fn inner<F, Acc>(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
749+
len: usize, function: &mut F) -> FoldWhile<Acc>
750+
where
751+
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
752+
P: ZippableTuple
753+
{
754+
let mut i = 0;
755+
while i < len {
756+
let p = ptr.stride_offset(strides, i);
757+
acc = fold_while!(function(acc, self.parts.as_ref(p)));
758+
i += 1;
740759
}
741760
FoldWhile::Continue(acc)
742761
}
743762

763+
744764
fn apply_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
745765
where
746766
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
@@ -773,15 +793,11 @@ where
773793
while let Some(index) = index_ {
774794
unsafe {
775795
let ptr = self.parts.uget_ptr(&index);
776-
for i in 0..inner_len {
777-
let p = ptr.stride_offset(inner_strides, i);
778-
acc = fold_while!(function(acc, self.parts.as_ref(p)));
779-
}
796+
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
780797
}
781798

782799
index_ = self.dimension.next_for(index);
783800
}
784-
self.dimension[unroll_axis] = inner_len;
785801
FoldWhile::Continue(acc)
786802
}
787803

@@ -801,18 +817,14 @@ where
801817
loop {
802818
unsafe {
803819
let ptr = self.parts.uget_ptr(&index);
804-
for i in 0..inner_len {
805-
let p = ptr.stride_offset(inner_strides, i);
806-
acc = fold_while!(function(acc, self.parts.as_ref(p)));
807-
}
820+
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
808821
}
809822

810823
if !self.dimension.next_for_f(&mut index) {
811824
break;
812825
}
813826
}
814827
}
815-
self.dimension[unroll_axis] = inner_len;
816828
FoldWhile::Continue(acc)
817829
}
818830

0 commit comments

Comments
 (0)