diff --git a/conf/config.go b/conf/config.go index 8a6ee70e2..270438b9d 100644 --- a/conf/config.go +++ b/conf/config.go @@ -10,31 +10,43 @@ import ( "github.com/expr-lang/expr/vm/runtime" ) +const ( + // DefaultMemoryBudget represents an upper limit of memory usage + DefaultMemoryBudget uint = 1e6 + + // DefaultMaxNodes represents an upper limit of AST nodes + DefaultMaxNodes uint = 10000 +) + type FunctionsTable map[string]*builtin.Function type Config struct { - EnvObject any - Env nature.Nature - Expect reflect.Kind - ExpectAny bool - Optimize bool - Strict bool - Profile bool - ConstFns map[string]reflect.Value - Visitors []ast.Visitor - Functions FunctionsTable - Builtins FunctionsTable - Disabled map[string]bool // disabled builtins + EnvObject any + Env nature.Nature + Expect reflect.Kind + ExpectAny bool + Optimize bool + Strict bool + Profile bool + MaxNodes uint + MemoryBudget uint + ConstFns map[string]reflect.Value + Visitors []ast.Visitor + Functions FunctionsTable + Builtins FunctionsTable + Disabled map[string]bool // disabled builtins } // CreateNew creates new config with default values. func CreateNew() *Config { c := &Config{ - Optimize: true, - ConstFns: make(map[string]reflect.Value), - Functions: make(map[string]*builtin.Function), - Builtins: make(map[string]*builtin.Function), - Disabled: make(map[string]bool), + Optimize: true, + MaxNodes: DefaultMaxNodes, + MemoryBudget: DefaultMemoryBudget, + ConstFns: make(map[string]reflect.Value), + Functions: make(map[string]*builtin.Function), + Builtins: make(map[string]*builtin.Function), + Disabled: make(map[string]bool), } for _, f := range builtin.Builtins { c.Builtins[f.Name] = f diff --git a/parser/parser.go b/parser/parser.go index 917e0db4f..086554223 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -45,12 +45,47 @@ var predicates = map[string]struct { } type parser struct { - tokens []Token - current Token - pos int - err *file.Error - depth int // predicate call depth - config *conf.Config + tokens []Token + current Token + pos int + err *file.Error + depth int // predicate call depth + config *conf.Config + nodeCount uint // tracks number of AST nodes created +} + +// checkNodeLimit verifies that adding a new node won't exceed configured limits +func (p *parser) checkNodeLimit() error { + p.nodeCount++ + if p.config.MaxNodes > 0 && p.nodeCount > p.config.MaxNodes { + p.error("compilation failed: expression exceeds maximum allowed nodes") + return nil + } + return nil +} + +// createNode handles creation of regular nodes +func (p *parser) createNode(n Node, loc file.Location) Node { + if err := p.checkNodeLimit(); err != nil { + return nil + } + if n == nil || p.err != nil { + return nil + } + n.SetLocation(loc) + return n +} + +// createMemberNode handles creation of member nodes +func (p *parser) createMemberNode(n *MemberNode, loc file.Location) *MemberNode { + if err := p.checkNodeLimit(); err != nil { + return nil + } + if n == nil || p.err != nil { + return nil + } + n.SetLocation(loc) + return n } type Tree struct { @@ -129,6 +164,10 @@ func (p *parser) expect(kind Kind, values ...string) { // parse functions func (p *parser) parseExpression(precedence int) Node { + if p.err != nil { + return nil + } + if precedence == 0 && p.current.Is(Operator, "let") { return p.parseVariableDeclaration() } @@ -190,19 +229,23 @@ func (p *parser) parseExpression(precedence int) Node { nodeRight = p.parseExpression(op.Precedence) } - nodeLeft = &BinaryNode{ + nodeLeft = p.createNode(&BinaryNode{ Operator: opToken.Value, Left: nodeLeft, Right: nodeRight, + }, opToken.Location) + if nodeLeft == nil { + return nil } - nodeLeft.SetLocation(opToken.Location) if negate { - nodeLeft = &UnaryNode{ + nodeLeft = p.createNode(&UnaryNode{ Operator: "not", Node: nodeLeft, + }, notToken.Location) + if nodeLeft == nil { + return nil } - nodeLeft.SetLocation(notToken.Location) } goto next @@ -229,13 +272,11 @@ func (p *parser) parseVariableDeclaration() Node { value := p.parseExpression(0) p.expect(Operator, ";") node := p.parseExpression(0) - let := &VariableDeclaratorNode{ + return p.createNode(&VariableDeclaratorNode{ Name: variableName.Value, Value: value, Expr: node, - } - let.SetLocation(variableName.Location) - return let + }, variableName.Location) } func (p *parser) parseConditionalIf() Node { @@ -272,10 +313,13 @@ func (p *parser) parseConditional(node Node) Node { expr2 = p.parseExpression(0) } - node = &ConditionalNode{ + node = p.createNode(&ConditionalNode{ Cond: node, Exp1: expr1, Exp2: expr2, + }, p.current.Location) + if node == nil { + return nil } } return node @@ -288,11 +332,13 @@ func (p *parser) parsePrimary() Node { if op, ok := operator.Unary[token.Value]; ok { p.next() expr := p.parseExpression(op.Precedence) - node := &UnaryNode{ + node := p.createNode(&UnaryNode{ Operator: token.Value, Node: expr, + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) return p.parsePostfixExpression(node) } } @@ -314,8 +360,10 @@ func (p *parser) parsePrimary() Node { p.next() } } - node := &PointerNode{Name: name} - node.SetLocation(token.Location) + node := p.createNode(&PointerNode{Name: name}, token.Location) + if node == nil { + return nil + } return p.parsePostfixExpression(node) } } else { @@ -344,23 +392,31 @@ func (p *parser) parseSecondary() Node { p.next() switch token.Value { case "true": - node := &BoolNode{Value: true} - node.SetLocation(token.Location) + node = p.createNode(&BoolNode{Value: true}, token.Location) + if node == nil { + return nil + } return node case "false": - node := &BoolNode{Value: false} - node.SetLocation(token.Location) + node = p.createNode(&BoolNode{Value: false}, token.Location) + if node == nil { + return nil + } return node case "nil": - node := &NilNode{} - node.SetLocation(token.Location) + node = p.createNode(&NilNode{}, token.Location) + if node == nil { + return nil + } return node default: if p.current.Is(Bracket, "(") { node = p.parseCall(token, []Node{}, true) } else { - node = &IdentifierNode{Value: token.Value} - node.SetLocation(token.Location) + node = p.createNode(&IdentifierNode{Value: token.Value}, token.Location) + if node == nil { + return nil + } } } @@ -407,8 +463,10 @@ func (p *parser) parseSecondary() Node { return node case String: p.next() - node = &StringNode{Value: token.Value} - node.SetLocation(token.Location) + node = p.createNode(&StringNode{Value: token.Value}, token.Location) + if node == nil { + return nil + } default: if token.Is(Bracket, "[") { @@ -428,7 +486,7 @@ func (p *parser) toIntegerNode(number int64) Node { p.error("integer literal is too large") return nil } - return &IntegerNode{Value: int(number)} + return p.createNode(&IntegerNode{Value: int(number)}, p.current.Location) } func (p *parser) toFloatNode(number float64) Node { @@ -436,7 +494,7 @@ func (p *parser) toFloatNode(number float64) Node { p.error("float literal is too large") return nil } - return &FloatNode{Value: number} + return p.createNode(&FloatNode{Value: number}, p.current.Location) } func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) Node { @@ -478,25 +536,34 @@ func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) N p.expect(Bracket, ")") - node = &BuiltinNode{ + node = p.createNode(&BuiltinNode{ Name: token.Value, Arguments: arguments, + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) } else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden { - node = &BuiltinNode{ + node = p.createNode(&BuiltinNode{ Name: token.Value, Arguments: p.parseArguments(arguments), + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) + } else { - callee := &IdentifierNode{Value: token.Value} - callee.SetLocation(token.Location) - node = &CallNode{ + callee := p.createNode(&IdentifierNode{Value: token.Value}, token.Location) + if callee == nil { + return nil + } + node = p.createNode(&CallNode{ Callee: callee, Arguments: p.parseArguments(arguments), + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) } return node } @@ -534,10 +601,12 @@ func (p *parser) parsePredicate() Node { if expectClosingBracket { p.expect(Bracket, "}") } - predicateNode := &PredicateNode{ + predicateNode := p.createNode(&PredicateNode{ Node: node, + }, startToken.Location) + if predicateNode == nil { + return nil } - predicateNode.SetLocation(startToken.Location) return predicateNode } @@ -558,8 +627,10 @@ func (p *parser) parseArrayExpression(token Token) Node { end: p.expect(Bracket, "]") - node := &ArrayNode{Nodes: nodes} - node.SetLocation(token.Location) + node := p.createNode(&ArrayNode{Nodes: nodes}, token.Location) + if node == nil { + return nil + } return node } @@ -585,8 +656,10 @@ func (p *parser) parseMapExpression(token Token) Node { // * identifier, which is equivalent to a string // * expression, which must be enclosed in parentheses -- (1 + 2) if p.current.Is(Number) || p.current.Is(String) || p.current.Is(Identifier) { - key = &StringNode{Value: p.current.Value} - key.SetLocation(token.Location) + key = p.createNode(&StringNode{Value: p.current.Value}, p.current.Location) + if key == nil { + return nil + } p.next() } else if p.current.Is(Bracket, "(") { key = p.parseExpression(0) @@ -597,16 +670,20 @@ func (p *parser) parseMapExpression(token Token) Node { p.expect(Operator, ":") node := p.parseExpression(0) - pair := &PairNode{Key: key, Value: node} - pair.SetLocation(token.Location) + pair := p.createNode(&PairNode{Key: key, Value: node}, token.Location) + if pair == nil { + return nil + } nodes = append(nodes, pair) } end: p.expect(Bracket, "}") - node := &MapNode{Pairs: nodes} - node.SetLocation(token.Location) + node := p.createNode(&MapNode{Pairs: nodes}, token.Location) + if node == nil { + return nil + } return node } @@ -631,8 +708,10 @@ func (p *parser) parsePostfixExpression(node Node) Node { p.error("expected name") } - property := &StringNode{Value: propertyToken.Value} - property.SetLocation(propertyToken.Location) + property := p.createNode(&StringNode{Value: propertyToken.Value}, propertyToken.Location) + if property == nil { + return nil + } chainNode, isChain := node.(*ChainNode) optional := postfixToken.Value == "?." @@ -641,26 +720,33 @@ func (p *parser) parsePostfixExpression(node Node) Node { node = chainNode.Node } - memberNode := &MemberNode{ + memberNode := p.createMemberNode(&MemberNode{ Node: node, Property: property, Optional: optional, + }, propertyToken.Location) + if memberNode == nil { + return nil } - memberNode.SetLocation(propertyToken.Location) if p.current.Is(Bracket, "(") { memberNode.Method = true - node = &CallNode{ + node = p.createNode(&CallNode{ Callee: memberNode, Arguments: p.parseArguments([]Node{}), + }, propertyToken.Location) + if node == nil { + return nil } - node.SetLocation(propertyToken.Location) } else { node = memberNode } if isChain || optional { - node = &ChainNode{Node: node} + node = p.createNode(&ChainNode{Node: node}, propertyToken.Location) + if node == nil { + return nil + } } } else if postfixToken.Value == "[" { @@ -674,11 +760,13 @@ func (p *parser) parsePostfixExpression(node Node) Node { to = p.parseExpression(0) } - node = &SliceNode{ + node = p.createNode(&SliceNode{ Node: node, To: to, + }, postfixToken.Location) + if node == nil { + return nil } - node.SetLocation(postfixToken.Location) p.expect(Bracket, "]") } else { @@ -692,25 +780,32 @@ func (p *parser) parsePostfixExpression(node Node) Node { to = p.parseExpression(0) } - node = &SliceNode{ + node = p.createNode(&SliceNode{ Node: node, From: from, To: to, + }, postfixToken.Location) + if node == nil { + return nil } - node.SetLocation(postfixToken.Location) p.expect(Bracket, "]") } else { // Slice operator [:] was not found, // it should be just an index node. - node = &MemberNode{ + node = p.createNode(&MemberNode{ Node: node, Property: from, Optional: optional, + }, postfixToken.Location) + if node == nil { + return nil } - node.SetLocation(postfixToken.Location) if optional { - node = &ChainNode{Node: node} + node = p.createNode(&ChainNode{Node: node}, postfixToken.Location) + if node == nil { + return nil + } } p.expect(Bracket, "]") } @@ -722,26 +817,29 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } - func (p *parser) parseComparison(left Node, token Token, precedence int) Node { var rootNode Node for { comparator := p.parseExpression(precedence + 1) - cmpNode := &BinaryNode{ + cmpNode := p.createNode(&BinaryNode{ Operator: token.Value, Left: left, Right: comparator, + }, token.Location) + if cmpNode == nil { + return nil } - cmpNode.SetLocation(token.Location) if rootNode == nil { rootNode = cmpNode } else { - rootNode = &BinaryNode{ + rootNode = p.createNode(&BinaryNode{ Operator: "&&", Left: rootNode, Right: cmpNode, + }, token.Location) + if rootNode == nil { + return nil } - rootNode.SetLocation(token.Location) } left = comparator diff --git a/parser/parser_test.go b/parser/parser_test.go index 6bb17e604..efae1f413 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/internal/testify/assert" "github.com/expr-lang/expr/internal/testify/require" @@ -925,3 +926,78 @@ func TestParse_pipe_operator(t *testing.T) { require.NoError(t, err) assert.Equal(t, Dump(expect), Dump(actual.Node)) } + +func TestNodeBudget(t *testing.T) { + tests := []struct { + name string + expr string + maxNodes uint + shouldError bool + }{ + { + name: "simple expression equal to limit", + expr: "a + b", + maxNodes: 3, + shouldError: false, + }, + { + name: "medium expression under limit", + expr: "a + b * c / d", + maxNodes: 20, + shouldError: false, + }, + { + name: "deeply nested expression over limit", + expr: "1 + (2 + (3 + (4 + (5 + (6 + (7 + 8))))))", + maxNodes: 10, + shouldError: true, + }, + { + name: "array expression over limit", + expr: "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", + maxNodes: 5, + shouldError: true, + }, + { + name: "disabled node budget", + expr: "1 + (2 + (3 + (4 + (5 + (6 + (7 + 8))))))", + maxNodes: 0, + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := conf.CreateNew() + config.MaxNodes = tt.maxNodes + config.Disabled = make(map[string]bool, 0) + + _, err := parser.ParseWithConfig(tt.expr, config) + hasError := err != nil && strings.Contains(err.Error(), "exceeds maximum allowed nodes") + + if hasError != tt.shouldError { + t.Errorf("ParseWithConfig(%q) error = %v, shouldError %v", tt.expr, err, tt.shouldError) + } + + // Verify error message format when expected + if tt.shouldError && err != nil { + expected := "compilation failed: expression exceeds maximum allowed nodes" + if !strings.Contains(err.Error(), expected) { + t.Errorf("Expected error message to contain %q, got %q", expected, err.Error()) + } + } + }) + } +} + +func TestNodeBudgetDisabled(t *testing.T) { + config := conf.CreateNew() + config.MaxNodes = 0 // Disable node budget + + expr := strings.Repeat("a + ", 1000) + "b" + _, err := parser.ParseWithConfig(expr, config) + + if err != nil && strings.Contains(err.Error(), "exceeds maximum allowed nodes") { + t.Error("Node budget check should be disabled when MaxNodes is 0") + } +} diff --git a/vm/utils.go b/vm/utils.go index fc2f5e7b8..11005137c 100644 --- a/vm/utils.go +++ b/vm/utils.go @@ -11,9 +11,6 @@ type ( ) var ( - // MemoryBudget represents an upper limit of memory usage. - MemoryBudget uint = 1e6 - errorType = reflect.TypeOf((*error)(nil)).Elem() ) diff --git a/vm/vm.go b/vm/vm.go index fa1223b42..62c6511f4 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -11,6 +11,7 @@ import ( "time" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/vm/runtime" @@ -20,11 +21,23 @@ func Run(program *Program, env any) (any, error) { if program == nil { return nil, fmt.Errorf("program is nil") } - vm := VM{} return vm.Run(program, env) } +func RunWithConfig(program *Program, env any, config *conf.Config) (any, error) { + if program == nil { + return nil, fmt.Errorf("program is nil") + } + if config == nil { + return nil, fmt.Errorf("config is nil") + } + vm := VM{ + MemoryBudget: config.MemoryBudget, + } + return vm.Run(program, env) +} + func Debug() *VM { vm := &VM{ debug: true, @@ -38,9 +51,9 @@ type VM struct { Stack []any Scopes []*Scope Variables []any + MemoryBudget uint ip int memory uint - memoryBudget uint debug bool step chan struct{} curr chan int @@ -76,7 +89,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { vm.Variables = make([]any, program.variables) } - vm.memoryBudget = MemoryBudget + if vm.MemoryBudget == 0 { + vm.MemoryBudget = conf.DefaultMemoryBudget + } vm.memory = 0 vm.ip = 0 @@ -599,7 +614,7 @@ func (vm *VM) pop() any { func (vm *VM) memGrow(size uint) { vm.memory += size - if vm.memory >= vm.memoryBudget { + if vm.memory >= vm.MemoryBudget { panic("memory budget exceeded") } } diff --git a/vm/vm_test.go b/vm/vm_test.go index d45ef2d6c..aec358ab7 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "strings" "testing" "github.com/expr-lang/expr/internal/testify/require" @@ -17,8 +18,35 @@ import ( ) func TestRun_NilProgram(t *testing.T) { - _, err := vm.Run(nil, nil) - require.Error(t, err) + t.Run("run with nil program", func(t *testing.T) { + newVM, err := vm.Run(nil, nil) + require.Error(t, err) + require.Nil(t, newVM) + }) + t.Run("run with nil program and nil config", func(t *testing.T) { + newVM, err := vm.RunWithConfig(nil, nil, nil) + require.Error(t, err) + require.Nil(t, newVM) + }) + t.Run("run with nil config", func(t *testing.T) { + program, err := expr.Compile("1") + require.Nil(t, err) + newVM, err := vm.RunWithConfig(program, nil, nil) + require.Error(t, err) + require.Nil(t, newVM) + }) + t.Run("run with config", func(t *testing.T) { + program, err := expr.Compile("1") + require.Nil(t, err) + config := conf.New(nil) + env := map[string]any{ + "a": 1, + } + config.MemoryBudget = 100 + newVM, err := vm.RunWithConfig(program, env, config) + require.Nil(t, err) + require.Equal(t, newVM, 1) + }) } func TestRun_ReuseVM(t *testing.T) { @@ -1191,3 +1219,135 @@ func TestVM_DirectBasicOpcodes(t *testing.T) { }) } } + +func TestVM_MemoryBudget(t *testing.T) { + tests := []struct { + name string + expr string + memBudget uint + expectError string + }{ + { + name: "under budget", + expr: "map(1..10, #)", + memBudget: 100, + }, + { + name: "exceeds budget", + expr: "map(1..1000, #)", + memBudget: 10, + expectError: "memory budget exceeded", + }, + { + name: "zero budget uses default", + expr: "map(1..10, #)", + memBudget: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node, err := parser.Parse(tt.expr) + require.NoError(t, err) + + program, err := compiler.Compile(node, nil) + require.NoError(t, err) + + vm := vm.VM{MemoryBudget: tt.memBudget} + out, err := vm.Run(program, nil) + + if tt.expectError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectError) + } else { + require.NoError(t, err) + require.NotNil(t, out) + } + }) + } +} + +// Helper functions for creating deeply nested expressions +func createNestedArithmeticExpr(t *testing.T, depth int) string { + t.Helper() + if depth == 0 { + return "a" + } + return fmt.Sprintf("(%s + %d)", createNestedArithmeticExpr(t, depth-1), depth) +} + +func createNestedMapExpr(t *testing.T, depth int) string { + t.Helper() + if depth == 0 { + return `{"value": 1}` + } + return fmt.Sprintf(`{"nested": %s}`, createNestedMapExpr(t, depth-1)) +} + +func TestVM_Limits(t *testing.T) { + tests := []struct { + name string + expr string + memoryBudget uint + maxNodes uint + env map[string]any + expectError string + }{ + { + name: "nested arithmetic allowed with max nodes and memory budget", + expr: createNestedArithmeticExpr(t, 100), + env: map[string]any{"a": 1}, + maxNodes: 1000, + memoryBudget: 1, // arithmetic expressions not counted towards memory budget + }, + { + name: "nested arithmetic blocked by max nodes", + expr: createNestedArithmeticExpr(t, 10000), + env: map[string]any{"a": 1}, + maxNodes: 100, + memoryBudget: 1, // arithmetic expressions not counted towards memory budget + expectError: "compilation failed: expression exceeds maximum allowed nodes", + }, + { + name: "nested map blocked by memory budget", + expr: createNestedMapExpr(t, 100), + env: map[string]any{}, + maxNodes: 1000, + memoryBudget: 10, // Small memory budget to trigger limit + expectError: "memory budget exceeded", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var options []expr.Option + options = append(options, expr.Env(test.env)) + if test.maxNodes > 0 { + options = append(options, func(c *conf.Config) { + c.MaxNodes = test.maxNodes + }) + } + + program, err := expr.Compile(test.expr, options...) + if err != nil { + if test.expectError != "" && strings.Contains(err.Error(), test.expectError) { + return + } + t.Fatal(err) + } + + testVM := &vm.VM{ + MemoryBudget: test.memoryBudget, + } + + _, err = testVM.Run(program, test.env) + + if test.expectError == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), test.expectError) + } + }) + } +}