diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index c523fb9452e..7a706ebeaf9 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -114,6 +114,21 @@ impl UnexpectedType { } } +#[derive(Debug, Error)] +#[error("Unexpected array type: expected {expected}")] +pub struct UnexpectedArrayType { + expected: String, + // TODO: inferred +} + +impl UnexpectedArrayType { + pub fn new(expected: &AlgebraicType) -> Self { + Self { + expected: fmt_algebraic_type(expected).to_string(), + } + } +} + #[derive(Debug, Error)] #[error("Duplicate name `{0}`")] pub struct DuplicateName(pub String); @@ -144,6 +159,8 @@ pub enum TypingError { #[error(transparent)] Unexpected(#[from] UnexpectedType), #[error(transparent)] + UnexpectedArray(#[from] UnexpectedArrayType), + #[error(transparent)] Wildcard(#[from] InvalidWildcard), #[error(transparent)] DuplicateName(#[from] DuplicateName), diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index 4f4a68592a0..8ce3576a824 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -322,6 +322,8 @@ pub enum Expr { LogOp(LogOp, Box, Box), /// A typed literal expression Value(AlgebraicValue, AlgebraicType), + /// A typed literal tuple expression + Tuple(Box<[AlgebraicValue]>, AlgebraicType), /// A field projection Field(FieldProject), } @@ -335,7 +337,7 @@ impl Expr { a.visit(f); b.visit(f); } - Self::Value(..) | Self::Field(..) => {} + Self::Value(..) | Self::Tuple(..) | Self::Field(..) => {} } } @@ -347,7 +349,7 @@ impl Expr { a.visit_mut(f); b.visit_mut(f); } - Self::Value(..) | Self::Field(..) => {} + Self::Value(..) | Self::Tuple(..) | Self::Field(..) => {} } } @@ -365,7 +367,7 @@ impl Expr { pub fn ty(&self) -> &AlgebraicType { match self { Self::BinOp(..) | Self::LogOp(..) => &AlgebraicType::Bool, - Self::Value(_, ty) | Self::Field(FieldProject { ty, .. }) => ty, + Self::Value(_, ty) | Self::Tuple(_, ty) | Self::Field(FieldProject { ty, .. }) => ty, } } } diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 4860bdc882e..e91821ce18b 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -7,14 +7,15 @@ use anyhow::Context; use bigdecimal::BigDecimal; use bigdecimal::ToPrimitive; use check::{Relvars, TypingResult}; -use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved}; +use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedArrayType, UnexpectedType, Unresolved}; use ethnum::i256; use ethnum::u256; use expr::AggType; use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr}; use spacetimedb_lib::ser::Serialize; use spacetimedb_lib::Timestamp; -use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity}; +use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity, ProductType, ProductTypeElement}; +use spacetimedb_sats::{ArrayType, ArrayValue, F32, F64}; use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; use spacetimedb_sats::algebraic_value::ser::ValueSerializer; use spacetimedb_schema::schema::ColumnSchema; @@ -90,6 +91,45 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?, ty.clone(), )), + (SqlExpr::Lit(SqlLiteral::Arr(_)), None) => { + Err(Unresolved::Literal.into()) + }, + (SqlExpr::Lit(SqlLiteral::Arr(_)), Some(ty)) => { + Err(UnexpectedArrayType::new(ty).into()) + }, + (SqlExpr::Tup(_), None) => { + Err(Unresolved::Literal.into()) + } + (SqlExpr::Tup(t), Some(&AlgebraicType::Product(ref pty))) => Ok(Expr::Tuple( + t.iter().zip(pty.elements.iter()).map(|(lit, ty)| { + match (lit, ty) { + (SqlLiteral::Bool(v), ProductTypeElement { + algebraic_type: AlgebraicType::Bool, + .. + }) => Ok(AlgebraicValue::Bool(*v)), + (SqlLiteral::Bool(_), ProductTypeElement { + algebraic_type: ty, + .. + }) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()), + (SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v), ProductTypeElement { + algebraic_type: ty, + .. + }) => Ok(parse(&v, ty).map_err(|_| InvalidLiteral::new(v.clone().into_string(), ty))?), + (SqlLiteral::Arr(v), ProductTypeElement { + algebraic_type: AlgebraicType::Array(a), + .. + }) => Ok(parse_array_value(v, a).map_err(|_| InvalidLiteral::new("[…]".into(), &a.elem_ty))?), + (SqlLiteral::Arr(_), ProductTypeElement { + algebraic_type: ty, + .. + }) => Err(UnexpectedArrayType::new(ty).into()), + } + }).collect::>>()?, + AlgebraicType::Product(pty.clone()), + )), + (SqlExpr::Tup(_), Some(ty)) => { + Err(UnexpectedType::new(&AlgebraicType::Product(ProductType::unit()), ty).into()) + } (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => { let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?; let ColumnSchema { col_pos, col_type, .. } = table_type @@ -145,15 +185,19 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra } /// Is this type compatible with this binary operator? -fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool { - t.is_bool() - || t.is_integer() - || t.is_float() - || t.is_string() - || t.is_bytes() - || t.is_identity() - || t.is_connection_id() - || t.is_timestamp() +fn op_supports_type(op: BinOp, ty: &AlgebraicType) -> bool { + match (ty, op) { + (AlgebraicType::Product(_), BinOp::Eq | BinOp::Ne) => true, + _ if ty.is_bool() => true, + _ if ty.is_integer() => true, + _ if ty.is_float() => true, + _ if ty.is_string() => true, + _ if ty.is_bytes() => true, + _ if ty.is_identity() => true, + _ if ty.is_connection_id() => true, + _ if ty.is_timestamp() => true, + _ => false, + } } /// Parse an integer literal into an [AlgebraicValue] @@ -346,6 +390,50 @@ pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result { + match $elem_ty { + AlgebraicType::Bool => ArrayValue::Bool($arr.iter().map(|x| match x { + SqlLiteral::Bool(b) => Ok(*b), + _ => Err(UnexpectedType::new(&$elem_ty, &AlgebraicType::Bool).into()), + }).collect::>>()?), + AlgebraicType::String => ArrayValue::String($arr.iter().map(|x| match x { + SqlLiteral::Str(b) => Ok(b.clone()), + _ => Err(UnexpectedType::new(&$elem_ty, &AlgebraicType::String).into()), + }).collect::]>>>()?), + $(AlgebraicType::$t => ArrayValue::$t($arr.iter().map(|x| match x { + SqlLiteral::Num(v) | SqlLiteral::Hex(v) => Ok(match parse(v, &$elem_ty).map_err(|_| InvalidLiteral::new(v.clone().into_string(), &$elem_ty))? { + AlgebraicValue::$t(r) => r, + _ => unreachable!(), // guaranteed by `parse()' + }.into()), + SqlLiteral::Str(_) => Err(UnexpectedType::new(&$elem_ty, &AlgebraicType::String).into()), + SqlLiteral::Bool(_) => Err(UnexpectedType::new(&$elem_ty, &AlgebraicType::Bool).into()), + SqlLiteral::Arr(_) => Err(UnexpectedArrayType::new(&$elem_ty).into()), + }).collect::>>()?),)* + _ => { + return Err(UnexpectedArrayType::new(&$elem_ty).into()); + } + } + } +} + +pub(crate) fn parse_array_value(arr: &Box<[SqlLiteral]>, a: &ArrayType) -> anyhow::Result { + Ok(AlgebraicValue::Array(parse_array_number!(arr, *a.elem_ty, + I8, i8, + U8, u8, + I16, i16, + U16, u16, + I32, i32, + U32, u32, + I128, i128, + U128, u128, + /*I256, i256, + U256, u256,*/ // TODO: Boxed + F32, F32, + F64, F64 + ))) +} + /// The source of a statement pub enum StatementSource { Subscription, diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 39abf447118..18dc3929da1 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -21,9 +21,9 @@ use crate::{ use super::{ check::{SchemaView, TypeChecker, TypingResult}, - errors::{InsertFieldsError, InsertValuesError, TypingError, UnexpectedType, Unresolved}, + errors::{InsertFieldsError, InsertValuesError, TypingError, UnexpectedArrayType, UnexpectedType, Unresolved}, expr::Expr, - parse, type_expr, type_proj, type_select, StatementCtx, StatementSource, + parse, parse_array_value, type_expr, type_proj, type_select, StatementCtx, StatementSource, }; pub enum Statement { @@ -140,6 +140,12 @@ pub fn type_insert(insert: SqlInsert, tx: &impl SchemaView) -> TypingResult { values.push(parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?); } + (SqlLiteral::Arr(v), AlgebraicType::Array(a)) => { + values.push(parse_array_value(&v, a).map_err(|_| InvalidLiteral::new("[…]".into(), &a.elem_ty))?); + } + (SqlLiteral::Arr(_), _) => { + return Err(UnexpectedArrayType::new(ty).into()); + } } } rows.push(ProductValue::from(values)); @@ -210,6 +216,15 @@ pub fn type_update(update: SqlUpdate, tx: &impl SchemaView) -> TypingResult { + values.push(( + *col_id, + parse_array_value(&v, a).map_err(|_| InvalidLiteral::new("[…]".into(), &a.elem_ty))?, + )); + } + (SqlLiteral::Arr(_), _) => { + return Err(UnexpectedArrayType::new(ty).into()); + } } } let mut vars = Relvars::default(); @@ -269,6 +284,7 @@ pub fn type_and_rewrite_set(set: SqlSet, tx: &impl SchemaView) -> TypingResult Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::Bool).into()), SqlLiteral::Str(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::String).into()), SqlLiteral::Hex(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::bytes()).into()), + SqlLiteral::Arr(_) => Err(UnexpectedArrayType::new(&AlgebraicType::U64).into()), SqlLiteral::Num(n) => { let table = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?; let var_name = AlgebraicValue::String(var_name); diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs index 56215f5e354..7ed01e737d2 100644 --- a/crates/physical-plan/src/compile.rs +++ b/crates/physical-plan/src/compile.rs @@ -23,6 +23,7 @@ fn compile_expr(expr: Expr, var: &mut impl VarLabel) -> PhysicalExpr { PhysicalExpr::BinOp(op, a, b) } Expr::Value(v, _) => PhysicalExpr::Value(v), + Expr::Tuple(t, _) => PhysicalExpr::Tuple(t), Expr::Field(proj) => PhysicalExpr::Field(compile_field_project(var, proj)), } } diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index 6fa9b58297e..ae1aff9e56e 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -1031,6 +1031,8 @@ pub enum PhysicalExpr { BinOp(BinOp, Box, Box), /// A constant algebraic value Value(AlgebraicValue), + /// A tuple of constant algebraic values + Tuple(Box<[AlgebraicValue]>), /// A field projection expression Field(TupleField), } @@ -1094,6 +1096,7 @@ impl PhysicalExpr { pub fn map(self, f: &impl Fn(Self) -> Self) -> Self { match f(self) { value @ Self::Value(..) => value, + values @ Self::Tuple(..) => values, field @ Self::Field(..) => field, Self::BinOp(op, a, b) => Self::BinOp(op, Box::new(a.map(f)), Box::new(b.map(f))), Self::LogOp(op, exprs) => Self::LogOp(op, exprs.into_iter().map(|expr| expr.map(f)).collect()), @@ -1155,6 +1158,7 @@ impl PhysicalExpr { Cow::Owned(value) } Self::Value(v) => Cow::Borrowed(v), + Self::Tuple(t) => Cow::Owned(AlgebraicValue::Product(ProductValue {elements: t.clone()})), } } @@ -1173,7 +1177,7 @@ impl PhysicalExpr { .collect(), ), Self::BinOp(op, a, b) => Self::BinOp(op, Box::new(a.flatten()), Box::new(b.flatten())), - Self::Field(..) | Self::Value(..) => self, + Self::Field(..) | Self::Value(..) | Self::Tuple(..) => self, } } } diff --git a/crates/sats/src/algebraic_value.rs b/crates/sats/src/algebraic_value.rs index 2666765c6d2..31e670e6cb3 100644 --- a/crates/sats/src/algebraic_value.rs +++ b/crates/sats/src/algebraic_value.rs @@ -85,7 +85,7 @@ pub enum AlgebraicValue { I256(Box), /// A [`u256`] value of type [`AlgebraicType::U256`]. /// - /// We pack these to shrink `AlgebraicValue`. + /// We box these up to shrink `AlgebraicValue`. U256(Box), /// A totally ordered [`F32`] value of type [`AlgebraicType::F32`]. /// @@ -123,6 +123,18 @@ impl From for Packed { } } +macro_rules! impl_from_packed { + ($ty:ty) => { + impl From> for $ty { + fn from(packed: Packed<$ty>) -> Self { + packed.0 + } + } + }; +} +impl_from_packed!(i128); +impl_from_packed!(u128); + #[allow(non_snake_case)] impl AlgebraicValue { /// Extract the value and replace it with a dummy one that is cheap to make. diff --git a/crates/sql-parser/src/ast/mod.rs b/crates/sql-parser/src/ast/mod.rs index 776d4fc5006..acfb1755234 100644 --- a/crates/sql-parser/src/ast/mod.rs +++ b/crates/sql-parser/src/ast/mod.rs @@ -107,6 +107,8 @@ impl Project { pub enum SqlExpr { /// A constant expression Lit(SqlLiteral), + /// A tuple of constant expressions + Tup(Vec), /// Unqualified column ref Var(SqlIdent), /// A parameter prefixed with `:` @@ -123,7 +125,7 @@ impl SqlExpr { pub fn qualify_vars(self, with: SqlIdent) -> Self { match self { Self::Var(name) => Self::Field(with, name), - Self::Lit(..) | Self::Field(..) | Self::Param(..) => self, + Self::Lit(..) | Self::Tup(..) | Self::Field(..) | Self::Param(..) => self, Self::Bin(a, b, op) => Self::Bin( Box::new(a.qualify_vars(with.clone())), Box::new(b.qualify_vars(with)), @@ -149,7 +151,7 @@ impl SqlExpr { /// We need to know in order to hash subscription queries correctly. pub fn has_parameter(&self) -> bool { match self { - Self::Lit(_) | Self::Var(_) | Self::Field(..) => false, + Self::Lit(_) | Self::Tup(_) | Self::Var(_) | Self::Field(..) => false, Self::Param(Parameter::Sender) => true, Self::Bin(a, b, _) | Self::Log(a, b, _) => a.has_parameter() || b.has_parameter(), } @@ -158,7 +160,7 @@ impl SqlExpr { /// Replace the `:sender` parameter with the [Identity] it represents pub fn resolve_sender(self, sender_identity: Identity) -> Self { match self { - Self::Lit(_) | Self::Var(_) | Self::Field(..) => self, + Self::Lit(_) | Self::Tup(_) | Self::Var(_) | Self::Field(..) => self, Self::Param(Parameter::Sender) => { Self::Lit(SqlLiteral::Hex(String::from(sender_identity.to_hex()).into_boxed_str())) } @@ -207,6 +209,8 @@ pub enum SqlLiteral { Num(Box), /// A string value Str(Box), + /// A literal array + Arr(Box<[SqlLiteral]>), } /// Binary infix operators diff --git a/crates/sql-parser/src/parser/mod.rs b/crates/sql-parser/src/parser/mod.rs index 61260823a43..c6f4195ff0d 100644 --- a/crates/sql-parser/src/parser/mod.rs +++ b/crates/sql-parser/src/parser/mod.rs @@ -1,8 +1,8 @@ use errors::{SqlParseError, SqlRequired, SqlUnsupported}; use sqlparser::ast::{ - BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, - ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value, - WildcardAdditionalOptions, + Array, BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, + JoinOperator, ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, + Value, WildcardAdditionalOptions, }; use crate::ast::{ @@ -215,6 +215,12 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { Expr::Nested(expr) => parse_expr(*expr), Expr::Value(Value::Placeholder(param)) if ¶m == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)), Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)), + Expr::Tuple(ref t) => Ok(SqlExpr::Tup(t.iter().map(|x| { + match x { + Expr::Value(v) => parse_literal(v.clone()), + _ => Err(SqlUnsupported::Expr(expr.clone()).into()), + } + }).collect::>>()?)), Expr::UnaryOp { op: UnaryOperator::Plus, expr, @@ -289,6 +295,16 @@ pub(crate) fn parse_literal(value: Value) -> SqlParseResult { } } +/// Parse a literal array expression +pub(crate) fn parse_literal_array(array: Array) -> SqlParseResult { + Ok(SqlLiteral::Arr(array.elem.into_iter().map(|expr| { + match expr { + Expr::Value(value) => Ok(parse_literal(value)?), + _ => Err(SqlUnsupported::Expr(expr).into()), + } + }).collect::>>()?)) +} + /// Parse an identifier pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult { parse_parts(parts) diff --git a/crates/sql-parser/src/parser/sql.rs b/crates/sql-parser/src/parser/sql.rs index a1eb5078726..8270acaa4f2 100644 --- a/crates/sql-parser/src/parser/sql.rs +++ b/crates/sql-parser/src/parser/sql.rs @@ -142,7 +142,7 @@ use crate::ast::{ }; use super::{ - errors::SqlUnsupported, parse_expr_opt, parse_ident, parse_literal, parse_parts, parse_projection, RelParser, + errors::SqlUnsupported, parse_expr_opt, parse_ident, parse_literal, parse_literal_array, parse_parts, parse_projection, RelParser, SqlParseResult, }; @@ -242,10 +242,10 @@ fn parse_values(values: Query) -> SqlParseResult { for row in rows { let mut literals = Vec::new(); for expr in row { - if let Expr::Value(value) = expr { - literals.push(parse_literal(value)?); - } else { - return Err(SqlUnsupported::InsertValue(expr).into()); + match expr { + Expr::Array(array) => literals.push(parse_literal_array(array)?), + Expr::Value(value) => literals.push(parse_literal(value)?), + _ => return Err(SqlUnsupported::InsertValue(expr).into()), } } row_literals.push(literals);