Skip to content

Commit 3a2040d

Browse files
authored
Merge pull request #855 from rust-ndarray/checked-shape-strides
Error-check array shape before computing strides
2 parents 143d5b4 + c195930 commit 3a2040d

File tree

8 files changed

+122
-77
lines changed

8 files changed

+122
-77
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
with:
5050
profile: minimal
5151
toolchain: ${{ matrix.rust }}
52-
taret: ${{ matrix.target }}
52+
target: ${{ matrix.target }}
5353
override: true
5454
- name: Cache cargo plugins
5555
uses: actions/cache@v1

src/dimension/mod.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pub use self::dynindeximpl::IxDynImpl;
1919
pub use self::ndindex::NdIndex;
2020
pub use self::remove_axis::RemoveAxis;
2121

22+
use crate::shape_builder::Strides;
23+
2224
use std::isize;
2325
use std::mem;
2426

@@ -114,11 +116,24 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
114116
/// conditions 1 and 2 are sufficient to guarantee that the offset in units of
115117
/// `A` and in units of bytes between the least address and greatest address
116118
/// accessible by moving along all axes does not exceed `isize::MAX`.
117-
pub fn can_index_slice_not_custom<A, D: Dimension>(data: &[A], dim: &D) -> Result<(), ShapeError> {
119+
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(data: &[A], dim: &D,
120+
strides: &Strides<D>)
121+
-> Result<(), ShapeError>
122+
{
123+
if let Strides::Custom(strides) = strides {
124+
can_index_slice(data, dim, strides)
125+
} else {
126+
can_index_slice_not_custom(data.len(), dim)
127+
}
128+
}
129+
130+
pub(crate) fn can_index_slice_not_custom<D: Dimension>(data_len: usize, dim: &D)
131+
-> Result<(), ShapeError>
132+
{
118133
// Condition 1.
119134
let len = size_of_shape_checked(dim)?;
120135
// Condition 2.
121-
if len > data.len() {
136+
if len > data_len {
122137
return Err(from_kind(ErrorKind::OutOfBounds));
123138
}
124139
Ok(())
@@ -217,7 +232,7 @@ where
217232
/// condition 4 is sufficient to guarantee that the absolute difference in
218233
/// units of `A` and in units of bytes between the least address and greatest
219234
/// address accessible by moving along all axes does not exceed `isize::MAX`.
220-
pub fn can_index_slice<A, D: Dimension>(
235+
pub(crate) fn can_index_slice<A, D: Dimension>(
221236
data: &[A],
222237
dim: &D,
223238
strides: &D,
@@ -771,7 +786,7 @@ mod test {
771786
quickcheck! {
772787
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
773788
let dim = IxDyn(&dim);
774-
let result = can_index_slice_not_custom(&data, &dim);
789+
let result = can_index_slice_not_custom(data.len(), &dim);
775790
if dim.size_checked().is_none() {
776791
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
777792
result.is_err()

src/impl_constructors.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ where
358358
{
359359
let shape = shape.into_shape();
360360
let _ = size_of_shape_checked_unwrap!(&shape.dim);
361-
if shape.is_c {
361+
if shape.is_c() {
362362
let v = to_vec_mapped(indices(shape.dim.clone()).into_iter(), f);
363363
unsafe { Self::from_shape_vec_unchecked(shape, v) }
364364
} else {
@@ -411,15 +411,12 @@ where
411411

412412
fn from_shape_vec_impl(shape: StrideShape<D>, v: Vec<A>) -> Result<Self, ShapeError> {
413413
let dim = shape.dim;
414-
let strides = shape.strides;
415-
if shape.custom {
416-
dimension::can_index_slice(&v, &dim, &strides)?;
417-
} else {
418-
dimension::can_index_slice_not_custom::<A, _>(&v, &dim)?;
419-
if dim.size() != v.len() {
420-
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
421-
}
414+
let is_custom = shape.strides.is_custom();
415+
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
416+
if !is_custom && dim.size() != v.len() {
417+
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
422418
}
419+
let strides = shape.strides.strides_for_dim(&dim);
423420
unsafe { Ok(Self::from_vec_dim_stride_unchecked(dim, strides, v)) }
424421
}
425422

@@ -451,7 +448,9 @@ where
451448
Sh: Into<StrideShape<D>>,
452449
{
453450
let shape = shape.into();
454-
Self::from_vec_dim_stride_unchecked(shape.dim, shape.strides, v)
451+
let dim = shape.dim;
452+
let strides = shape.strides.strides_for_dim(&dim);
453+
Self::from_vec_dim_stride_unchecked(dim, strides, v)
455454
}
456455

457456
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self {

src/impl_raw_views.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use std::ptr::NonNull;
44
use crate::dimension::{self, stride_offset};
55
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
66
use crate::imp_prelude::*;
7-
use crate::{is_aligned, StrideShape};
7+
use crate::is_aligned;
8+
use crate::shape_builder::{Strides, StrideShape};
89

910
impl<A, D> RawArrayView<A, D>
1011
where
@@ -69,11 +70,15 @@ where
6970
{
7071
let shape = shape.into();
7172
let dim = shape.dim;
72-
let strides = shape.strides;
7373
if cfg!(debug_assertions) {
7474
assert!(!ptr.is_null(), "The pointer must be non-null.");
75-
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
75+
if let Strides::Custom(strides) = &shape.strides {
76+
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
77+
} else {
78+
dimension::size_of_shape_checked(&dim).unwrap();
79+
}
7680
}
81+
let strides = shape.strides.strides_for_dim(&dim);
7782
RawArrayView::new_(ptr, dim, strides)
7883
}
7984

@@ -205,11 +210,15 @@ where
205210
{
206211
let shape = shape.into();
207212
let dim = shape.dim;
208-
let strides = shape.strides;
209213
if cfg!(debug_assertions) {
210214
assert!(!ptr.is_null(), "The pointer must be non-null.");
211-
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
215+
if let Strides::Custom(strides) = &shape.strides {
216+
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
217+
} else {
218+
dimension::size_of_shape_checked(&dim).unwrap();
219+
}
212220
}
221+
let strides = shape.strides.strides_for_dim(&dim);
213222
RawArrayViewMut::new_(ptr, dim, strides)
214223
}
215224

src/impl_views/constructors.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,8 @@ where
5353

5454
fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError> {
5555
let dim = shape.dim;
56-
let strides = shape.strides;
57-
if shape.custom {
58-
dimension::can_index_slice(xs, &dim, &strides)?;
59-
} else {
60-
dimension::can_index_slice_not_custom::<A, _>(xs, &dim)?;
61-
}
56+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
57+
let strides = shape.strides.strides_for_dim(&dim);
6258
unsafe { Ok(Self::new_(xs.as_ptr(), dim, strides)) }
6359
}
6460

@@ -149,12 +145,8 @@ where
149145

150146
fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError> {
151147
let dim = shape.dim;
152-
let strides = shape.strides;
153-
if shape.custom {
154-
dimension::can_index_slice(xs, &dim, &strides)?;
155-
} else {
156-
dimension::can_index_slice_not_custom::<A, _>(xs, &dim)?;
157-
}
148+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
149+
let strides = shape.strides.strides_for_dim(&dim);
158150
unsafe { Ok(Self::new_(xs.as_mut_ptr(), dim, strides)) }
159151
}
160152

src/lib.rs

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ pub use crate::linalg_traits::{LinalgScalar, NdFloat};
138138
pub use crate::stacking::{concatenate, stack, stack_new_axis};
139139

140140
pub use crate::impl_views::IndexLonger;
141-
pub use crate::shape_builder::ShapeBuilder;
141+
pub use crate::shape_builder::{Shape, StrideShape, ShapeBuilder};
142142

143143
#[macro_use]
144144
mod macro_utils;
@@ -1595,24 +1595,8 @@ mod impl_raw_views;
15951595
// Copy-on-write array methods
15961596
mod impl_cow;
15971597

1598-
/// A contiguous array shape of n dimensions.
1599-
///
1600-
/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
1601-
#[derive(Copy, Clone, Debug)]
1602-
pub struct Shape<D> {
1603-
dim: D,
1604-
is_c: bool,
1605-
}
1606-
1607-
/// An array shape of n dimensions in c-order, f-order or custom strides.
1608-
#[derive(Copy, Clone, Debug)]
1609-
pub struct StrideShape<D> {
1610-
dim: D,
1611-
strides: D,
1612-
custom: bool,
1613-
}
1614-
16151598
/// Returns `true` if the pointer is aligned.
16161599
pub(crate) fn is_aligned<T>(ptr: *const T) -> bool {
16171600
(ptr as usize) % ::std::mem::align_of::<T>() == 0
16181601
}
1602+

src/shape_builder.rs

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,66 @@
11
use crate::dimension::IntoDimension;
22
use crate::Dimension;
3-
use crate::{Shape, StrideShape};
3+
4+
/// A contiguous array shape of n dimensions.
5+
///
6+
/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
7+
#[derive(Copy, Clone, Debug)]
8+
pub struct Shape<D> {
9+
/// Shape (axis lengths)
10+
pub(crate) dim: D,
11+
/// Strides can only be C or F here
12+
pub(crate) strides: Strides<Contiguous>,
13+
}
14+
15+
#[derive(Copy, Clone, Debug)]
16+
pub(crate) enum Contiguous { }
17+
18+
impl<D> Shape<D> {
19+
pub(crate) fn is_c(&self) -> bool {
20+
matches!(self.strides, Strides::C)
21+
}
22+
}
23+
24+
25+
/// An array shape of n dimensions in c-order, f-order or custom strides.
26+
#[derive(Copy, Clone, Debug)]
27+
pub struct StrideShape<D> {
28+
pub(crate) dim: D,
29+
pub(crate) strides: Strides<D>,
30+
}
31+
32+
/// Stride description
33+
#[derive(Copy, Clone, Debug)]
34+
pub(crate) enum Strides<D> {
35+
/// Row-major ("C"-order)
36+
C,
37+
/// Column-major ("F"-order)
38+
F,
39+
/// Custom strides
40+
Custom(D)
41+
}
42+
43+
impl<D> Strides<D> {
44+
/// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
45+
pub(crate) fn strides_for_dim(self, dim: &D) -> D
46+
where D: Dimension
47+
{
48+
match self {
49+
Strides::C => dim.default_strides(),
50+
Strides::F => dim.fortran_strides(),
51+
Strides::Custom(c) => {
52+
debug_assert_eq!(c.ndim(), dim.ndim(),
53+
"Custom strides given with {} dimensions, expected {}",
54+
c.ndim(), dim.ndim());
55+
c
56+
}
57+
}
58+
}
59+
60+
pub(crate) fn is_custom(&self) -> bool {
61+
matches!(*self, Strides::Custom(_))
62+
}
63+
}
464

565
/// A trait for `Shape` and `D where D: Dimension` that allows
666
/// customizing the memory layout (strides) of an array shape.
@@ -34,36 +94,18 @@ where
3494
{
3595
fn from(value: T) -> Self {
3696
let shape = value.into_shape();
37-
let d = shape.dim;
38-
let st = if shape.is_c {
39-
d.default_strides()
97+
let st = if shape.is_c() {
98+
Strides::C
4099
} else {
41-
d.fortran_strides()
100+
Strides::F
42101
};
43102
StrideShape {
44103
strides: st,
45-
dim: d,
46-
custom: false,
104+
dim: shape.dim,
47105
}
48106
}
49107
}
50108

51-
/*
52-
impl<D> From<Shape<D>> for StrideShape<D>
53-
where D: Dimension
54-
{
55-
fn from(shape: Shape<D>) -> Self {
56-
let d = shape.dim;
57-
let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() };
58-
StrideShape {
59-
strides: st,
60-
dim: d,
61-
custom: false,
62-
}
63-
}
64-
}
65-
*/
66-
67109
impl<T> ShapeBuilder for T
68110
where
69111
T: IntoDimension,
@@ -73,7 +115,7 @@ where
73115
fn into_shape(self) -> Shape<Self::Dim> {
74116
Shape {
75117
dim: self.into_dimension(),
76-
is_c: true,
118+
strides: Strides::C,
77119
}
78120
}
79121
fn f(self) -> Shape<Self::Dim> {
@@ -93,21 +135,24 @@ where
93135
{
94136
type Dim = D;
95137
type Strides = D;
138+
96139
fn into_shape(self) -> Shape<D> {
97140
self
98141
}
142+
99143
fn f(self) -> Self {
100144
self.set_f(true)
101145
}
146+
102147
fn set_f(mut self, is_f: bool) -> Self {
103-
self.is_c = !is_f;
148+
self.strides = if !is_f { Strides::C } else { Strides::F };
104149
self
105150
}
151+
106152
fn strides(self, st: D) -> StrideShape<D> {
107153
StrideShape {
108154
dim: self.dim,
109-
strides: st,
110-
custom: true,
155+
strides: Strides::Custom(st),
111156
}
112157
}
113158
}

tests/array-construct.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ fn test_from_fn_f3() {
148148
fn deny_wraparound_from_vec() {
149149
let five = vec![0; 5];
150150
let five_large = Array::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone());
151+
println!("{:?}", five_large);
151152
assert!(five_large.is_err());
152153
let six = Array::from_shape_vec(6, five.clone());
153154
assert!(six.is_err());

0 commit comments

Comments
 (0)