Skip to content

Commit 5eb7bf0

Browse files
committed
Implemented LEFT JOIN support.
1 parent 34201c6 commit 5eb7bf0

File tree

13 files changed: 245 additions (+245) and 57 deletions (-57).

crates/execution/src/iter.rs

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -792,6 +792,8 @@ pub struct HashJoinIter<'a> {
792792
rhs_ptr: usize,
793793
/// The lhs probe field
794794
lhs_field: &'a TupleField,
795+
/// Is the join outer
796+
outer: bool,
795797
}
796798

797799
impl<'a> HashJoinIter<'a> {
@@ -820,6 +822,7 @@ impl<'a> HashJoinIter<'a> {
820822
lhs_tuple: None,
821823
rhs_ptr: 0,
822824
lhs_field: &join.lhs_field,
825+
outer: join.outer,
823826
})
824827
}
825828
}
@@ -839,11 +842,15 @@ impl<'a> Iterator for HashJoinIter<'a> {
839842
})
840843
.or_else(|| {
841844
self.lhs.find_map(|tuple| {
842-
self.rhs.get(&tuple.project(self.lhs_field)).and_then(|ptrs| {
845+
if let Some(ptrs) = self.rhs.get(&tuple.project(self.lhs_field)) {
843846
self.rhs_ptr = 1;
844847
self.lhs_tuple = Some(tuple.clone());
845848
ptrs.first().map(|ptr| (tuple, ptr.clone()))
846-
})
849+
} else {
850+
if self.outer {
851+
Some((tuple, Row::Null))
852+
} else { None }
853+
}
847854
})
848855
})
849856
}

crates/execution/src/lib.rs

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,7 @@ pub trait DeltaStore {
117117

118118
#[derive(Clone)]
119119
pub enum Row<'a> {
120+
Null,
120121
Ptr(RowRef<'a>),
121122
Ref(&'a ProductValue),
122123
}
@@ -128,6 +129,7 @@ impl PartialEq for Row<'_> {
128129
(Self::Ref(x), Self::Ref(y)) => x == y,
129130
(Self::Ptr(x), Self::Ref(y)) => x == *y,
130131
(Self::Ref(x), Self::Ptr(y)) => y == *x,
132+
(Self::Null, _) | (_, Self::Null) => false,
131133
}
132134
}
133135
}
@@ -137,6 +139,7 @@ impl Eq for Row<'_> {}
137139
impl Hash for Row<'_> {
138140
fn hash<H: Hasher>(&self, state: &mut H) {
139141
match self {
142+
Self::Null => AlgebraicValue::unit().hash(state),
140143
Self::Ptr(x) => x.hash(state),
141144
Self::Ref(x) => x.hash(state),
142145
}
@@ -146,34 +149,39 @@ impl Hash for Row<'_> {
146149
impl Row<'_> {
147150
pub fn to_product_value(&self) -> ProductValue {
148151
match self {
152+
Self::Null => ProductValue { elements: Box::new([]) },
149153
Self::Ptr(ptr) => ptr.to_product_value(),
150154
Self::Ref(val) => (*val).clone(),
151155
}
152156
}
153157
}
154158

155159
impl_serialize!(['a] Row<'a>, (self, ser) => match self {
160+
Self::Null => AlgebraicValue::unit().serialize(ser),
156161
Self::Ptr(row) => row.serialize(ser),
157162
Self::Ref(row) => row.serialize(ser),
158163
});
159164

160165
impl ToBsatn for Row<'_> {
161166
fn static_bsatn_size(&self) -> Option<u16> {
162167
match self {
168+
Self::Null => self.to_product_value().static_bsatn_size(),
163169
Self::Ptr(ptr) => ptr.static_bsatn_size(),
164170
Self::Ref(val) => val.static_bsatn_size(),
165171
}
166172
}
167173

168174
fn to_bsatn_extend(&self, buf: &mut Vec<u8>) -> std::result::Result<(), EncodeError> {
169175
match self {
176+
Self::Null => self.to_product_value().to_bsatn_extend(buf),
170177
Self::Ptr(ptr) => ptr.to_bsatn_extend(buf),
171178
Self::Ref(val) => val.to_bsatn_extend(buf),
172179
}
173180
}
174181

175182
fn to_bsatn_vec(&self) -> std::result::Result<Vec<u8>, EncodeError> {
176183
match self {
184+
Self::Null => self.to_product_value().to_bsatn_vec(),
177185
Self::Ptr(ptr) => ptr.to_bsatn_vec(),
178186
Self::Ref(val) => val.to_bsatn_vec(),
179187
}
@@ -183,6 +191,7 @@ impl ToBsatn for Row<'_> {
183191
impl ProjectField for Row<'_> {
184192
fn project(&self, field: &TupleField) -> AlgebraicValue {
185193
match self {
194+
Self::Null => AlgebraicValue::unit(),
186195
Self::Ptr(ptr) => ptr.project(field),
187196
Self::Ref(val) => val.project(field),
188197
}
@@ -208,7 +217,7 @@ impl ProjectField for Tuple<'_> {
208217
.label_pos
209218
.and_then(|i| ptrs.get(i))
210219
.map(|ptr| ptr.project(field))
211-
.unwrap(),
220+
.unwrap_or(AlgebraicValue::unit()),
212221
}
213222
}
214223
}

crates/execution/src/pipelined.rs

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,7 @@ impl From<PhysicalPlan> for PipelinedExecutor {
257257
lhs_field,
258258
rhs_field,
259259
unique,
260+
outer,
260261
},
261262
semijoin,
262263
) => Self::HashJoin(BlockingHashJoin {
@@ -265,6 +266,7 @@ impl From<PhysicalPlan> for PipelinedExecutor {
265266
lhs_field,
266267
rhs_field,
267268
unique,
269+
outer,
268270
semijoin,
269271
}),
270272
PhysicalPlan::NLJoin(lhs, rhs) => Self::NLJoin(BlockingNLJoin {
@@ -1088,6 +1090,7 @@ pub struct BlockingHashJoin {
10881090
pub lhs_field: TupleField,
10891091
pub rhs_field: TupleField,
10901092
pub unique: bool,
1093+
pub outer: bool,
10911094
pub semijoin: Semi,
10921095
}
10931096

@@ -1106,12 +1109,18 @@ impl BlockingHashJoin {
11061109
let mut n = 0;
11071110
let mut bytes_scanned = 0;
11081111
match self {
1112+
Self {
1113+
outer: true,
1114+
semijoin: Semi::Lhs | Semi::Rhs,
1115+
..
1116+
} => unreachable!("Outer semijoin is not possible"),
11091117
Self {
11101118
lhs,
11111119
rhs,
11121120
lhs_field,
11131121
rhs_field,
11141122
unique: true,
1123+
outer: false,
11151124
semijoin: Semi::Lhs,
11161125
} => {
11171126
let mut rhs_table = HashSet::new();
@@ -1137,6 +1146,7 @@ impl BlockingHashJoin {
11371146
lhs_field,
11381147
rhs_field,
11391148
unique: true,
1149+
outer: false,
11401150
semijoin: Semi::Rhs,
11411151
} => {
11421152
let mut rhs_table = HashMap::new();
@@ -1162,6 +1172,7 @@ impl BlockingHashJoin {
11621172
lhs_field,
11631173
rhs_field,
11641174
unique: true,
1175+
outer,
11651176
semijoin: Semi::All,
11661177
} => {
11671178
let mut rhs_table = HashMap::new();
@@ -1177,6 +1188,8 @@ impl BlockingHashJoin {
11771188
n += 1;
11781189
if let Some(v) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)) {
11791190
f(u.clone().join(v.clone()))?;
1191+
} else if *outer {
1192+
f(u.clone().append(Row::Null))?;
11801193
}
11811194
Ok(())
11821195
})?;
@@ -1187,6 +1200,7 @@ impl BlockingHashJoin {
11871200
lhs_field,
11881201
rhs_field,
11891202
unique: false,
1203+
outer: false,
11901204
semijoin: Semi::Lhs,
11911205
} => {
11921206
let mut rhs_table = HashMap::new();
@@ -1214,6 +1228,7 @@ impl BlockingHashJoin {
12141228
lhs_field,
12151229
rhs_field,
12161230
unique: false,
1231+
outer: false,
12171232
semijoin: Semi::Rhs,
12181233
} => {
12191234
let mut rhs_table: HashMap<AlgebraicValue, Vec<_>> = HashMap::new();
@@ -1243,6 +1258,7 @@ impl BlockingHashJoin {
12431258
lhs_field,
12441259
rhs_field,
12451260
unique: false,
1261+
outer,
12461262
semijoin: Semi::All,
12471263
} => {
12481264
let mut rhs_table: HashMap<AlgebraicValue, Vec<_>> = HashMap::new();
@@ -1262,6 +1278,8 @@ impl BlockingHashJoin {
12621278
for v in rhs_tuples {
12631279
f(u.clone().join(v.clone()))?;
12641280
}
1281+
} else if *outer {
1282+
f(u.clone().append(Row::Null))?;
12651283
}
12661284
Ok(())
12671285
})?;

crates/expr/src/check.rs

Lines changed: 41 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ use std::collections::HashMap;
22
use std::ops::{Deref, DerefMut};
33
use std::sync::Arc;
44

5+
use crate::ast::{CrossJoin, InnerJoin, OuterJoin};
56
use crate::expr::LeftDeepJoin;
67
use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
78
use spacetimedb_lib::identity::AuthCtx;
@@ -78,34 +79,56 @@ pub trait TypeChecker {
7879
delta: None,
7980
});
8081

81-
for SqlJoin {
82-
var: SqlIdent(name),
83-
alias: SqlIdent(alias),
84-
on,
85-
} in joins
86-
{
82+
for jn in joins {
8783
// Check for duplicate aliases
88-
if vars.contains_key(&alias) {
89-
return Err(DuplicateName(alias.into_string()).into());
84+
match jn {
85+
SqlJoin::Cross(CrossJoin { alias: SqlIdent(alias), .. })
86+
| SqlJoin::Inner(InnerJoin { alias: SqlIdent(alias), .. })
87+
| SqlJoin::Left(OuterJoin { alias: SqlIdent(alias), .. })
88+
if vars.contains_key(&alias) => {
89+
return Err(DuplicateName(alias.into_string()).into());
90+
}
91+
SqlJoin::Cross(_) => (),
92+
SqlJoin::Inner(_) => (),
93+
SqlJoin::Left(_) => (),
9094
}
9195

9296
let lhs = Box::new(join);
93-
let rhs = Relvar {
94-
schema: Self::type_relvar(tx, &name)?,
95-
alias,
96-
delta: None,
97+
let rhs = match &jn {
98+
SqlJoin::Cross(CrossJoin { var: SqlIdent(name), alias: SqlIdent(alias), .. })
99+
| SqlJoin::Inner(InnerJoin { var: SqlIdent(name), alias: SqlIdent(alias), .. })
100+
| SqlJoin::Left(OuterJoin { var: SqlIdent(name), alias: SqlIdent(alias), .. }) => {
101+
Relvar {
102+
schema: Self::type_relvar(tx, &name)?,
103+
alias: alias.clone(),
104+
delta: None,
105+
}
106+
}
97107
};
98108

99109
vars.insert(rhs.alias.clone(), rhs.schema.clone());
100110

101-
if let Some(on) = on {
102-
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
103-
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
104-
join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
105-
continue;
111+
match jn {
112+
SqlJoin::Cross(_) => (),
113+
SqlJoin::Inner(InnerJoin { on: Some(on), .. }) => {
114+
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
115+
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
116+
join = RelExpr::InnerEqJoin(LeftDeepJoin { lhs, rhs }, a, b);
117+
continue;
118+
}
119+
}
120+
unreachable!("Unreachability guaranteed by parser")
121+
}
122+
SqlJoin::Inner(_) => (),
123+
SqlJoin::Left(OuterJoin { on, .. }) => {
124+
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
125+
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
126+
join = RelExpr::LeftOuterEqJoin(LeftDeepJoin { lhs, rhs }, a, b);
127+
continue;
128+
}
106129
}
130+
unreachable!("Unreachability guaranteed by parser")
107131
}
108-
unreachable!("Unreachability guaranteed by parser")
109132
}
110133

111134
join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });

crates/expr/src/expr.rs

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -197,8 +197,10 @@ pub enum RelExpr {
197197
Select(Box<RelExpr>, Expr),
198198
/// A left deep binary cross product
199199
LeftDeepJoin(LeftDeepJoin),
200-
/// A left deep binary equi-join
201-
EqJoin(LeftDeepJoin, FieldProject, FieldProject),
200+
/// A left deep binary inner equi-join
201+
InnerEqJoin(LeftDeepJoin, FieldProject, FieldProject),
202+
/// A left deep binary left outer equi-join
203+
LeftOuterEqJoin(LeftDeepJoin, FieldProject, FieldProject),
202204
}
203205

204206
/// A table reference
@@ -219,7 +221,8 @@ impl RelExpr {
219221
match self {
220222
Self::Select(lhs, _)
221223
| Self::LeftDeepJoin(LeftDeepJoin { lhs, .. })
222-
| Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => {
224+
| Self::InnerEqJoin(LeftDeepJoin { lhs, .. }, ..)
225+
| Self::LeftOuterEqJoin(LeftDeepJoin { lhs, .. }, ..) => {
223226
lhs.visit(f);
224227
}
225228
Self::RelVar(..) => {}
@@ -232,7 +235,8 @@ impl RelExpr {
232235
match self {
233236
Self::Select(lhs, _)
234237
| Self::LeftDeepJoin(LeftDeepJoin { lhs, .. })
235-
| Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => {
238+
| Self::InnerEqJoin(LeftDeepJoin { lhs, .. }, ..)
239+
| Self::LeftOuterEqJoin(LeftDeepJoin { lhs, .. }, ..) => {
236240
lhs.visit_mut(f);
237241
}
238242
Self::RelVar(..) => {}
@@ -243,7 +247,11 @@ impl RelExpr {
243247
pub fn nfields(&self) -> usize {
244248
match self {
245249
Self::RelVar(..) => 1,
246-
Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => join.lhs.nfields() + 1,
250+
Self::LeftDeepJoin(join)
251+
| Self::InnerEqJoin(join, ..)
252+
| Self::LeftOuterEqJoin(join, ..) => {
253+
join.lhs.nfields() + 1
254+
}
247255
Self::Select(input, _) => input.nfields(),
248256
}
249257
}
@@ -252,7 +260,9 @@ impl RelExpr {
252260
pub fn has_field(&self, field: &str) -> bool {
253261
match self {
254262
Self::RelVar(Relvar { alias, .. }) => alias.as_ref() == field,
255-
Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => {
263+
Self::LeftDeepJoin(join)
264+
| Self::InnerEqJoin(join, ..)
265+
| Self::LeftOuterEqJoin(join, ..) => {
256266
join.rhs.alias.as_ref() == field || join.lhs.has_field(field)
257267
}
258268
Self::Select(input, _) => input.has_field(field),
@@ -264,10 +274,12 @@ impl RelExpr {
264274
match self {
265275
Self::RelVar(relvar) if relvar.alias.as_ref() == alias => Some(&relvar.schema),
266276
Self::Select(input, _) => input.find_table_schema(alias),
267-
Self::EqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
268-
Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias),
269277
Self::LeftDeepJoin(LeftDeepJoin { rhs, .. }) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
270278
Self::LeftDeepJoin(LeftDeepJoin { lhs, .. }) => lhs.find_table_schema(alias),
279+
Self::InnerEqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
280+
Self::InnerEqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias),
281+
Self::LeftOuterEqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
282+
Self::LeftOuterEqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias),
271283
_ => None,
272284
}
273285
}

0 commit comments

Comments
 (0)