diff --git a/ndslice/src/selection.rs b/ndslice/src/selection.rs index 3f895ec5..039aa31b 100644 --- a/ndslice/src/selection.rs +++ b/ndslice/src/selection.rs @@ -1021,16 +1021,22 @@ impl ReifyView for Slice { /// matches all coordinates in the given `view`, expressed in the /// coordinate system of the provided `base` slice (`self`). /// - /// The resulting expression uses nested `range(start..end, ...)` - /// combinators to represent the rectangular region selected by - /// the view within the base slice. + /// The result is a nested sequence of `range(start..end, step)` + /// combinators that match the rectangular region covered by `view` + /// in base coordinates. This preserves geometry and layout when + /// `view` is *layout-aligned* — that is, each of its strides is + /// a multiple of the corresponding base stride. /// - /// Returns [`dsl::false_()`] for empty views. + /// If any dimension is not layout-aligned, the view is reified + /// by explicitly enumerating its coordinates. + /// + /// Returns [`dsl::false_()`] if the view is empty. /// /// # Errors /// /// Returns an error if: - /// - The view lies outside the bounds of the base slice + /// - The base is not contiguous and row-major + /// - The view lies outside the bounds of the base /// /// # Example /// @@ -1043,6 +1049,11 @@ impl ReifyView for Slice { /// let selection = base.reify_view(view).unwrap(); /// ``` fn reify_view(&self, view: &Slice) -> Result { + // Precondition: the base is contiguous and row major. + if !self.is_contiguous() { + return Err(SliceError::NonContiguous); + } + if view.is_empty() { return Ok(dsl::false_()); } @@ -1055,8 +1066,21 @@ impl ReifyView for Slice { let origin = self.coordinates(view.offset())?; let mut acc = dsl::true_(); - for (&start, &len) in origin.iter().zip(view.sizes()).rev() { - acc = dsl::range(start..start + len, acc); + for ((&start, &len), (&view_stride, &base_stride)) in origin + .iter() + .zip(view.sizes()) + .zip(view.strides().iter().zip(self.strides())) + .rev() + { + if view_stride % base_stride == 0 { + // Layout-aligned with base. + let step = view_stride / base_stride; + let end = start + step * len; + acc = dsl::range(crate::shape::Range(start, Some(end), step), acc); + } else { + // Irregular layout; fallback to explicit enumeration. + return Selection::of_ranks(self, &view.iter().collect::>()); + } } Ok(acc) @@ -1338,6 +1362,7 @@ mod tests { use super::Selection; use super::dsl::*; use super::is_equivalent_true; + use crate::Range; use crate::Slice; use crate::assert_structurally_eq; use crate::select; @@ -2134,6 +2159,55 @@ mod tests { ); } + #[test] + #[allow(clippy::identity_op)] + fn test_reify_view_1d_with_stride() { + let shape = shape!(x = 7); // 1D shape with 7 elements + let selected = shape.select("x", Range(0, None, 2)).unwrap(); + let view = selected.slice(); + assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap()); + + let base = shape.slice(); + let selection = base.reify_view(view).unwrap(); + // Note: ceil(7 / 2) = 4, hence end = 0 + 2 × 4 = 8. See the + // more detailed explanation in + // `test_reify_view_2d_with_stride`. + let expected = range(Range(0, Some(8), 2), true_()); + assert_structurally_eq!(&selection, expected); + + let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect(); + assert_eq!(flat, vec![0, 2, 4, 6]); + } + + #[test] + #[allow(clippy::identity_op)] + fn test_reify_view_2d_with_stride() { + // 4 x 4: x = 4, y = 4. + let base = shape!(x = 4, y = 4); + // Step 1: select odd rows (x = 1..4 step 2) + let shape = base.select("x", Range(1, Some(4), 2)).unwrap(); + // Step 2: then select odd columns (y = 1..4 step 2) + let shape = shape.select("y", Range(1, Some(4), 2)).unwrap(); + let view = shape.slice(); + assert_eq!( + view, + &Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap() + ); + + let base = base.slice(); + let selection = base.reify_view(view).unwrap(); + // We use `end = start + step * len` to reify the selection. + // Note: This may yield `end > original_end` (e.g., 5 instead of 4) + // when the selection length was computed via ceiling division. + // This is safe: the resulting range will still select the correct + // indices (e.g., 1 and 3 for Range(1, Some(5), 2)). + let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_())); + assert_structurally_eq!(&selection, expected); + + let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect(); + assert_eq!(flat, vec![5, 7, 13, 15]); + } + #[test] fn test_reify_view_selects_column_across_rows() { let shape = shape!(host = 2, gpu = 4); // shape [2, 4] diff --git a/ndslice/src/shape.rs b/ndslice/src/shape.rs index 38fcac86..c8cc44a9 100644 --- a/ndslice/src/shape.rs +++ b/ndslice/src/shape.rs @@ -80,37 +80,36 @@ impl Shape { Ok(Self { labels, slice }) } - /// Restrict this shape along a named dimension using a [`Range`]. The - /// provided range must be nonempty. - // - /// A shape defines a "strided view" where a strided view is a - /// triple (`offset, `sizes`, `strides`). Each coordinate maps to - /// a flat memory index using the formula: - /// ``` text - /// index = offset + ∑ i_k * strides[k] + /// Restrict this shape along a named dimension using a [`Range`]. + /// The provided range must be nonempty. + /// + /// A shape defines a **strided view**; a triple (`offset, + /// `sizes`, `strides`). Each coordinate maps to a flat memory + /// index using the formula: + /// ```text + /// index = offset + ∑ iₖ × strides[k] /// ``` - /// where `i_k` is the coordinate in dimension `k`. + /// where `iₖ` is the coordinate in dimension `k`. /// /// The `select(dim, range)` operation restricts the view to a /// subrange along a single dimension. It refines the shape by /// updating the `offset`, `sizes[dim]`, and `strides[dim]` to /// describe a logically reindexed subregion: - /// /// ```text - /// offset += begin x strides[dim] - /// sizes[dim] = (end - begin) / step - /// strides[dim] *= step + /// offset += begin × strides[dim] + /// sizes[dim] = ⎡(end - begin) / step⎤ + /// strides[dim] ×= step /// ``` /// /// This transformation preserves the strided layout and avoids /// copying data. After `select`, the view behaves as if indexing /// starts at zero in the selected dimension, with a new length - /// and stride. From the user's perspective, nothing changes — + /// and stride. From the user's perspective, nothing changes; /// indexing remains zero-based, and the resulting shape can be /// used like any other. The transformation is internal: the /// view's offset and stride absorb the selection logic. /// - /// `select` is composable — it can be applied repeatedly, even on + /// `select` is composable, it can be applied repeatedly, even on /// the same dimension, to refine the view incrementally. pub fn select>(&self, label: &str, range: R) -> Result { let dim = self.dim(label)?; @@ -133,7 +132,10 @@ impl Shape { } offset += begin * strides[dim]; - sizes[dim] = (end - begin) / stride; + // The # of elems in `begin..end` with step `stride`. This is + // ⌈(end - begin) / stride⌉ — the number of stride steps that + // fit in the half-open interval. + sizes[dim] = (end - begin).div_ceil(stride); strides[dim] *= stride; Ok(Self { @@ -626,4 +628,18 @@ mod tests { .is_err() ); } + + #[test] + fn test_shape_select_stride_rounding() { + let shape = shape!(x = 10); + // Select x = 0..10 step 3 → expect indices [0, 3, 6, 9] + let sub = shape.select("x", Range(0, Some(10), 3)).unwrap(); + let slice = sub.slice(); + // 10 / 3 = 3.33..., so ceil(10 / 3) = 4 + assert_eq!( + slice, + &Slice::new(0, vec![4], vec![3]).unwrap(), + "Expected offset 0, size 4, stride 3" + ); + } } diff --git a/ndslice/src/slice.rs b/ndslice/src/slice.rs index 27e742d0..e2da82f8 100644 --- a/ndslice/src/slice.rs +++ b/ndslice/src/slice.rs @@ -35,6 +35,9 @@ pub enum SliceError { #[error("incompatible view: {reason}")] IncompatibleView { reason: String }, + + #[error("noncontiguous shape")] + NonContiguous, } /// Slice is a compact representation of indices into the flat