Skip to content

Commit ea6ee56

Browse files
: selection: bugfix reifying strided views
Summary: **fix two bugs** thanks go to pzhan9 for reporting these issues! (see D78298712) (1) one in `reify_view` where views with strides were reified incorrectly. the old code assumed views could always be expressed as contiguous `range(start..start + len)` expressions, even when the view's stride differed from the base. this produced incorrect selections when the view was strided relative to the base layout. the fix is to recognize when a view is layout-aligned; that is, when each of its strides is an integer multiple of the corresponding base stride, including the unitary case; and to reify such views using `Range(start, Some(end), step)` expressions that preserve both shape and layout. previously, only the unitary case (`step` = 1) was handled. this change extends support to non-unitary aligned views, such as `step` = 2, by correctly computing the step factor and corresponding end coordinate. if any dimension is not layout-aligned — that is, if `view_stride` % `base_stride` ≠ 0 — we conservatively fall back to enumerating all selected coordinates explicitly. (2) also fixes a bug in `Shape::select()` where the length of a strided range was computed using truncating division. this now correctly uses ceiling division to ensure all selected indices are included. Differential Revision: D78315005
1 parent 7a98d6e commit ea6ee56

File tree

3 files changed

+99
-8
lines changed

3 files changed

+99
-8
lines changed

ndslice/src/selection.rs

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,16 +1021,22 @@ impl ReifyView for Slice {
10211021
/// matches all coordinates in the given `view`, expressed in the
10221022
/// coordinate system of the provided `base` slice (`self`).
10231023
///
1024-
/// The resulting expression uses nested `range(start..end, ...)`
1025-
/// combinators to represent the rectangular region selected by
1026-
/// the view within the base slice.
1024+
/// The result is a nested sequence of `range(start..end, step)`
1025+
/// combinators that match the rectangular region covered by `view`
1026+
/// in base coordinates. This preserves geometry and layout when
1027+
/// `view` is *layout-aligned* — that is, each of its strides is
1028+
/// a multiple of the corresponding base stride.
10271029
///
1028-
/// Returns [`dsl::false_()`] for empty views.
1030+
/// If any dimension is not layout-aligned, the view is reified
1031+
/// by explicitly enumerating its coordinates.
1032+
///
1033+
/// Returns [`dsl::false_()`] if the view is empty.
10291034
///
10301035
/// # Errors
10311036
///
10321037
/// Returns an error if:
1033-
/// - The view lies outside the bounds of the base slice
1038+
/// - The base is not contiguous and row-major
1039+
/// - The view lies outside the bounds of the base
10341040
///
10351041
/// # Example
10361042
///
@@ -1043,6 +1049,11 @@ impl ReifyView for Slice {
10431049
/// let selection = base.reify_view(view).unwrap();
10441050
/// ```
10451051
fn reify_view(&self, view: &Slice) -> Result<Selection, SliceError> {
1052+
// Precondition: the base is contiguous and row major.
1053+
if !self.is_contiguous() {
1054+
return Err(SliceError::NonContiguous);
1055+
}
1056+
10461057
if view.is_empty() {
10471058
return Ok(dsl::false_());
10481059
}
@@ -1055,8 +1066,21 @@ impl ReifyView for Slice {
10551066

10561067
let origin = self.coordinates(view.offset())?;
10571068
let mut acc = dsl::true_();
1058-
for (&start, &len) in origin.iter().zip(view.sizes()).rev() {
1059-
acc = dsl::range(start..start + len, acc);
1069+
for ((&start, &len), (&view_stride, &base_stride)) in origin
1070+
.iter()
1071+
.zip(view.sizes())
1072+
.zip(view.strides().iter().zip(self.strides()))
1073+
.rev()
1074+
{
1075+
if view_stride % base_stride == 0 {
1076+
// Layout-aligned with base.
1077+
let step = view_stride / base_stride;
1078+
let end = start + step * len;
1079+
acc = dsl::range(crate::shape::Range(start, Some(end), step), acc);
1080+
} else {
1081+
// Irregular layout; fallback to explicit enumeration.
1082+
return Selection::of_ranks(self, &view.iter().collect::<BTreeSet<_>>());
1083+
}
10601084
}
10611085

10621086
Ok(acc)
@@ -1338,6 +1362,7 @@ mod tests {
13381362
use super::Selection;
13391363
use super::dsl::*;
13401364
use super::is_equivalent_true;
1365+
use crate::Range;
13411366
use crate::Slice;
13421367
use crate::assert_structurally_eq;
13431368
use crate::select;
@@ -2134,6 +2159,52 @@ mod tests {
21342159
);
21352160
}
21362161

2162+
#[test]
2163+
#[allow(clippy::identity_op)]
2164+
fn test_reify_view_1d_with_stride() {
2165+
let shape = shape!(x = 7); // 1D shape with 7 elements
2166+
let selected = shape.select("x", Range(0, None, 2)).unwrap();
2167+
let view = selected.slice();
2168+
assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap());
2169+
2170+
let base = shape.slice();
2171+
let selection = base.reify_view(view).unwrap();
2172+
let expected = range(Range(0, Some(8), 2), true_());
2173+
assert_structurally_eq!(&selection, expected);
2174+
2175+
let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2176+
assert_eq!(flat, vec![0, 2, 4, 6]);
2177+
}
2178+
2179+
#[test]
2180+
#[allow(clippy::identity_op)]
2181+
fn test_reify_view_2d_with_stride() {
2182+
// 4 x 4: x = 4, y = 4.
2183+
let base = shape!(x = 4, y = 4);
2184+
// Step 1: select odd rows (x = 1..4 step 2)
2185+
let shape = base.select("x", Range(1, Some(4), 2)).unwrap();
2186+
// Step 2: then select odd columns (y = 1..4 step 2)
2187+
let shape = shape.select("y", Range(1, Some(4), 2)).unwrap();
2188+
let view = shape.slice();
2189+
assert_eq!(
2190+
view,
2191+
&Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap()
2192+
);
2193+
2194+
let base = base.slice();
2195+
let selection = base.reify_view(view).unwrap();
2196+
// We use `end = start + step * len` to reify the selection.
2197+
// Note: This may yield `end > original_end` (e.g., 5 instead of 4)
2198+
// when the selection length was computed via ceiling division.
2199+
// This is safe: the resulting range will still select the correct
2200+
// indices (e.g., 1 and 3 for Range(1, Some(5), 2)).
2201+
let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_()));
2202+
assert_structurally_eq!(&selection, expected);
2203+
2204+
let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2205+
assert_eq!(flat, vec![5, 7, 13, 15]);
2206+
}
2207+
21372208
#[test]
21382209
fn test_reify_view_selects_column_across_rows() {
21392210
let shape = shape!(host = 2, gpu = 4); // shape [2, 4]

ndslice/src/shape.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ impl Shape {
133133
}
134134

135135
offset += begin * strides[dim];
136-
sizes[dim] = (end - begin) / stride;
136+
// The # of elems in `begin..end` with step `stride`. This is
137+
// ⌈(end - begin) / stride⌉ — the number of stride steps that
138+
// fit in the half-open interval.
139+
sizes[dim] = (end - begin).div_ceil(stride);
137140
strides[dim] *= stride;
138141

139142
Ok(Self {
@@ -626,4 +629,18 @@ mod tests {
626629
.is_err()
627630
);
628631
}
632+
633+
#[test]
634+
fn test_shape_select_stride_rounding() {
635+
let shape = shape!(x = 10);
636+
// Select x = 0..10 step 3 → expect indices [0, 3, 6, 9]
637+
let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
638+
let slice = sub.slice();
639+
// 10 / 3 = 3.33..., so ceil(10 / 3) = 4
640+
assert_eq!(
641+
slice,
642+
&Slice::new(0, vec![4], vec![3]).unwrap(),
643+
"Expected offset 0, size 4, stride 3"
644+
);
645+
}
629646
}

ndslice/src/slice.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ pub enum SliceError {
3535

3636
#[error("incompatible view: {reason}")]
3737
IncompatibleView { reason: String },
38+
39+
#[error("noncontiguous shape")]
40+
NonContiguous,
3841
}
3942

4043
/// Slice is a compact representation of indices into the flat

0 commit comments

Comments
 (0)