Skip to content

Commit c2b609e

Browse files
committed
Add sum([]) optimization
1 parent edb1b5a commit c2b609e

File tree

5 files changed

+184
-141
lines changed

5 files changed

+184
-141
lines changed

expr_test.go

Lines changed: 0 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -901,147 +901,6 @@ func TestExpr(t *testing.T) {
901901
`all(1..3, {# > 0})`,
902902
true,
903903
},
904-
{
905-
`all(1..3, {# > 0}) && all(1..3, {# < 4})`,
906-
true,
907-
},
908-
{
909-
`all(1..3, {# > 2}) && all(1..3, {# < 4})`,
910-
false,
911-
},
912-
{
913-
`all(1..3, {# > 0}) && all(1..3, {# < 2})`,
914-
false,
915-
},
916-
{
917-
`all(1..3, {# > 2}) && all(1..3, {# < 2})`,
918-
false,
919-
},
920-
{
921-
`all(1..3, {# > 0}) || all(1..3, {# < 4})`,
922-
true,
923-
},
924-
{
925-
`all(1..3, {# > 0}) || all(1..3, {# != 2})`,
926-
true,
927-
},
928-
{
929-
`all(1..3, {# != 3}) || all(1..3, {# < 4})`,
930-
true,
931-
},
932-
{
933-
`all(1..3, {# != 3}) || all(1..3, {# != 2})`,
934-
false,
935-
},
936-
{
937-
`none(1..3, {# == 0})`,
938-
true,
939-
},
940-
{
941-
`none(1..3, {# == 0}) && none(1..3, {# == 4})`,
942-
true,
943-
},
944-
{
945-
`none(1..3, {# == 0}) && none(1..3, {# == 3})`,
946-
false,
947-
},
948-
{
949-
`none(1..3, {# == 1}) && none(1..3, {# == 4})`,
950-
false,
951-
},
952-
{
953-
`none(1..3, {# == 1}) && none(1..3, {# == 3})`,
954-
false,
955-
},
956-
{
957-
`none(1..3, {# == 0}) || none(1..3, {# == 4})`,
958-
true,
959-
},
960-
{
961-
`none(1..3, {# == 0}) || none(1..3, {# == 3})`,
962-
true,
963-
},
964-
{
965-
`none(1..3, {# == 1}) || none(1..3, {# == 4})`,
966-
true,
967-
},
968-
{
969-
`none(1..3, {# == 1}) || none(1..3, {# == 3})`,
970-
false,
971-
},
972-
{
973-
`any([1,1,0,1], {# == 0})`,
974-
true,
975-
},
976-
{
977-
`any(1..3, {# == 1}) && any(1..3, {# == 2})`,
978-
true,
979-
},
980-
{
981-
`any(1..3, {# == 0}) && any(1..3, {# == 2})`,
982-
false,
983-
},
984-
{
985-
`any(1..3, {# == 1}) && any(1..3, {# == 4})`,
986-
false,
987-
},
988-
{
989-
`any(1..3, {# == 0}) && any(1..3, {# == 4})`,
990-
false,
991-
},
992-
{
993-
`any(1..3, {# == 1}) || any(1..3, {# == 2})`,
994-
true,
995-
},
996-
{
997-
`any(1..3, {# == 0}) || any(1..3, {# == 2})`,
998-
true,
999-
},
1000-
{
1001-
`any(1..3, {# == 1}) || any(1..3, {# == 4})`,
1002-
true,
1003-
},
1004-
{
1005-
`any(1..3, {# == 0}) || any(1..3, {# == 4})`,
1006-
false,
1007-
},
1008-
{
1009-
`one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`,
1010-
true,
1011-
},
1012-
{
1013-
`one(1..3, {# == 1}) and one(1..3, {# == 2})`,
1014-
true,
1015-
},
1016-
{
1017-
`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`,
1018-
false,
1019-
},
1020-
{
1021-
`one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`,
1022-
false,
1023-
},
1024-
{
1025-
`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`,
1026-
false,
1027-
},
1028-
{
1029-
`one(1..3, {# == 1}) or one(1..3, {# == 2})`,
1030-
true,
1031-
},
1032-
{
1033-
`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`,
1034-
true,
1035-
},
1036-
{
1037-
`one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`,
1038-
true,
1039-
},
1040-
{
1041-
`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`,
1042-
false,
1043-
},
1044-
1045904
{
1046905
`count(1..30, {# % 3 == 0})`,
1047906
10,

optimizer/optimizer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ func Optimize(node *Node, config *conf.Config) error {
3737
Walk(node, &filterLast{})
3838
Walk(node, &filterFirst{})
3939
Walk(node, &predicateCombination{})
40+
Walk(node, &sumArray{})
4041
Walk(node, &sumMap{})
4142
return nil
4243
}

optimizer/optimizer_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,79 @@ import (
1717
"github.com/expr-lang/expr/parser"
1818
)
1919

20+
func TestOptimize(t *testing.T) {
21+
env := map[string]any{
22+
"a": 1,
23+
"b": 2,
24+
"c": 3,
25+
}
26+
27+
tests := []struct {
28+
expr string
29+
want any
30+
}{
31+
{`1 + 2`, 3},
32+
{`sum([])`, 0},
33+
{`sum([a])`, 1},
34+
{`sum([a, b])`, 3},
35+
{`sum([a, b, c])`, 6},
36+
{`sum([a, b, c, 4])`, 10},
37+
{`all(1..3, {# > 0}) && all(1..3, {# < 4})`, true},
38+
{`all(1..3, {# > 2}) && all(1..3, {# < 4})`, false},
39+
{`all(1..3, {# > 0}) && all(1..3, {# < 2})`, false},
40+
{`all(1..3, {# > 2}) && all(1..3, {# < 2})`, false},
41+
{`all(1..3, {# > 0}) || all(1..3, {# < 4})`, true},
42+
{`all(1..3, {# > 0}) || all(1..3, {# != 2})`, true},
43+
{`all(1..3, {# != 3}) || all(1..3, {# < 4})`, true},
44+
{`all(1..3, {# != 3}) || all(1..3, {# != 2})`, false},
45+
{`none(1..3, {# == 0})`, true},
46+
{`none(1..3, {# == 0}) && none(1..3, {# == 4})`, true},
47+
{`none(1..3, {# == 0}) && none(1..3, {# == 3})`, false},
48+
{`none(1..3, {# == 1}) && none(1..3, {# == 4})`, false},
49+
{`none(1..3, {# == 1}) && none(1..3, {# == 3})`, false},
50+
{`none(1..3, {# == 0}) || none(1..3, {# == 4})`, true},
51+
{`none(1..3, {# == 0}) || none(1..3, {# == 3})`, true},
52+
{`none(1..3, {# == 1}) || none(1..3, {# == 4})`, true},
53+
{`none(1..3, {# == 1}) || none(1..3, {# == 3})`, false},
54+
{`any([1, 1, 0, 1], {# == 0})`, true},
55+
{`any(1..3, {# == 1}) && any(1..3, {# == 2})`, true},
56+
{`any(1..3, {# == 0}) && any(1..3, {# == 2})`, false},
57+
{`any(1..3, {# == 1}) && any(1..3, {# == 4})`, false},
58+
{`any(1..3, {# == 0}) && any(1..3, {# == 4})`, false},
59+
{`any(1..3, {# == 1}) || any(1..3, {# == 2})`, true},
60+
{`any(1..3, {# == 0}) || any(1..3, {# == 2})`, true},
61+
{`any(1..3, {# == 1}) || any(1..3, {# == 4})`, true},
62+
{`any(1..3, {# == 0}) || any(1..3, {# == 4})`, false},
63+
{`one([1, 1, 0, 1], {# == 0}) and not one([1, 0, 0, 1], {# == 0})`, true},
64+
{`one(1..3, {# == 1}) and one(1..3, {# == 2})`, true},
65+
{`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, false},
66+
{`one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, false},
67+
{`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, false},
68+
{`one(1..3, {# == 1}) or one(1..3, {# == 2})`, true},
69+
{`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, true},
70+
{`one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, true},
71+
{`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, false},
72+
}
73+
74+
for _, tt := range tests {
75+
t.Run(tt.expr, func(t *testing.T) {
76+
program, err := expr.Compile(tt.expr, expr.Env(env))
77+
require.NoError(t, err)
78+
79+
output, err := expr.Run(program, env)
80+
require.NoError(t, err)
81+
assert.Equal(t, tt.want, output)
82+
83+
unoptimizedProgram, err := expr.Compile(tt.expr, expr.Env(env), expr.Optimize(false))
84+
require.NoError(t, err)
85+
86+
unoptimizedOutput, err := expr.Run(unoptimizedProgram, env)
87+
require.NoError(t, err)
88+
assert.Equal(t, tt.want, unoptimizedOutput)
89+
})
90+
}
91+
}
92+
2093
func TestOptimize_constant_folding(t *testing.T) {
2194
tree, err := parser.Parse(`[1,2,3][5*5-25]`)
2295
require.NoError(t, err)

optimizer/sum_array.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package optimizer
2+
3+
import (
4+
"fmt"
5+
6+
. "github.com/expr-lang/expr/ast"
7+
)
8+
9+
type sumArray struct{}
10+
11+
func (*sumArray) Visit(node *Node) {
12+
if sumBuiltin, ok := (*node).(*BuiltinNode); ok &&
13+
sumBuiltin.Name == "sum" &&
14+
len(sumBuiltin.Arguments) == 1 {
15+
if array, ok := sumBuiltin.Arguments[0].(*ArrayNode); ok &&
16+
len(array.Nodes) >= 2 {
17+
Patch(node, sumArrayFold(array))
18+
}
19+
}
20+
}
21+
22+
func sumArrayFold(array *ArrayNode) *BinaryNode {
23+
if len(array.Nodes) > 2 {
24+
return &BinaryNode{
25+
Operator: "+",
26+
Left: array.Nodes[0],
27+
Right: sumArrayFold(&ArrayNode{Nodes: array.Nodes[1:]}),
28+
}
29+
} else if len(array.Nodes) == 2 {
30+
return &BinaryNode{
31+
Operator: "+",
32+
Left: array.Nodes[0],
33+
Right: array.Nodes[1],
34+
}
35+
}
36+
panic(fmt.Errorf("sumArrayFold: invalid array length %d", len(array.Nodes)))
37+
}

optimizer/sum_array_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package optimizer_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/expr-lang/expr"
10+
"github.com/expr-lang/expr/ast"
11+
"github.com/expr-lang/expr/optimizer"
12+
"github.com/expr-lang/expr/parser"
13+
"github.com/expr-lang/expr/vm"
14+
)
15+
16+
func BenchmarkSumArray(b *testing.B) {
17+
env := map[string]any{
18+
"a": 1,
19+
"b": 2,
20+
"c": 3,
21+
"d": 4,
22+
}
23+
24+
program, err := expr.Compile(`sum([a, b, c, d])`, expr.Env(env))
25+
require.NoError(b, err)
26+
27+
var out any
28+
b.ResetTimer()
29+
for n := 0; n < b.N; n++ {
30+
out, err = vm.Run(program, env)
31+
}
32+
b.StopTimer()
33+
34+
require.NoError(b, err)
35+
require.Equal(b, 10, out)
36+
37+
}
38+
39+
func TestOptimize_sum_array(t *testing.T) {
40+
tree, err := parser.Parse(`sum([a, b])`)
41+
require.NoError(t, err)
42+
43+
err = optimizer.Optimize(&tree.Node, nil)
44+
require.NoError(t, err)
45+
46+
expected := &ast.BinaryNode{
47+
Operator: "+",
48+
Left: &ast.IdentifierNode{Value: "a"},
49+
Right: &ast.IdentifierNode{Value: "b"},
50+
}
51+
52+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
53+
}
54+
55+
func TestOptimize_sum_array_3(t *testing.T) {
56+
tree, err := parser.Parse(`sum([a, b, c])`)
57+
require.NoError(t, err)
58+
59+
err = optimizer.Optimize(&tree.Node, nil)
60+
require.NoError(t, err)
61+
62+
expected := &ast.BinaryNode{
63+
Operator: "+",
64+
Left: &ast.IdentifierNode{Value: "a"},
65+
Right: &ast.BinaryNode{
66+
Operator: "+",
67+
Left: &ast.IdentifierNode{Value: "b"},
68+
Right: &ast.IdentifierNode{Value: "c"},
69+
},
70+
}
71+
72+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
73+
}

0 commit comments

Comments
 (0)