Skip to content

Commit 73daf9d

Browse files
: selection: routing: handle evals of selections over 0-dim slices (#507)
Summary: handle 0d slices by canonically embedding them as 1d slices of extent 1, enabling uniform evaluation and routing logic. adds test coverage for `eval` and `next_steps`. Differential Revision: D78168758
1 parent 21e99bc commit 73daf9d

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

ndslice/src/selection.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,20 @@ impl Selection {
565565
opts: &EvalOpts,
566566
slice: &'a Slice,
567567
) -> Result<Box<dyn Iterator<Item = usize> + 'a>, ShapeError> {
568-
Ok(Self::validate(self, opts, slice)?.eval_rec(slice, vec![0; slice.num_dim()], 0))
568+
// Canonically embed 0D as 1D (extent 1).
569+
if slice.num_dim() == 0 {
570+
let slice = Slice::new(slice.offset(), vec![1], vec![1]).unwrap();
571+
return Ok(Box::new(
572+
self.validate(opts, &slice)?
573+
.eval_rec(&slice, vec![0; 1], 0)
574+
.collect::<Vec<_>>()
575+
.into_iter(),
576+
));
577+
}
578+
579+
Ok(self
580+
.validate(opts, slice)?
581+
.eval_rec(slice, vec![0; slice.num_dim()], 0))
569582
}
570583

571584
fn eval_rec<'a>(
@@ -1831,6 +1844,17 @@ mod tests {
18311844
assert_matches!(res.as_slice(), [i, j] if *i < *j && *i < 8 && *j < 8);
18321845
}
18331846

1847+
#[test]
1848+
fn test_eval_zero_dim_slice() {
1849+
let slice_0d = Slice::new(1, vec![], vec![]).unwrap();
1850+
assert_eq!(eval(true_(), &slice_0d), vec![1]);
1851+
assert_eq!(eval(false_(), &slice_0d), vec![]);
1852+
assert_eq!(eval(all(true_()), &slice_0d), vec![1]);
1853+
assert_eq!(eval(all(false_()), &slice_0d), vec![]);
1854+
assert_eq!(eval(union(true_(), true_()), &slice_0d), vec![1]);
1855+
assert_eq!(eval(intersection(true_(), false_()), &slice_0d), vec![]);
1856+
}
1857+
18341858
#[test]
18351859
fn test_selection_10() {
18361860
let slice = &test_slice();

ndslice/src/selection/routing.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ impl RoutingFrame {
378378
_chooser: &mut dyn FnMut(&Choice) -> usize,
379379
f: &mut dyn FnMut(RoutingStep) -> ControlFlow<()>,
380380
) -> ControlFlow<()> {
381+
if self.slice.num_dim() == 0 {
382+
// Canonically embed 0D as 1D (extent 1).
383+
let embedded = Slice::new(self.slice.offset(), vec![1], vec![1]).unwrap();
384+
let mut this = self.clone();
385+
this.slice = Arc::new(embedded);
386+
this.here = vec![0];
387+
return this.next_steps(_chooser, f);
388+
}
381389
let selection = self
382390
.selection
383391
.clone()
@@ -1590,4 +1598,48 @@ mod tests {
15901598
"Expected panic due to overdelivery, but no panic occurred"
15911599
);
15921600
}
1601+
1602+
#[test]
1603+
fn test_next_steps_zero_dim_slice() {
1604+
use std::ops::ControlFlow;
1605+
1606+
use crate::selection::dsl::*;
1607+
1608+
let slice = Slice::new(42, vec![], vec![]).unwrap();
1609+
let selection = true_();
1610+
let frame = RoutingFrame::root(selection, slice.clone());
1611+
1612+
let mut steps = vec![];
1613+
let _ = frame.next_steps(
1614+
&mut |_| panic!("Unexpected Choice in 0D test"),
1615+
&mut |step| {
1616+
steps.push(step);
1617+
ControlFlow::Continue(())
1618+
},
1619+
);
1620+
1621+
assert_eq!(steps.len(), 1);
1622+
let step = steps[0].as_forward().unwrap();
1623+
assert_eq!(step.here, vec![0]);
1624+
assert!(step.deliver_here());
1625+
assert_eq!(step.slice.location(&step.here).unwrap(), 42);
1626+
1627+
let selection = all(false_());
1628+
let frame = RoutingFrame::root(selection, slice);
1629+
1630+
let mut steps = vec![];
1631+
let _ = frame.next_steps(
1632+
&mut |_| panic!("Unexpected Choice in 0D test"),
1633+
&mut |step| {
1634+
steps.push(step);
1635+
ControlFlow::Continue(())
1636+
},
1637+
);
1638+
1639+
assert_eq!(steps.len(), 1);
1640+
let step = steps[0].as_forward().unwrap();
1641+
assert_eq!(step.here, vec![0]);
1642+
assert!(!step.deliver_here());
1643+
assert_eq!(step.slice.location(&step.here).unwrap(), 42);
1644+
}
15931645
}

0 commit comments

Comments
 (0)