Skip to content

Commit 06e5be5

Browse files
feat: Type check DML (#1727)
1 parent f559f0a commit 06e5be5

File tree

8 files changed

+497
-26
lines changed

8 files changed

+497
-26
lines changed

crates/planner/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ license-file = "LICENSE"
88
derive_more.workspace = true
99
thiserror.workspace = true
1010
spacetimedb-lib.workspace = true
11+
spacetimedb-primitives.workspace = true
1112
spacetimedb-sats.workspace = true
1213
spacetimedb-schema.workspace = true
1314
spacetimedb-sql-parser.workspace = true
1415

1516
[dev-dependencies]
1617
spacetimedb-lib.workspace = true
17-
spacetimedb-primitives.workspace = true

crates/planner/src/logical/bind.rs

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,122 @@ pub trait SchemaView {
2525
fn schema(&self, name: &str, case_sensitive: bool) -> Option<Arc<TableSchema>>;
2626
}
2727

28+
pub trait TypeChecker {
29+
type Ast;
30+
type Set;
31+
32+
fn type_ast(ctx: &mut TyCtx, ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<RelExpr>;
33+
34+
fn type_set(ctx: &mut TyCtx, ast: Self::Set, tx: &impl SchemaView) -> TypingResult<RelExpr>;
35+
36+
fn type_from(ctx: &mut TyCtx, from: SqlFrom<Self::Ast>, tx: &impl SchemaView) -> TypingResult<(RelExpr, Vars)> {
37+
match from {
38+
SqlFrom::Expr(expr, None) => Self::type_rel(ctx, expr, tx),
39+
SqlFrom::Expr(expr, Some(alias)) => {
40+
let (expr, _) = Self::type_rel(ctx, expr, tx)?;
41+
let ty = expr.ty_id();
42+
Ok((expr, vec![(alias.name, ty)].into()))
43+
}
44+
SqlFrom::Join(r, alias, joins) => {
45+
let (mut vars, mut args, mut exprs) = (Vars::default(), Vec::new(), Vec::new());
46+
47+
let (r, _) = Self::type_rel(ctx, r, tx)?;
48+
let ty = r.ty_id();
49+
50+
args.push(r);
51+
vars.push((alias.name, ty));
52+
53+
for join in joins {
54+
let (r, _) = Self::type_rel(ctx, join.expr, tx)?;
55+
let ty = r.ty_id();
56+
57+
args.push(r);
58+
vars.push((join.alias.name, ty));
59+
60+
if let Some(on) = join.on {
61+
exprs.push(type_expr(ctx, &vars, on, Some(TyId::BOOL))?);
62+
}
63+
}
64+
let types = vars.iter().map(|(_, ty)| *ty).collect();
65+
let ty = Type::Tup(types);
66+
let input = RelExpr::Join(args.into(), ctx.add(ty));
67+
Ok((RelExpr::select(input, vars.clone(), exprs), vars))
68+
}
69+
}
70+
}
71+
72+
fn type_rel(ctx: &mut TyCtx, expr: ast::RelExpr<Self::Ast>, tx: &impl SchemaView) -> TypingResult<(RelExpr, Vars)> {
73+
match expr {
74+
ast::RelExpr::Var(var) => {
75+
let schema = tx
76+
.schema(&var.name, var.case_sensitive)
77+
.ok_or_else(|| Unresolved::table(&var.name))
78+
.map_err(TypingError::from)?;
79+
let mut types = Vec::new();
80+
for ColumnSchema { col_name, col_type, .. } in schema.columns() {
81+
let ty = Type::Alg(col_type.clone());
82+
let id = ctx.add(ty);
83+
types.push((col_name.to_string(), id));
84+
}
85+
let ty = Type::Var(types.into_boxed_slice());
86+
let id = ctx.add(ty);
87+
Ok((RelExpr::RelVar(schema, id), vec![(var.name, id)].into()))
88+
}
89+
ast::RelExpr::Ast(ast) => Ok((Self::type_ast(ctx, *ast, tx)?, Vars::default())),
90+
}
91+
}
92+
}
93+
94+
/// Type checker for subscriptions
95+
struct SubChecker;
96+
97+
impl TypeChecker for SubChecker {
98+
type Ast = SqlAst;
99+
type Set = SqlAst;
100+
101+
fn type_ast(ctx: &mut TyCtx, ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<RelExpr> {
102+
Self::type_set(ctx, ast, tx)
103+
}
104+
105+
fn type_set(ctx: &mut TyCtx, ast: Self::Set, tx: &impl SchemaView) -> TypingResult<RelExpr> {
106+
match ast {
107+
SqlAst::Union(a, b) => {
108+
let a = type_ast(ctx, *a, tx)?;
109+
let b = type_ast(ctx, *b, tx)?;
110+
assert_eq_types(a.ty_id().try_with_ctx(ctx)?, b.ty_id().try_with_ctx(ctx)?)?;
111+
Ok(RelExpr::Union(Box::new(a), Box::new(b)))
112+
}
113+
SqlAst::Minus(a, b) => {
114+
let a = type_ast(ctx, *a, tx)?;
115+
let b = type_ast(ctx, *b, tx)?;
116+
assert_eq_types(a.ty_id().try_with_ctx(ctx)?, b.ty_id().try_with_ctx(ctx)?)?;
117+
Ok(RelExpr::Minus(Box::new(a), Box::new(b)))
118+
}
119+
SqlAst::Select(SqlSelect {
120+
project,
121+
from,
122+
filter: None,
123+
}) => {
124+
let (arg, vars) = type_from(ctx, from, tx)?;
125+
type_proj(ctx, project, arg, vars)
126+
}
127+
SqlAst::Select(SqlSelect {
128+
project,
129+
from,
130+
filter: Some(expr),
131+
}) => {
132+
let (from, vars) = type_from(ctx, from, tx)?;
133+
let arg = type_select(ctx, expr, from, vars.clone())?;
134+
type_proj(ctx, project, arg, vars.clone())
135+
}
136+
}
137+
}
138+
}
139+
28140
/// Parse and type check a subscription query
29141
pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<RelExpr> {
30142
let mut ctx = TyCtx::default();
31-
let expr = type_ast(&mut ctx, parse_subscription(sql)?, tx)?;
143+
let expr = SubChecker::type_ast(&mut ctx, parse_subscription(sql)?, tx)?;
32144
expect_table_type(&ctx, expr)
33145
}
34146

@@ -128,13 +240,13 @@ fn type_rel(ctx: &mut TyCtx, expr: ast::RelExpr<SqlAst>, tx: &impl SchemaView) -
128240
}
129241

130242
/// Type check and lower a [SqlExpr]
131-
fn type_select(ctx: &mut TyCtx, expr: SqlExpr, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
243+
pub(crate) fn type_select(ctx: &mut TyCtx, expr: SqlExpr, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
132244
let exprs = vec![type_expr(ctx, &vars, expr, Some(TyId::BOOL))?];
133245
Ok(RelExpr::select(input, vars, exprs))
134246
}
135247

136248
/// Type check and lower a [ast::Project]
137-
fn type_proj(ctx: &mut TyCtx, proj: ast::Project, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
249+
pub(crate) fn type_proj(ctx: &mut TyCtx, proj: ast::Project, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
138250
match proj {
139251
ast::Project::Star(None) => Ok(input),
140252
ast::Project::Star(Some(var)) => {
@@ -167,7 +279,7 @@ fn type_proj(ctx: &mut TyCtx, proj: ast::Project, input: RelExpr, vars: Vars) ->
167279
}
168280

169281
/// Type check and lower a [SqlExpr] into a logical [Expr].
170-
fn type_expr(ctx: &TyCtx, vars: &Vars, expr: SqlExpr, expected: Option<TyId>) -> TypingResult<Expr> {
282+
pub(crate) fn type_expr(ctx: &TyCtx, vars: &Vars, expr: SqlExpr, expected: Option<TyId>) -> TypingResult<Expr> {
171283
match (expr, expected) {
172284
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(TyId::BOOL)) => Ok(Expr::bool(v)),
173285
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(id)) => {
@@ -195,7 +307,7 @@ fn type_expr(ctx: &TyCtx, vars: &Vars, expr: SqlExpr, expected: Option<TyId>) ->
195307
}
196308

197309
/// Parses a source text literal as a particular type
198-
fn parse(ctx: &TyCtx, v: String, id: TyId) -> TypingResult<AlgebraicValue> {
310+
pub(crate) fn parse(ctx: &TyCtx, v: String, id: TyId) -> TypingResult<AlgebraicValue> {
199311
let err = |v, ty| TypingError::from(ConstraintViolation::lit(v, ty));
200312
match ctx.try_resolve(id)? {
201313
ty @ Type::Alg(AlgebraicType::I8) => v
@@ -260,7 +372,7 @@ fn parse(ctx: &TyCtx, v: String, id: TyId) -> TypingResult<AlgebraicValue> {
260372
}
261373

262374
/// Returns a type constraint violation for an unexpected type
263-
fn unexpected_type(expected: TypeWithCtx<'_>, inferred: TypeWithCtx<'_>) -> TypingError {
375+
pub(crate) fn unexpected_type(expected: TypeWithCtx<'_>, inferred: TypeWithCtx<'_>) -> TypingError {
264376
ConstraintViolation::eq(expected, inferred).into()
265377
}
266378

@@ -282,7 +394,7 @@ fn expect_op_type(ctx: &TyCtx, op: BinOp, expr: Expr) -> TypingResult<Expr> {
282394
}
283395
}
284396

285-
fn assert_eq_types(a: TypeWithCtx<'_>, b: TypeWithCtx<'_>) -> TypingResult<()> {
397+
pub(crate) fn assert_eq_types(a: TypeWithCtx<'_>, b: TypeWithCtx<'_>) -> TypingResult<()> {
286398
if a == b {
287399
Ok(())
288400
} else {

crates/planner/src/logical/errors.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
use spacetimedb_sql_parser::{ast::BinOp, parser::errors::SqlParseError};
22
use thiserror::Error;
33

4-
use super::ty::{InvalidTyId, TypeWithCtx};
4+
use super::{
5+
stmt::InvalidVar,
6+
ty::{InvalidTyId, TypeWithCtx},
7+
};
58

69
#[derive(Error, Debug)]
710
pub enum ConstraintViolation {
811
#[error("(expected) {expected} != {inferred} (inferred)")]
912
Eq { expected: String, inferred: String },
10-
#[error("{ty} is not a numeric type")]
13+
#[error("`{ty}` is not a numeric type")]
1114
Num { ty: String },
12-
#[error("{ty} cannot be interpreted as a byte array")]
15+
#[error("`{ty}` cannot be interpreted as a byte array")]
1316
Hex { ty: String },
14-
#[error("{expr} cannot be parsed as type {ty}")]
17+
#[error("`{expr}` cannot be parsed as type `{ty}`")]
1518
Lit { expr: String, ty: String },
16-
#[error("The binary operator {op} does not support type {ty}")]
19+
#[error("The binary operator `{op}` does not support type `{ty}`")]
1720
Bin { op: BinOp, ty: String },
1821
}
1922

@@ -52,11 +55,11 @@ impl ConstraintViolation {
5255

5356
#[derive(Error, Debug)]
5457
pub enum Unresolved {
55-
#[error("Cannot resolve {0}")]
58+
#[error("Cannot resolve `{0}`")]
5659
Var(String),
57-
#[error("Cannot resolve table {0}")]
60+
#[error("Cannot resolve table `{0}`")]
5861
Table(String),
59-
#[error("Cannot resolve field {1} in {0}")]
62+
#[error("Cannot resolve field `{1}` in `{0}`")]
6063
Field(String, String),
6164
#[error("Cannot resolve type for literal expression")]
6265
Literal,
@@ -87,6 +90,19 @@ pub enum Unsupported {
8790
ProjectExpr,
8891
#[error("Unqualified column projections are not supported")]
8992
UnqualifiedProjectExpr,
93+
#[error("ORDER BY is not supported")]
94+
OrderBy,
95+
#[error("LIMIT is not supported")]
96+
Limit,
97+
}
98+
99+
// TODO: It might be better to return the missing/extra fields
100+
#[derive(Error, Debug)]
101+
#[error("Inserting a row with {values} values into `{table}` which has {fields} fields")]
102+
pub struct InsertError {
103+
pub table: String,
104+
pub values: usize,
105+
pub fields: usize,
90106
}
91107

92108
#[derive(Error, Debug)]
@@ -100,5 +116,9 @@ pub enum TypingError {
100116
#[error(transparent)]
101117
InvalidTyId(#[from] InvalidTyId),
102118
#[error(transparent)]
119+
InvalidVar(#[from] InvalidVar),
120+
#[error(transparent)]
121+
Insert(#[from] InsertError),
122+
#[error(transparent)]
103123
ParseError(#[from] SqlParseError),
104124
}

crates/planner/src/logical/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub mod bind;
22
pub mod errors;
33
pub mod expr;
4+
pub mod stmt;
45
pub mod ty;
56

67
/// Asserts that `$ty` is `$size` bytes in `static_assert_size($ty, $size)`.

0 commit comments

Comments
 (0)