Skip to content

Commit bfb3eb0

Browse files
authored
[KQP] Add new rule to KQP pipeline (#20809)
1 parent 4e90a2b commit bfb3eb0

File tree

6 files changed

+217
-0
lines changed

6 files changed

+217
-0
lines changed

ydb/core/kqp/opt/physical/kqp_opt_phy.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
103103
AddHandler(1, &TCoExtractMembers::Match, HNDL(PushExtractMembersToStage<true>));
104104
AddHandler(1, &TCoFlatMapBase::Match, HNDL(BuildFlatmapStage<true>));
105105
AddHandler(1, &TCoCombineByKey::Match, HNDL(PushCombineToStage<true>));
106+
AddHandler(1, &TCoCombineByKey::Match, HNDL(PushCombineToStageDependsOnOtherStage<true>));
106107
AddHandler(1, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage<true>));
107108
AddHandler(1, &TCoFinalizeByKey::Match, HNDL(BuildFinalizeByKeyStage<true>));
108109
AddHandler(1, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage<true>));
@@ -330,6 +331,15 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
330331
return output;
331332
}
332333

334+
// Optimizer handler registered for TCoCombineByKey nodes.
// Forwards to DqPushCombineToStageDependsOnOtherStage, passing the parents map
// obtained from getParents() and IsGlobal as the allowStageMultiUsage argument.
// DumpAppliedRule is invoked unconditionally with the input and output nodes,
// and the node returned by the DQ rule is handed back to the caller.
template <bool IsGlobal>
335+
TMaybeNode<TExprBase> PushCombineToStageDependsOnOtherStage(TExprBase node, TExprContext& ctx,
336+
IOptimizationContext& optCtx, const TGetParents& getParents)
337+
{
338+
TExprBase output = DqPushCombineToStageDependsOnOtherStage(node, ctx, optCtx, *getParents(), IsGlobal);
339+
DumpAppliedRule("PushCombineToStageDependsOnOtherStage", node.Ptr(), output.Ptr(), ctx);
340+
return output;
341+
}
342+
333343
template <bool IsGlobal>
334344
TMaybeNode<TExprBase> BuildShuffleStage(TExprBase node, TExprContext& ctx,
335345
IOptimizationContext& optCtx, const TGetParents& getParents)

ydb/core/kqp/ut/opt/kqp_agg_ut.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,88 @@ Y_UNIT_TEST_SUITE(KqpAgg) {
244244
}
245245

246246
}
247+
248+
// Checks an aggregate whose arguments contain `IN (SELECT ...)` subqueries:
// creates three column-store tables, fills each with keys 1..3, then runs the
// query both in Explain mode and for real and verifies the aggregated result.
Y_UNIT_TEST(AggWithSqlIn) {
249+
// Start the test cluster without the predefined sample tables.
auto settings = TKikimrSettings().SetWithSampleTables(false);
250+
TKikimrRunner kikimr(settings);
251+
252+
// `session` (table client) is used for the scheme queries below.
auto tableClient = kikimr.GetTableClient();
253+
auto session = tableClient.CreateSession().GetValueSync().GetSession();
254+
255+
// `session2` (query client) is used for the inserts and the tested SELECT.
auto queryClient = kikimr.GetQueryClient();
256+
auto result = queryClient.GetSession().GetValueSync();
257+
NStatusHelpers::ThrowOnError(result);
258+
auto session2 = result.GetSession();
259+
260+
auto res = session.ExecuteSchemeQuery(R"(
261+
CREATE TABLE `/Root/t1` (
262+
a Int64 NOT NULL,
263+
b Int32,
264+
primary key(a)
265+
)
266+
PARTITION BY HASH(a)
267+
WITH (STORE = COLUMN);
268+
)").GetValueSync();
269+
UNIT_ASSERT(res.IsSuccess());
270+
271+
res = session.ExecuteSchemeQuery(R"(
272+
CREATE TABLE `/Root/t2` (
273+
a Int64 NOT NULL,
274+
b Int32,
275+
primary key(a)
276+
)
277+
PARTITION BY HASH(a)
278+
WITH (STORE = COLUMN);
279+
)").GetValueSync();
280+
UNIT_ASSERT(res.IsSuccess());
281+
282+
res = session.ExecuteSchemeQuery(R"(
283+
CREATE TABLE `/Root/t3` (
284+
a Int64 NOT NULL,
285+
b Int32,
286+
primary key(a)
287+
)
288+
PARTITION BY HASH(a)
289+
WITH (STORE = COLUMN);
290+
)").GetValueSync();
291+
UNIT_ASSERT(res.IsSuccess());
292+
293+
// Every table receives the same three keys (1, 2, 3).
auto insertRes = session2.ExecuteQuery(R"(
294+
INSERT INTO `/Root/t1` (a, b) VALUES (1, 1);
295+
INSERT INTO `/Root/t2` (a, b) VALUES (1, 1);
296+
INSERT INTO `/Root/t3` (a, b) VALUES (1, 1);
297+
INSERT INTO `/Root/t1` (a, b) VALUES (2, 1);
298+
INSERT INTO `/Root/t2` (a, b) VALUES (2, 1);
299+
INSERT INTO `/Root/t3` (a, b) VALUES (2, 1);
300+
INSERT INTO `/Root/t1` (a, b) VALUES (3, 1);
301+
INSERT INTO `/Root/t2` (a, b) VALUES (3, 1);
302+
INSERT INTO `/Root/t3` (a, b) VALUES (3, 1);
303+
)", NYdb::NQuery::TTxControl::NoTx()).GetValueSync();
304+
UNIT_ASSERT(insertRes.IsSuccess());
305+
306+
std::vector<TString> queries = {
307+
R"(
308+
SELECT sum(CASE WHEN (t1.a IN (SELECT t2.a FROM t2)) THEN 1 ELSE 0 END),
309+
sum(CASE WHEN (t1.a IN (SELECT t3.a FROM t3)) THEN 1 ELSE 0 END)
310+
FROM t1;
311+
)",
312+
};
313+
314+
for (ui32 i = 0; i < queries.size(); ++i) {
315+
const auto query = queries[i];
316+
// First make sure the query can be planned (Explain mode) ...
auto result =
317+
session2
318+
.ExecuteQuery(query, NYdb::NQuery::TTxControl::NoTx(), NYdb::NQuery::TExecuteQuerySettings().ExecMode(NQuery::EExecMode::Explain))
319+
.ExtractValueSync();
320+
UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
321+
322+
// ... then execute it and check the actual result.
result = session2.ExecuteQuery(query, NYdb::NQuery::TTxControl::NoTx(), NYdb::NQuery::TExecuteQuerySettings()).ExtractValueSync();
323+
UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS);
324+
325+
// All three keys of t1 are present in both t2 and t3, so each sum is 3.
// Format the result set once and assert on that value.
TString output = FormatResultSetYson(result.GetResultSet(0));
326+
UNIT_ASSERT_VALUES_EQUAL(output, "[[[3];[3]]]");
327+
}
328+
}
247329
}
248330

249331
} // namespace NKikimr::NKqp

ydb/library/yql/dq/opt/dq_opt.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ bool IsDqDependsOnStage(const TExprBase& node, const TDqStageBase& stage) {
140140
});
141141
}
142142

143+
// Searches the expression rooted at `node` (via FindNode) for a TDqStage node
// whose pointer differs from `stage`, and returns true when one is found.
// Nodes that are not TDqStage never satisfy the predicate.
// NOTE(review): only TDqStage is matched although `stage` is a TDqStageBase —
// confirm that other stage kinds are intentionally ignored.
// NOTE(review): the lambda parameter `node` shadows the function parameter.
bool IsDqDependsOnOtherStage(const TExprBase& node, const TDqStageBase& stage) {
144+
return !!FindNode(node.Ptr(), [ptr = stage.Raw()](const TExprNode::TPtr& node) {
145+
if (TMaybeNode<TDqStage>(node)) {
146+
// A stage node counts only if it is not the reference stage itself.
return node.Get() != ptr;
147+
}
148+
return false;
149+
});
150+
}
151+
143152
bool IsDqDependsOnStageOutput(const TExprBase& node, const TDqStageBase& stage, ui32 outputIndex) {
144153
return !!FindNode(node.Ptr(), [ptr = stage.Raw(), outputIndex](const TExprNode::TPtr& exprNode) {
145154
if (TDqOutput::Match(exprNode.Get())) {

ydb/library/yql/dq/opt/dq_opt.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ inline bool IsDqCompletePureExpr(const NNodes::TExprBase& node, bool isPrecomput
3838

3939
bool IsDqSelfContainedExpr(const NNodes::TExprBase& node);
4040
bool IsDqDependsOnStage(const NNodes::TExprBase& node, const NNodes::TDqStageBase& stage);
41+
// True when `node` contains a TDqStage node that is not `stage` (pointer comparison).
bool IsDqDependsOnOtherStage(const NNodes::TExprBase& node, const NNodes::TDqStageBase& stage);
4142
bool IsDqDependsOnStageOutput(const NNodes::TExprBase& node, const NNodes::TDqStageBase& stage, ui32 outputIndex);
4243

4344
bool CanPushDqExpr(const NNodes::TExprBase& expr, const NNodes::TDqStageBase& stage);

ydb/library/yql/dq/opt/dq_opt_phy.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,8 @@ TExprBase DqBuildLMapOverMuxStage(TExprBase node, TExprContext& ctx, IOptimizati
10751075
return DqBuildLMapOverMuxStageStub<TCoLMap>(node, ctx, optCtx, parentsMap);
10761076
}
10771077

1078+
1079+
10781080
TExprBase DqPushCombineToStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
10791081
const TParentsMap& parentsMap, bool allowStageMultiUsage)
10801082
{
@@ -1156,6 +1158,112 @@ TExprBase DqPushCombineToStage(TExprBase node, TExprContext& ctx, IOptimizationC
11561158
return result.Cast();
11571159
}
11581160

1161+
// Rewrites CombineByKey(DqCnUnionAll) whose InitHandler and UpdateHandler
// lambdas reference stages other than the union's own stage.
//
// When all preconditions hold, the combine is moved into a new TDqStage:
//   * input 0 of the new stage is the original union connection;
//   * every connection found inside the handlers becomes an additional input,
//     which is squeezed to a list and bound to a lambda argument by a FlatMap
//     wrapped around the combine;
//   * the result is a TDqCnUnionAll over output "0" of the new stage.
// In every other case the input `node` is returned unchanged.
//
// `optCtx` is unused. `parentsMap` and `allowStageMultiUsage` are only passed
// on to IsSingleConsumerConnection.
TExprBase DqPushCombineToStageDependsOnOtherStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TParentsMap& parentsMap,
1162+
bool allowStageMultiUsage) {
1163+
Y_UNUSED(optCtx);
1164+
// Applicable only to a CombineByKey fed directly by a DqCnUnionAll.
if (!node.Maybe<TCoCombineByKey>().Input().Maybe<TDqCnUnionAll>()) {
1165+
return node;
1166+
}
1167+
1168+
auto combine = node.Cast<TCoCombineByKey>();
1169+
auto dqUnion = combine.Input().Cast<TDqCnUnionAll>();
1170+
1171+
if (!IsSingleConsumerConnection(dqUnion, parentsMap, allowStageMultiUsage)) {
1172+
return node;
1173+
}
1174+
1175+
// Both the init and update handlers must reference another stage, while the
// pre-map, key selector and finish handler must not.
if (IsDqDependsOnOtherStage(combine.InitHandlerLambda(), dqUnion.Output().Stage()) &&
1176+
IsDqDependsOnOtherStage(combine.UpdateHandlerLambda(), dqUnion.Output().Stage()) &&
1177+
!IsDqDependsOnOtherStage(combine.PreMapLambda(), dqUnion.Output().Stage()) &&
1178+
!IsDqDependsOnOtherStage(combine.KeySelectorLambda(), dqUnion.Output().Stage()) &&
1179+
!IsDqDependsOnOtherStage(combine.FinishHandlerLambda(), dqUnion.Output().Stage())) {
1180+
1181+
// Collect all connections of the `UpdateHandler` and `InitHandler` lambdas; they will be replaced.
1182+
// The original union connection always comes first in `connections`.
TExprNode::TListType connections{dqUnion.Ptr()};
1183+
auto connectionPredicate = [](const TExprNode::TPtr& node) { return !!TMaybeNode<TDqConnection>(node); };
1184+
const auto connectionsInitHandler = FindNodes(combine.InitHandlerLambda().Ptr(), connectionPredicate);
1185+
const auto connectionsUpdateHandler = FindNodes(combine.UpdateHandlerLambda().Ptr(), connectionPredicate);
1186+
1187+
if (connectionsInitHandler.empty() || (connectionsInitHandler.size() != connectionsUpdateHandler.size())) {
1188+
return node;
1189+
}
1190+
1191+
// Check that both handlers reference the very same connection nodes and that
// none of them is the union connection itself.
// NOTE(review): the comparison is positional, so it presumably relies on
// FindNodes returning nodes in the same order for both lambdas — confirm.
// The last comparison is implied by the first two.
1192+
for (ui32 i = 0; i < connectionsInitHandler.size(); ++i) {
1193+
if ((connectionsInitHandler[i].Get() != connectionsUpdateHandler[i].Get()) || (connectionsInitHandler[i].Get() == dqUnion.Raw()) ||
1194+
(connectionsUpdateHandler[i].Get() == dqUnion.Raw())) {
1195+
return node;
1196+
}
1197+
}
1198+
1199+
connections.insert(connections.end(), connectionsInitHandler.begin(), connectionsInitHandler.end());
1200+
1201+
// Arguments for DqStage: one per connection, index 0 being the union.
1202+
TVector<TCoArgument> inputArgs;
1203+
for (ui32 i = 0; i < connections.size(); ++i) {
1204+
TCoArgument arg = Build<TCoArgument>(ctx, node.Pos())
1205+
.Name(TStringBuilder() << "input_arg_" << i)
1206+
.Done();
1207+
inputArgs.push_back(arg);
1208+
}
1209+
1210+
// Arguments for the Flatmap's lambdas: one per connection except the union,
// hence the loop starts at 1.
1211+
TVector<TCoArgument> lambdaArgs;
1212+
for (ui32 i = 1; i < connections.size(); ++i) {
1213+
TCoArgument lambdaArg = Build<TCoArgument>(ctx, node.Pos())
1214+
.Name(TStringBuilder() << "lambda_arg_" << i)
1215+
.Done();
1216+
lambdaArgs.push_back(lambdaArg);
1217+
}
1218+
1219+
// The union maps to the first stage argument; each handler connection maps
// to its lambda argument (lambdaArgs[i] corresponds to connections[i + 1]).
TNodeOnNodeOwnedMap replaces;
1220+
replaces[connections[0].Get()] = inputArgs.front().Ptr();
1221+
for (ui32 i = 0; i < lambdaArgs.size(); ++i) {
1222+
replaces[connections[i + 1].Get()] = lambdaArgs[i].Ptr();
1223+
}
1224+
1225+
// Replace connections with `lambdaArgs` for the future flatmaps.
1226+
auto newBody = TExprBase(ctx.ReplaceNodes(combine.Ptr(), replaces));
1227+
1228+
// For input args in range [1, n] wrap to `SqueezeToList` and create
1229+
// a flatmap. Each iteration wraps the body built so far, so the flatmaps
// end up nested.
1230+
for (ui32 i = 1; i < inputArgs.size(); ++i) {
1231+
auto squeezeToListArg = Build<TCoSqueezeToList>(ctx, node.Pos())
1232+
.Stream(inputArgs[i])
1233+
.Done();
1234+
1235+
newBody = Build<TCoFlatMap>(ctx, node.Pos())
1236+
.Input(squeezeToListArg)
1237+
.Lambda()
1238+
.Args(lambdaArgs[i - 1])
1239+
.Body(newBody)
1240+
.Build()
1241+
.Done();
1242+
}
1243+
1244+
// Build a stage with inputs; the stage uses default-constructed TDqStageSettings.
1245+
auto stage = Build<TDqStage>(ctx, node.Pos())
1246+
.Inputs()
1247+
.Add(connections)
1248+
.Build()
1249+
.Program()
1250+
.Args(inputArgs)
1251+
.Body(newBody)
1252+
.Build()
1253+
.Settings(TDqStageSettings().BuildNode(ctx, node.Pos()))
1254+
.Done();
1255+
1256+
// Expose output "0" of the new stage through a UnionAll connection.
return Build<TDqCnUnionAll>(ctx, node.Pos())
1257+
.Output()
1258+
.Stage(stage)
1259+
.Index().Build("0")
1260+
.Build()
1261+
.Done();
1262+
}
1263+
1264+
return node;
1265+
}
1266+
11591267
NNodes::TExprBase DqPushAggregateCombineToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
11601268
const TParentsMap& parentsMap, bool allowStageMultiUsage)
11611269
{
@@ -1169,6 +1277,10 @@ NNodes::TExprBase DqPushAggregateCombineToStage(NNodes::TExprBase node, TExprCon
11691277
return node;
11701278
}
11711279

1280+
if (!IsDqCompletePureExpr(aggCombine.Handlers())) {
1281+
return node;
1282+
}
1283+
11721284
auto lambda = Build<TCoLambda>(ctx, aggCombine.Pos())
11731285
.Args({"stream"})
11741286
.Body<TCoAggregateCombine>()

ydb/library/yql/dq/opt/dq_opt_phy.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ NNodes::TExprBase DqPushFlatmapToStage(NNodes::TExprBase node, TExprContext& ctx
6060
NNodes::TExprBase DqPushCombineToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
6161
const TParentsMap& parentsMap, bool allowStageMultiUsage = true);
6262

63+
// Physical optimizer rule over `node` that returns the resulting expression;
// `allowStageMultiUsage` defaults to true, matching the neighbouring DqPush* rules.
NNodes::TExprBase DqPushCombineToStageDependsOnOtherStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
64+
const TParentsMap& parentsMap, bool allowStageMultiUsage = true);
65+
6366
NNodes::TExprBase DqPushAggregateCombineToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx,
6467
const TParentsMap& parentsMap, bool allowStageMultiUsage = true);
6568

0 commit comments

Comments
 (0)