Skip to content

Commit 0fc0b0f

Browse files
authored
feat: add UnboundPredicate::negate() (#228)
Issue: #150
1 parent b248fd6 commit 0fc0b0f

File tree

5 files changed

+284
-20
lines changed

5 files changed

+284
-20
lines changed

crates/iceberg/src/expr/mod.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub use predicate::*;
3030
/// The discriminant of this enum is used for determining the type of the operator, see
3131
/// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], [`PredicateOperator::is_set`]
3232
#[allow(missing_docs)]
33-
#[derive(Debug, Clone, Copy)]
33+
#[derive(Debug, Clone, Copy, PartialEq)]
3434
#[repr(u16)]
3535
pub enum PredicateOperator {
3636
// Unary operators
@@ -112,6 +112,39 @@ impl PredicateOperator {
112112
pub fn is_set(self) -> bool {
113113
(self as u16) > (PredicateOperator::NotStartsWith as u16)
114114
}
115+
116+
/// Returns the predicate that is the inverse of self
117+
///
118+
/// # Example
119+
///
120+
/// ```rust
121+
/// use iceberg::expr::PredicateOperator;
122+
/// assert!(PredicateOperator::IsNull.negate() == PredicateOperator::NotNull);
123+
/// assert!(PredicateOperator::IsNan.negate() == PredicateOperator::NotNan);
124+
/// assert!(PredicateOperator::LessThan.negate() == PredicateOperator::GreaterThanOrEq);
125+
/// assert!(PredicateOperator::GreaterThan.negate() == PredicateOperator::LessThanOrEq);
126+
/// assert!(PredicateOperator::Eq.negate() == PredicateOperator::NotEq);
127+
/// assert!(PredicateOperator::In.negate() == PredicateOperator::NotIn);
128+
/// assert!(PredicateOperator::StartsWith.negate() == PredicateOperator::NotStartsWith);
129+
/// ```
130+
pub fn negate(self) -> PredicateOperator {
131+
match self {
132+
PredicateOperator::IsNull => PredicateOperator::NotNull,
133+
PredicateOperator::NotNull => PredicateOperator::IsNull,
134+
PredicateOperator::IsNan => PredicateOperator::NotNan,
135+
PredicateOperator::NotNan => PredicateOperator::IsNan,
136+
PredicateOperator::LessThan => PredicateOperator::GreaterThanOrEq,
137+
PredicateOperator::LessThanOrEq => PredicateOperator::GreaterThan,
138+
PredicateOperator::GreaterThan => PredicateOperator::LessThanOrEq,
139+
PredicateOperator::GreaterThanOrEq => PredicateOperator::LessThan,
140+
PredicateOperator::Eq => PredicateOperator::NotEq,
141+
PredicateOperator::NotEq => PredicateOperator::Eq,
142+
PredicateOperator::In => PredicateOperator::NotIn,
143+
PredicateOperator::NotIn => PredicateOperator::In,
144+
PredicateOperator::StartsWith => PredicateOperator::NotStartsWith,
145+
PredicateOperator::NotStartsWith => PredicateOperator::StartsWith,
146+
}
147+
}
115148
}
116149

117150
#[cfg(test)]

crates/iceberg/src/expr/predicate.rs

Lines changed: 157 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
2222
use crate::expr::{BoundReference, PredicateOperator, Reference};
2323
use crate::spec::Datum;
24+
use itertools::Itertools;
2425
use std::collections::HashSet;
2526
use std::fmt::{Debug, Display, Formatter};
2627
use std::ops::Not;
2728

2829
/// Logical expression, such as `AND`, `OR`, `NOT`.
30+
#[derive(PartialEq)]
2931
pub struct LogicalExpression<T, const N: usize> {
3032
inputs: [Box<T>; N],
3133
}
@@ -54,6 +56,7 @@ impl<T, const N: usize> LogicalExpression<T, N> {
5456
}
5557

5658
/// Unary predicate, for example, `a IS NULL`.
59+
#[derive(PartialEq)]
5760
pub struct UnaryExpression<T> {
5861
/// Operator of this predicate, must be single operand operator.
5962
op: PredicateOperator,
@@ -84,6 +87,7 @@ impl<T> UnaryExpression<T> {
8487
}
8588

8689
/// Binary predicate, for example, `a > 10`.
90+
#[derive(PartialEq)]
8791
pub struct BinaryExpression<T> {
8892
/// Operator of this predicate, must be binary operator, such as `=`, `>`, `<`, etc.
8993
op: PredicateOperator,
@@ -117,6 +121,7 @@ impl<T: Display> Display for BinaryExpression<T> {
117121
}
118122

119123
/// Set predicates, for example, `a in (1, 2, 3)`.
124+
#[derive(PartialEq)]
120125
pub struct SetExpression<T> {
121126
/// Operator of this predicate, must be set operator, such as `IN`, `NOT IN`, etc.
122127
op: PredicateOperator,
@@ -136,8 +141,22 @@ impl<T: Debug> Debug for SetExpression<T> {
136141
}
137142
}
138143

144+
impl<T: Debug> SetExpression<T> {
145+
pub(crate) fn new(op: PredicateOperator, term: T, literals: HashSet<Datum>) -> Self {
146+
Self { op, term, literals }
147+
}
148+
}
149+
150+
impl<T: Display + Debug> Display for SetExpression<T> {
151+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152+
let mut literal_strs = self.literals.iter().map(|l| format!("{}", l));
153+
154+
write!(f, "{} {} ({})", self.term, self.op, literal_strs.join(", "))
155+
}
156+
}
157+
139158
/// Unbound predicate expression before binding to a schema.
140-
#[derive(Debug)]
159+
#[derive(Debug, PartialEq)]
141160
pub enum Predicate {
142161
/// And predicate, for example, `a > 10 AND b < 20`.
143162
And(LogicalExpression<Predicate, 2>),
@@ -166,23 +185,13 @@ impl Display for Predicate {
166185
write!(f, "NOT ({})", expr.inputs()[0])
167186
}
168187
Predicate::Unary(expr) => {
169-
write!(f, "{}", expr.term)
188+
write!(f, "{}", expr)
170189
}
171190
Predicate::Binary(expr) => {
172-
write!(f, "{} {} {}", expr.term, expr.op, expr.literal)
191+
write!(f, "{}", expr)
173192
}
174193
Predicate::Set(expr) => {
175-
write!(
176-
f,
177-
"{} {} ({})",
178-
expr.term,
179-
expr.op,
180-
expr.literals
181-
.iter()
182-
.map(|l| format!("{:?}", l))
183-
.collect::<Vec<String>>()
184-
.join(", ")
185-
)
194+
write!(f, "{}", expr)
186195
}
187196
}
188197
}
@@ -230,6 +239,54 @@ impl Predicate {
230239
pub fn or(self, other: Predicate) -> Predicate {
231240
Predicate::Or(LogicalExpression::new([Box::new(self), Box::new(other)]))
232241
}
242+
243+
/// Returns a predicate representing the negation ('NOT') of this one,
244+
/// by using inverse predicates rather than wrapping in a `NOT`.
245+
/// Used for `NOT` elimination.
246+
///
247+
/// # Example
248+
///
249+
/// ```rust
250+
/// use std::ops::Bound::Unbounded;
251+
/// use iceberg::expr::BoundPredicate::Unary;
252+
/// use iceberg::expr::{LogicalExpression, Predicate, Reference};
253+
/// use iceberg::spec::Datum;
254+
/// let expr1 = Reference::new("a").less_than(Datum::long(10));
255+
/// let expr2 = Reference::new("b").less_than(Datum::long(5)).and(Reference::new("c").less_than(Datum::long(10)));
256+
///
257+
/// let result = expr1.negate();
258+
/// assert_eq!(&format!("{result}"), "a >= 10");
259+
///
260+
/// let result = expr2.negate();
261+
/// assert_eq!(&format!("{result}"), "(b >= 5) OR (c >= 10)");
262+
/// ```
263+
pub fn negate(self) -> Predicate {
264+
match self {
265+
Predicate::And(expr) => Predicate::Or(LogicalExpression::new(
266+
expr.inputs.map(|expr| Box::new(expr.negate())),
267+
)),
268+
Predicate::Or(expr) => Predicate::And(LogicalExpression::new(
269+
expr.inputs.map(|expr| Box::new(expr.negate())),
270+
)),
271+
Predicate::Not(expr) => {
272+
let LogicalExpression { inputs: [input_0] } = expr;
273+
*input_0
274+
}
275+
Predicate::Unary(expr) => {
276+
Predicate::Unary(UnaryExpression::new(expr.op.negate(), expr.term))
277+
}
278+
Predicate::Binary(expr) => Predicate::Binary(BinaryExpression::new(
279+
expr.op.negate(),
280+
expr.term,
281+
expr.literal,
282+
)),
283+
Predicate::Set(expr) => Predicate::Set(SetExpression::new(
284+
expr.op.negate(),
285+
expr.term,
286+
expr.literals,
287+
)),
288+
}
289+
}
233290
}
234291

235292
impl Not for Predicate {
@@ -271,6 +328,91 @@ pub enum BoundPredicate {
271328
Unary(UnaryExpression<BoundReference>),
272329
/// Binary expression, for example, `a > 10`.
273330
Binary(BinaryExpression<BoundReference>),
274-
/// Set predicates, for example, `a in (1, 2, 3)`.
331+
/// Set predicates, for example, `a IN (1, 2, 3)`.
275332
Set(SetExpression<BoundReference>),
276333
}
334+
335+
#[cfg(test)]
mod tests {
    use crate::expr::Reference;
    use crate::spec::Datum;
    use std::collections::HashSet;
    use std::ops::Not;

    /// Negating `b < 5 AND c < 10` gives `b >= 5 OR c >= 10`.
    #[test]
    fn test_predicate_negate_and() {
        let negated = Reference::new("b")
            .less_than(Datum::long(5))
            .and(Reference::new("c").less_than(Datum::long(10)))
            .negate();

        let want = Reference::new("b")
            .greater_than_or_equal_to(Datum::long(5))
            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));

        assert_eq!(negated, want);
    }

    /// Negating `b >= 5 OR c >= 10` gives `b < 5 AND c < 10`.
    #[test]
    fn test_predicate_negate_or() {
        let negated = Reference::new("b")
            .greater_than_or_equal_to(Datum::long(5))
            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)))
            .negate();

        let want = Reference::new("b")
            .less_than(Datum::long(5))
            .and(Reference::new("c").less_than(Datum::long(10)));

        assert_eq!(negated, want);
    }

    /// Negating `NOT (b >= 5)` unwraps to `b >= 5`.
    #[test]
    fn test_predicate_negate_not() {
        let negated = Reference::new("b")
            .greater_than_or_equal_to(Datum::long(5))
            .not()
            .negate();

        let want = Reference::new("b").greater_than_or_equal_to(Datum::long(5));

        assert_eq!(negated, want);
    }

    /// Negating `b IS NOT NULL` gives `b IS NULL`.
    #[test]
    fn test_predicate_negate_unary() {
        let negated = Reference::new("b").is_not_null().negate();

        let want = Reference::new("b").is_null();

        assert_eq!(negated, want);
    }

    /// Negating `a < 5` gives `a >= 5`.
    #[test]
    fn test_predicate_negate_binary() {
        let negated = Reference::new("a").less_than(Datum::long(5)).negate();

        let want = Reference::new("a").greater_than_or_equal_to(Datum::long(5));

        assert_eq!(negated, want);
    }

    /// Negating `a IN (5, 6)` gives `a NOT IN (5, 6)`.
    #[test]
    fn test_predicate_negate_set() {
        let negated = Reference::new("a")
            .is_in(HashSet::from([Datum::long(5), Datum::long(6)]))
            .negate();

        let want = Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)]));

        assert_eq!(negated, want);
    }
}

0 commit comments

Comments
 (0)