diff --git a/ndslice/src/selection.rs b/ndslice/src/selection.rs index 3f895ec5..8737abbc 100644 --- a/ndslice/src/selection.rs +++ b/ndslice/src/selection.rs @@ -565,7 +565,20 @@ impl Selection { opts: &EvalOpts, slice: &'a Slice, ) -> Result + 'a>, ShapeError> { - Ok(Self::validate(self, opts, slice)?.eval_rec(slice, vec![0; slice.num_dim()], 0)) + // Canonically embed 0D as 1D (extent 1). + if slice.num_dim() == 0 { + let slice = Slice::new(slice.offset(), vec![1], vec![1]).unwrap(); + return Ok(Box::new( + self.validate(opts, &slice)? + .eval_rec(&slice, vec![0; 1], 0) + .collect::>() + .into_iter(), + )); + } + + Ok(self + .validate(opts, slice)? + .eval_rec(slice, vec![0; slice.num_dim()], 0)) } fn eval_rec<'a>( @@ -865,26 +878,28 @@ impl Selection { } } - // "Pads out" a selection so that if `Selection::True` appears before - // the final dimension, it becomes All(All(...(True))), enough to fill - // the remaining dimensions. - pub(crate) fn promote_terminal_true(self, dim: usize, max_dim: usize) -> Selection { + /// Pads out a terminal selection (e.g., `True`, `False`) with + /// `All(...)` to reach `max_dim` dimensions. + pub(crate) fn promote_terminal(self, dim: usize, max_dim: usize) -> Selection { use crate::selection::dsl::*; match self { Selection::True if dim < max_dim => all(true_()), - Selection::All(inner) => all(inner.promote_terminal_true(dim + 1, max_dim)), - Selection::Range(r, inner) => range(r, inner.promote_terminal_true(dim + 1, max_dim)), + Selection::False if dim < max_dim => all(false_()), + + Selection::All(inner) => all(inner.promote_terminal(dim + 1, max_dim)), + Selection::Range(r, inner) => range(r, inner.promote_terminal(dim + 1, max_dim)), Selection::Intersection(a, b) => intersection( - a.promote_terminal_true(dim, max_dim), - b.promote_terminal_true(dim, max_dim), + a.promote_terminal(dim, max_dim), + b.promote_terminal(dim, max_dim), ), Selection::Union(a, b) => union( - a.promote_terminal_true(dim, max_dim), - b.promote_terminal_true(dim, max_dim), + a.promote_terminal(dim, max_dim), + b.promote_terminal(dim, max_dim), ), - Selection::First(inner) => first(inner.promote_terminal_true(dim + 1, max_dim)), - Selection::Any(inner) => any(inner.promote_terminal_true(dim + 1, max_dim)), + Selection::First(inner) => first(inner.promote_terminal(dim + 1, max_dim)), + Selection::Any(inner) => any(inner.promote_terminal(dim + 1, max_dim)), + other => other, } } @@ -1831,6 +1846,17 @@ mod tests { assert_matches!(res.as_slice(), [i, j] if *i < *j && *i < 8 && *j < 8); } + #[test] + fn test_eval_zero_dim_slice() { + let slice_0d = Slice::new(1, vec![], vec![]).unwrap(); + assert_eq!(eval(true_(), &slice_0d), vec![1]); + assert_eq!(eval(false_(), &slice_0d), vec![]); + assert_eq!(eval(all(true_()), &slice_0d), vec![1]); + assert_eq!(eval(all(false_()), &slice_0d), vec![]); + assert_eq!(eval(union(true_(), true_()), &slice_0d), vec![1]); + assert_eq!(eval(intersection(true_(), false_()), &slice_0d), vec![]); + } + #[test] fn test_selection_10() { let slice = &test_slice(); diff --git a/ndslice/src/selection/routing.rs b/ndslice/src/selection/routing.rs index 7f7db729..0b5c64f7 100644 --- a/ndslice/src/selection/routing.rs +++ b/ndslice/src/selection/routing.rs @@ -378,10 +378,18 @@ impl RoutingFrame { _chooser: &mut dyn FnMut(&Choice) -> usize, f: &mut dyn FnMut(RoutingStep) -> ControlFlow<()>, ) -> ControlFlow<()> { + if self.slice.num_dim() == 0 { + // Canonically embed 0D as 1D (extent 1). + let embedded = Slice::new(self.slice.offset(), vec![1], vec![1]).unwrap(); + let mut this = self.clone(); + this.slice = Arc::new(embedded); + this.here = vec![0]; + return this.next_steps(_chooser, f); + } let selection = self .selection .clone() - .promote_terminal_true(self.dim, self.slice.num_dim()); + .promote_terminal(self.dim, self.slice.num_dim()); match &selection { Selection::True => ControlFlow::Continue(()), Selection::False => ControlFlow::Continue(()), @@ -1590,4 +1598,48 @@ mod tests { "Expected panic due to overdelivery, but no panic occurred" ); } + + #[test] + fn test_next_steps_zero_dim_slice() { + use std::ops::ControlFlow; + + use crate::selection::dsl::*; + + let slice = Slice::new(42, vec![], vec![]).unwrap(); + let selection = true_(); + let frame = RoutingFrame::root(selection, slice.clone()); + + let mut steps = vec![]; + let _ = frame.next_steps( + &mut |_| panic!("Unexpected Choice in 0D test"), + &mut |step| { + steps.push(step); + ControlFlow::Continue(()) + }, + ); + + assert_eq!(steps.len(), 1); + let step = steps[0].as_forward().unwrap(); + assert_eq!(step.here, vec![0]); + assert!(step.deliver_here()); + assert_eq!(step.slice.location(&step.here).unwrap(), 42); + + let selection = false_(); + let frame = RoutingFrame::root(selection, slice); + + let mut steps = vec![]; + let _ = frame.next_steps( + &mut |_| panic!("Unexpected Choice in 0D test"), + &mut |step| { + steps.push(step); + ControlFlow::Continue(()) + }, + ); + + assert_eq!(steps.len(), 1); + let step = steps[0].as_forward().unwrap(); + assert_eq!(step.here, vec![0]); + assert!(!step.deliver_here()); + assert_eq!(step.slice.location(&step.here).unwrap(), 42); + } }