Skip to content

Commit baa09fc

Browse files
AlSchloyliang412
andauthored
feat: DSL parser + AST implementation (#23)
## Problem We want a DSL to be able to write rules & operators in a declarative fashion. This makes the code more maintainable, maximizes compatibility, and speeds up the writing of newer rules. ## Summary of changes Wrote a parser of the OPTD-DSL using Pest. The syntax is highly functional and inspired from Scala. Some snippets below: ## Next steps Semantic analysis, type checking, and low-level IR generation. Also switching to Chumsky as parser, since Pest is not very maintainable. ```scala // Logical Properties Logical Props(schema_len: Int64) // Scalar Operators Scalar ColumnRef(idx: Int64) Scalar Mult(left: Int64, right: Int64) Scalar Add(left: Int64, right: Int64) Scalar And(children: [Scalar]) Scalar Or(children: [Scalar]) Scalar Not(child: Scalar) // Logical Operators Logical Scan(table_name: String) derive { schema_len = 5 // <call catalog> } Logical Filter(child: Logical, cond: Scalar) derive { schema_len = input.schema_len } Logical Project(child: Logical, exprs: [Scalar]) derive { schema_len = exprs.len } Logical Join( left: Logical, right: Logical, typ: String, cond: Scalar ) derive { schema_len = left.schema_len + right.schema_len } Logical Sort(child: Logical, keys: [(Scalar, String)]) derive { schema_len = input.schema_len } Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) derive { schema_len = group_keys.len + aggs.len } // Rules def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar = match predicate case ColumnRef(idx) => ColumnRef(map(idx)), case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map)) @rule(Logical) def join_commute(expr: Logical): Logical = match expr case Join("Inner", left, right, cond) => val left_len = left.schema_len; val right_len = right.schema_len; val right_indices = 0..right_len; val left_indices = 0..left_len; val remapping = (left_indices.map(i => (i, i + right_len)) ++ right_indices.map(i => (left_len + i, i))).to_map(); Project( 
Join("Inner", right, left, rewrite_column_refs(cond, remapping)), left_indices.map(i => ColumnRef(i + right_len)) ++ right_indices.map(i => ColumnRef(i) ) ) def has_refs_in_range(cond: Scalar, from: Int64, to: Int64): Bool = match predicate case ColumnRef(idx) => from <= idx && idx < to, case _ => predicate.children.any(child => has_refs_in_range(child, from, to)) @rule(Logical) def join_associate(expr: Logical): Logical = match expr case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) => val a_len = a.schema_len; if !has_refs_in_range(cond_outer, 0, a_len) then val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map(); Join( "Inner", a, Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner), cond_inner) ) else fail("") @rule(Scalar) def conjunctive_normal_form(expr: Scalar): Scalar = fail("unimplemented") def with_optional_filter(key: String, old: Scalar, grouped: Map[String, [Scalar]]): Scalar = match grouped(key) case Some(conds) => Filter(conds, old), case _ => old @rule(Logical) def filter_pushdown_join(expr: Logical): Logical = match expr case op @ Filter(Join(join_type, left, right, join_cond), cond) => val cnf = conjunctive_normal_form(cond); val grouped = cnf.children.groupBy(cond => { if has_refs_in_range(cond, 0, left.schema_len) && !has_refs_in_range(cond, left.schema_len, op.schema_len) then "left" else if !has_refs_in_range(cond, 0, left.schema_len) && has_refs_in_range(cond, left.schema_len, op.schema_len) then "right" else "remain" }); with_optional_filter("remain", Join(join_type, with_optional_filter("left", grouped), with_optional_filter("right", grouped), join_cond ) ) ``` --------- Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu> Co-authored-by: Yuchen Liang <yuchenl3@andrew.cmu.edu> Co-authored-by: Yuchen Liang <70461588+yliang412@users.noreply.github.com>
1 parent a94829a commit baa09fc

File tree

15 files changed

+2100
-1
lines changed

15 files changed

+2100
-1
lines changed

Cargo.lock

Lines changed: 53 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

optd-core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ serde = { version = "1.0", features = ["derive"] }
1515
serde_json = { version = "1", features = ["raw_value"] }
1616
dotenvy = "0.15"
1717
async-recursion = "1.1.1"
18+
pest = "2.7.15"
19+
pest_derive = "2.7.15"

optd-core/src/cascades/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ mod tests {
147147

148148
#[tokio::test]
149149
async fn test_ingest_partial_logical_plan() -> anyhow::Result<()> {
150-
let memo = SqliteMemo::new("sqlite://memo.db").await?;
150+
let memo = SqliteMemo::new_in_memory().await?;
151151
// select * from t1, t2 where t1.id = t2.id and t2.name = 'Memo' and t2.v1 = 1 + 1
152152
let partial_logical_plan = filter(
153153
join(

optd-core/src/dsl/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/// Parser-side modules of the DSL (the AST definitions live under `parser/ast.rs`).
pub mod parser;

optd-core/src/dsl/parser/ast.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
use std::collections::HashMap;
2+
3+
/// Types supported by the language.
///
/// Composite variants keep their component types behind `Box`/`Vec`; the indirection is what
/// allows this enum to contain itself.
4+
#[derive(Debug, Clone, PartialEq)]
5+
pub enum Type {
6+
Int64,
7+
String,
8+
Bool,
9+
Float64,
10+
Array(Box<Type>), // Array types like [T]; the box holds the element type
11+
Map(Box<Type>, Box<Type>), // Map types like map[K->V]; presumably (key, value) order. NOTE(review): the commit description's sample writes `Map[Int64, Int64]` — confirm which surface syntax the grammar accepts
12+
Tuple(Vec<Type>), // Tuple types like (T1, T2); one entry per component
13+
Function(Box<Type>, Box<Type>), // Function types like (T1)->T2; presumably (parameter, return) order
14+
Operator(OperatorKind), // Operator types (scalar/logical), tagged by `OperatorKind`
15+
}
16+
17+
/// Kinds of operators supported in the language.
///
/// This is a fieldless tag, so it derives `Copy`, `Eq` and `Hash` in addition to `PartialEq`:
/// values can be passed by value, compared, and used as map or set keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperatorKind {
    /// Scalar operators.
    Scalar,
    /// Logical operators with derivable properties.
    Logical,
}
23+
24+
/// A field in an operator or properties block
25+
#[derive(Debug, Clone)]
26+
pub struct Field {
27+
pub name: String,
28+
pub ty: Type,
29+
}
30+
31+
/// Logical properties block that must appear exactly once per file.
///
/// NOTE(review): nothing in this type enforces the "exactly once" rule; presumably the parser
/// does — confirm.
32+
#[derive(Debug, Clone)]
33+
pub struct Properties {
34+
pub fields: Vec<Field>, // Declared properties, one `Field` (name and type) each
35+
}
36+
37+
/// Top-level operator definition.
///
/// Wraps one of the two concrete definition structs so both kinds can share a single list.
38+
#[derive(Debug, Clone)]
39+
pub enum Operator {
40+
Scalar(ScalarOp), // A scalar operator definition
41+
Logical(LogicalOp), // A logical operator definition, including its derived properties
42+
}
43+
44+
/// Scalar operator definition: an operator name plus its typed fields.
45+
#[derive(Debug, Clone)]
46+
pub struct ScalarOp {
47+
pub name: String, // Operator name
48+
pub fields: Vec<Field>, // Operator fields, kept as an ordered list
49+
}
50+
51+
/// Logical operator definition with derived properties.
///
/// Same shape as `ScalarOp`, plus one derivation expression per property name.
52+
#[derive(Debug, Clone)]
53+
pub struct LogicalOp {
54+
pub name: String, // Operator name
55+
pub fields: Vec<Field>, // Operator fields, kept as an ordered list
56+
pub derived_props: HashMap<String, Expr>, // Maps property names to their derivation expressions; being a HashMap, iteration order is unspecified
57+
}
58+
59+
/// Patterns used in match expressions.
///
/// Derives `PartialEq` (together with `Literal`, `Expr` and `MatchArm`, which it reaches
/// transitively) so that whole pattern/expression trees can be compared structurally.
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
    /// Binding patterns like `x@p` or `x:p`: the bound name and the sub-pattern it names.
    Bind(String, Box<Pattern>),
    /// Constructor pattern: the constructor name followed by its subpatterns, which can be
    /// named (`x:p`) or positional.
    Constructor(String, Vec<Pattern>),
    /// Literal patterns like `42` or `"hello"`.
    Literal(Literal),
    /// Wildcard pattern `_`.
    Wildcard,
    /// Variable binding pattern.
    Var(String),
}

/// Literal values.
///
/// Only `PartialEq` is derived (not `Eq`/`Hash`) because `Float64` carries an `f64`; as with any
/// `f64`, a NaN literal does not compare equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    Int64(i64),
    String(String),
    Bool(bool),
    Float64(f64),
    /// Array literals `[e1, e2, ...]`.
    Array(Vec<Expr>),
    /// Tuple literals `(e1, e2, ...)`.
    Tuple(Vec<Expr>),
}

/// Expressions - the core of the language.
///
/// Every recursive position is boxed (or held in a `Vec`) so the enum has a finite size.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// Pattern matching: the scrutinee and its arms.
    Match(Box<Expr>, Vec<MatchArm>),
    /// If-then-else: condition, then-branch, else-branch.
    If(Box<Expr>, Box<Expr>, Box<Expr>),
    /// Local binding (`val x = e1; e2`): name, bound expression, continuation.
    Val(String, Box<Expr>, Box<Expr>),
    /// Constructor application (currently only operators): name and arguments.
    Constructor(String, Vec<Expr>),
    /// Binary operations: left operand, operator, right operand.
    Binary(Box<Expr>, BinOp, Box<Expr>),
    /// Unary operations: operator and operand.
    Unary(UnaryOp, Box<Expr>),
    /// Function application: callee and arguments.
    Call(Box<Expr>, Vec<Expr>),
    /// Field access (`e.f`).
    Member(Box<Expr>, String),
    /// Method call (`e.f(args)`).
    MemberCall(Box<Expr>, String, Vec<Expr>),
    /// Array indexing (`e[i]`).
    ArrayIndex(Box<Expr>, Box<Expr>),
    /// Variable reference.
    Var(String),
    /// Literal values.
    Literal(Literal),
    /// Failure with message.
    Fail(String),
    /// Anonymous functions like `(x, y) => x + y`: parameter names and body.
    Closure(Vec<String>, Box<Expr>),
}

/// A case in a match expression: the pattern to test and the expression it guards.
#[derive(Debug, Clone, PartialEq)]
pub struct MatchArm {
    pub pattern: Pattern,
    pub expr: Expr,
}

/// Binary operators with fixed precedence.
///
/// Fieldless, so it is `Copy` and can be compared, hashed and matched on by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    Add,    // +
    Sub,    // -
    Mul,    // *
    Div,    // /
    Concat, // ++
    Eq,     // ==
    Neq,    // !=
    Gt,     // >
    Lt,     // <
    Ge,     // >=
    Le,     // <=
    And,    // &&
    Or,     // ||
    Range,  // ..
}

/// Unary operators.
///
/// Fieldless, so it is `Copy` and can be compared, hashed and matched on by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Neg, // -
    Not, // !
}
134+
135+
/// Function definition: signature, body, and an optional rule marker.
136+
#[derive(Debug, Clone)]
137+
pub struct Function {
138+
pub name: String, // Function name
139+
pub params: Vec<(String, Type)>, // Parameter name and type pairs, kept as an ordered list
140+
pub return_type: Type, // Declared return type
141+
pub body: Expr, // The function body is a single expression
142+
pub rule_type: Option<OperatorKind>, // Some if this is a rule, indicating what kind; None otherwise
143+
}
144+
145+
/// A complete source file: one properties block plus every operator and function definition.
146+
#[derive(Debug, Clone)]
147+
pub struct File {
148+
pub properties: Properties, // The single logical properties block (not optional, so a `File` always carries one)
149+
pub operators: Vec<Operator>, // All operator definitions
150+
pub functions: Vec<Function>, // All function definitions
151+
}

0 commit comments

Comments
 (0)