diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 4e550b234b3..4994ef69431 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -82,8 +82,6 @@ pub enum DatabaseError { DatabasedOpened(PathBuf, anyhow::Error), } -// FIXME: reduce type size -#[expect(clippy::large_enum_variant)] #[derive(Error, Debug, EnumAsInner)] pub enum DBError { #[error("LibError: {0}")] diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index dc8a75640a6..b158ef35ee0 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -539,8 +539,6 @@ pub enum InitDatabaseError { Other(anyhow::Error), } -// FIXME: reduce type size -#[expect(clippy::large_enum_variant)] #[derive(thiserror::Error, Debug)] pub enum ClientConnectedError { #[error(transparent)] diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index b089db36be6..d12e811ca7a 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -1126,6 +1126,11 @@ pub(crate) mod tests { Ok(()) } + /// Test we are protected against stack overflows when: + /// 1. The query is too large (too many characters) + /// 2. 
The AST is too deep + /// + /// Exercise the limit [`recursion::MAX_RECURSION_EXPR`] #[test] fn test_large_query_no_panic() -> ResultTest<()> { let db = TestDB::durable()?; @@ -1138,16 +1143,43 @@ pub(crate) mod tests { ) .unwrap(); - let mut query = "select * from test where ".to_string(); - for x in 0..1_000 { - for y in 0..1_000 { - let fragment = format!("((x = {x}) and y = {y}) or"); - query.push_str(&fragment); + let build_query = |total| { + let mut sql = "select * from test where ".to_string(); + for x in 1..total { + let fragment = format!("x = {x} or "); + sql.push_str(&fragment.repeat((total - 1) as usize)); } + sql.push_str("(y = 0)"); + sql + }; + let run = |db: &RelationalDB, sep: char, sql_text: &str| { + run_for_testing(db, sql_text).map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string()) + }; + let sql = build_query(1_000); + assert_eq!( + run(&db, ':', &sql), + Err("SQL query exceeds maximum allowed length".to_string()) + ); + + let sql = build_query(41); // This causes a stack overflow without the limit + assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string())); + + let sql = build_query(40); // The max we can run with the current limit + assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic"); + + // Check no overflow with a lot of joins + let mut sql = "SELECT test.* FROM test ".to_string(); + // We could push up to 700 joins without overflow as long as we don't have any conditions, + // but here execution becomes too slow.
+ // TODO: Move this test to the `Plan` + for i in 0..200 { + sql.push_str(&format!("JOIN test AS m{i} ON test.x = m{i}.y ")); } - query.push_str("((x = 1000) and (y = 1000))"); - assert!(run_for_testing(&db, &query).is_err()); + assert!( + run(&db, ',', &sql).is_ok(), + "Query with many joins and conditions should not overflow" + ); Ok(()) } diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index d611ec945d2..c523fb9452e 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -122,8 +122,6 @@ pub struct DuplicateName(pub String); #[error("`filter!` does not support column projections; Must return table rows")] pub struct FilterReturnType; -// FIXME: reduce type size -#[expect(clippy::large_enum_variant)] #[derive(Error, Debug)] pub enum TypingError { #[error(transparent)] diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 4860bdc882e..2d9b3cdc5ab 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -19,6 +19,7 @@ use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; use spacetimedb_sats::algebraic_value::ser::ValueSerializer; use spacetimedb_schema::schema::ColumnSchema; use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral}; +use spacetimedb_sql_parser::parser::recursion; pub mod check; pub mod errors; @@ -78,8 +79,14 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T } } -/// Type check and lower a [SqlExpr] into a logical [Expr]. -pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult { +// These types determine the size of each stack frame during type checking. +// Changing their sizes will require updating the recursion limit to avoid stack overflows. 
+const _: () = assert!(size_of::>() == 64); +const _: () = assert!(size_of::() == 40); + +fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, depth: usize) -> TypingResult { + recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?; + match (expr, expected) { (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)), (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()), @@ -117,21 +124,21 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra })) } (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => { - let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?; - let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?; + let a = _type_expr(vars, *a, Some(&AlgebraicType::Bool), depth + 1)?; + let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?; Ok(Expr::LogOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => { - let b = type_expr(vars, *b, None)?; - let a = type_expr(vars, *a, Some(b.ty()))?; + let b = _type_expr(vars, *b, None, depth + 1)?; + let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?; if !op_supports_type(op, a.ty()) { return Err(InvalidOp::new(op, a.ty()).into()); } Ok(Expr::BinOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => { - let a = type_expr(vars, *a, None)?; - let b = type_expr(vars, *b, Some(a.ty()))?; + let a = _type_expr(vars, *a, None, depth + 1)?; + let b = _type_expr(vars, *b, Some(a.ty()), depth + 1)?; if !op_supports_type(op, a.ty()) { return Err(InvalidOp::new(op, a.ty()).into()); } @@ -144,6 +151,11 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra } } +/// Type check and lower a [SqlExpr] into a logical [Expr]. 
+pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult { + _type_expr(vars, expr, expected, 0) +} + /// Is this type compatible with this binary operator? fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool { t.is_bool() diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 39abf447118..50e9bdb4c22 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -450,15 +450,16 @@ pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) #[cfg(test)] mod tests { - use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType}; - use spacetimedb_schema::def::ModuleDef; - + use super::Statement; + use crate::ast::LogOp; use crate::check::{ test_utils::{build_module_def, SchemaViewer}, - SchemaView, TypingResult, + Relvars, SchemaView, TypingResult, }; - - use super::Statement; + use crate::type_expr; + use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType}; + use spacetimedb_schema::def::ModuleDef; + use spacetimedb_sql_parser::ast::{SqlExpr, SqlLiteral}; fn module_def() -> ModuleDef { build_module_def(vec![ @@ -519,4 +520,27 @@ mod tests { assert!(result.is_err()); } } + + /// Manually build the AST for a recursive query, + /// because we limit the length of the query to prevent stack overflow on parsing. 
+ /// Exercise the limit [`recursion::MAX_RECURSION_TYP_EXPR`] + #[test] + fn typing_recursion() { + let build_query = |total, sep: char| { + let mut expr = SqlExpr::Lit(SqlLiteral::Bool(true)); + for _ in 1..total { + let next = SqlExpr::Log( + Box::new(SqlExpr::Lit(SqlLiteral::Bool(true))), + Box::new(SqlExpr::Lit(SqlLiteral::Bool(false))), + LogOp::And, + ); + expr = SqlExpr::Log(Box::new(expr), Box::new(next), LogOp::And); + } + type_expr(&Relvars::default(), expr, Some(&AlgebraicType::Bool)) + .map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string()) + }; + assert_eq!(build_query(2_501, ','), Err("Recursion limit exceeded".to_string())); + + assert!(build_query(2_500, ',').is_ok()); + } } diff --git a/crates/sql-parser/src/parser/errors.rs b/crates/sql-parser/src/parser/errors.rs index 510a5747414..953a031b8b8 100644 --- a/crates/sql-parser/src/parser/errors.rs +++ b/crates/sql-parser/src/parser/errors.rs @@ -50,7 +50,7 @@ pub enum SqlUnsupported { #[error("Unsupported FROM expression: {0}")] From(TableFactor), #[error("Unsupported set operation: {0}")] - SetOp(SetExpr), + SetOp(Box), #[error("Unsupported INSERT expression: {0}")] Insert(Query), #[error("Unsupported INSERT value: {0}")] @@ -93,14 +93,34 @@ pub enum SqlRequired { JoinAlias, } +#[derive(Error, Debug)] +#[error("Recursion limit exceeded, `{source_}`")] +pub struct RecursionError { + pub(crate) source_: &'static str, +} + #[derive(Error, Debug)] pub enum SqlParseError { #[error(transparent)] - SqlUnsupported(#[from] SqlUnsupported), + SqlUnsupported(#[from] Box), #[error(transparent)] - SubscriptionUnsupported(#[from] SubscriptionUnsupported), + SubscriptionUnsupported(#[from] Box), #[error(transparent)] SqlRequired(#[from] SqlRequired), #[error(transparent)] ParserError(#[from] ParserError), + #[error(transparent)] + Recursion(#[from] RecursionError), +} + +impl From for SqlParseError { + fn from(value: SubscriptionUnsupported) -> Self { + 
SqlParseError::SubscriptionUnsupported(Box::new(value)) + } +} + +impl From for SqlParseError { + fn from(value: SqlUnsupported) -> Self { + SqlParseError::SqlUnsupported(Box::new(value)) + } } diff --git a/crates/sql-parser/src/parser/mod.rs b/crates/sql-parser/src/parser/mod.rs index 61260823a43..9e6e5642bda 100644 --- a/crates/sql-parser/src/parser/mod.rs +++ b/crates/sql-parser/src/parser/mod.rs @@ -10,6 +10,7 @@ use crate::ast::{ }; pub mod errors; +pub mod recursion; pub mod sql; pub mod sub; @@ -61,11 +62,14 @@ trait RelParser { Ok(SqlJoin { var, alias, - on: Some(parse_expr(Expr::BinaryOp { - left, - op: BinaryOperator::Eq, - right, - })?), + on: Some(parse_expr( + Expr::BinaryOp { + left, + op: BinaryOperator::Eq, + right, + }, + 0, + )?), }) } _ => Err(SqlUnsupported::JoinType.into()), @@ -203,16 +207,22 @@ pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult { } } +// These types determine the size of [`parse_expr`]'s stack frame. +// Changing their sizes will require updating the recursion limit to avoid stack overflows. 
+const _: () = assert!(size_of::() == 168); +const _: () = assert!(size_of::>() == 40); + /// Parse a scalar expression -pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { - fn signed_num(sign: impl Into, expr: Expr) -> Result { +fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult { + fn signed_num(sign: impl Into, expr: Expr) -> Result> { match expr { Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))), - expr => Err(SqlUnsupported::Expr(expr)), + expr => Err(SqlUnsupported::Expr(expr).into()), } } + recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?; match expr { - Expr::Nested(expr) => parse_expr(*expr), + Expr::Nested(expr) => parse_expr(*expr, depth + 1), Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)), Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)), Expr::UnaryOp { @@ -238,8 +248,8 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { op: BinaryOperator::And, right, } => { - let l = parse_expr(*left)?; - let r = parse_expr(*right)?; + let l = parse_expr(*left, depth + 1)?; + let r = parse_expr(*right, depth + 1)?; Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And)) } Expr::BinaryOp { @@ -247,13 +257,13 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { op: BinaryOperator::Or, right, } => { - let l = parse_expr(*left)?; - let r = parse_expr(*right)?; + let l = parse_expr(*left, depth + 1)?; + let r = parse_expr(*right, depth + 1)?; Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or)) } Expr::BinaryOp { left, op, right } => { - let l = parse_expr(*left)?; - let r = parse_expr(*right)?; + let l = parse_expr(*left, depth + 1)?; + let r = parse_expr(*right, depth + 1)?; Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?)) } _ => Err(SqlUnsupported::Expr(expr).into()), @@ -262,7 +272,7 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult { /// Parse an optional scalar expression
pub(crate) fn parse_expr_opt(opt: Option) -> SqlParseResult> { - opt.map(parse_expr).transpose() + opt.map(|expr| parse_expr(expr, 0)).transpose() } /// Parse a scalar binary operator diff --git a/crates/sql-parser/src/parser/recursion.rs b/crates/sql-parser/src/parser/recursion.rs new file mode 100644 index 00000000000..f4e2b1dec65 --- /dev/null +++ b/crates/sql-parser/src/parser/recursion.rs @@ -0,0 +1,26 @@ +//! A utility for guarding against stack overflows in the SQL parser. +//! +//! Different parts of the parser may have different recursion limits, based on the size of the structures they parse. + +use crate::parser::errors::{RecursionError, SqlParseError}; + +/// A conservative limit for recursion depth on `parse_expr`. +pub const MAX_RECURSION_EXPR: usize = 1_600; +/// A conservative limit for recursion depth on `type_expr`. +pub const MAX_RECURSION_TYP_EXPR: usize = 2_500; + +/// A utility for guarding against stack overflows in the SQL parser. +/// +/// **Usage:** +/// ``` +/// use spacetimedb_sql_parser::parser::recursion; +/// let depth = 0; +/// assert!(recursion::guard(depth, 10, "test").is_ok()); +/// ``` +pub fn guard(depth: usize, limit: usize, source: &'static str) -> Result<(), SqlParseError> { + if depth > limit { + Err(RecursionError { source_: source }.into()) + } else { + Ok(()) + } +} diff --git a/crates/sql-parser/src/parser/sub.rs b/crates/sql-parser/src/parser/sub.rs index 08a917b0679..6a8ef34a1e8 100644 --- a/crates/sql-parser/src/parser/sub.rs +++ b/crates/sql-parser/src/parser/sub.rs @@ -111,7 +111,7 @@ impl RelParser for SubParser { fn parse_set_op(expr: SetExpr) -> SqlParseResult { match expr { SetExpr::Select(select) => parse_select(*select).map(SqlSelect::qualify_vars), - _ => Err(SqlUnsupported::SetOp(expr).into()), + _ => Err(SqlUnsupported::SetOp(Box::new(expr)).into()), } }