Skip to content

Commit dee84ae

Browse files
andrei.papouandrei-papou
authored andcommitted
Removed allocation in stack_new_axis, renamed concatenate -> stack, stack -> stack_new_axis
1 parent 9f1a3d0 commit dee84ae

File tree

5 files changed

+64
-41
lines changed

5 files changed

+64
-41
lines changed

src/doc/ndarray_for_numpy_users/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,8 @@
532532
//! ------|-----------|------
533533
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
534534
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
535-
//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1
536-
//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1
535+
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
536+
//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
537537
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
538538
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
539539
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
@@ -641,10 +641,10 @@
641641
//! [.slice_move()]: ../../struct.ArrayBase.html#method.slice_move
642642
//! [.slice_mut()]: ../../struct.ArrayBase.html#method.slice_mut
643643
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
644-
//! [concatenate!]: ../../macro.concatenate.html
645-
//! [concatenate()]: ../../fn.concatenate.html
646644
//! [stack!]: ../../macro.stack.html
647645
//! [stack()]: ../../fn.stack.html
646+
//! [stack_new_axis!]: ../../macro.stack_new_axis.html
647+
//! [stack_new_axis()]: ../../fn.stack_new_axis.html
648648
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
649649
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
650650
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis

src/impl_methods.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::iter::{
2828
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
2929
};
3030
use crate::slice::MultiSlice;
31-
use crate::stacking::concatenate;
31+
use crate::stacking::stack;
3232
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};
3333

3434
/// # Methods For All Array Types
@@ -840,7 +840,7 @@ where
840840
dim.set_axis(axis, 0);
841841
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
842842
} else {
843-
concatenate(axis, &subs).unwrap()
843+
stack(axis, &subs).unwrap()
844844
}
845845
}
846846

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
131131

132132
pub use crate::arraytraits::AsArray;
133133
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
134-
pub use crate::stacking::{concatenate, stack};
134+
pub use crate::stacking::{stack, stack_new_axis};
135135

136136
pub use crate::impl_views::IndexLonger;
137137
pub use crate::shape_builder::ShapeBuilder;

src/stacking.rs

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@ use crate::imp_prelude::*;
1717
/// if the result is larger than is possible to represent.
1818
///
1919
/// ```
20-
/// use ndarray::{arr2, Axis, concatenate};
20+
/// use ndarray::{arr2, Axis, stack};
2121
///
2222
/// let a = arr2(&[[2., 2.],
2323
/// [3., 3.]]);
2424
/// assert!(
25-
/// concatenate(Axis(0), &[a.view(), a.view()])
25+
/// stack(Axis(0), &[a.view(), a.view()])
2626
/// == Ok(arr2(&[[2., 2.],
2727
/// [3., 3.],
2828
/// [2., 2.],
2929
/// [3., 3.]]))
3030
/// );
3131
/// ```
32-
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
32+
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
3333
where
3434
A: Copy,
3535
D: RemoveAxis,
@@ -73,7 +73,7 @@ where
7373
Ok(res)
7474
}
7575

76-
pub fn stack<A, D>(
76+
pub fn stack_new_axis<A, D>(
7777
axis: Axis,
7878
arrays: Vec<ArrayView<A, D>>,
7979
) -> Result<Array<A, D::Larger>, ShapeError>
@@ -82,36 +82,59 @@ where
8282
D: Dimension,
8383
D::Larger: RemoveAxis,
8484
{
85-
if let Some(ndim) = D::NDIM {
86-
if axis.index() > ndim {
87-
return Err(from_kind(ErrorKind::OutOfBounds));
88-
}
85+
if arrays.is_empty() {
86+
return Err(from_kind(ErrorKind::Unsupported));
87+
}
88+
let common_dim = arrays[0].raw_dim();
89+
// Avoid panic on `insert_axis` call, return an Err instead of it.
90+
if axis.index() > common_dim.ndim() {
91+
return Err(from_kind(ErrorKind::OutOfBounds));
92+
}
93+
let mut res_dim = common_dim.insert_axis(axis);
94+
95+
if arrays.iter().any(|a| a.raw_dim() != common_dim) {
96+
return Err(from_kind(ErrorKind::IncompatibleShape));
8997
}
90-
let arrays: Vec<ArrayView<A, D::Larger>> =
91-
arrays.into_iter().map(|a| a.insert_axis(axis)).collect();
92-
concatenate(axis, &arrays)
98+
99+
res_dim.set_axis(axis, arrays.len());
100+
101+
// we can safely use uninitialized values here because they are Copy
102+
// and we will only ever write to them
103+
let size = res_dim.size();
104+
let mut v = Vec::with_capacity(size);
105+
unsafe {
106+
v.set_len(size);
107+
}
108+
let mut res = Array::from_shape_vec(res_dim, v)?;
109+
110+
res.axis_iter_mut(axis).zip(arrays.into_iter())
111+
.for_each(|(mut assign_view, array)| {
112+
assign_view.assign(&array);
113+
});
114+
115+
Ok(res)
93116
}
94117

95118
/// Concatenate arrays along the given axis.
96119
///
97-
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
120+
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
98121
/// argument `a`.
99122
///
100-
/// [1]: fn.concatenate.html
123+
/// [1]: fn.stack.html
101124
///
102125
/// ***Panics*** if the `concatenate` function would return an error.
103126
///
104127
/// ```
105128
/// extern crate ndarray;
106129
///
107-
/// use ndarray::{arr2, concatenate, Axis};
130+
/// use ndarray::{arr2, stack, Axis};
108131
///
109132
/// # fn main() {
110133
///
111134
/// let a = arr2(&[[2., 2.],
112135
/// [3., 3.]]);
113136
/// assert!(
114-
/// concatenate![Axis(0), a, a]
137+
/// stack![Axis(0), a, a]
115138
/// == arr2(&[[2., 2.],
116139
/// [3., 3.],
117140
/// [2., 2.],
@@ -120,32 +143,32 @@ where
120143
/// # }
121144
/// ```
122145
#[macro_export]
123-
macro_rules! concatenate {
146+
macro_rules! stack {
124147
($axis:expr, $( $array:expr ),+ ) => {
125-
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
148+
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
126149
}
127150
}
128151

129152
/// Stack arrays along the new axis.
130153
///
131-
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
154+
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
132155
/// argument `a`.
133156
///
134-
/// [1]: fn.concatenate.html
157+
/// [1]: fn.stack_new_axis.html
135158
///
136159
/// ***Panics*** if the `stack` function would return an error.
137160
///
138161
/// ```
139162
/// extern crate ndarray;
140163
///
141-
/// use ndarray::{arr2, arr3, stack, Axis};
164+
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
142165
///
143166
/// # fn main() {
144167
///
145168
/// let a = arr2(&[[2., 2.],
146169
/// [3., 3.]]);
147170
/// assert!(
148-
/// stack![Axis(0), a, a]
171+
/// stack_new_axis![Axis(0), a, a]
149172
/// == arr3(&[[[2., 2.],
150173
/// [3., 3.]],
151174
/// [[2., 2.],
@@ -154,8 +177,8 @@ macro_rules! concatenate {
154177
/// # }
155178
/// ```
156179
#[macro_export]
157-
macro_rules! stack {
180+
macro_rules! stack_new_axis {
158181
($axis:expr, $( $array:expr ),+ ) => {
159-
$crate::stack($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
182+
$crate::stack_new_axis($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
160183
}
161184
}

tests/stacking.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,43 @@
1-
use ndarray::{arr2, arr3, aview1, concatenate, Array2, Axis, ErrorKind, Ix1};
1+
use ndarray::{arr2, arr3, aview1, stack, Array2, Axis, ErrorKind, Ix1};
22

33
#[test]
44
fn concatenating() {
55
let a = arr2(&[[2., 2.], [3., 3.]]);
6-
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
6+
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
77
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
88

9-
let c = concatenate![Axis(0), a, b];
9+
let c = stack![Axis(0), a, b];
1010
assert_eq!(
1111
c,
1212
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
1313
);
1414

15-
let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
15+
let d = stack![Axis(0), a.row(0), &[9., 9.]];
1616
assert_eq!(d, aview1(&[2., 2., 9., 9.]));
1717

18-
let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
18+
let res = ndarray::stack(Axis(1), &[a.view(), c.view()]);
1919
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
2020

21-
let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
21+
let res = ndarray::stack(Axis(2), &[a.view(), c.view()]);
2222
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
2323

24-
let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
24+
let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
2525
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
2626
}
2727

2828
#[test]
2929
fn stacking() {
3030
let a = arr2(&[[2., 2.], [3., 3.]]);
31-
let b = ndarray::stack(Axis(0), vec![a.view(), a.view()]).unwrap();
31+
let b = ndarray::stack_new_axis(Axis(0), vec![a.view(), a.view()]).unwrap();
3232
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));
3333

3434
let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
35-
let res = ndarray::stack(Axis(1), vec![a.view(), c.view()]);
35+
let res = ndarray::stack_new_axis(Axis(1), vec![a.view(), c.view()]);
3636
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
3737

38-
let res = ndarray::stack(Axis(3), vec![a.view(), a.view()]);
38+
let res = ndarray::stack_new_axis(Axis(3), vec![a.view(), a.view()]);
3939
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
4040

41-
let res: Result<Array2<f64>, _> = ndarray::stack::<_, Ix1>(Axis(0), vec![]);
41+
let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), vec![]);
4242
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
4343
}

0 commit comments

Comments
 (0)