20
20
import org .apache .doris .nereids .processor .post .CommonSubExpressionCollector ;
21
21
import org .apache .doris .nereids .processor .post .CommonSubExpressionOpt ;
22
22
import org .apache .doris .nereids .rules .expression .ExpressionRewriteTestHelper ;
23
+ import org .apache .doris .nereids .trees .expressions .Add ;
23
24
import org .apache .doris .nereids .trees .expressions .Alias ;
25
+ import org .apache .doris .nereids .trees .expressions .And ;
26
+ import org .apache .doris .nereids .trees .expressions .ArrayItemReference ;
27
+ import org .apache .doris .nereids .trees .expressions .ExprId ;
24
28
import org .apache .doris .nereids .trees .expressions .Expression ;
25
29
import org .apache .doris .nereids .trees .expressions .NamedExpression ;
26
30
import org .apache .doris .nereids .trees .expressions .Slot ;
27
31
import org .apache .doris .nereids .trees .expressions .SlotReference ;
32
+ import org .apache .doris .nereids .trees .expressions .functions .scalar .ArrayMap ;
33
+ import org .apache .doris .nereids .trees .expressions .functions .scalar .Lambda ;
34
+ import org .apache .doris .nereids .trees .expressions .literal .Literal ;
28
35
import org .apache .doris .nereids .trees .expressions .visitor .DefaultExpressionRewriter ;
36
+ import org .apache .doris .nereids .types .ArrayType ;
29
37
import org .apache .doris .nereids .types .IntegerType ;
30
38
39
+ import com .google .common .collect .ImmutableList ;
40
+ import com .google .common .collect .Lists ;
31
41
import org .junit .jupiter .api .Assertions ;
32
42
import org .junit .jupiter .api .Test ;
33
43
37
47
import java .util .List ;
38
48
import java .util .Map ;
39
49
import java .util .Set ;
40
- import java .util .stream .Collectors ;
41
50
42
51
public class CommonSubExpressionTest extends ExpressionRewriteTestHelper {
43
52
@ Test
44
53
public void testExtractCommonExpr () {
45
54
List <NamedExpression > exprs = parseProjections ("a+b, a+b+1, abs(a+b+1), a" );
46
- CommonSubExpressionCollector collector =
47
- new CommonSubExpressionCollector ();
55
+ CommonSubExpressionCollector collector = new CommonSubExpressionCollector ();
48
56
exprs .forEach (expr -> collector .visit (expr , null ));
49
- System .out .println (collector .commonExprByDepth );
50
57
Assertions .assertEquals (2 , collector .commonExprByDepth .size ());
51
- List <Expression > l1 = collector .commonExprByDepth .get (Integer .valueOf (1 ))
52
- .stream ().collect (Collectors .toList ());
53
- List <Expression > l2 = collector .commonExprByDepth .get (Integer .valueOf (2 ))
54
- .stream ().collect (Collectors .toList ());
58
+ List <Expression > l1 = new ArrayList <>(collector .commonExprByDepth .get (1 ));
59
+ List <Expression > l2 = new ArrayList <>(collector .commonExprByDepth .get (2 ));
55
60
Assertions .assertEquals (1 , l1 .size ());
56
61
assertExpression (l1 .get (0 ), "a+b" );
57
62
Assertions .assertEquals (1 , l2 .size ());
58
63
assertExpression (l2 .get (0 ), "a+b+1" );
59
64
}
60
65
66
+ @ Test
67
+ void testLambdaExpression () {
68
+ ArrayItemReference ref = new ArrayItemReference ("x" , new SlotReference (new ExprId (1 ), "y" ,
69
+ ArrayType .of (IntegerType .INSTANCE ), true , ImmutableList .of ()));
70
+ Expression add = new Add (ref .toSlot (), Literal .of (1 ));
71
+ Expression and = new And (add , add );
72
+ ArrayMap arrayMap = new ArrayMap (new Lambda (ImmutableList .of ("x" ), and , ImmutableList .of (ref )));
73
+ List <NamedExpression > exprs = Lists .newArrayList (
74
+ new Alias (new ExprId (10000 ), arrayMap , "c1" ),
75
+ new Alias (new ExprId (10001 ), arrayMap , "c2" )
76
+ );
77
+ CommonSubExpressionCollector collector = new CommonSubExpressionCollector ();
78
+ exprs .forEach (expr -> collector .visit (expr , false ));
79
+ Assertions .assertEquals (1 , collector .commonExprByDepth .size ());
80
+ Assertions .assertEquals (1 , collector .commonExprByDepth .get (4 ).size ());
81
+ Assertions .assertEquals (arrayMap , collector .commonExprByDepth .get (4 ).iterator ().next ());
82
+ }
83
+
61
84
@ Test
62
85
public void testMultiLayers () throws Exception {
63
86
List <NamedExpression > exprs = parseProjections ("a, a+b, a+b+1, abs(a+b+1), a" );
@@ -68,15 +91,14 @@ public void testMultiLayers() throws Exception {
68
91
computeMultLayerProjectionsMethod .setAccessible (true );
69
92
List <List <NamedExpression >> multiLayers = (List <List <NamedExpression >>) computeMultLayerProjectionsMethod
70
93
.invoke (opt , inputSlots , exprs );
71
- System .out .println (multiLayers );
72
94
Assertions .assertEquals (3 , multiLayers .size ());
73
95
List <NamedExpression > l0 = multiLayers .get (0 );
74
96
Assertions .assertEquals (2 , l0 .size ());
75
97
Assertions .assertTrue (l0 .contains (ExprParser .INSTANCE .parseExpression ("a" )));
76
- Assertions .assertTrue ( l0 .get (1 ) instanceof Alias );
98
+ Assertions .assertInstanceOf ( Alias . class , l0 .get (1 ));
77
99
assertExpression (l0 .get (1 ).child (0 ), "a+b" );
78
- Assertions .assertEquals (multiLayers .get (1 ).size (), 3 );
79
- Assertions .assertEquals (multiLayers .get (2 ).size (), 5 );
100
+ Assertions .assertEquals (3 , multiLayers .get (1 ).size ());
101
+ Assertions .assertEquals (5 , multiLayers .get (2 ).size ());
80
102
List <NamedExpression > l2 = multiLayers .get (2 );
81
103
for (int i = 0 ; i < 5 ; i ++) {
82
104
Assertions .assertEquals (exprs .get (i ).getExprId ().asInt (), l2 .get (i ).getExprId ().asInt ());
0 commit comments