Skip to content

Commit b9167c3

Browse files
selection: normal: implement absorbtion (#389)
Summary: Pull Request resolved: #389 ghstack-source-id: 293656201 implement the absorbtion normalization rules Reviewed By: mariusae Differential Revision: D77597600 fbshipit-source-id: 5e20504f18017f9fdbd02a0dd2ff8d66593c28e8
1 parent 2960002 commit b9167c3

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

ndslice/src/selection/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,9 @@ pub fn structurally_equal(a: &Selection, b: &Selection) -> bool {
359359
/// structure. It is designed to improve over time as additional
360360
/// rewrites (e.g., flattening, simplification) are introduced.
361361
pub fn normalize(sel: &Selection) -> NormalizedSelection {
362-
let rule = normal::FlatteningRules.then(normal::IdentityRules);
362+
let rule = normal::FlatteningRules
363+
.then(normal::IdentityRules)
364+
.then(normal::AbsorbtionRules);
363365
sel.fold::<normal::NormalizedSelection>()
364366
.rewrite_bottom_up(&rule)
365367
}

ndslice/src/selection/normal.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,42 @@ impl RewriteRule for FlatteningRules {
272272
}
273273
}
274274

275+
/// A normalization rule that applies absorption laws for unions and
276+
/// intersections.
277+
///
278+
/// A union containing `True` always evaluates to `True`, and an
279+
/// intersection containing `False` always evaluates to `False`.
280+
#[derive(Default)]
281+
pub struct AbsorbtionRules;
282+
283+
impl RewriteRule for AbsorbtionRules {
284+
// Absorption rewrites:
285+
//
286+
// - Union(..., True, ...) → True
287+
// - Intersection(..., False, ...) → False
288+
fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
289+
use NormalizedSelection::*;
290+
291+
match node {
292+
Union(set) => {
293+
if set.contains(&True) {
294+
True // Union(..., True, ...) → True
295+
} else {
296+
Union(set)
297+
}
298+
}
299+
Intersection(set) => {
300+
if set.contains(&False) {
301+
False // Intersection(..., False, ...) → False
302+
} else {
303+
Intersection(set)
304+
}
305+
}
306+
other => other,
307+
}
308+
}
309+
}
310+
275311
impl NormalizedSelection {
276312
pub fn rewrite_bottom_up(self, rule: &impl RewriteRule) -> Self {
277313
let mapped = self.trav(|child| child.rewrite_bottom_up(rule));
@@ -396,3 +432,31 @@ mod tests {
396432
}
397433
}
398434
}
435+
436+
#[test]
437+
fn test_absorbtion_rules() {
438+
use NormalizedSelection::*;
439+
440+
// Union(True, Any(True)) should absorb to True
441+
let union_case = {
442+
let mut set = BTreeSet::new();
443+
set.insert(True);
444+
set.insert(Any(Box::new(True)));
445+
Union(set)
446+
};
447+
448+
let rule = AbsorbtionRules;
449+
let result = rule.rewrite(union_case);
450+
assert_eq!(result, True);
451+
452+
// Intersection(False, All(True)) should absorb to False
453+
let intersection_case = {
454+
let mut set = BTreeSet::new();
455+
set.insert(False);
456+
set.insert(All(Box::new(True)));
457+
Intersection(set)
458+
};
459+
460+
let result = rule.rewrite(intersection_case);
461+
assert_eq!(result, False);
462+
}

0 commit comments

Comments
 (0)