Skip to content

Implemented OUTER JOIN support #2892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions crates/execution/src/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,8 @@ pub struct HashJoinIter<'a> {
rhs_ptr: usize,
/// The lhs probe field
lhs_field: &'a TupleField,
/// Whether the join is outer: an lhs tuple with no rhs match is still emitted, paired with `Row::Null`
outer: bool,
}

impl<'a> HashJoinIter<'a> {
Expand Down Expand Up @@ -820,6 +822,7 @@ impl<'a> HashJoinIter<'a> {
lhs_tuple: None,
rhs_ptr: 0,
lhs_field: &join.lhs_field,
outer: join.outer,
})
}
}
Expand All @@ -839,11 +842,15 @@ impl<'a> Iterator for HashJoinIter<'a> {
})
.or_else(|| {
self.lhs.find_map(|tuple| {
self.rhs.get(&tuple.project(self.lhs_field)).and_then(|ptrs| {
if let Some(ptrs) = self.rhs.get(&tuple.project(self.lhs_field)) {
self.rhs_ptr = 1;
self.lhs_tuple = Some(tuple.clone());
ptrs.first().map(|ptr| (tuple, ptr.clone()))
})
} else {
if self.outer {
Some((tuple, Row::Null))
} else { None }
}
})
})
}
Expand Down
11 changes: 10 additions & 1 deletion crates/execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ pub trait DeltaStore {

/// A row value, which is either absent, a `RowRef`, or a borrowed `ProductValue`.
#[derive(Clone)]
pub enum Row<'a> {
/// Stand-in for a missing row. The hash joins emit it as the rhs of an
/// lhs tuple that found no match when the join is outer.
Null,
/// A row held as a `RowRef`.
/// NOTE(review): presumably a reference into table storage — confirm against `RowRef`.
Ptr(RowRef<'a>),
/// A row held as a borrowed `ProductValue`.
Ref(&'a ProductValue),
}
Expand All @@ -128,6 +129,7 @@ impl PartialEq for Row<'_> {
(Self::Ref(x), Self::Ref(y)) => x == y,
(Self::Ptr(x), Self::Ref(y)) => x == *y,
(Self::Ref(x), Self::Ptr(y)) => y == *x,
(Self::Null, _) | (_, Self::Null) => false,
}
}
}
Expand All @@ -137,6 +139,7 @@ impl Eq for Row<'_> {}
impl Hash for Row<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Self::Null => AlgebraicValue::unit().hash(state),
Self::Ptr(x) => x.hash(state),
Self::Ref(x) => x.hash(state),
}
Expand All @@ -146,34 +149,39 @@ impl Hash for Row<'_> {
impl Row<'_> {
pub fn to_product_value(&self) -> ProductValue {
match self {
Self::Null => ProductValue { elements: Box::new([]) },
Self::Ptr(ptr) => ptr.to_product_value(),
Self::Ref(val) => (*val).clone(),
}
}
}

// Serialization for `Row`: `Null` is written as `AlgebraicValue::unit()`, while
// `Ptr` and `Ref` defer to the wrapped row's own `serialize`.
impl_serialize!(['a] Row<'a>, (self, ser) => match self {
Self::Null => AlgebraicValue::unit().serialize(ser),
Self::Ptr(row) => row.serialize(ser),
Self::Ref(row) => row.serialize(ser),
});

impl ToBsatn for Row<'_> {
fn static_bsatn_size(&self) -> Option<u16> {
match self {
Self::Null => self.to_product_value().static_bsatn_size(),
Self::Ptr(ptr) => ptr.static_bsatn_size(),
Self::Ref(val) => val.static_bsatn_size(),
}
}

fn to_bsatn_extend(&self, buf: &mut Vec<u8>) -> std::result::Result<(), EncodeError> {
match self {
Self::Null => self.to_product_value().to_bsatn_extend(buf),
Self::Ptr(ptr) => ptr.to_bsatn_extend(buf),
Self::Ref(val) => val.to_bsatn_extend(buf),
}
}

fn to_bsatn_vec(&self) -> std::result::Result<Vec<u8>, EncodeError> {
match self {
Self::Null => self.to_product_value().to_bsatn_vec(),
Self::Ptr(ptr) => ptr.to_bsatn_vec(),
Self::Ref(val) => val.to_bsatn_vec(),
}
Expand All @@ -183,6 +191,7 @@ impl ToBsatn for Row<'_> {
impl ProjectField for Row<'_> {
fn project(&self, field: &TupleField) -> AlgebraicValue {
match self {
Self::Null => AlgebraicValue::unit(),
Self::Ptr(ptr) => ptr.project(field),
Self::Ref(val) => val.project(field),
}
Expand All @@ -208,7 +217,7 @@ impl ProjectField for Tuple<'_> {
.label_pos
.and_then(|i| ptrs.get(i))
.map(|ptr| ptr.project(field))
.unwrap(),
.unwrap_or(AlgebraicValue::unit()),
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions crates/execution/src/pipelined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ impl From<PhysicalPlan> for PipelinedExecutor {
lhs_field,
rhs_field,
unique,
outer,
},
semijoin,
) => Self::HashJoin(BlockingHashJoin {
Expand All @@ -265,6 +266,7 @@ impl From<PhysicalPlan> for PipelinedExecutor {
lhs_field,
rhs_field,
unique,
outer,
semijoin,
}),
PhysicalPlan::NLJoin(lhs, rhs) => Self::NLJoin(BlockingNLJoin {
Expand Down Expand Up @@ -1088,6 +1090,7 @@ pub struct BlockingHashJoin {
pub lhs_field: TupleField,
pub rhs_field: TupleField,
pub unique: bool,
pub outer: bool,
pub semijoin: Semi,
}

Expand All @@ -1106,12 +1109,18 @@ impl BlockingHashJoin {
let mut n = 0;
let mut bytes_scanned = 0;
match self {
Self {
outer: true,
semijoin: Semi::Lhs | Semi::Rhs,
..
} => unreachable!("Outer semijoin is not possible"),
Self {
lhs,
rhs,
lhs_field,
rhs_field,
unique: true,
outer: false,
semijoin: Semi::Lhs,
} => {
let mut rhs_table = HashSet::new();
Expand All @@ -1137,6 +1146,7 @@ impl BlockingHashJoin {
lhs_field,
rhs_field,
unique: true,
outer: false,
semijoin: Semi::Rhs,
} => {
let mut rhs_table = HashMap::new();
Expand All @@ -1162,6 +1172,7 @@ impl BlockingHashJoin {
lhs_field,
rhs_field,
unique: true,
outer,
semijoin: Semi::All,
} => {
let mut rhs_table = HashMap::new();
Expand All @@ -1177,6 +1188,8 @@ impl BlockingHashJoin {
n += 1;
if let Some(v) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)) {
f(u.clone().join(v.clone()))?;
} else if *outer {
f(u.clone().append(Row::Null))?;
}
Ok(())
})?;
Expand All @@ -1187,6 +1200,7 @@ impl BlockingHashJoin {
lhs_field,
rhs_field,
unique: false,
outer: false,
semijoin: Semi::Lhs,
} => {
let mut rhs_table = HashMap::new();
Expand Down Expand Up @@ -1214,6 +1228,7 @@ impl BlockingHashJoin {
lhs_field,
rhs_field,
unique: false,
outer: false,
semijoin: Semi::Rhs,
} => {
let mut rhs_table: HashMap<AlgebraicValue, Vec<_>> = HashMap::new();
Expand Down Expand Up @@ -1243,6 +1258,7 @@ impl BlockingHashJoin {
lhs_field,
rhs_field,
unique: false,
outer,
semijoin: Semi::All,
} => {
let mut rhs_table: HashMap<AlgebraicValue, Vec<_>> = HashMap::new();
Expand All @@ -1262,6 +1278,8 @@ impl BlockingHashJoin {
for v in rhs_tuples {
f(u.clone().join(v.clone()))?;
}
} else if *outer {
f(u.clone().append(Row::Null))?;
}
Ok(())
})?;
Expand Down
59 changes: 41 additions & 18 deletions crates/expr/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use crate::ast::{CrossJoin, InnerJoin, OuterJoin};
use crate::expr::LeftDeepJoin;
use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
use spacetimedb_lib::identity::AuthCtx;
Expand Down Expand Up @@ -78,34 +79,56 @@ pub trait TypeChecker {
delta: None,
});

for SqlJoin {
var: SqlIdent(name),
alias: SqlIdent(alias),
on,
} in joins
{
for jn in joins {
// Check for duplicate aliases
if vars.contains_key(&alias) {
return Err(DuplicateName(alias.into_string()).into());
match jn {
SqlJoin::Cross(CrossJoin { alias: SqlIdent(alias), .. })
| SqlJoin::Inner(InnerJoin { alias: SqlIdent(alias), .. })
| SqlJoin::Left(OuterJoin { alias: SqlIdent(alias), .. })
if vars.contains_key(&alias) => {
return Err(DuplicateName(alias.into_string()).into());
}
SqlJoin::Cross(_) => (),
SqlJoin::Inner(_) => (),
SqlJoin::Left(_) => (),
}

let lhs = Box::new(join);
let rhs = Relvar {
schema: Self::type_relvar(tx, &name)?,
alias,
delta: None,
let rhs = match &jn {
SqlJoin::Cross(CrossJoin { var: SqlIdent(name), alias: SqlIdent(alias), .. })
| SqlJoin::Inner(InnerJoin { var: SqlIdent(name), alias: SqlIdent(alias), .. })
| SqlJoin::Left(OuterJoin { var: SqlIdent(name), alias: SqlIdent(alias), .. }) => {
Relvar {
schema: Self::type_relvar(tx, &name)?,
alias: alias.clone(),
delta: None,
}
}
};

vars.insert(rhs.alias.clone(), rhs.schema.clone());

if let Some(on) = on {
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
continue;
match jn {
SqlJoin::Cross(_) => (),
SqlJoin::Inner(InnerJoin { on: Some(on), .. }) => {
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
join = RelExpr::InnerEqJoin(LeftDeepJoin { lhs, rhs }, a, b);
continue;
}
}
unreachable!("Unreachability guaranteed by parser")
}
SqlJoin::Inner(_) => (),
SqlJoin::Left(OuterJoin { on, .. }) => {
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
join = RelExpr::LeftOuterEqJoin(LeftDeepJoin { lhs, rhs }, a, b);
continue;
}
}
unreachable!("Unreachability guaranteed by parser")
}
unreachable!("Unreachability guaranteed by parser")
}

join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
Expand Down
28 changes: 20 additions & 8 deletions crates/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,10 @@ pub enum RelExpr {
Select(Box<RelExpr>, Expr),
/// A left deep binary cross product
LeftDeepJoin(LeftDeepJoin),
/// A left deep binary equi-join
EqJoin(LeftDeepJoin, FieldProject, FieldProject),
/// A left deep binary inner equi-join
InnerEqJoin(LeftDeepJoin, FieldProject, FieldProject),
/// A left deep binary left outer equi-join
LeftOuterEqJoin(LeftDeepJoin, FieldProject, FieldProject),
}

/// A table reference
Expand All @@ -219,7 +221,8 @@ impl RelExpr {
match self {
Self::Select(lhs, _)
| Self::LeftDeepJoin(LeftDeepJoin { lhs, .. })
| Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => {
| Self::InnerEqJoin(LeftDeepJoin { lhs, .. }, ..)
| Self::LeftOuterEqJoin(LeftDeepJoin { lhs, .. }, ..) => {
lhs.visit(f);
}
Self::RelVar(..) => {}
Expand All @@ -232,7 +235,8 @@ impl RelExpr {
match self {
Self::Select(lhs, _)
| Self::LeftDeepJoin(LeftDeepJoin { lhs, .. })
| Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => {
| Self::InnerEqJoin(LeftDeepJoin { lhs, .. }, ..)
| Self::LeftOuterEqJoin(LeftDeepJoin { lhs, .. }, ..) => {
lhs.visit_mut(f);
}
Self::RelVar(..) => {}
Expand All @@ -243,7 +247,11 @@ impl RelExpr {
pub fn nfields(&self) -> usize {
match self {
Self::RelVar(..) => 1,
Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => join.lhs.nfields() + 1,
Self::LeftDeepJoin(join)
| Self::InnerEqJoin(join, ..)
| Self::LeftOuterEqJoin(join, ..) => {
join.lhs.nfields() + 1
}
Self::Select(input, _) => input.nfields(),
}
}
Expand All @@ -252,7 +260,9 @@ impl RelExpr {
pub fn has_field(&self, field: &str) -> bool {
match self {
Self::RelVar(Relvar { alias, .. }) => alias.as_ref() == field,
Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => {
Self::LeftDeepJoin(join)
| Self::InnerEqJoin(join, ..)
| Self::LeftOuterEqJoin(join, ..) => {
join.rhs.alias.as_ref() == field || join.lhs.has_field(field)
}
Self::Select(input, _) => input.has_field(field),
Expand All @@ -264,10 +274,12 @@ impl RelExpr {
match self {
Self::RelVar(relvar) if relvar.alias.as_ref() == alias => Some(&relvar.schema),
Self::Select(input, _) => input.find_table_schema(alias),
Self::EqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias),
Self::LeftDeepJoin(LeftDeepJoin { rhs, .. }) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
Self::LeftDeepJoin(LeftDeepJoin { lhs, .. }) => lhs.find_table_schema(alias),
Self::InnerEqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
Self::InnerEqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias),
Self::LeftOuterEqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema),
Self::LeftOuterEqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias),
_ => None,
}
}
Expand Down
Loading