Skip to content

Commit 004d978

Browse files
committed
Add a recursion limit to the evaluation of type_expr & parse_expr
1 parent ccc4d06 commit 004d978

File tree

9 files changed

+127
-38
lines changed

9 files changed

+127
-38
lines changed

crates/core/src/sql/execute.rs

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,9 @@ pub(crate) mod tests {
11261126
Ok(())
11271127
}
11281128

1129+
/// Test we are protected against recursion when:
1130+
/// 1. The query is too large
1131+
/// 2. The AST is too deep
11291132
#[test]
11301133
fn test_large_query_no_panic() -> ResultTest<()> {
11311134
let db = TestDB::durable()?;
@@ -1138,16 +1141,46 @@ pub(crate) mod tests {
11381141
)
11391142
.unwrap();
11401143

1141-
let mut query = "select * from test where ".to_string();
1142-
for x in 0..1_000 {
1143-
for y in 0..1_000 {
1144-
let fragment = format!("((x = {x}) and y = {y}) or");
1145-
query.push_str(&fragment);
1144+
let build_query = |total| {
1145+
let mut sql = "select * from test where ".to_string();
1146+
for x in 0..total {
1147+
for y in 0..total {
1148+
let fragment = format!("((x = {x}) and (y = {y})) or ");
1149+
sql.push_str(&fragment);
1150+
}
11461151
}
1152+
sql.push_str("((x = 1000) and (y = 1000))");
1153+
sql
1154+
};
1155+
let run = |db: &RelationalDB, sep: char, sql_text: &str| {
1156+
run_for_testing(db, sql_text).map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string())
1157+
};
1158+
let sql = build_query(1_000);
1159+
assert_eq!(
1160+
run(&db, ':', &sql),
1161+
Err("SQL query exceeds maximum allowed length".to_string())
1162+
);
1163+
1164+
// Exercise the limit [recursion::MAX_RECURSION_EXPR] && [recursion::MAX_RECURSION_TYP_EXPR]
1165+
let sql = build_query(8);
1166+
assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string()));
1167+
1168+
let sql = build_query(7);
1169+
assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic");
1170+
1171+
// Check no overflow with lot of joins
1172+
let mut sql = "SELECT test.* FROM test ".to_string();
1173+
// We could pust up to 700 joins without overflow as long we don't have any conditions,
1174+
// but here execution become too slow.
1175+
// TODO: Move this test to the `Plan`
1176+
for i in 0..200 {
1177+
sql.push_str(&format!("JOIN test AS m{i} ON test.x = m{i}.y "));
11471178
}
1148-
query.push_str("((x = 1000) and (y = 1000))");
11491179

1150-
assert!(run_for_testing(&db, &query).is_err());
1180+
assert!(
1181+
run(&db, ',', &sql).is_ok(),
1182+
"Query with many joins and conditions should not overflow"
1183+
);
11511184
Ok(())
11521185
}
11531186

crates/expr/src/check.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ pub trait TypeChecker {
9999
vars.insert(rhs.alias.clone(), rhs.schema.clone());
100100

101101
if let Some(on) = on {
102-
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
102+
if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool), &mut 0)? {
103103
if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
104104
join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
105105
continue;

crates/expr/src/lib.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
1919
use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
2020
use spacetimedb_schema::schema::ColumnSchema;
2121
use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22+
use spacetimedb_sql_parser::parser::recursion;
2223

2324
pub mod check;
2425
pub mod errors;
@@ -30,7 +31,7 @@ pub mod statement;
3031
pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
3132
Ok(RelExpr::Select(
3233
Box::new(input),
33-
type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
34+
type_expr(vars, expr, Some(&AlgebraicType::Bool), &mut 0)?,
3435
))
3536
}
3637

@@ -68,7 +69,7 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T
6869
return Err(DuplicateName(alias.into_string()).into());
6970
}
7071

71-
if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
72+
if let Expr::Field(p) = type_expr(vars, expr.into(), None, &mut 0)? {
7273
projections.push((alias, p));
7374
}
7475
}
@@ -79,7 +80,14 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T
7980
}
8081

8182
/// Type check and lower a [SqlExpr] into a logical [Expr].
82-
pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
83+
pub(crate) fn type_expr(
84+
vars: &Relvars,
85+
expr: SqlExpr,
86+
expected: Option<&AlgebraicType>,
87+
depth: &mut usize,
88+
) -> TypingResult<Expr> {
89+
recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?;
90+
8391
match (expr, expected) {
8492
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
8593
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
@@ -117,21 +125,21 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra
117125
}))
118126
}
119127
(SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
120-
let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
121-
let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
128+
let a = type_expr(vars, *a, Some(&AlgebraicType::Bool), depth)?;
129+
let b = type_expr(vars, *b, Some(&AlgebraicType::Bool), depth)?;
122130
Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
123131
}
124132
(SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
125-
let b = type_expr(vars, *b, None)?;
126-
let a = type_expr(vars, *a, Some(b.ty()))?;
133+
let b = type_expr(vars, *b, None, depth)?;
134+
let a = type_expr(vars, *a, Some(b.ty()), depth)?;
127135
if !op_supports_type(op, a.ty()) {
128136
return Err(InvalidOp::new(op, a.ty()).into());
129137
}
130138
Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
131139
}
132140
(SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
133-
let a = type_expr(vars, *a, None)?;
134-
let b = type_expr(vars, *b, Some(a.ty()))?;
141+
let a = type_expr(vars, *a, None, depth)?;
142+
let b = type_expr(vars, *b, Some(a.ty()), depth)?;
135143
if !op_supports_type(op, a.ty()) {
136144
return Err(InvalidOp::new(op, a.ty()).into());
137145
}

crates/expr/src/statement.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ pub fn type_delete(delete: SqlDelete, tx: &impl SchemaView) -> TypingResult<Tabl
162162
let mut vars = Relvars::default();
163163
vars.insert(table_name.clone(), from.clone());
164164
let expr = filter
165-
.map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool)))
165+
.map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool), &mut 0))
166166
.transpose()?;
167167
Ok(TableDelete {
168168
table: from,
@@ -216,7 +216,7 @@ pub fn type_update(update: SqlUpdate, tx: &impl SchemaView) -> TypingResult<Tabl
216216
vars.insert(table_name.clone(), schema.clone());
217217
let values = values.into_boxed_slice();
218218
let filter = filter
219-
.map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool)))
219+
.map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool), &mut 0))
220220
.transpose()?;
221221
Ok(TableUpdate {
222222
table: schema,

crates/sql-parser/src/parser/errors.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ pub enum SqlRequired {
9393
JoinAlias,
9494
}
9595

96+
#[derive(Error, Debug)]
97+
#[error("Recursion limit exceeded, `{message}` limit: {limit}")]
98+
pub struct RecursionError {
99+
pub(crate) limit: usize,
100+
pub(crate) message: String,
101+
}
102+
96103
#[derive(Error, Debug)]
97104
pub enum SqlParseError {
98105
#[error(transparent)]
@@ -103,4 +110,6 @@ pub enum SqlParseError {
103110
SqlRequired(#[from] SqlRequired),
104111
#[error(transparent)]
105112
ParserError(#[from] ParserError),
113+
#[error(transparent)]
114+
Recursion(#[from] RecursionError),
106115
}

crates/sql-parser/src/parser/mod.rs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::ast::{
1010
};
1111

1212
pub mod errors;
13+
pub mod recursion;
1314
pub mod sql;
1415
pub mod sub;
1516

@@ -61,11 +62,14 @@ trait RelParser {
6162
Ok(SqlJoin {
6263
var,
6364
alias,
64-
on: Some(parse_expr(Expr::BinaryOp {
65-
left,
66-
op: BinaryOperator::Eq,
67-
right,
68-
})?),
65+
on: Some(parse_expr(
66+
Expr::BinaryOp {
67+
left,
68+
op: BinaryOperator::Eq,
69+
right,
70+
},
71+
&mut 0,
72+
)?),
6973
})
7074
}
7175
_ => Err(SqlUnsupported::JoinType.into()),
@@ -204,15 +208,16 @@ pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
204208
}
205209

206210
/// Parse a scalar expression
207-
pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
211+
pub(crate) fn parse_expr(expr: Expr, depth: &mut usize) -> SqlParseResult<SqlExpr> {
208212
fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, SqlUnsupported> {
209213
match expr {
210214
Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
211215
expr => Err(SqlUnsupported::Expr(expr)),
212216
}
213217
}
218+
recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
214219
match expr {
215-
Expr::Nested(expr) => parse_expr(*expr),
220+
Expr::Nested(expr) => parse_expr(*expr, depth),
216221
Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
217222
Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
218223
Expr::UnaryOp {
@@ -238,31 +243,31 @@ pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
238243
op: BinaryOperator::And,
239244
right,
240245
} => {
241-
let l = parse_expr(*left)?;
242-
let r = parse_expr(*right)?;
246+
let l = parse_expr(*left, depth)?;
247+
let r = parse_expr(*right, depth)?;
243248
Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
244249
}
245250
Expr::BinaryOp {
246251
left,
247252
op: BinaryOperator::Or,
248253
right,
249254
} => {
250-
let l = parse_expr(*left)?;
251-
let r = parse_expr(*right)?;
255+
let l = parse_expr(*left, depth)?;
256+
let r = parse_expr(*right, depth)?;
252257
Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
253258
}
254259
Expr::BinaryOp { left, op, right } => {
255-
let l = parse_expr(*left)?;
256-
let r = parse_expr(*right)?;
260+
let l = parse_expr(*left, depth)?;
261+
let r = parse_expr(*right, depth)?;
257262
Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
258263
}
259264
_ => Err(SqlUnsupported::Expr(expr).into()),
260265
}
261266
}
262267

263268
/// Parse an optional scalar expression
264-
pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
265-
opt.map(parse_expr).transpose()
269+
pub(crate) fn parse_expr_opt(opt: Option<Expr>, depth: &mut usize) -> SqlParseResult<Option<SqlExpr>> {
270+
opt.map(|expr| parse_expr(expr, depth)).transpose()
266271
}
267272

268273
/// Parse a scalar binary operator
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//! A utility for guarding against excessive recursion depth in the SQL parser.
2+
//!
3+
//! Different parts of the parser may have different recursion limits.
4+
//!
5+
//! Removing one could allow the others to be higher, but depending on how the `SQL` is structured, it could lead to a `stack overflow`
6+
//! if is not guarded against, so is incorrect to assume that a limit is sufficient for the next part of the parser.
7+
use crate::parser::errors::{RecursionError, SqlParseError};
8+
use std::fmt::Display;
9+
10+
/// A conservative limit for recursion depth on `parse_expr`.
11+
pub const MAX_RECURSION_EXPR: usize = 700;
12+
/// A conservative limit for recursion depth on `type_expr`.
13+
pub const MAX_RECURSION_TYP_EXPR: usize = 5_000;
14+
15+
/// A utility for guarding against excessive recursion depth.
16+
///
17+
/// **Usage:**
18+
/// ```
19+
/// use spacetimedb_sql_parser::parser::recursion;
20+
/// let mut depth = 0;
21+
/// assert!(recursion::guard(&mut depth, 10, "test").is_ok());
22+
/// ```
23+
pub fn guard(depth: &mut usize, limit: usize, msg: impl Display) -> Result<(), SqlParseError> {
24+
*depth += 1;
25+
if *depth > limit {
26+
Err(RecursionError {
27+
limit,
28+
message: msg.to_string(),
29+
}
30+
.into())
31+
} else {
32+
Ok(())
33+
}
34+
}

crates/sql-parser/src/parser/sql.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
202202
} if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
203203
table: parse_ident(name)?,
204204
assignments: parse_assignments(assignments)?,
205-
filter: parse_expr_opt(selection)?,
205+
filter: parse_expr_opt(selection, &mut 0)?,
206206
})),
207207
Statement::Delete {
208208
tables,
@@ -297,7 +297,7 @@ fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlPa
297297
joins,
298298
} if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
299299
table: parse_ident(name)?,
300-
filter: parse_expr_opt(selection)?,
300+
filter: parse_expr_opt(selection, &mut 0)?,
301301
}),
302302
t => Err(SqlUnsupported::DeleteTable(t).into()),
303303
}
@@ -395,7 +395,7 @@ fn parse_select(select: Select, limit: Option<Box<str>>) -> SqlParseResult<SqlSe
395395
Ok(SqlSelect {
396396
project: parse_projection(projection)?,
397397
from: SqlParser::parse_from(from)?,
398-
filter: parse_expr_opt(selection)?,
398+
filter: parse_expr_opt(selection, &mut 0)?,
399399
limit,
400400
})
401401
}

crates/sql-parser/src/parser/sub.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ fn parse_select(select: Select) -> SqlParseResult<SqlSelect> {
142142
{
143143
Ok(SqlSelect {
144144
from: SubParser::parse_from(from)?,
145-
filter: parse_expr_opt(selection)?,
145+
filter: parse_expr_opt(selection, &mut 0)?,
146146
project: parse_projection(projection)?,
147147
})
148148
}

0 commit comments

Comments
 (0)