Skip to content

Commit 14f2f57

Browse files
andrei.papouandrei-papou
authored andcommitted
Introduced concatenate function, deprecated stack function since 0.13.0
1 parent 1bb7104 commit 14f2f57

File tree

3 files changed

+92
-4
lines changed

3 files changed

+92
-4
lines changed

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
131131

132132
pub use crate::arraytraits::AsArray;
133133
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
134-
pub use crate::stacking::{stack, stack_new_axis};
134+
pub use crate::stacking::{concatenate, stack, stack_new_axis};
135135

136136
pub use crate::impl_views::IndexLonger;
137137
pub use crate::shape_builder::ShapeBuilder;

src/stacking.rs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
use crate::error::{from_kind, ErrorKind, ShapeError};
1010
use crate::imp_prelude::*;
1111

12-
/// Stack arrays along the given axis.
12+
/// Concatenate arrays along the given axis.
1313
///
1414
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
1515
/// (may be made more flexible in the future).<br>
@@ -29,6 +29,10 @@ use crate::imp_prelude::*;
2929
/// [3., 3.]]))
3030
/// );
3131
/// ```
32+
#[deprecated(
33+
since = "0.13.0",
34+
note = "Please use the `concatenate` function instead"
35+
)]
3236
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
3337
where
3438
A: Copy,
@@ -73,6 +77,34 @@ where
7377
Ok(res)
7478
}
7579

80+
/// Concatenate arrays along the given axis.
81+
///
82+
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
83+
/// (may be made more flexible in the future).<br>
84+
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
85+
/// if the result is larger than is possible to represent.
86+
///
87+
/// ```
88+
/// use ndarray::{arr2, Axis, concatenate};
89+
///
90+
/// let a = arr2(&[[2., 2.],
91+
/// [3., 3.]]);
92+
/// assert!(
93+
/// concatenate(Axis(0), &[a.view(), a.view()])
94+
/// == Ok(arr2(&[[2., 2.],
95+
/// [3., 3.],
96+
/// [2., 2.],
97+
/// [3., 3.]]))
98+
/// );
99+
/// ```
100+
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
101+
where
102+
A: Copy,
103+
D: RemoveAxis,
104+
{
105+
stack(axis, arrays)
106+
}
107+
76108
pub fn stack_new_axis<A, D>(
77109
axis: Axis,
78110
arrays: Vec<ArrayView<A, D>>,
@@ -123,7 +155,7 @@ where
123155
///
124156
/// [1]: fn.stack.html
125157
///
126-
/// ***Panics*** if the `concatenate` function would return an error.
158+
/// ***Panics*** if the `stack` function would return an error.
127159
///
128160
/// ```
129161
/// extern crate ndarray;
@@ -150,6 +182,40 @@ macro_rules! stack {
150182
}
151183
}
152184

185+
/// Concatenate arrays along the given axis.
186+
///
187+
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
188+
/// argument `a`.
189+
///
190+
/// [1]: fn.concatenate.html
191+
///
192+
/// ***Panics*** if the `concatenate` function would return an error.
193+
///
194+
/// ```
195+
/// extern crate ndarray;
196+
///
197+
/// use ndarray::{arr2, concatenate, Axis};
198+
///
199+
/// # fn main() {
200+
///
201+
/// let a = arr2(&[[2., 2.],
202+
/// [3., 3.]]);
203+
/// assert!(
204+
/// concatenate![Axis(0), a, a]
205+
/// == arr2(&[[2., 2.],
206+
/// [3., 3.],
207+
/// [2., 2.],
208+
/// [3., 3.]])
209+
/// );
210+
/// # }
211+
/// ```
212+
#[macro_export]
213+
macro_rules! concatenate {
214+
($axis:expr, $( $array:expr ),+ ) => {
215+
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
216+
}
217+
}
218+
153219
/// Stack arrays along the new axis.
154220
///
155221
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each

tests/stacking.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ndarray::{arr2, arr3, aview1, stack, Array2, Axis, ErrorKind, Ix1};
1+
use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};
22

33
#[test]
44
fn concatenating() {
@@ -23,6 +23,28 @@ fn concatenating() {
2323

2424
let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
2525
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
26+
27+
let a = arr2(&[[2., 2.], [3., 3.]]);
28+
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
29+
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
30+
31+
let c = concatenate![Axis(0), a, b];
32+
assert_eq!(
33+
c,
34+
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
35+
);
36+
37+
let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
38+
assert_eq!(d, aview1(&[2., 2., 9., 9.]));
39+
40+
let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
41+
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
42+
43+
let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
44+
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
45+
46+
let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
47+
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
2648
}
2749

2850
#[test]

0 commit comments

Comments
 (0)