@@ -272,6 +272,42 @@ impl RewriteRule for FlatteningRules {
272
272
}
273
273
}
274
274
275
+ /// A normalization rule that applies absorption laws for unions and
276
+ /// intersections.
277
+ ///
278
+ /// A union containing `True` always evaluates to `True`, and an
279
+ /// intersection containing `False` always evaluates to `False`.
280
+ #[ derive( Default ) ]
281
+ pub struct AbsorbtionRules ;
282
+
283
+ impl RewriteRule for AbsorbtionRules {
284
+ // Absorption rewrites:
285
+ //
286
+ // - Union(..., True, ...) → True
287
+ // - Intersection(..., False, ...) → False
288
+ fn rewrite ( & self , node : NormalizedSelection ) -> NormalizedSelection {
289
+ use NormalizedSelection :: * ;
290
+
291
+ match node {
292
+ Union ( set) => {
293
+ if set. contains ( & True ) {
294
+ True // Union(..., True, ...) → True
295
+ } else {
296
+ Union ( set)
297
+ }
298
+ }
299
+ Intersection ( set) => {
300
+ if set. contains ( & False ) {
301
+ False // Intersection(..., False, ...) → False
302
+ } else {
303
+ Intersection ( set)
304
+ }
305
+ }
306
+ other => other,
307
+ }
308
+ }
309
+ }
310
+
275
311
impl NormalizedSelection {
276
312
pub fn rewrite_bottom_up ( self , rule : & impl RewriteRule ) -> Self {
277
313
let mapped = self . trav ( |child| child. rewrite_bottom_up ( rule) ) ;
@@ -396,3 +432,31 @@ mod tests {
396
432
}
397
433
}
398
434
}
435
+
436
+ #[ test]
437
+ fn test_absorbtion_rules ( ) {
438
+ use NormalizedSelection :: * ;
439
+
440
+ // Union(True, Any(True)) should absorb to True
441
+ let union_case = {
442
+ let mut set = BTreeSet :: new ( ) ;
443
+ set. insert ( True ) ;
444
+ set. insert ( Any ( Box :: new ( True ) ) ) ;
445
+ Union ( set)
446
+ } ;
447
+
448
+ let rule = AbsorbtionRules ;
449
+ let result = rule. rewrite ( union_case) ;
450
+ assert_eq ! ( result, True ) ;
451
+
452
+ // Intersection(False, All(True)) should absorb to False
453
+ let intersection_case = {
454
+ let mut set = BTreeSet :: new ( ) ;
455
+ set. insert ( False ) ;
456
+ set. insert ( All ( Box :: new ( True ) ) ) ;
457
+ Intersection ( set)
458
+ } ;
459
+
460
+ let result = rule. rewrite ( intersection_case) ;
461
+ assert_eq ! ( result, False ) ;
462
+ }
0 commit comments