Skip to content

Commit a80d8bc

Browse files
committed
Update documentation and function names
1 parent 2eca6fb commit a80d8bc

File tree

4 files changed

+64
-72
lines changed

4 files changed

+64
-72
lines changed

src/dimension/broadcast.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
use crate::error::*;
22
use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
33

4-
/// Calculate the co_broadcast shape of two dimensions. Return error if shapes are
5-
/// not compatible.
6-
fn broadcast_shape<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
7-
where
8-
D1: Dimension,
9-
D2: Dimension,
10-
Output: Dimension,
4+
/// Calculate the common shape for a pair of array shapes, which can be broadcasted
5+
/// to each other. Return an error if shapes are not compatible.
6+
///
7+
/// Uses the [NumPy broadcasting rules]
8+
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9+
fn co_broadcasting<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
10+
where
11+
D1: Dimension,
12+
D2: Dimension,
13+
Output: Dimension,
1114
{
1215
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
1316
// Swap the order if d2 is longer.
1417
if overflow {
15-
return broadcast_shape::<D2, D1, Output>(shape2, shape1);
18+
return co_broadcasting::<D2, D1, Output>(shape2, shape1);
1619
}
1720
// The output should be the same length as shape1.
1821
let mut out = Output::zeros(shape1.ndim());
19-
// Uses the [NumPy broadcasting rules]
20-
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
21-
//
22-
// Zero dimension element is not in the original rules of broadcasting.
23-
// We currently treat it like any other number greater than 1. As numpy does.
2422
for (out, s) in izip!(out.slice_mut(), shape1.slice()) {
2523
*out = *s;
2624
}
@@ -42,10 +40,7 @@ pub trait BroadcastShape<Other: Dimension> {
4240

4341
/// Determines the shape after broadcasting the dimensions together.
4442
///
45-
/// If the dimensions are not compatible, returns `Err`.
46-
///
47-
/// Uses the [NumPy broadcasting rules]
48-
/// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
43+
/// If the shapes are not compatible, returns `Err`.
4944
fn broadcast_shape(&self, other: &Other) -> Result<Self::Output, ShapeError>;
5045
}
5146

@@ -56,7 +51,7 @@ impl<D: Dimension> BroadcastShape<D> for D {
5651
type Output = D;
5752

5853
fn broadcast_shape(&self, other: &D) -> Result<Self::Output, ShapeError> {
59-
broadcast_shape::<D, D, Self::Output>(self, other)
54+
co_broadcasting::<D, D, Self::Output>(self, other)
6055
}
6156
}
6257

@@ -66,15 +61,15 @@ macro_rules! impl_broadcast_distinct_fixed {
6661
type Output = $larger;
6762

6863
fn broadcast_shape(&self, other: &$larger) -> Result<Self::Output, ShapeError> {
69-
broadcast_shape::<Self, $larger, Self::Output>(self, other)
64+
co_broadcasting::<Self, $larger, Self::Output>(self, other)
7065
}
7166
}
7267

7368
impl BroadcastShape<$smaller> for $larger {
7469
type Output = $larger;
7570

7671
fn broadcast_shape(&self, other: &$smaller) -> Result<Self::Output, ShapeError> {
77-
broadcast_shape::<Self, $smaller, Self::Output>(self, other)
72+
co_broadcasting::<Self, $smaller, Self::Output>(self, other)
7873
}
7974
}
8075
};

src/impl_ops.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ where
6868
A: Clone + $trt<B, Output=A>,
6969
B: Clone,
7070
S: DataOwned<Elem=A> + DataMut,
71-
S::MaybeUninit: DataMut,
7271
S2: Data<Elem=B>,
7372
D: Dimension + BroadcastShape<E>,
7473
E: Dimension,
@@ -96,7 +95,6 @@ where
9695
A: Clone + $trt<B, Output=A>,
9796
B: Clone,
9897
S: DataOwned<Elem=A> + DataMut,
99-
S::MaybeUninit: DataMut,
10098
S2: Data<Elem=B>,
10199
D: Dimension + BroadcastShape<E>,
102100
E: Dimension,
@@ -134,7 +132,6 @@ where
134132
B: Clone,
135133
S: Data<Elem=A>,
136134
S2: DataOwned<Elem=B> + DataMut,
137-
S2::MaybeUninit: DataMut,
138135
D: Dimension,
139136
E: Dimension + BroadcastShape<D>,
140137
{

tests/array.rs

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use defmac::defmac;
1111
use itertools::{enumerate, zip, Itertools};
1212
use ndarray::prelude::*;
1313
use ndarray::{arr3, rcarr2};
14-
use ndarray::{indices, BroadcastShape, ErrorKind, IxDynImpl, ShapeError};
14+
use ndarray::indices;
1515
use ndarray::{Slice, SliceInfo, SliceOrIndex};
1616
use std::iter::FromIterator;
1717

@@ -1558,53 +1558,6 @@ fn insert_axis_view() {
15581558
);
15591559
}
15601560

1561-
#[test]
1562-
fn test_broadcast_shape() {
1563-
fn test_co<D1, D2>(
1564-
d1: &D1,
1565-
d2: &D2,
1566-
r: Result<<D1 as BroadcastShape<D2>>::Output, ShapeError>,
1567-
) where
1568-
D1: Dimension + BroadcastShape<D2>,
1569-
D2: Dimension,
1570-
{
1571-
let d = d1.broadcast_shape(d2);
1572-
assert_eq!(d, r);
1573-
}
1574-
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
1575-
test_co(
1576-
&Dim([1, 2, 2]),
1577-
&Dim([1, 3, 4]),
1578-
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
1579-
);
1580-
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
1581-
let v = vec![1, 2, 3, 4, 5, 6, 7];
1582-
test_co(
1583-
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
1584-
&Dim([2, 1, 4, 1, 6, 1]),
1585-
Ok(Dim(IxDynImpl::from(v.as_slice()))),
1586-
);
1587-
let d = Dim([1, 2, 1, 3]);
1588-
test_co(&d, &d, Ok(d));
1589-
test_co(
1590-
&Dim([2, 1, 2]).into_dyn(),
1591-
&Dim(0),
1592-
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
1593-
);
1594-
test_co(
1595-
&Dim([2, 1, 1]),
1596-
&Dim([0, 0, 1, 3, 4]),
1597-
Ok(Dim([0, 0, 2, 3, 4])),
1598-
);
1599-
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
1600-
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
1601-
test_co(
1602-
&Dim([1, 3, 0, 1, 1]),
1603-
&Dim([1, 2, 3, 1]),
1604-
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
1605-
);
1606-
}
1607-
16081561
#[test]
16091562
fn arithmetic_broadcast() {
16101563
let mut a = arr2(&[[1., 2.], [3., 4.]]);

tests/dimension.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use defmac::defmac;
44

5-
use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IntoDimension, IxDyn, RemoveAxis};
5+
use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IntoDimension,Ix0, IxDyn, IxDynImpl, RemoveAxis, ErrorKind, ShapeError, BroadcastShape};
66

77
use std::hash::{Hash, Hasher};
88

@@ -339,3 +339,50 @@ fn test_all_ndindex() {
339339
ndindex!(10, 4, 3, 2, 2);
340340
ndindex!(10, 4, 3, 2, 2, 2);
341341
}
342+
343+
#[test]
344+
fn test_broadcast_shape() {
345+
fn test_co<D1, D2>(
346+
d1: &D1,
347+
d2: &D2,
348+
r: Result<<D1 as BroadcastShape<D2>>::Output, ShapeError>,
349+
) where
350+
D1: Dimension + BroadcastShape<D2>,
351+
D2: Dimension,
352+
{
353+
let d = d1.broadcast_shape(d2);
354+
assert_eq!(d, r);
355+
}
356+
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
357+
test_co(
358+
&Dim([1, 2, 2]),
359+
&Dim([1, 3, 4]),
360+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
361+
);
362+
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
363+
let v = vec![1, 2, 3, 4, 5, 6, 7];
364+
test_co(
365+
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
366+
&Dim([2, 1, 4, 1, 6, 1]),
367+
Ok(Dim(IxDynImpl::from(v.as_slice()))),
368+
);
369+
let d = Dim([1, 2, 1, 3]);
370+
test_co(&d, &d, Ok(d));
371+
test_co(
372+
&Dim([2, 1, 2]).into_dyn(),
373+
&Dim(0),
374+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
375+
);
376+
test_co(
377+
&Dim([2, 1, 1]),
378+
&Dim([0, 0, 1, 3, 4]),
379+
Ok(Dim([0, 0, 2, 3, 4])),
380+
);
381+
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
382+
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
383+
test_co(
384+
&Dim([1, 3, 0, 1, 1]),
385+
&Dim([1, 2, 3, 1]),
386+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
387+
);
388+
}

0 commit comments

Comments
 (0)