Skip to content

: selection: bugfix reifying strided views #533

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 81 additions & 7 deletions ndslice/src/selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,16 +1021,22 @@ impl ReifyView for Slice {
/// matches all coordinates in the given `view`, expressed in the
/// coordinate system of the provided `base` slice (`self`).
///
/// The resulting expression uses nested `range(start..end, ...)`
/// combinators to represent the rectangular region selected by
/// the view within the base slice.
/// The result is a nested sequence of `range(start..end, step)`
/// combinators that match the rectangular region covered by `view`
/// in base coordinates. This preserves geometry and layout when
/// `view` is *layout-aligned* — that is, each of its strides is
/// a multiple of the corresponding base stride.
///
/// Returns [`dsl::false_()`] for empty views.
/// If any dimension is not layout-aligned, the view is reified
/// by explicitly enumerating its coordinates.
///
/// Returns [`dsl::false_()`] if the view is empty.
///
/// # Errors
///
/// Returns an error if:
/// - The view lies outside the bounds of the base slice
/// - The base is not contiguous and row-major
/// - The view lies outside the bounds of the base
///
/// # Example
///
Expand All @@ -1043,6 +1049,11 @@ impl ReifyView for Slice {
/// let selection = base.reify_view(view).unwrap();
/// ```
fn reify_view(&self, view: &Slice) -> Result<Selection, SliceError> {
// Precondition: the base is contiguous and row major.
if !self.is_contiguous() {
return Err(SliceError::NonContiguous);
}

if view.is_empty() {
return Ok(dsl::false_());
}
Expand All @@ -1055,8 +1066,21 @@ impl ReifyView for Slice {

let origin = self.coordinates(view.offset())?;
let mut acc = dsl::true_();
for (&start, &len) in origin.iter().zip(view.sizes()).rev() {
acc = dsl::range(start..start + len, acc);
for ((&start, &len), (&view_stride, &base_stride)) in origin
.iter()
.zip(view.sizes())
.zip(view.strides().iter().zip(self.strides()))
.rev()
{
if view_stride % base_stride == 0 {
// Layout-aligned with base.
let step = view_stride / base_stride;
let end = start + step * len;
acc = dsl::range(crate::shape::Range(start, Some(end), step), acc);
} else {
// Irregular layout; fallback to explicit enumeration.
return Selection::of_ranks(self, &view.iter().collect::<BTreeSet<_>>());
}
}

Ok(acc)
Expand Down Expand Up @@ -1338,6 +1362,7 @@ mod tests {
use super::Selection;
use super::dsl::*;
use super::is_equivalent_true;
use crate::Range;
use crate::Slice;
use crate::assert_structurally_eq;
use crate::select;
Expand Down Expand Up @@ -2134,6 +2159,55 @@ mod tests {
);
}

#[test]
#[allow(clippy::identity_op)]
fn test_reify_view_1d_with_stride() {
let shape = shape!(x = 7); // 1D shape with 7 elements
let selected = shape.select("x", Range(0, None, 2)).unwrap();
let view = selected.slice();
assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap());

let base = shape.slice();
let selection = base.reify_view(view).unwrap();
// Note: ceil(7 / 2) = 4, hence end = 0 + 2 × 4 = 8. See the
// more detailed explanation in
// `test_reify_view_2d_with_stride`.
let expected = range(Range(0, Some(8), 2), true_());
assert_structurally_eq!(&selection, expected);

let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
assert_eq!(flat, vec![0, 2, 4, 6]);
}

#[test]
#[allow(clippy::identity_op)]
fn test_reify_view_2d_with_stride() {
// 4 x 4: x = 4, y = 4.
let base = shape!(x = 4, y = 4);
// Step 1: select odd rows (x = 1..4 step 2)
let shape = base.select("x", Range(1, Some(4), 2)).unwrap();
// Step 2: then select odd columns (y = 1..4 step 2)
let shape = shape.select("y", Range(1, Some(4), 2)).unwrap();
let view = shape.slice();
assert_eq!(
view,
&Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap()
);

let base = base.slice();
let selection = base.reify_view(view).unwrap();
// We use `end = start + step * len` to reify the selection.
// Note: This may yield `end > original_end` (e.g., 5 instead of 4)
// when the selection length was computed via ceiling division.
// This is safe: the resulting range will still select the correct
// indices (e.g., 1 and 3 for Range(1, Some(5), 2)).
let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_()));
assert_structurally_eq!(&selection, expected);

let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
assert_eq!(flat, vec![5, 7, 13, 15]);
}

#[test]
fn test_reify_view_selects_column_across_rows() {
let shape = shape!(host = 2, gpu = 4); // shape [2, 4]
Expand Down
48 changes: 32 additions & 16 deletions ndslice/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,37 +80,36 @@ impl Shape {
Ok(Self { labels, slice })
}

/// Restrict this shape along a named dimension using a [`Range`]. The
/// provided range must be nonempty.
//
/// A shape defines a "strided view" where a strided view is a
/// triple (`offset, `sizes`, `strides`). Each coordinate maps to
/// a flat memory index using the formula:
/// ``` text
/// index = offset + ∑ i_k * strides[k]
/// Restrict this shape along a named dimension using a [`Range`].
/// The provided range must be nonempty.
///
/// A shape defines a **strided view**; a triple (`offset,
/// `sizes`, `strides`). Each coordinate maps to a flat memory
/// index using the formula:
/// ```text
/// index = offset + ∑ iₖ × strides[k]
/// ```
/// where `i_k` is the coordinate in dimension `k`.
/// where `iₖ` is the coordinate in dimension `k`.
///
/// The `select(dim, range)` operation restricts the view to a
/// subrange along a single dimension. It refines the shape by
/// updating the `offset`, `sizes[dim]`, and `strides[dim]` to
/// describe a logically reindexed subregion:
///
/// ```text
/// offset += begin x strides[dim]
/// sizes[dim] = (end - begin) / step
/// strides[dim] *= step
/// offset += begin × strides[dim]
/// sizes[dim] = (end - begin) / step
/// strides[dim] ×= step
/// ```
///
/// This transformation preserves the strided layout and avoids
/// copying data. After `select`, the view behaves as if indexing
/// starts at zero in the selected dimension, with a new length
/// and stride. From the user's perspective, nothing changes
/// and stride. From the user's perspective, nothing changes;
/// indexing remains zero-based, and the resulting shape can be
/// used like any other. The transformation is internal: the
/// view's offset and stride absorb the selection logic.
///
/// `select` is composable it can be applied repeatedly, even on
/// `select` is composable, it can be applied repeatedly, even on
/// the same dimension, to refine the view incrementally.
pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
let dim = self.dim(label)?;
Expand All @@ -133,7 +132,10 @@ impl Shape {
}

offset += begin * strides[dim];
sizes[dim] = (end - begin) / stride;
// The # of elems in `begin..end` with step `stride`. This is
// ⌈(end - begin) / stride⌉ — the number of stride steps that
// fit in the half-open interval.
sizes[dim] = (end - begin).div_ceil(stride);
strides[dim] *= stride;

Ok(Self {
Expand Down Expand Up @@ -626,4 +628,18 @@ mod tests {
.is_err()
);
}

#[test]
fn test_shape_select_stride_rounding() {
let shape = shape!(x = 10);
// Select x = 0..10 step 3 → expect indices [0, 3, 6, 9]
let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
let slice = sub.slice();
// 10 / 3 = 3.33..., so ceil(10 / 3) = 4
assert_eq!(
slice,
&Slice::new(0, vec![4], vec![3]).unwrap(),
"Expected offset 0, size 4, stride 3"
);
}
}
3 changes: 3 additions & 0 deletions ndslice/src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pub enum SliceError {

#[error("incompatible view: {reason}")]
IncompatibleView { reason: String },

#[error("noncontiguous shape")]
NonContiguous,
}

/// Slice is a compact representation of indices into the flat
Expand Down
Loading