Skip to content

Commit 9fb8e0a

Browse files
authored
Add INNER and CROSS JOINs to planner and evaluator (#218)
1 parent a279276 commit 9fb8e0a

File tree

4 files changed

+285
-54
lines changed

4 files changed

+285
-54
lines changed

partiql-eval/src/eval.rs

Lines changed: 128 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use thiserror::Error;
66

77
use petgraph::algo::toposort;
88
use petgraph::prelude::StableGraph;
9-
use petgraph::{Directed, Incoming, Outgoing};
9+
use petgraph::{Directed, Outgoing};
1010

1111
use partiql_value::Value::{Boolean, Missing, Null};
1212
use partiql_value::{
@@ -18,7 +18,7 @@ use crate::env::basic::MapBindings;
1818
use crate::env::Bindings;
1919

2020
#[derive(Debug)]
21-
pub struct EvalPlan(pub StableGraph<Box<dyn Evaluable>, (), Directed>);
21+
pub struct EvalPlan(pub StableGraph<Box<dyn Evaluable>, u8, Directed>);
2222

2323
impl Default for EvalPlan {
2424
fn default() -> Self {
@@ -28,7 +28,7 @@ impl Default for EvalPlan {
2828

2929
impl EvalPlan {
3030
fn new() -> Self {
31-
EvalPlan(StableGraph::<Box<dyn Evaluable>, (), Directed>::new())
31+
EvalPlan(StableGraph::<Box<dyn Evaluable>, u8, Directed>::new())
3232
}
3333
}
3434

@@ -50,7 +50,7 @@ pub enum EvaluationError {
5050

5151
pub trait Evaluable: Debug {
5252
fn evaluate(&mut self, ctx: &dyn EvalContext) -> Option<Value>;
53-
fn update_input(&mut self, input: &Value);
53+
fn update_input(&mut self, input: &Value, branch_num: u8);
5454
}
5555

5656
#[derive(Debug)]
@@ -108,11 +108,103 @@ impl Evaluable for EvalScan {
108108
self.output.clone()
109109
}
110110

111-
fn update_input(&mut self, _input: &Value) {
111+
fn update_input(&mut self, _input: &Value, _branch_num: u8) {
112112
todo!("update_input for Scan")
113113
}
114114
}
115115

116+
#[derive(Debug)]
117+
pub enum EvalJoinKind {
118+
Inner,
119+
Left,
120+
Right,
121+
Full,
122+
Cross,
123+
}
124+
125+
#[derive(Debug)]
126+
pub struct EvalJoin {
127+
pub kind: EvalJoinKind,
128+
pub on: Option<Box<dyn EvalExpr>>,
129+
pub input_l: Option<Value>,
130+
pub input_r: Option<Value>,
131+
pub output: Option<Value>,
132+
}
133+
134+
impl EvalJoin {
135+
pub fn new(kind: EvalJoinKind, on: Option<Box<dyn EvalExpr>>) -> Self {
136+
EvalJoin {
137+
kind,
138+
on,
139+
input_l: None,
140+
input_r: None,
141+
output: None,
142+
}
143+
}
144+
}
145+
146+
impl Evaluable for EvalJoin {
147+
fn evaluate(&mut self, ctx: &dyn EvalContext) -> Option<Value> {
148+
// TODO: PartiQL defaults to lateral JOINs (RHS can reference binding tuples defined from the LHS)
149+
// https://partiql.org/assets/PartiQL-Specification.pdf#subsection.5.3. Adding this behavior
150+
// to be spec-compliant may result in changes to the DAG flows.
151+
let output = match self.kind {
152+
EvalJoinKind::Inner => {
153+
let mut result = partiql_bag!();
154+
for binding_tuple_l in self.input_l.clone().unwrap() {
155+
let binding_tuple_l = binding_tuple_l.coerce_to_tuple();
156+
for binding_tuple_r in self.input_r.clone().unwrap() {
157+
let binding_tuple_r = binding_tuple_r.coerce_to_tuple();
158+
let mut new_result = binding_tuple_l.clone();
159+
for pairs in binding_tuple_r.pairs() {
160+
new_result.insert(pairs.0, pairs.1.clone());
161+
}
162+
if let Some(on_condition) = &self.on {
163+
if on_condition.evaluate(&new_result, ctx) == Boolean(true) {
164+
result.push(new_result.into());
165+
}
166+
} else {
167+
result.push(new_result.into());
168+
}
169+
}
170+
}
171+
Some(result.into())
172+
}
173+
EvalJoinKind::Left => {
174+
todo!("Left JOINs")
175+
}
176+
EvalJoinKind::Cross => {
177+
let mut result = partiql_bag!();
178+
for binding_tuple_l in self.input_l.clone().unwrap() {
179+
let binding_tuple_l = binding_tuple_l.coerce_to_tuple();
180+
for binding_tuple_r in self.input_r.clone().unwrap() {
181+
let binding_tuple_r = binding_tuple_r.coerce_to_tuple();
182+
let mut new_result = binding_tuple_l.clone();
183+
for pairs in binding_tuple_r.pairs() {
184+
new_result.insert(pairs.0, pairs.1.clone());
185+
}
186+
result.push(new_result.into());
187+
}
188+
}
189+
Some(result.into())
190+
}
191+
EvalJoinKind::Full | EvalJoinKind::Right => {
192+
todo!("Full and Right Joins are not yet implemented for `partiql-lang-rust`")
193+
}
194+
};
195+
self.output = output;
196+
self.output.clone()
197+
}
198+
199+
fn update_input(&mut self, input: &Value, branch_num: u8) {
200+
match branch_num {
201+
0 => self.input_l = Some(input.clone()),
202+
1 => self.input_r = Some(input.clone()),
203+
_ => panic!("EvalJoin nodes only support `0` and `1` for the `branch_num`"),
204+
};
205+
}
206+
}
207+
116208
#[derive(Debug)]
117209
pub struct EvalUnpivot {
118210
pub expr: Box<dyn EvalExpr>,
@@ -154,7 +246,7 @@ impl Evaluable for EvalUnpivot {
154246
self.output.clone()
155247
}
156248

157-
fn update_input(&mut self, _input: &Value) {
249+
fn update_input(&mut self, _input: &Value, _branch_num: u8) {
158250
todo!()
159251
}
160252
}
@@ -206,7 +298,7 @@ impl Evaluable for EvalFilter {
206298
self.output = Some(Value::Bag(Box::new(out)));
207299
self.output.clone()
208300
}
209-
fn update_input(&mut self, input: &Value) {
301+
fn update_input(&mut self, input: &Value, _branch_num: u8) {
210302
self.input = Some(input.clone())
211303
}
212304
}
@@ -257,7 +349,7 @@ impl Evaluable for EvalProject {
257349
self.output.clone()
258350
}
259351

260-
fn update_input(&mut self, input: &Value) {
352+
fn update_input(&mut self, input: &Value, _branch_num: u8) {
261353
self.input = Some(input.clone());
262354
}
263355
}
@@ -304,7 +396,7 @@ impl Evaluable for EvalProjectValue {
304396
self.output.clone()
305397
}
306398

307-
fn update_input(&mut self, input: &Value) {
399+
fn update_input(&mut self, input: &Value, _branch_num: u8) {
308400
self.input = Some(input.clone());
309401
}
310402
}
@@ -436,7 +528,7 @@ impl Evaluable for EvalDistinct {
436528
self.output.clone()
437529
}
438530

439-
fn update_input(&mut self, input: &Value) {
531+
fn update_input(&mut self, input: &Value, _branch_num: u8) {
440532
self.input = Some(input.clone());
441533
}
442534
}
@@ -451,7 +543,7 @@ impl Evaluable for EvalSink {
451543
fn evaluate(&mut self, _ctx: &dyn EvalContext) -> Option<Value> {
452544
self.input.clone()
453545
}
454-
fn update_input(&mut self, input: &Value) {
546+
fn update_input(&mut self, input: &Value, _branch_num: u8) {
455547
self.input = Some(input.clone());
456548
}
457549
}
@@ -627,39 +719,33 @@ impl Evaluator {
627719
// that all v ∈ V \{v0} are reachable from v0. Note that this is the definition of trees
628720
// without the condition |E| = |V | − 1. Hence, all trees are DAGs.
629721
// Reference: https://link.springer.com/article/10.1007/s00450-009-0061-0
630-
match graph.externals(Incoming).exactly_one() {
631-
Ok(_) => {
632-
let sorted_ops = toposort(&graph, None);
633-
match sorted_ops {
634-
Ok(ops) => {
635-
let mut result = None;
636-
for idx in ops.into_iter() {
637-
let src = graph
638-
.node_weight_mut(idx)
639-
.expect("Error in retrieving node");
640-
result = src.evaluate(&*self.ctx);
641-
642-
let mut ne = graph.neighbors_directed(idx, Outgoing).detach();
643-
while let Some(n) = ne.next_node(&graph) {
644-
let dst =
645-
graph.node_weight_mut(n).expect("Error in retrieving node");
646-
dst.update_input(
647-
&result.clone().expect("Error in retrieving source value"),
648-
);
649-
}
650-
}
651-
let evaluated = Evaluated {
652-
result: result.expect("Error in retrieving eval output"),
653-
};
654-
Ok(evaluated)
722+
let sorted_ops = toposort(&graph, None);
723+
match sorted_ops {
724+
Ok(ops) => {
725+
let mut result = None;
726+
for idx in ops.into_iter() {
727+
let src = graph
728+
.node_weight_mut(idx)
729+
.expect("Error in retrieving node");
730+
result = src.evaluate(&*self.ctx);
731+
732+
let mut ne = graph.neighbors_directed(idx, Outgoing).detach();
733+
while let Some((e, n)) = ne.next(&graph) {
734+
// use the edge weight to store the `branch_num`
735+
let branch_num = *graph
736+
.edge_weight(e)
737+
.expect("Error in retrieving weight for edge");
738+
let dst = graph.node_weight_mut(n).expect("Error in retrieving node");
739+
dst.update_input(
740+
&result.clone().expect("Error in retrieving source value"),
741+
branch_num,
742+
);
655743
}
656-
Err(e) => Err(EvalErr {
657-
errors: vec![EvaluationError::InvalidEvaluationPlan(format!(
658-
"Malformed evaluation plan detected: {:?}",
659-
e
660-
))],
661-
}),
662744
}
745+
let evaluated = Evaluated {
746+
result: result.expect("Error in retrieving eval output"),
747+
};
748+
Ok(evaluated)
663749
}
664750
Err(e) => Err(EvalErr {
665751
errors: vec![EvaluationError::InvalidEvaluationPlan(format!(

partiql-eval/src/lib.rs

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ mod tests {
1717
use partiql_logical as logical;
1818
use partiql_logical::BindingsExpr::{Distinct, Project, ProjectValue};
1919
use partiql_logical::{
20-
BagExpr, BinaryOp, BindingsExpr, ListExpr, LogicalPlan, PathComponent, TupleExpr, ValueExpr,
20+
BagExpr, BinaryOp, BindingsExpr, JoinKind, ListExpr, LogicalPlan, PathComponent, TupleExpr,
21+
ValueExpr,
2122
};
22-
2323
use partiql_value as value;
2424
use partiql_value::{
2525
partiql_bag, partiql_list, partiql_tuple, Bag, BindingsName, List, Tuple, Value,
@@ -85,6 +85,23 @@ mod tests {
8585
)
8686
}
8787

88+
fn join_data() -> MapBindings<Value> {
89+
let customers = partiql_list![
90+
partiql_tuple![("id", 5), ("name", "Joe")],
91+
partiql_tuple![("id", 7), ("name", "Mary")],
92+
];
93+
94+
let orders = partiql_list![
95+
partiql_tuple![("custId", 7), ("productId", 101)],
96+
partiql_tuple![("custId", 7), ("productId", 523)],
97+
];
98+
99+
let mut bindings = MapBindings::default();
100+
bindings.insert("customers", customers.into());
101+
bindings.insert("orders", orders.into());
102+
bindings
103+
}
104+
88105
// Creates the plan: `SELECT <lhs> <op> <rhs> AS result FROM data` where <lhs> comes from data
89106
// Evaluates the plan and asserts the result is a bag of the tuple mapping to `expected_first_elem`
90107
// (i.e. <<{'result': <expected_first_elem>}>>)
@@ -452,6 +469,94 @@ mod tests {
452469
);
453470
}
454471

472+
#[test]
473+
fn select_with_cross_join() {
474+
let mut lg = LogicalPlan::new();
475+
476+
// Example 9 from spec with projected columns from different tables demonstrates a cross join:
477+
// SELECT c.id, c.name, o.custId, o.productId FROM customers AS c, orders AS o
478+
let from_lhs = lg.add_operator(scan("customers", "c"));
479+
let from_rhs = lg.add_operator(scan("orders", "o"));
480+
481+
let project = lg.add_operator(Project(logical::Project {
482+
exprs: HashMap::from([
483+
("id".to_string(), path_var("c", "id")),
484+
("name".to_string(), path_var("c", "name")),
485+
("custId".to_string(), path_var("o", "custId")),
486+
("productId".to_string(), path_var("o", "productId")),
487+
]),
488+
}));
489+
490+
let join = lg.add_operator(BindingsExpr::Join(logical::Join {
491+
kind: JoinKind::Cross,
492+
on: None,
493+
}));
494+
495+
let sink = lg.add_operator(BindingsExpr::Sink);
496+
lg.add_flow_with_branch_num(from_lhs, join, 0);
497+
lg.add_flow_with_branch_num(from_rhs, join, 1);
498+
lg.add_flow_with_branch_num(join, project, 0);
499+
lg.add_flow_with_branch_num(project, sink, 0);
500+
501+
let out = evaluate(lg, join_data());
502+
println!("{:?}", &out);
503+
504+
assert_matches!(out, Value::Bag(bag) => {
505+
let expected = partiql_bag![
506+
partiql_tuple![("custId", 7), ("name", "Joe"), ("id", 5), ("productId", 101)],
507+
partiql_tuple![("custId", 7), ("name", "Joe"), ("id", 5), ("productId", 523)],
508+
partiql_tuple![("custId", 7), ("name", "Mary"), ("id", 7), ("productId", 101)],
509+
partiql_tuple![("custId", 7), ("name", "Mary"), ("id", 7), ("productId", 523)],
510+
];
511+
assert_eq!(*bag, expected);
512+
});
513+
}
514+
515+
#[test]
516+
fn select_with_join_and_on() {
517+
let mut lg = LogicalPlan::new();
518+
519+
// Similar to ex 9 from spec with projected columns from different tables with an inner JOIN and ON condition
520+
// SELECT c.id, c.name, o.custId, o.productId FROM customers AS c, orders AS o ON c.id = o.custId
521+
let from_lhs = lg.add_operator(scan("customers", "c"));
522+
let from_rhs = lg.add_operator(scan("orders", "o"));
523+
524+
let project = lg.add_operator(Project(logical::Project {
525+
exprs: HashMap::from([
526+
("id".to_string(), path_var("c", "id")),
527+
("name".to_string(), path_var("c", "name")),
528+
("custId".to_string(), path_var("o", "custId")),
529+
("productId".to_string(), path_var("o", "productId")),
530+
]),
531+
}));
532+
533+
let join = lg.add_operator(BindingsExpr::Join(logical::Join {
534+
kind: JoinKind::Inner,
535+
on: Some(ValueExpr::BinaryExpr(
536+
BinaryOp::Eq,
537+
Box::new(path_var("c", "id")),
538+
Box::new(path_var("o", "custId")),
539+
)),
540+
}));
541+
542+
let sink = lg.add_operator(BindingsExpr::Sink);
543+
lg.add_flow_with_branch_num(from_lhs, join, 0);
544+
lg.add_flow_with_branch_num(from_rhs, join, 1);
545+
lg.add_flow_with_branch_num(join, project, 0);
546+
lg.add_flow_with_branch_num(project, sink, 0);
547+
548+
let out = evaluate(lg, join_data());
549+
println!("{:?}", &out);
550+
551+
assert_matches!(out, Value::Bag(bag) => {
552+
let expected = partiql_bag![
553+
partiql_tuple![("custId", 7), ("name", "Mary"), ("id", 7), ("productId", 101)],
554+
partiql_tuple![("custId", 7), ("name", "Mary"), ("id", 7), ("productId", 523)],
555+
];
556+
assert_eq!(*bag, expected);
557+
});
558+
}
559+
455560
#[test]
456561
fn select() {
457562
let mut lg = LogicalPlan::new();

0 commit comments

Comments
 (0)