Skip to content

Commit d327838

Browse files
selection: bugfix reifying strided views (#533)
Summary: Pull Request resolved: #533 **fix two bugs** thanks go to pzhan9 for reporting these issues! (see D78298712) (1) one in `reify_view` where views with strides were reified incorrectly. the old code assumed views could always be expressed as contiguous `range(start..start + len)` expressions, even when the view's stride differed from the base. this produced incorrect selections when the view was strided relative to the base layout. the fix is to recognize when a view is layout-aligned; that is, when each of its strides is an integer multiple of the corresponding base stride, including the unitary case; and to reify such views using `Range(start, Some(end), step)` expressions that preserve both shape and layout. previously, only the unitary case (`step` = 1) was handled. this change extends support to non-unitary aligned views, such as `step` = 2, by correctly computing the step factor and corresponding end coordinate. if any dimension is not layout-aligned — that is, if `view_stride` % `base_stride` ≠ 0 — we conservatively fall back to enumerating all selected coordinates explicitly. (2) also fixes a bug in `Shape::select()` where the length of a strided range was computed using truncating division (symptom: a strided `select()` like `0..7 step 2` gives sizes: `[3]`, but the selected indices are clearly `[0, 2, 4, 6]` i.e. 4 elements). this now correctly uses ceiling division to ensure all selected indices are included. Reviewed By: pzhan9 Differential Revision: D78315005 fbshipit-source-id: 2c5d9f944eda8f309e0068f538ee9796096a9e52
1 parent a7b299b commit d327838

File tree

3 files changed

+116
-23
lines changed

3 files changed

+116
-23
lines changed

ndslice/src/selection.rs

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,16 +1021,22 @@ impl ReifyView for Slice {
10211021
/// matches all coordinates in the given `view`, expressed in the
10221022
/// coordinate system of the provided `base` slice (`self`).
10231023
///
1024-
/// The resulting expression uses nested `range(start..end, ...)`
1025-
/// combinators to represent the rectangular region selected by
1026-
/// the view within the base slice.
1024+
/// The result is a nested sequence of `range(start..end, step)`
1025+
/// combinators that match the rectangular region covered by `view`
1026+
/// in base coordinates. This preserves geometry and layout when
1027+
/// `view` is *layout-aligned* — that is, each of its strides is
1028+
/// a multiple of the corresponding base stride.
10271029
///
1028-
/// Returns [`dsl::false_()`] for empty views.
1030+
/// If any dimension is not layout-aligned, the view is reified
1031+
/// by explicitly enumerating its coordinates.
1032+
///
1033+
/// Returns [`dsl::false_()`] if the view is empty.
10291034
///
10301035
/// # Errors
10311036
///
10321037
/// Returns an error if:
1033-
/// - The view lies outside the bounds of the base slice
1038+
/// - The base is not contiguous and row-major
1039+
/// - The view lies outside the bounds of the base
10341040
///
10351041
/// # Example
10361042
///
@@ -1043,6 +1049,11 @@ impl ReifyView for Slice {
10431049
/// let selection = base.reify_view(view).unwrap();
10441050
/// ```
10451051
fn reify_view(&self, view: &Slice) -> Result<Selection, SliceError> {
1052+
// Precondition: the base is contiguous and row major.
1053+
if !self.is_contiguous() {
1054+
return Err(SliceError::NonContiguous);
1055+
}
1056+
10461057
if view.is_empty() {
10471058
return Ok(dsl::false_());
10481059
}
@@ -1055,8 +1066,21 @@ impl ReifyView for Slice {
10551066

10561067
let origin = self.coordinates(view.offset())?;
10571068
let mut acc = dsl::true_();
1058-
for (&start, &len) in origin.iter().zip(view.sizes()).rev() {
1059-
acc = dsl::range(start..start + len, acc);
1069+
for ((&start, &len), (&view_stride, &base_stride)) in origin
1070+
.iter()
1071+
.zip(view.sizes())
1072+
.zip(view.strides().iter().zip(self.strides()))
1073+
.rev()
1074+
{
1075+
if view_stride % base_stride == 0 {
1076+
// Layout-aligned with base.
1077+
let step = view_stride / base_stride;
1078+
let end = start + step * len;
1079+
acc = dsl::range(crate::shape::Range(start, Some(end), step), acc);
1080+
} else {
1081+
// Irregular layout; fallback to explicit enumeration.
1082+
return Selection::of_ranks(self, &view.iter().collect::<BTreeSet<_>>());
1083+
}
10601084
}
10611085

10621086
Ok(acc)
@@ -1338,6 +1362,7 @@ mod tests {
13381362
use super::Selection;
13391363
use super::dsl::*;
13401364
use super::is_equivalent_true;
1365+
use crate::Range;
13411366
use crate::Slice;
13421367
use crate::assert_structurally_eq;
13431368
use crate::select;
@@ -2134,6 +2159,55 @@ mod tests {
21342159
);
21352160
}
21362161

2162+
#[test]
2163+
#[allow(clippy::identity_op)]
2164+
fn test_reify_view_1d_with_stride() {
2165+
let shape = shape!(x = 7); // 1D shape with 7 elements
2166+
let selected = shape.select("x", Range(0, None, 2)).unwrap();
2167+
let view = selected.slice();
2168+
assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap());
2169+
2170+
let base = shape.slice();
2171+
let selection = base.reify_view(view).unwrap();
2172+
// Note: ceil(7 / 2) = 4, hence end = 0 + 2 × 4 = 8. See the
2173+
// more detailed explanation in
2174+
// `test_reify_view_2d_with_stride`.
2175+
let expected = range(Range(0, Some(8), 2), true_());
2176+
assert_structurally_eq!(&selection, expected);
2177+
2178+
let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2179+
assert_eq!(flat, vec![0, 2, 4, 6]);
2180+
}
2181+
2182+
#[test]
2183+
#[allow(clippy::identity_op)]
2184+
fn test_reify_view_2d_with_stride() {
2185+
// 4 x 4: x = 4, y = 4.
2186+
let base = shape!(x = 4, y = 4);
2187+
// Step 1: select odd rows (x = 1..4 step 2)
2188+
let shape = base.select("x", Range(1, Some(4), 2)).unwrap();
2189+
// Step 2: then select odd columns (y = 1..4 step 2)
2190+
let shape = shape.select("y", Range(1, Some(4), 2)).unwrap();
2191+
let view = shape.slice();
2192+
assert_eq!(
2193+
view,
2194+
&Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap()
2195+
);
2196+
2197+
let base = base.slice();
2198+
let selection = base.reify_view(view).unwrap();
2199+
// We use `end = start + step * len` to reify the selection.
2200+
// Note: This may yield `end > original_end` (e.g., 5 instead of 4)
2201+
// when the selection length was computed via ceiling division.
2202+
// This is safe: the resulting range will still select the correct
2203+
// indices (e.g., 1 and 3 for Range(1, Some(5), 2)).
2204+
let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_()));
2205+
assert_structurally_eq!(&selection, expected);
2206+
2207+
let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2208+
assert_eq!(flat, vec![5, 7, 13, 15]);
2209+
}
2210+
21372211
#[test]
21382212
fn test_reify_view_selects_column_across_rows() {
21392213
let shape = shape!(host = 2, gpu = 4); // shape [2, 4]

ndslice/src/shape.rs

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,37 +80,36 @@ impl Shape {
8080
Ok(Self { labels, slice })
8181
}
8282

83-
/// Restrict this shape along a named dimension using a [`Range`]. The
84-
/// provided range must be nonempty.
85-
//
86-
/// A shape defines a "strided view" where a strided view is a
87-
/// triple (`offset, `sizes`, `strides`). Each coordinate maps to
88-
/// a flat memory index using the formula:
89-
/// ``` text
90-
/// index = offset + ∑ i_k * strides[k]
83+
/// Restrict this shape along a named dimension using a [`Range`].
84+
/// The provided range must be nonempty.
85+
///
86+
/// A shape defines a **strided view**; a triple (`offset,
87+
/// `sizes`, `strides`). Each coordinate maps to a flat memory
88+
/// index using the formula:
89+
/// ```text
90+
/// index = offset + ∑ iₖ × strides[k]
9191
/// ```
92-
/// where `i_k` is the coordinate in dimension `k`.
92+
/// where `iₖ` is the coordinate in dimension `k`.
9393
///
9494
/// The `select(dim, range)` operation restricts the view to a
9595
/// subrange along a single dimension. It refines the shape by
9696
/// updating the `offset`, `sizes[dim]`, and `strides[dim]` to
9797
/// describe a logically reindexed subregion:
98-
///
9998
/// ```text
100-
/// offset += begin x strides[dim]
101-
/// sizes[dim] = (end - begin) / step
102-
/// strides[dim] *= step
99+
/// offset += begin × strides[dim]
100+
/// sizes[dim] = (end - begin) / step
101+
/// strides[dim] ×= step
103102
/// ```
104103
///
105104
/// This transformation preserves the strided layout and avoids
106105
/// copying data. After `select`, the view behaves as if indexing
107106
/// starts at zero in the selected dimension, with a new length
108-
/// and stride. From the user's perspective, nothing changes
107+
/// and stride. From the user's perspective, nothing changes;
109108
/// indexing remains zero-based, and the resulting shape can be
110109
/// used like any other. The transformation is internal: the
111110
/// view's offset and stride absorb the selection logic.
112111
///
113-
/// `select` is composable it can be applied repeatedly, even on
112+
/// `select` is composable, it can be applied repeatedly, even on
114113
/// the same dimension, to refine the view incrementally.
115114
pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
116115
let dim = self.dim(label)?;
@@ -133,7 +132,10 @@ impl Shape {
133132
}
134133

135134
offset += begin * strides[dim];
136-
sizes[dim] = (end - begin) / stride;
135+
// The # of elems in `begin..end` with step `stride`. This is
136+
// ⌈(end - begin) / stride⌉ — the number of stride steps that
137+
// fit in the half-open interval.
138+
sizes[dim] = (end - begin).div_ceil(stride);
137139
strides[dim] *= stride;
138140

139141
Ok(Self {
@@ -626,4 +628,18 @@ mod tests {
626628
.is_err()
627629
);
628630
}
631+
632+
#[test]
633+
fn test_shape_select_stride_rounding() {
634+
let shape = shape!(x = 10);
635+
// Select x = 0..10 step 3 → expect indices [0, 3, 6, 9]
636+
let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
637+
let slice = sub.slice();
638+
// 10 / 3 = 3.33..., so ceil(10 / 3) = 4
639+
assert_eq!(
640+
slice,
641+
&Slice::new(0, vec![4], vec![3]).unwrap(),
642+
"Expected offset 0, size 4, stride 3"
643+
);
644+
}
629645
}

ndslice/src/slice.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ pub enum SliceError {
3535

3636
#[error("incompatible view: {reason}")]
3737
IncompatibleView { reason: String },
38+
39+
#[error("noncontiguous shape")]
40+
NonContiguous,
3841
}
3942

4043
/// Slice is a compact representation of indices into the flat

0 commit comments

Comments
 (0)