Skip to content

Add a recursion limit to the evaluation of type_expr & parse_expr #2935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions crates/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ pub enum DatabaseError {
DatabasedOpened(PathBuf, anyhow::Error),
}

// FIXME: reduce type size
#[expect(clippy::large_enum_variant)]
#[derive(Error, Debug, EnumAsInner)]
pub enum DBError {
#[error("LibError: {0}")]
Expand Down
2 changes: 0 additions & 2 deletions crates/core/src/host/module_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,6 @@ pub enum InitDatabaseError {
Other(anyhow::Error),
}

// FIXME: reduce type size
#[expect(clippy::large_enum_variant)]
#[derive(thiserror::Error, Debug)]
pub enum ClientConnectedError {
#[error(transparent)]
Expand Down
46 changes: 39 additions & 7 deletions crates/core/src/sql/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,11 @@ pub(crate) mod tests {
Ok(())
}

/// Test we are protected against stack overflows when:
/// 1. The query is too large (too many characters)
/// 2. The AST is too deep
///
/// Exercise the limit [`recursion::MAX_RECURSION_EXPR`]
#[test]
fn test_large_query_no_panic() -> ResultTest<()> {
let db = TestDB::durable()?;
Expand All @@ -1138,16 +1143,43 @@ pub(crate) mod tests {
)
.unwrap();

let mut query = "select * from test where ".to_string();
for x in 0..1_000 {
for y in 0..1_000 {
let fragment = format!("((x = {x}) and y = {y}) or");
query.push_str(&fragment);
let build_query = |total| {
let mut sql = "select * from test where ".to_string();
for x in 1..total {
let fragment = format!("x = {x} or ");
sql.push_str(&fragment.repeat((total - 1) as usize));
}
sql.push_str("(y = 0)");
sql
};
let run = |db: &RelationalDB, sep: char, sql_text: &str| {
run_for_testing(db, sql_text).map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string())
};
let sql = build_query(1_000);
assert_eq!(
run(&db, ':', &sql),
Err("SQL query exceeds maximum allowed length".to_string())
);

let sql = build_query(41); // This causes stack overflow without the limit
assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string()));

let sql = build_query(40); // The max we can with the current limit
assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic");

// Check no overflow with lot of joins
let mut sql = "SELECT test.* FROM test ".to_string();
// We could push up to 700 joins without overflow as long we don't have any conditions,
// but here execution become too slow.
// TODO: Move this test to the `Plan`
for i in 0..200 {
sql.push_str(&format!("JOIN test AS m{i} ON test.x = m{i}.y "));
}
query.push_str("((x = 1000) and (y = 1000))");

assert!(run_for_testing(&db, &query).is_err());
assert!(
run(&db, ',', &sql).is_ok(),
"Query with many joins and conditions should not overflow"
);
Ok(())
}

Expand Down
2 changes: 0 additions & 2 deletions crates/expr/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ pub struct DuplicateName(pub String);
#[error("`filter!` does not support column projections; Must return table rows")]
pub struct FilterReturnType;

// FIXME: reduce type size
#[expect(clippy::large_enum_variant)]
#[derive(Error, Debug)]
pub enum TypingError {
#[error(transparent)]
Expand Down
28 changes: 20 additions & 8 deletions crates/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
use spacetimedb_schema::schema::ColumnSchema;
use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
use spacetimedb_sql_parser::parser::recursion;

pub mod check;
pub mod errors;
Expand Down Expand Up @@ -78,8 +79,14 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T
}
}

/// Type check and lower a [SqlExpr] into a logical [Expr].
pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
// These types determine the size of each stack frame during type checking.
// Changing their sizes will require updating the recursion limit to avoid stack overflows.
const _: () = assert!(size_of::<TypingResult<Expr>>() == 64);
const _: () = assert!(size_of::<SqlExpr>() == 40);

fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, depth: usize) -> TypingResult<Expr> {
recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?;

match (expr, expected) {
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
Expand Down Expand Up @@ -117,21 +124,21 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra
}))
}
(SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
let a = _type_expr(vars, *a, Some(&AlgebraicType::Bool), depth + 1)?;
let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?;
Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
}
(SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
let b = type_expr(vars, *b, None)?;
let a = type_expr(vars, *a, Some(b.ty()))?;
let b = _type_expr(vars, *b, None, depth + 1)?;
let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?;
if !op_supports_type(op, a.ty()) {
return Err(InvalidOp::new(op, a.ty()).into());
}
Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
}
(SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
let a = type_expr(vars, *a, None)?;
let b = type_expr(vars, *b, Some(a.ty()))?;
let a = _type_expr(vars, *a, None, depth + 1)?;
let b = _type_expr(vars, *b, Some(a.ty()), depth + 1)?;
if !op_supports_type(op, a.ty()) {
return Err(InvalidOp::new(op, a.ty()).into());
}
Expand All @@ -144,6 +151,11 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra
}
}

/// Type check and lower a [SqlExpr] into a logical [Expr].
pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
_type_expr(vars, expr, expected, 0)
}

/// Is this type compatible with this binary operator?
fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
t.is_bool()
Expand Down
36 changes: 30 additions & 6 deletions crates/expr/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,15 +450,16 @@ pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx)

#[cfg(test)]
mod tests {
use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType};
use spacetimedb_schema::def::ModuleDef;

use super::Statement;
use crate::ast::LogOp;
use crate::check::{
test_utils::{build_module_def, SchemaViewer},
SchemaView, TypingResult,
Relvars, SchemaView, TypingResult,
};

use super::Statement;
use crate::type_expr;
use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType};
use spacetimedb_schema::def::ModuleDef;
use spacetimedb_sql_parser::ast::{SqlExpr, SqlLiteral};

fn module_def() -> ModuleDef {
build_module_def(vec![
Expand Down Expand Up @@ -519,4 +520,27 @@ mod tests {
assert!(result.is_err());
}
}

/// Manually build the AST for a recursive query,
/// because we limit the length of the query to prevent stack overflow on parsing.
/// Exercise the limit [`recursion::MAX_RECURSION_TYP_EXPR`]
#[test]
fn typing_recursion() {
let build_query = |total, sep: char| {
let mut expr = SqlExpr::Lit(SqlLiteral::Bool(true));
for _ in 1..total {
let next = SqlExpr::Log(
Box::new(SqlExpr::Lit(SqlLiteral::Bool(true))),
Box::new(SqlExpr::Lit(SqlLiteral::Bool(false))),
LogOp::And,
);
expr = SqlExpr::Log(Box::new(expr), Box::new(next), LogOp::And);
}
type_expr(&Relvars::default(), expr, Some(&AlgebraicType::Bool))
.map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string())
};
assert_eq!(build_query(2_501, ','), Err("Recursion limit exceeded".to_string()));

assert!(build_query(2_500, ',').is_ok());
}
}
26 changes: 23 additions & 3 deletions crates/sql-parser/src/parser/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub enum SqlUnsupported {
#[error("Unsupported FROM expression: {0}")]
From(TableFactor),
#[error("Unsupported set operation: {0}")]
SetOp(SetExpr),
SetOp(Box<SetExpr>),
#[error("Unsupported INSERT expression: {0}")]
Insert(Query),
#[error("Unsupported INSERT value: {0}")]
Expand Down Expand Up @@ -93,14 +93,34 @@ pub enum SqlRequired {
JoinAlias,
}

#[derive(Error, Debug)]
#[error("Recursion limit exceeded, `{source_}`")]
pub struct RecursionError {
pub(crate) source_: &'static str,
}

#[derive(Error, Debug)]
pub enum SqlParseError {
#[error(transparent)]
SqlUnsupported(#[from] SqlUnsupported),
SqlUnsupported(#[from] Box<SqlUnsupported>),
#[error(transparent)]
SubscriptionUnsupported(#[from] SubscriptionUnsupported),
SubscriptionUnsupported(#[from] Box<SubscriptionUnsupported>),
#[error(transparent)]
SqlRequired(#[from] SqlRequired),
#[error(transparent)]
ParserError(#[from] ParserError),
#[error(transparent)]
Recursion(#[from] RecursionError),
}

impl From<SubscriptionUnsupported> for SqlParseError {
fn from(value: SubscriptionUnsupported) -> Self {
SqlParseError::SubscriptionUnsupported(Box::new(value))
}
}

impl From<SqlUnsupported> for SqlParseError {
fn from(value: SqlUnsupported) -> Self {
SqlParseError::SqlUnsupported(Box::new(value))
}
}
42 changes: 26 additions & 16 deletions crates/sql-parser/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::ast::{
};

pub mod errors;
pub mod recursion;
pub mod sql;
pub mod sub;

Expand Down Expand Up @@ -61,11 +62,14 @@ trait RelParser {
Ok(SqlJoin {
var,
alias,
on: Some(parse_expr(Expr::BinaryOp {
left,
op: BinaryOperator::Eq,
right,
})?),
on: Some(parse_expr(
Expr::BinaryOp {
left,
op: BinaryOperator::Eq,
right,
},
0,
)?),
})
}
_ => Err(SqlUnsupported::JoinType.into()),
Expand Down Expand Up @@ -203,16 +207,22 @@ pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
}
}

// These types determine the size of [`parse_expr`]'s stack frame.
// Changing their sizes will require updating the recursion limit to avoid stack overflows.
const _: () = assert!(size_of::<Expr>() == 168);
const _: () = assert!(size_of::<SqlParseResult<SqlExpr>>() == 40);

/// Parse a scalar expression
pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, SqlUnsupported> {
fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult<SqlExpr> {
fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, Box<SqlUnsupported>> {
match expr {
Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
expr => Err(SqlUnsupported::Expr(expr)),
expr => Err(SqlUnsupported::Expr(expr).into()),
}
}
recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
match expr {
Expr::Nested(expr) => parse_expr(*expr),
Expr::Nested(expr) => parse_expr(*expr, depth + 1),
Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
Expr::UnaryOp {
Expand All @@ -238,22 +248,22 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
op: BinaryOperator::And,
right,
} => {
let l = parse_expr(*left)?;
let r = parse_expr(*right)?;
let l = parse_expr(*left, depth + 1)?;
let r = parse_expr(*right, depth + 1)?;
Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
}
Expr::BinaryOp {
left,
op: BinaryOperator::Or,
right,
} => {
let l = parse_expr(*left)?;
let r = parse_expr(*right)?;
let l = parse_expr(*left, depth + 1)?;
let r = parse_expr(*right, depth + 1)?;
Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
}
Expr::BinaryOp { left, op, right } => {
let l = parse_expr(*left)?;
let r = parse_expr(*right)?;
let l = parse_expr(*left, depth + 1)?;
let r = parse_expr(*right, depth + 1)?;
Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
}
_ => Err(SqlUnsupported::Expr(expr).into()),
Expand All @@ -262,7 +272,7 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {

/// Parse an optional scalar expression
pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
opt.map(parse_expr).transpose()
opt.map(|expr| parse_expr(expr, 0)).transpose()
}

/// Parse a scalar binary operator
Expand Down
26 changes: 26 additions & 0 deletions crates/sql-parser/src/parser/recursion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//! A utility for guarding against stack overflows in the SQL parser.
//!
//! Different parts of the parser may have different recursion limits, based in the size of the structures they parse.

use crate::parser::errors::{RecursionError, SqlParseError};

/// A conservative limit for recursion depth on `parse_expr`.
pub const MAX_RECURSION_EXPR: usize = 1_600;
/// A conservative limit for recursion depth on `type_expr`.
pub const MAX_RECURSION_TYP_EXPR: usize = 2_500;

/// A utility for guarding against stack overflows in the SQL parser.
///
/// **Usage:**
/// ```
/// use spacetimedb_sql_parser::parser::recursion;
/// let mut depth = 0;
/// assert!(recursion::guard(depth, 10, "test").is_ok());
/// ```
pub fn guard(depth: usize, limit: usize, source: &'static str) -> Result<(), SqlParseError> {
if depth > limit {
Err(RecursionError { source_: source }.into())
} else {
Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/sql-parser/src/parser/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl RelParser for SubParser {
fn parse_set_op(expr: SetExpr) -> SqlParseResult<SqlSelect> {
match expr {
SetExpr::Select(select) => parse_select(*select).map(SqlSelect::qualify_vars),
_ => Err(SqlUnsupported::SetOp(expr).into()),
_ => Err(SqlUnsupported::SetOp(Box::new(expr)).into()),
}
}

Expand Down
Loading