Skip to content

Commit 285fe5f

Browse files
authored
feat: DSL Array Pattern Match & Simplification (#36)
# Add List Pattern Matching to the HIR and Evaluation Engine ## Summary This PR adds list pattern matching capabilities to the language. It introduces two new pattern variants: `EmptyArray` for matching empty arrays, and `ArrayDecomp` for decomposing arrays into head and tail components. ## Changes - Added new pattern variants to the `Pattern` enum in the HIR (engine) & AST (parsing). - Extended the pattern matching logic in `match_pattern` to handle list patterns - Added comprehensive tests demonstrating list pattern matching functionality - Simplified syntax elements as agreed during the design phase ## Implementation Details ```rust // New pattern variants pub enum Pattern { // ... existing variants EmptyArray, ArrayDecomp(Box<Pattern>, Box<Pattern>), } ``` The pattern matching logic was extended to handle: - Empty arrays with the `EmptyArray` pattern - Non-empty arrays with the `ArrayDecomp` pattern, which splits arrays into head and tail ## Testing Added a test that shows a practical application of list pattern matching: a recursive sum function defined as: ``` fn sum(arr: [I64]): I64 = match arr | [] -> 0 \ [head .. tail] -> head + sum(tail) ``` The test verifies that this function correctly computes: - `sum([]) = 0` - `sum([42]) = 42` - `sum([1, 2, 3]) = 6` ## Related Changes As part of this PR, I also simplified some syntax elements: - Changed the composition operator from `->` to `.` - Changed the match arrow from `=>` to `->`
1 parent 2c17bf7 commit 285fe5f

File tree

14 files changed

+435
-92
lines changed

14 files changed

+435
-92
lines changed

optd-dsl/src/analyzer/hir.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ pub enum Pattern {
101101
Struct(Identifier, Vec<Pattern>),
102102
Operator(Operator<Pattern>),
103103
Wildcard,
104+
EmptyArray,
105+
ArrayDecomp(Box<Pattern>, Box<Pattern>),
104106
}
105107

106108
/// Match arm combining pattern and expression

optd-dsl/src/cli/basic.op

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
data LogicalProps(schema_len: I64)
22

3-
data Scalar with
3+
data Scalar =
44
| ColumnRef(idx: Int64)
5-
| Literal with
5+
| Literal =
66
| IntLiteral(value: Int64)
77
| StringLiteral(value: String)
88
| BoolLiteral(value: Bool)
99
\ NullLiteral
10-
| Arithmetic with
10+
| Arithmetic =
1111
| Mult(left: Scalar, right: Scalar)
1212
| Add(left: Scalar, right: Scalar)
1313
| Sub(left: Scalar, right: Scalar)
1414
\ Div(left: Scalar, right: Scalar)
15-
| Predicate with
15+
| Predicate =
1616
| And(children: [Predicate])
1717
| Or(children: [Predicate])
1818
| Not(child: Predicate)
@@ -24,18 +24,18 @@ data Scalar with
2424
| GreaterThanEqual(left: Scalar, right: Scalar)
2525
| IsNull(expr: Scalar)
2626
\ IsNotNull(expr: Scalar)
27-
| Function with
27+
| Function =
2828
| Cast(expr: Scalar, target_type: String)
2929
| Substring(str: Scalar, start: Scalar, length: Scalar)
3030
\ Concat(args: [Scalar])
31-
\ AggregateExpr with
31+
\ AggregateExpr =
3232
| Sum(expr: Scalar)
3333
| Count(expr: Scalar)
3434
| Min(expr: Scalar)
3535
| Max(expr: Scalar)
3636
\ Avg(expr: Scalar)
3737

38-
data Logical with
38+
data Logical =
3939
| Scan(table_name: String)
4040
| Filter(child: Logical, cond: Predicate)
4141
| Project(child: Logical, exprs: [Scalar])
@@ -51,11 +51,11 @@ data Logical with
5151
aggregates: [AggregateExpr]
5252
)
5353

54-
data Physical with
54+
data Physical =
5555
| Scan(table_name: String)
5656
| Filter(child: Physical, cond: Predicate)
5757
| Project(child: Physical, exprs: [Scalar])
58-
| Join with
58+
| Join =
5959
| HashJoin(
6060
build_side: Physical,
6161
probe_side: Physical,
@@ -84,7 +84,7 @@ data Physical with
8484
order_by: [(Scalar, SortOrder)]
8585
)
8686

87-
data JoinType with
87+
data JoinType =
8888
| Inner
8989
| Left
9090
| Right
@@ -96,19 +96,19 @@ fn (expr: Scalar) apply_children(f: Scalar => Scalar) = ()
9696

9797
fn (pred: Predicate) remap(map: {I64 : I64)}) =
9898
match predicate
99-
| ColumnRef(idx) => ColumnRef(map(idx))
100-
\ _ => predicate -> apply_children(child => rewrite_column_refs(child, map))
99+
| ColumnRef(idx) -> ColumnRef(map(idx))
100+
\ _ -> predicate.apply_children(child -> rewrite_column_refs(child, map))
101101

102102
[rule]
103103
fn (expr: Logical) join_commute = match expr
104-
\ Join(left, right, Inner, cond) ->
104+
\ Join(left, right, Inner, cond) =>
105105
let
106106
right_indices = 0.right.schema_len,
107107
left_indices = 0..left.schema_len,
108-
remapping = left_indices.map(i => (i, i + right_len)) ++
109-
right_indices.map(i => (left_len + i, i)).to_map,
108+
remapping = left_indices.map(i -> (i, i + right_len)) ++
109+
right_indices.map(i -> (left_len + i, i)).to_map,
110110
in
111111
Project(
112112
Join(right, left, Inner, cond.remap(remapping)),
113-
right_indices.map(i => ColumnRef(i)).to_array
113+
right_indices.map(i -> ColumnRef(i)).to_array
114114
)

optd-dsl/src/engine/eval/binary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub(super) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value {
9494
}
9595

9696
// Any other combination of value types or operations is not supported
97-
_ => panic!("Invalid binary operation"),
97+
_ => panic!("Invalid binary operation: {:?} {:?} {:?}", left, op, right),
9898
}
9999
}
100100

optd-dsl/src/engine/eval/expr.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,4 +665,71 @@ mod tests {
665665
assert_eq!(values.len(), 1);
666666
assert!(matches!(&values[0].0, Literal(Int64(30)))); // 10 + 20 = 30 (since 10 < 20)
667667
}
668+
669+
#[test]
670+
fn test_recursive_list_sum() {
671+
let context = Context::new(HashMap::new());
672+
673+
// Define a recursive sum function using pattern matching
674+
// sum([]) = 0
675+
// sum([x .. xs]) = x + sum(xs)
676+
let sum_function = Value(Function(Closure(
677+
vec!["arr".to_string()],
678+
Box::new(PatternMatch(
679+
Box::new(Ref("arr".to_string())),
680+
vec![
681+
// Base case: empty array returns 0
682+
MatchArm {
683+
pattern: Pattern::EmptyArray,
684+
expr: CoreVal(int_val(0)),
685+
},
686+
// Recursive case: add head + sum(tail)
687+
MatchArm {
688+
pattern: Pattern::ArrayDecomp(
689+
Box::new(Bind("head".to_string(), Box::new(Wildcard))),
690+
Box::new(Bind("tail".to_string(), Box::new(Wildcard))),
691+
),
692+
expr: Binary(
693+
Box::new(Ref("head".to_string())),
694+
BinOp::Add,
695+
Box::new(Call(
696+
Box::new(Ref("sum".to_string())),
697+
vec![Ref("tail".to_string())],
698+
)),
699+
),
700+
},
701+
],
702+
)),
703+
)));
704+
705+
// Bind the recursive function in the context
706+
let mut test_context = context.clone();
707+
test_context.bind("sum".to_string(), sum_function);
708+
709+
// Test arrays
710+
let empty_array = Value(CoreData::Array(vec![]));
711+
let array_123 = Value(CoreData::Array(vec![int_val(1), int_val(2), int_val(3)]));
712+
let array_42 = Value(CoreData::Array(vec![int_val(42)]));
713+
714+
// Test 1: Sum of empty array should be 0
715+
let call_empty = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(empty_array)]);
716+
717+
let result = collect_stream_values(call_empty.evaluate(test_context.clone()));
718+
assert_eq!(result.len(), 1);
719+
assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 0));
720+
721+
// Test 2: Sum of [1, 2, 3] should be 6
722+
let call_123 = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(array_123)]);
723+
724+
let result = collect_stream_values(call_123.evaluate(test_context.clone()));
725+
assert_eq!(result.len(), 1);
726+
assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 6));
727+
728+
// Test 3: Sum of [42] should be 42
729+
let call_42 = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(array_42)]);
730+
731+
let result = collect_stream_values(call_42.evaluate(test_context));
732+
assert_eq!(result.len(), 1);
733+
assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 42));
734+
}
668735
}

optd-dsl/src/engine/eval/match.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,36 @@ async fn match_pattern(value: Value, pattern: Pattern, context: Context) -> Vec<
101101
_ => vec![],
102102
},
103103

104+
// Empty list pattern: match if value is an empty array
105+
(EmptyArray, CoreData::Array(arr)) if arr.is_empty() => vec![context],
106+
107+
// List decomposition pattern: match first element and rest of the array
108+
(ArrayDecomp(head_pattern, tail_pattern), CoreData::Array(arr)) => {
109+
if arr.is_empty() {
110+
return vec![];
111+
}
112+
113+
// Split array into head and tail
114+
let head = arr[0].clone();
115+
let tail = Value(CoreData::Array(arr[1..].to_vec()));
116+
117+
// Match head against head pattern
118+
let head_contexts = match_pattern(head, (**head_pattern).clone(), context).await;
119+
if head_contexts.is_empty() {
120+
return vec![];
121+
}
122+
123+
// For each successful head match, try to match tail
124+
let mut result_contexts = Vec::new();
125+
for head_ctx in head_contexts {
126+
let tail_contexts =
127+
match_pattern(tail.clone(), (**tail_pattern).clone(), head_ctx).await;
128+
result_contexts.extend(tail_contexts);
129+
}
130+
131+
result_contexts
132+
}
133+
104134
// Struct pattern: match name and recursively match fields
105135
(Struct(pat_name, field_patterns), CoreData::Struct(val_name, field_values)) => {
106136
if pat_name != val_name || field_patterns.len() != field_values.len() {

optd-dsl/src/lexer/lex.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ fn lexer() -> impl Parser<char, Vec<(Token, Span)>, Error = Simple<char, Span>>
5454
("false", Token::Bool(false)),
5555
("Unit", Token::TUnit),
5656
("data", Token::Data),
57-
("with", Token::With),
58-
("as", Token::As),
5957
("in", Token::In),
6058
("let", Token::Let),
6159
("match", Token::Match),

optd-dsl/src/lexer/tokens.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ pub enum Token {
1212
// Other keywords
1313
Fn,
1414
Data,
15-
With,
16-
As,
1715
In,
1816
Let,
1917
Match,
@@ -85,8 +83,6 @@ impl std::fmt::Display for Token {
8583
// Other keywords
8684
Token::Fn => write!(f, "fn"),
8785
Token::Data => write!(f, "data"),
88-
Token::With => write!(f, "with"),
89-
Token::As => write!(f, "as"),
9086
Token::In => write!(f, "in"),
9187
Token::Let => write!(f, "let"),
9288
Token::Match => write!(f, "match"),

optd-dsl/src/parser/adt.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub fn adt_parser() -> impl Parser<Token, Spanned<Adt>, Error = Simple<Token, Sp
2525
.map_with_span(Spanned::new);
2626

2727
let with_sum_parser = type_ident
28-
.then_ignore(just(Token::With))
28+
.then_ignore(just(Token::Eq))
2929
.then(
3030
just(Token::Vertical)
3131
.ignore_then(inner_adt_parser.clone())
@@ -96,7 +96,7 @@ mod tests {
9696

9797
#[test]
9898
fn test_enum_adt() {
99-
let input = "data JoinType with
99+
let input = "data JoinType =
100100
| Inner
101101
\\ Outer";
102102
let (result, errors) = parse_adt(input);
@@ -132,7 +132,7 @@ mod tests {
132132

133133
#[test]
134134
fn test_enum_with_struct_variants() {
135-
let input = "data Shape with
135+
let input = "data Shape =
136136
| Circle(center: Point, radius: F64)
137137
| Rectangle(topLeft: Point, width: F64, height: F64)
138138
\\ Triangle(p1: Point, p2: Point, p3: Point)";
@@ -175,8 +175,8 @@ mod tests {
175175

176176
#[test]
177177
fn test_nested_enum() {
178-
let input = "data Expression with
179-
| Literal with
178+
let input = "data Expression =
179+
| Literal =
180180
| IntLiteral(value: I64)
181181
| BoolLiteral(value: Bool)
182182
\\ StringLiteral(value: String)
@@ -234,15 +234,15 @@ mod tests {
234234

235235
#[test]
236236
fn test_double_nested_enum() {
237-
let input = "data Menu with
238-
| File with
239-
| New with
237+
let input = "data Menu =
238+
| File =
239+
| New =
240240
| Document
241241
| Project
242242
\\ Template
243243
| Open
244244
\\ Save
245-
| Edit with
245+
| Edit =
246246
| Cut
247247
| Copy
248248
\\ Paste

optd-dsl/src/parser/ast.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ pub enum Pattern {
145145
Literal(Literal),
146146
/// Wildcard pattern: matches any value
147147
Wildcard,
148+
/// Empty array pattern: matches an empty array
149+
EmptyArray,
150+
/// Array decomposition pattern: matches an array with head and rest elements
151+
ArrayDecomp(Spanned<Pattern>, Spanned<Pattern>),
148152
}
149153

150154
/// Represents a single arm in a pattern match expression
@@ -212,8 +216,6 @@ pub enum UnaryOp {
212216
pub enum PostfixOp {
213217
/// Function or method call with arguments
214218
Call(Vec<Spanned<Expr>>),
215-
/// Function composition operator
216-
Compose(Identifier),
217219
/// Member/field access
218220
Member(Identifier),
219221
}

0 commit comments

Comments
 (0)