@@ -28,6 +28,7 @@ use itertools::Itertools;
28
28
use serde:: { Deserialize , Serialize } ;
29
29
30
30
use crate :: error:: Result ;
31
+ use crate :: expr:: visitors:: bound_predicate_visitor:: visit as visit_bound;
31
32
use crate :: expr:: visitors:: predicate_visitor:: visit;
32
33
use crate :: expr:: visitors:: rewrite_not:: RewriteNotVisitor ;
33
34
use crate :: expr:: { Bind , BoundReference , PredicateOperator , Reference } ;
@@ -711,6 +712,63 @@ impl BoundPredicate {
711
712
pub ( crate ) fn and ( self , other : BoundPredicate ) -> BoundPredicate {
712
713
BoundPredicate :: And ( LogicalExpression :: new ( [ Box :: new ( self ) , Box :: new ( other) ] ) )
713
714
}
715
+
716
+ pub ( crate ) fn or ( self , other : BoundPredicate ) -> BoundPredicate {
717
+ BoundPredicate :: Or ( LogicalExpression :: new ( [ Box :: new ( self ) , Box :: new ( other) ] ) )
718
+ }
719
+
720
+ pub ( crate ) fn negate ( self ) -> BoundPredicate {
721
+ match self {
722
+ BoundPredicate :: AlwaysTrue => BoundPredicate :: AlwaysFalse ,
723
+ BoundPredicate :: AlwaysFalse => BoundPredicate :: AlwaysTrue ,
724
+ BoundPredicate :: And ( expr) => BoundPredicate :: Or ( LogicalExpression :: new (
725
+ expr. inputs . map ( |expr| Box :: new ( expr. negate ( ) ) ) ,
726
+ ) ) ,
727
+ BoundPredicate :: Or ( expr) => BoundPredicate :: And ( LogicalExpression :: new (
728
+ expr. inputs . map ( |expr| Box :: new ( expr. negate ( ) ) ) ,
729
+ ) ) ,
730
+ BoundPredicate :: Not ( expr) => {
731
+ let LogicalExpression { inputs : [ input_0] } = expr;
732
+ * input_0
733
+ }
734
+ BoundPredicate :: Unary ( expr) => {
735
+ BoundPredicate :: Unary ( UnaryExpression :: new ( expr. op . negate ( ) , expr. term ) )
736
+ }
737
+ BoundPredicate :: Binary ( expr) => BoundPredicate :: Binary ( BinaryExpression :: new (
738
+ expr. op . negate ( ) ,
739
+ expr. term ,
740
+ expr. literal ,
741
+ ) ) ,
742
+ BoundPredicate :: Set ( expr) => BoundPredicate :: Set ( SetExpression :: new (
743
+ expr. op . negate ( ) ,
744
+ expr. term ,
745
+ expr. literals ,
746
+ ) ) ,
747
+ }
748
+ }
749
+
750
+ /// Simplifies the expression by removing `NOT` predicates,
751
+ /// directly negating the inner expressions instead. The transformation
752
+ /// applies logical laws (such as De Morgan's laws) to
753
+ /// recursively negate and simplify inner expressions within `NOT`
754
+ /// predicates.
755
+ ///
756
+ /// # Example
757
+ ///
758
+ /// ```rust
759
+ /// use std::ops::Not;
760
+ ///
761
+ /// use iceberg::expr::{Bind, BoundPredicate, Reference};
762
+ /// use iceberg::spec::Datum;
763
+ ///
764
+ /// // This would need to be bound first, but the concept is:
765
+ /// // let expression = bound_predicate.not();
766
+ /// // let result = expression.rewrite_not();
767
+ /// ```
768
+ pub fn rewrite_not ( self ) -> BoundPredicate {
769
+ visit_bound ( & mut RewriteNotVisitor :: new ( ) , & self )
770
+ . expect ( "RewriteNotVisitor guarantees always success" )
771
+ }
714
772
}
715
773
716
774
impl Display for BoundPredicate {
@@ -1447,4 +1505,186 @@ mod tests {
1447
1505
assert_eq ! ( & format!( "{bound_expr}" ) , r#"True"# ) ;
1448
1506
test_bound_predicate_serialize_diserialize ( bound_expr) ;
1449
1507
}
1508
+
1509
#[test]
fn test_bound_predicate_rewrite_not_binary() {
    let schema = table_schema_simple();

    // Input: NOT(bar < 10); after rewriting it must read bar >= 10.
    let negated = Reference::new("bar")
        .less_than(Datum::int(10))
        .not()
        .bind(schema.clone(), true)
        .unwrap();
    let rewritten = negated.rewrite_not();

    let expected = Reference::new("bar")
        .greater_than_or_equal_to(Datum::int(10))
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);
    assert_eq!(rewritten.to_string(), "bar >= 10");
}
1525
+
1526
#[test]
fn test_bound_predicate_rewrite_not_unary() {
    let schema = table_schema_simple();

    // Input: NOT(foo IS NULL); after rewriting it must read foo IS NOT NULL.
    let negated = Reference::new("foo")
        .is_null()
        .not()
        .bind(schema.clone(), true)
        .unwrap();
    let rewritten = negated.rewrite_not();

    let expected = Reference::new("foo")
        .is_not_null()
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);
    assert_eq!(rewritten.to_string(), "foo IS NOT NULL");
}
1542
+
1543
#[test]
fn test_bound_predicate_rewrite_not_set() {
    let schema = table_schema_simple();

    // Input: NOT(bar IN (10, 20)); after rewriting it must read
    // bar NOT IN (10, 20).
    let membership = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]);
    let negated = membership.not().bind(schema.clone(), true).unwrap();
    let rewritten = negated.rewrite_not();

    let expected = Reference::new("bar")
        .is_not_in([Datum::int(10), Datum::int(20)])
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);

    // The rendered order of the set literals may vary (HashSet), so only
    // the individual fragments are checked rather than the whole string.
    let rendered = rewritten.to_string();
    let fragments = ["bar NOT IN", "10", "20"];
    assert!(fragments.iter().all(|fragment| rendered.contains(*fragment)));
}
1567
+
1568
#[test]
fn test_bound_predicate_rewrite_not_and_demorgan() {
    let schema = table_schema_simple();

    // De Morgan: NOT(A AND B) == (NOT A) OR (NOT B)
    // Input:    NOT((bar < 10) AND (foo IS NULL))
    // Expected: (bar >= 10) OR (foo IS NOT NULL)
    let conjunction = Reference::new("bar")
        .less_than(Datum::int(10))
        .and(Reference::new("foo").is_null());
    let negated = conjunction.not().bind(schema.clone(), true).unwrap();
    let rewritten = negated.rewrite_not();

    let expected = Reference::new("bar")
        .greater_than_or_equal_to(Datum::int(10))
        .or(Reference::new("foo").is_not_null())
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);
    assert_eq!(rewritten.to_string(), "(bar >= 10) OR (foo IS NOT NULL)");
}
1592
+
1593
#[test]
fn test_bound_predicate_rewrite_not_or_demorgan() {
    let schema = table_schema_simple();

    // De Morgan: NOT(A OR B) == (NOT A) AND (NOT B)
    // Input:    NOT((bar < 10) OR (foo IS NULL))
    // Expected: (bar >= 10) AND (foo IS NOT NULL)
    let disjunction = Reference::new("bar")
        .less_than(Datum::int(10))
        .or(Reference::new("foo").is_null());
    let negated = disjunction.not().bind(schema.clone(), true).unwrap();
    let rewritten = negated.rewrite_not();

    let expected = Reference::new("bar")
        .greater_than_or_equal_to(Datum::int(10))
        .and(Reference::new("foo").is_not_null())
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);
    assert_eq!(rewritten.to_string(), "(bar >= 10) AND (foo IS NOT NULL)");
}
1617
+
1618
#[test]
fn test_bound_predicate_rewrite_not_double_negative() {
    let schema = table_schema_simple();

    // Input: NOT(NOT(bar < 10)); the two negations cancel, leaving bar < 10.
    let twice_negated = Reference::new("bar")
        .less_than(Datum::int(10))
        .not()
        .not()
        .bind(schema.clone(), true)
        .unwrap();
    let rewritten = twice_negated.rewrite_not();

    // Expected output is the original, un-negated comparison.
    let expected = Reference::new("bar")
        .less_than(Datum::int(10))
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);
    assert_eq!(rewritten.to_string(), "bar < 10");
}
1634
+
1635
#[test]
fn test_bound_predicate_rewrite_not_always_true_false() {
    let schema = table_schema_simple();

    // Case 1: the negation of an always-true predicate is AlwaysFalse.
    // NOTE(review): presumably `bar` is a required column in
    // `table_schema_simple`, so `bar IS NOT NULL` is always true - confirm.
    let never_matches = Reference::new("bar")
        .is_not_null()
        .not()
        .bind(schema.clone(), true)
        .unwrap()
        .rewrite_not();

    assert_eq!(never_matches, BoundPredicate::AlwaysFalse);
    assert_eq!(never_matches.to_string(), "False");

    // Case 2: the negation of an always-false predicate is AlwaysTrue.
    // NOTE(review): relies on the same assumption, making `bar IS NULL`
    // always false - confirm.
    let always_matches = Reference::new("bar")
        .is_null()
        .not()
        .bind(schema, true)
        .unwrap()
        .rewrite_not();

    assert_eq!(always_matches, BoundPredicate::AlwaysTrue);
    assert_eq!(always_matches.to_string(), "True");
}
1655
+
1656
#[test]
fn test_bound_predicate_rewrite_not_complex_nested() {
    let schema = table_schema_simple();

    // Input:
    //   NOT( NOT((bar >= 10) AND (foo IS NOT NULL)) OR (bar < 5) )
    let negated_conjunction = Reference::new("bar")
        .greater_than_or_equal_to(Datum::int(10))
        .and(Reference::new("foo").is_not_null())
        .not();
    let below_five = Reference::new("bar").less_than(Datum::int(5));
    let input = negated_conjunction.or(below_five).not();

    let rewritten = input.bind(schema.clone(), true).unwrap().rewrite_not();

    // Since NOT(NOT(A) OR B) == A AND NOT(B), the expected output is:
    //   ((bar >= 10) AND (foo IS NOT NULL)) AND (bar >= 5)
    let conjunction = Reference::new("bar")
        .greater_than_or_equal_to(Datum::int(10))
        .and(Reference::new("foo").is_not_null());
    let at_least_five = Reference::new("bar").greater_than_or_equal_to(Datum::int(5));
    let expected = conjunction
        .and(at_least_five)
        .bind(schema, true)
        .unwrap();

    assert_eq!(rewritten, expected);
    assert_eq!(
        rewritten.to_string(),
        "((bar >= 10) AND (foo IS NOT NULL)) AND (bar >= 5)"
    );
}
1450
1690
}
0 commit comments