Skip to content

Commit c195930

Browse files
committed
FIX: Error-check array shape before computing strides
Update StrideShape's representation to use enum { C, F, Custom(D) } so that we don't try to compute a C or F stride set until after we have checked the actual array shape. This avoids overflow errors (that panic) in places where we expected to return an error. It was visible as a test failure on 32-bit that long went undetected because tests didn't run on such platforms before (but the bug affected all platforms, given sufficiently large inputs). Also move the Shape, StrideShape types into the shape builder module. It all calls out for a nicer organization of types that makes constructors easier to understand for users, but that's an issue for another time - and it's a breaking change.
1 parent 4549b00 commit c195930

File tree

7 files changed

+121
-76
lines changed

7 files changed

+121
-76
lines changed

src/dimension/mod.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pub use self::dynindeximpl::IxDynImpl;
1919
pub use self::ndindex::NdIndex;
2020
pub use self::remove_axis::RemoveAxis;
2121

22+
use crate::shape_builder::Strides;
23+
2224
use std::isize;
2325
use std::mem;
2426

@@ -114,11 +116,24 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
114116
/// conditions 1 and 2 are sufficient to guarantee that the offset in units of
115117
/// `A` and in units of bytes between the least address and greatest address
116118
/// accessible by moving along all axes does not exceed `isize::MAX`.
117-
pub fn can_index_slice_not_custom<A, D: Dimension>(data: &[A], dim: &D) -> Result<(), ShapeError> {
119+
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(data: &[A], dim: &D,
120+
strides: &Strides<D>)
121+
-> Result<(), ShapeError>
122+
{
123+
if let Strides::Custom(strides) = strides {
124+
can_index_slice(data, dim, strides)
125+
} else {
126+
can_index_slice_not_custom(data.len(), dim)
127+
}
128+
}
129+
130+
pub(crate) fn can_index_slice_not_custom<D: Dimension>(data_len: usize, dim: &D)
131+
-> Result<(), ShapeError>
132+
{
118133
// Condition 1.
119134
let len = size_of_shape_checked(dim)?;
120135
// Condition 2.
121-
if len > data.len() {
136+
if len > data_len {
122137
return Err(from_kind(ErrorKind::OutOfBounds));
123138
}
124139
Ok(())
@@ -217,7 +232,7 @@ where
217232
/// condition 4 is sufficient to guarantee that the absolute difference in
218233
/// units of `A` and in units of bytes between the least address and greatest
219234
/// address accessible by moving along all axes does not exceed `isize::MAX`.
220-
pub fn can_index_slice<A, D: Dimension>(
235+
pub(crate) fn can_index_slice<A, D: Dimension>(
221236
data: &[A],
222237
dim: &D,
223238
strides: &D,
@@ -771,7 +786,7 @@ mod test {
771786
quickcheck! {
772787
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
773788
let dim = IxDyn(&dim);
774-
let result = can_index_slice_not_custom(&data, &dim);
789+
let result = can_index_slice_not_custom(data.len(), &dim);
775790
if dim.size_checked().is_none() {
776791
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
777792
result.is_err()

src/impl_constructors.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ where
358358
{
359359
let shape = shape.into_shape();
360360
let _ = size_of_shape_checked_unwrap!(&shape.dim);
361-
if shape.is_c {
361+
if shape.is_c() {
362362
let v = to_vec_mapped(indices(shape.dim.clone()).into_iter(), f);
363363
unsafe { Self::from_shape_vec_unchecked(shape, v) }
364364
} else {
@@ -411,15 +411,12 @@ where
411411

412412
fn from_shape_vec_impl(shape: StrideShape<D>, v: Vec<A>) -> Result<Self, ShapeError> {
413413
let dim = shape.dim;
414-
let strides = shape.strides;
415-
if shape.custom {
416-
dimension::can_index_slice(&v, &dim, &strides)?;
417-
} else {
418-
dimension::can_index_slice_not_custom::<A, _>(&v, &dim)?;
419-
if dim.size() != v.len() {
420-
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
421-
}
414+
let is_custom = shape.strides.is_custom();
415+
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
416+
if !is_custom && dim.size() != v.len() {
417+
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
422418
}
419+
let strides = shape.strides.strides_for_dim(&dim);
423420
unsafe { Ok(Self::from_vec_dim_stride_unchecked(dim, strides, v)) }
424421
}
425422

@@ -451,7 +448,9 @@ where
451448
Sh: Into<StrideShape<D>>,
452449
{
453450
let shape = shape.into();
454-
Self::from_vec_dim_stride_unchecked(shape.dim, shape.strides, v)
451+
let dim = shape.dim;
452+
let strides = shape.strides.strides_for_dim(&dim);
453+
Self::from_vec_dim_stride_unchecked(dim, strides, v)
455454
}
456455

457456
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self {

src/impl_raw_views.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use std::ptr::NonNull;
44
use crate::dimension::{self, stride_offset};
55
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
66
use crate::imp_prelude::*;
7-
use crate::{is_aligned, StrideShape};
7+
use crate::is_aligned;
8+
use crate::shape_builder::{Strides, StrideShape};
89

910
impl<A, D> RawArrayView<A, D>
1011
where
@@ -69,11 +70,15 @@ where
6970
{
7071
let shape = shape.into();
7172
let dim = shape.dim;
72-
let strides = shape.strides;
7373
if cfg!(debug_assertions) {
7474
assert!(!ptr.is_null(), "The pointer must be non-null.");
75-
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
75+
if let Strides::Custom(strides) = &shape.strides {
76+
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
77+
} else {
78+
dimension::size_of_shape_checked(&dim).unwrap();
79+
}
7680
}
81+
let strides = shape.strides.strides_for_dim(&dim);
7782
RawArrayView::new_(ptr, dim, strides)
7883
}
7984

@@ -205,11 +210,15 @@ where
205210
{
206211
let shape = shape.into();
207212
let dim = shape.dim;
208-
let strides = shape.strides;
209213
if cfg!(debug_assertions) {
210214
assert!(!ptr.is_null(), "The pointer must be non-null.");
211-
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
215+
if let Strides::Custom(strides) = &shape.strides {
216+
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
217+
} else {
218+
dimension::size_of_shape_checked(&dim).unwrap();
219+
}
212220
}
221+
let strides = shape.strides.strides_for_dim(&dim);
213222
RawArrayViewMut::new_(ptr, dim, strides)
214223
}
215224

src/impl_views/constructors.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,8 @@ where
5353

5454
fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError> {
5555
let dim = shape.dim;
56-
let strides = shape.strides;
57-
if shape.custom {
58-
dimension::can_index_slice(xs, &dim, &strides)?;
59-
} else {
60-
dimension::can_index_slice_not_custom::<A, _>(xs, &dim)?;
61-
}
56+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
57+
let strides = shape.strides.strides_for_dim(&dim);
6258
unsafe { Ok(Self::new_(xs.as_ptr(), dim, strides)) }
6359
}
6460

@@ -149,12 +145,8 @@ where
149145

150146
fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError> {
151147
let dim = shape.dim;
152-
let strides = shape.strides;
153-
if shape.custom {
154-
dimension::can_index_slice(xs, &dim, &strides)?;
155-
} else {
156-
dimension::can_index_slice_not_custom::<A, _>(xs, &dim)?;
157-
}
148+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
149+
let strides = shape.strides.strides_for_dim(&dim);
158150
unsafe { Ok(Self::new_(xs.as_mut_ptr(), dim, strides)) }
159151
}
160152

src/lib.rs

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ pub use crate::linalg_traits::{LinalgScalar, NdFloat};
138138
pub use crate::stacking::{concatenate, stack, stack_new_axis};
139139

140140
pub use crate::impl_views::IndexLonger;
141-
pub use crate::shape_builder::ShapeBuilder;
141+
pub use crate::shape_builder::{Shape, StrideShape, ShapeBuilder};
142142

143143
#[macro_use]
144144
mod macro_utils;
@@ -1595,24 +1595,8 @@ mod impl_raw_views;
15951595
// Copy-on-write array methods
15961596
mod impl_cow;
15971597

1598-
/// A contiguous array shape of n dimensions.
1599-
///
1600-
/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
1601-
#[derive(Copy, Clone, Debug)]
1602-
pub struct Shape<D> {
1603-
dim: D,
1604-
is_c: bool,
1605-
}
1606-
1607-
/// An array shape of n dimensions in c-order, f-order or custom strides.
1608-
#[derive(Copy, Clone, Debug)]
1609-
pub struct StrideShape<D> {
1610-
dim: D,
1611-
strides: D,
1612-
custom: bool,
1613-
}
1614-
16151598
/// Returns `true` if the pointer is aligned.
16161599
pub(crate) fn is_aligned<T>(ptr: *const T) -> bool {
16171600
(ptr as usize) % ::std::mem::align_of::<T>() == 0
16181601
}
1602+

src/shape_builder.rs

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,66 @@
11
use crate::dimension::IntoDimension;
22
use crate::Dimension;
3-
use crate::{Shape, StrideShape};
3+
4+
/// A contiguous array shape of n dimensions.
5+
///
6+
/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
7+
#[derive(Copy, Clone, Debug)]
8+
pub struct Shape<D> {
9+
/// Shape (axis lengths)
10+
pub(crate) dim: D,
11+
/// Strides can only be C or F here
12+
pub(crate) strides: Strides<Contiguous>,
13+
}
14+
15+
#[derive(Copy, Clone, Debug)]
16+
pub(crate) enum Contiguous { }
17+
18+
impl<D> Shape<D> {
19+
pub(crate) fn is_c(&self) -> bool {
20+
matches!(self.strides, Strides::C)
21+
}
22+
}
23+
24+
25+
/// An array shape of n dimensions in c-order, f-order or custom strides.
26+
#[derive(Copy, Clone, Debug)]
27+
pub struct StrideShape<D> {
28+
pub(crate) dim: D,
29+
pub(crate) strides: Strides<D>,
30+
}
31+
32+
/// Stride description
33+
#[derive(Copy, Clone, Debug)]
34+
pub(crate) enum Strides<D> {
35+
/// Row-major ("C"-order)
36+
C,
37+
/// Column-major ("F"-order)
38+
F,
39+
/// Custom strides
40+
Custom(D)
41+
}
42+
43+
impl<D> Strides<D> {
44+
/// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
45+
pub(crate) fn strides_for_dim(self, dim: &D) -> D
46+
where D: Dimension
47+
{
48+
match self {
49+
Strides::C => dim.default_strides(),
50+
Strides::F => dim.fortran_strides(),
51+
Strides::Custom(c) => {
52+
debug_assert_eq!(c.ndim(), dim.ndim(),
53+
"Custom strides given with {} dimensions, expected {}",
54+
c.ndim(), dim.ndim());
55+
c
56+
}
57+
}
58+
}
59+
60+
pub(crate) fn is_custom(&self) -> bool {
61+
matches!(*self, Strides::Custom(_))
62+
}
63+
}
464

565
/// A trait for `Shape` and `D where D: Dimension` that allows
666
/// customizing the memory layout (strides) of an array shape.
@@ -34,36 +94,18 @@ where
3494
{
3595
fn from(value: T) -> Self {
3696
let shape = value.into_shape();
37-
let d = shape.dim;
38-
let st = if shape.is_c {
39-
d.default_strides()
97+
let st = if shape.is_c() {
98+
Strides::C
4099
} else {
41-
d.fortran_strides()
100+
Strides::F
42101
};
43102
StrideShape {
44103
strides: st,
45-
dim: d,
46-
custom: false,
104+
dim: shape.dim,
47105
}
48106
}
49107
}
50108

51-
/*
52-
impl<D> From<Shape<D>> for StrideShape<D>
53-
where D: Dimension
54-
{
55-
fn from(shape: Shape<D>) -> Self {
56-
let d = shape.dim;
57-
let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() };
58-
StrideShape {
59-
strides: st,
60-
dim: d,
61-
custom: false,
62-
}
63-
}
64-
}
65-
*/
66-
67109
impl<T> ShapeBuilder for T
68110
where
69111
T: IntoDimension,
@@ -73,7 +115,7 @@ where
73115
fn into_shape(self) -> Shape<Self::Dim> {
74116
Shape {
75117
dim: self.into_dimension(),
76-
is_c: true,
118+
strides: Strides::C,
77119
}
78120
}
79121
fn f(self) -> Shape<Self::Dim> {
@@ -93,21 +135,24 @@ where
93135
{
94136
type Dim = D;
95137
type Strides = D;
138+
96139
fn into_shape(self) -> Shape<D> {
97140
self
98141
}
142+
99143
fn f(self) -> Self {
100144
self.set_f(true)
101145
}
146+
102147
fn set_f(mut self, is_f: bool) -> Self {
103-
self.is_c = !is_f;
148+
self.strides = if !is_f { Strides::C } else { Strides::F };
104149
self
105150
}
151+
106152
fn strides(self, st: D) -> StrideShape<D> {
107153
StrideShape {
108154
dim: self.dim,
109-
strides: st,
110-
custom: true,
155+
strides: Strides::Custom(st),
111156
}
112157
}
113158
}

tests/array-construct.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ fn test_from_fn_f3() {
148148
fn deny_wraparound_from_vec() {
149149
let five = vec![0; 5];
150150
let five_large = Array::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone());
151+
println!("{:?}", five_large);
151152
assert!(five_large.is_err());
152153
let six = Array::from_shape_vec(6, five.clone());
153154
assert!(six.is_err());

0 commit comments

Comments
 (0)