Skip to content

Commit 9d4faa5

Browse files
authored
[fix](Nereids) cse extract wrong expression from lambda expressions (#49166)
### What problem does this PR solve? Related PR: #33087 Problem Summary: CSE should not extract ArrayItemSlot and ArrayItemReference, because they cannot be computed outside of the Lambda expression. NOTICE: currently, we cannot extract common expressions from Lambdas, because the ArrayItemSlots in Lambdas are not the same.
1 parent 9e45efb commit 9d4faa5

File tree

4 files changed

+73
-19
lines changed

4 files changed

+73
-19
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
package org.apache.doris.nereids.processor.post;
1919

20+
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
21+
import org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot;
2022
import org.apache.doris.nereids.trees.expressions.Expression;
23+
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
2124
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
2225

2326
import java.util.HashMap;
@@ -28,29 +31,38 @@
2831
/**
2932
* collect common expr
3033
*/
31-
public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> {
34+
public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Boolean> {
3235
public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>();
3336
private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>();
3437

38+
public int collect(Expression expr) {
39+
return expr.accept(this, expr instanceof Lambda);
40+
}
41+
3542
@Override
36-
public Integer visit(Expression expr, Void context) {
43+
public Integer visit(Expression expr, Boolean inLambda) {
3744
if (expr.children().isEmpty()) {
3845
return 0;
3946
}
4047
return collectCommonExpressionByDepth(
4148
expr.children()
4249
.stream()
43-
.map(child -> child.accept(this, context))
50+
.map(child -> child.accept(this, inLambda == null || inLambda || child instanceof Lambda))
4451
.reduce(Math::max)
4552
.map(m -> m + 1)
4653
.orElse(1),
47-
expr
54+
expr,
55+
inLambda == null || inLambda
4856
);
4957
}
5058

51-
private int collectCommonExpressionByDepth(int depth, Expression expr) {
59+
private int collectCommonExpressionByDepth(int depth, Expression expr, boolean inLambda) {
5260
Set<Expression> expressions = getExpressionsFromDepthMap(depth, expressionsByDepth);
53-
if (expressions.contains(expr)) {
61+
// ArrayItemSlot and ArrayItemReference cannot be common expressions
62+
// TODO: cannot yet extract a common expression when expressions contain the same lambda expression,
63+
// because the ArrayItemSlots in Lambdas are not the same.
64+
if (expressions.contains(expr)
65+
&& !(inLambda && expr.containsType(ArrayItemSlot.class, ArrayItemReference.class))) {
5466
Set<Expression> commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth);
5567
commonExpression.add(expr);
5668
}

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ private List<List<NamedExpression>> computeMultiLayerProjections(
6666
List<List<NamedExpression>> multiLayers = Lists.newArrayList();
6767
CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
6868
for (Expression expr : projects) {
69-
expr.accept(collector, null);
69+
collector.collect(expr);
7070
}
7171
// use linkedHashMap to make projects order stable
7272
Map<Expression, Alias> aliasMap = new LinkedHashMap<>();

fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,24 @@
2020
import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector;
2121
import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt;
2222
import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
23+
import org.apache.doris.nereids.trees.expressions.Add;
2324
import org.apache.doris.nereids.trees.expressions.Alias;
25+
import org.apache.doris.nereids.trees.expressions.And;
26+
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
27+
import org.apache.doris.nereids.trees.expressions.ExprId;
2428
import org.apache.doris.nereids.trees.expressions.Expression;
2529
import org.apache.doris.nereids.trees.expressions.NamedExpression;
2630
import org.apache.doris.nereids.trees.expressions.Slot;
2731
import org.apache.doris.nereids.trees.expressions.SlotReference;
32+
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
33+
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
34+
import org.apache.doris.nereids.trees.expressions.literal.Literal;
2835
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
36+
import org.apache.doris.nereids.types.ArrayType;
2937
import org.apache.doris.nereids.types.IntegerType;
3038

39+
import com.google.common.collect.ImmutableList;
40+
import com.google.common.collect.Lists;
3141
import org.junit.jupiter.api.Assertions;
3242
import org.junit.jupiter.api.Test;
3343

@@ -37,27 +47,40 @@
3747
import java.util.List;
3848
import java.util.Map;
3949
import java.util.Set;
40-
import java.util.stream.Collectors;
4150

4251
public class CommonSubExpressionTest extends ExpressionRewriteTestHelper {
4352
@Test
4453
public void testExtractCommonExpr() {
4554
List<NamedExpression> exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a");
46-
CommonSubExpressionCollector collector =
47-
new CommonSubExpressionCollector();
55+
CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
4856
exprs.forEach(expr -> collector.visit(expr, null));
49-
System.out.println(collector.commonExprByDepth);
5057
Assertions.assertEquals(2, collector.commonExprByDepth.size());
51-
List<Expression> l1 = collector.commonExprByDepth.get(Integer.valueOf(1))
52-
.stream().collect(Collectors.toList());
53-
List<Expression> l2 = collector.commonExprByDepth.get(Integer.valueOf(2))
54-
.stream().collect(Collectors.toList());
58+
List<Expression> l1 = new ArrayList<>(collector.commonExprByDepth.get(1));
59+
List<Expression> l2 = new ArrayList<>(collector.commonExprByDepth.get(2));
5560
Assertions.assertEquals(1, l1.size());
5661
assertExpression(l1.get(0), "a+b");
5762
Assertions.assertEquals(1, l2.size());
5863
assertExpression(l2.get(0), "a+b+1");
5964
}
6065

66+
@Test
67+
void testLambdaExpression() {
68+
ArrayItemReference ref = new ArrayItemReference("x", new SlotReference(new ExprId(1), "y",
69+
ArrayType.of(IntegerType.INSTANCE), true, ImmutableList.of()));
70+
Expression add = new Add(ref.toSlot(), Literal.of(1));
71+
Expression and = new And(add, add);
72+
ArrayMap arrayMap = new ArrayMap(new Lambda(ImmutableList.of("x"), and, ImmutableList.of(ref)));
73+
List<NamedExpression> exprs = Lists.newArrayList(
74+
new Alias(new ExprId(10000), arrayMap, "c1"),
75+
new Alias(new ExprId(10001), arrayMap, "c2")
76+
);
77+
CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
78+
exprs.forEach(expr -> collector.visit(expr, false));
79+
Assertions.assertEquals(1, collector.commonExprByDepth.size());
80+
Assertions.assertEquals(1, collector.commonExprByDepth.get(4).size());
81+
Assertions.assertEquals(arrayMap, collector.commonExprByDepth.get(4).iterator().next());
82+
}
83+
6184
@Test
6285
public void testMultiLayers() throws Exception {
6386
List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a");
@@ -68,15 +91,14 @@ public void testMultiLayers() throws Exception {
6891
computeMultLayerProjectionsMethod.setAccessible(true);
6992
List<List<NamedExpression>> multiLayers = (List<List<NamedExpression>>) computeMultLayerProjectionsMethod
7093
.invoke(opt, inputSlots, exprs);
71-
System.out.println(multiLayers);
7294
Assertions.assertEquals(3, multiLayers.size());
7395
List<NamedExpression> l0 = multiLayers.get(0);
7496
Assertions.assertEquals(2, l0.size());
7597
Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a")));
76-
Assertions.assertTrue(l0.get(1) instanceof Alias);
98+
Assertions.assertInstanceOf(Alias.class, l0.get(1));
7799
assertExpression(l0.get(1).child(0), "a+b");
78-
Assertions.assertEquals(multiLayers.get(1).size(), 3);
79-
Assertions.assertEquals(multiLayers.get(2).size(), 5);
100+
Assertions.assertEquals(3, multiLayers.get(1).size());
101+
Assertions.assertEquals(5, multiLayers.get(2).size());
80102
List<NamedExpression> l2 = multiLayers.get(2);
81103
for (int i = 0; i < 5; i++) {
82104
Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt());

regression-test/suites/nereids_rules_p0/cse/cse.groovy

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,26 @@ suite("cse") {
6767
"""
6868
contains("l1([k1#0, d#1, i#2, (k1 >= i)#9, cast(d as TEXT)#10, unix_timestamp(cast(d as TEXT)#10, '%Y%m%d') AS `unix_timestamp(cast(d as TEXT), '%Y%m%d')`#11])")
6969
}
70+
71+
// CSE should not extract expressions used inside a lambda, such as ArrayItemSlot and ArrayItemReference
72+
sql """
73+
drop table if exists array_cse;
74+
"""
75+
sql """
76+
create table array_cse(c1 int, c2 array<varchar(255)>) PROPERTIES ("replication_allocation" = "tag.location.default: 1");
77+
"""
78+
sql """
79+
insert into array_cse values(1, [1,2,3]);
80+
"""
81+
sql """
82+
sync
83+
"""
84+
sql """
85+
SELECT array_map(x-> if(left(x, 5) = '12345', x, left(x, 5)), c2) FROM array_cse;
86+
"""
87+
sql """
88+
SELECT c0, c0 FROM (SELECT ARRAY_MAP(x-> if(left(x, 5), x, left(x, 5)), `c2`) as `c0` FROM array_cse) t
89+
"""
7090

7191
}
7292

0 commit comments

Comments
 (0)