Skip to content

Commit 168df93

Browse files
authored
Fixed resolving of pg aggregation over state (#10441)
1 parent b9bc524 commit 168df93

File tree

14 files changed

+177
-89
lines changed

14 files changed

+177
-89
lines changed

ydb/library/yql/core/yql_aggregate_expander.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2523,9 +2523,20 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() {
25232523
.Seal()
25242524
.Build();
25252525

2526+
auto name = TString(originalTrait->ChildPtr(0)->Content());
2527+
if (name.StartsWith("pg_")) {
2528+
auto func = name.substr(3);
2529+
TVector<ui32> argTypes;
2530+
bool needRetype = false;
2531+
auto status = ExtractPgTypesFromMultiLambda(originalTrait->ChildRef(2), argTypes, needRetype, Ctx);
2532+
YQL_ENSURE(status == IGraphTransformer::TStatus::Ok);
2533+
const NPg::TAggregateDesc& aggDesc = NPg::LookupAggregation(TString(func), argTypes);
2534+
name = "pg_" + aggDesc.Name + "#" + ToString(aggDesc.AggId);
2535+
}
2536+
25262537
mergeTraits.push_back(Ctx.Builder(Node->Pos())
25272538
.Callable(many ? "AggApplyManyState" : "AggApplyState")
2528-
.Add(0, originalTrait->ChildPtr(0))
2539+
.Atom(0, name)
25292540
.Add(1, extractorTypeNode)
25302541
.Add(2, extractor)
25312542
.Add(3, originalExtractorTypeNode)

ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ struct TAggParams {
427427
ui32 Column_ = 0;
428428
TType* StateType_ = nullptr;
429429
TType* ReturnType_ = nullptr;
430+
ui32 Hint_ = 0;
430431
};
431432

432433
struct TKeyParams {
@@ -1723,15 +1724,18 @@ std::unique_ptr<IPreparedBlockAggregator<TAggregator>> PrepareBlockAggregator(co
17231724
std::optional<ui32> filterColumn,
17241725
const std::vector<ui32>& argsColumns,
17251726
const TTypeEnvironment& env,
1726-
TType* returnType);
1727+
TType* returnType,
1728+
ui32 hint);
17271729

17281730
template <>
17291731
std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineAll>> PrepareBlockAggregator<IBlockAggregatorCombineAll>(const IBlockAggregatorFactory& factory,
17301732
TTupleType* tupleType,
17311733
std::optional<ui32> filterColumn,
17321734
const std::vector<ui32>& argsColumns,
17331735
const TTypeEnvironment& env,
1734-
TType* returnType) {
1736+
TType* returnType,
1737+
ui32 hint) {
1738+
Y_UNUSED(hint);
17351739
MKQL_ENSURE(!returnType, "Unexpected return type");
17361740
return factory.PrepareCombineAll(tupleType, filterColumn, argsColumns, env);
17371741
}
@@ -1742,7 +1746,9 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineKeys>> PrepareBl
17421746
std::optional<ui32> filterColumn,
17431747
const std::vector<ui32>& argsColumns,
17441748
const TTypeEnvironment& env,
1745-
TType* returnType) {
1749+
TType* returnType,
1750+
ui32 hint) {
1751+
Y_UNUSED(hint);
17461752
MKQL_ENSURE(!filterColumn, "Unexpected filter column");
17471753
MKQL_ENSURE(!returnType, "Unexpected return type");
17481754
return factory.PrepareCombineKeys(tupleType, argsColumns, env);
@@ -1754,10 +1760,11 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareB
17541760
std::optional<ui32> filterColumn,
17551761
const std::vector<ui32>& argsColumns,
17561762
const TTypeEnvironment& env,
1757-
TType* returnType) {
1763+
TType* returnType,
1764+
ui32 hint) {
17581765
MKQL_ENSURE(!filterColumn, "Unexpected filter column");
17591766
MKQL_ENSURE(returnType, "Missing return type");
1760-
return factory.PrepareFinalizeKeys(tupleType, argsColumns, env, returnType);
1767+
return factory.PrepareFinalizeKeys(tupleType, argsColumns, env, returnType, hint);
17611768
}
17621769

17631770
template <typename TAggregator>
@@ -1802,9 +1809,13 @@ ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<
18021809
p.Column_ = argColumns[0];
18031810
p.StateType_ = AS_TYPE(TBlockType, tupleType->GetElementType(p.Column_))->GetItemType();
18041811
p.ReturnType_ = returnTypes[i + keysCount];
1812+
TStringBuf left, right;
1813+
if (TStringBuf(name).TrySplit('#', left, right)) {
1814+
p.Hint_ = FromString<ui32>(right);
1815+
}
18051816
}
18061817

1807-
p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), unwrappedTupleType, filterColumn, argColumns, env, p.ReturnType_);
1818+
p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), unwrappedTupleType, filterColumn, argColumns, env, p.ReturnType_, p.Hint_);
18081819

18091820
totalStateSize += p.Prepared_->StateSize;
18101821
aggsParams.emplace_back(std::move(p));

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,11 +360,13 @@ class TBlockCountAllFactory : public IBlockAggregatorFactory {
360360
TTupleType* tupleType,
361361
const std::vector<ui32>& argsColumns,
362362
const TTypeEnvironment& env,
363-
TType* returnType) const final {
363+
TType* returnType,
364+
ui32 hint) const final {
364365
Y_UNUSED(tupleType);
365366
Y_UNUSED(argsColumns);
366367
Y_UNUSED(env);
367368
Y_UNUSED(returnType);
369+
Y_UNUSED(hint);
368370
return PrepareCountAll<TFinalizeKeysTag>(std::optional<ui32>(), argsColumns[0]);
369371
}
370372
};
@@ -395,11 +397,13 @@ class TBlockCountFactory : public IBlockAggregatorFactory {
395397
TTupleType* tupleType,
396398
const std::vector<ui32>& argsColumns,
397399
const TTypeEnvironment& env,
398-
TType* returnType) const final {
400+
TType* returnType,
401+
ui32 hint) const final {
399402
Y_UNUSED(tupleType);
400403
Y_UNUSED(argsColumns);
401404
Y_UNUSED(env);
402405
Y_UNUSED(returnType);
406+
Y_UNUSED(hint);
403407
return PrepareCount<TFinalizeKeysTag>(std::optional<ui32>(), argsColumns[0]);
404408
}
405409
};

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ struct TAggregatorFactories {
2727

2828
const IBlockAggregatorFactory& GetBlockAggregatorFactory(TStringBuf name) {
2929
const auto& f = Singleton<TAggregatorFactories>()->Factories;
30+
TStringBuf left, right;
31+
if (name.TrySplit('#', left, right)) {
32+
name = left;
33+
}
34+
3035
auto it = f.find(name);
3136
if (it == f.end()) {
3237
throw yexception() << "Unsupported block aggregation function: " << name;

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ class IBlockAggregatorFactory {
116116
TTupleType* tupleType,
117117
const std::vector<ui32>& argsColumns,
118118
const TTypeEnvironment& env,
119-
TType* returnType) const = 0;
119+
TType* returnType,
120+
ui32 hint) const = 0;
120121
};
121122

122123
const IBlockAggregatorFactory& GetBlockAggregatorFactory(TStringBuf name);

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1005,9 +1005,11 @@ class TBlockMinMaxFactory : public IBlockAggregatorFactory {
10051005
TTupleType* tupleType,
10061006
const std::vector<ui32>& argsColumns,
10071007
const TTypeEnvironment& env,
1008-
TType* returnType) const final {
1008+
TType* returnType,
1009+
ui32 hint) const final {
10091010
Y_UNUSED(env);
10101011
Y_UNUSED(returnType);
1012+
Y_UNUSED(hint);
10111013
return PrepareMinMax<TFinalizeKeysTag, IsMin>(tupleType, std::optional<ui32>(), argsColumns[0]);
10121014
}
10131015
};

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_some.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,11 @@ class TBlockSomeFactory : public IBlockAggregatorFactory {
269269
TTupleType* tupleType,
270270
const std::vector<ui32>& argsColumns,
271271
const TTypeEnvironment& env,
272-
TType* returnType) const override {
272+
TType* returnType,
273+
ui32 hint) const override {
273274
Y_UNUSED(env);
274275
Y_UNUSED(returnType);
276+
Y_UNUSED(hint);
275277
return PrepareSome<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0]);
276278
}
277279
};

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,11 @@ class TBlockSumFactory : public IBlockAggregatorFactory {
700700
TTupleType* tupleType,
701701
const std::vector<ui32>& argsColumns,
702702
const TTypeEnvironment& env,
703-
TType* returnType) const final
703+
TType* returnType,
704+
ui32 hint) const final
704705
{
705706
Y_UNUSED(returnType);
707+
Y_UNUSED(hint);
706708
return PrepareSum<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
707709
}
708710
};
@@ -853,8 +855,10 @@ class TBlockAvgFactory : public IBlockAggregatorFactory {
853855
TTupleType* tupleType,
854856
const std::vector<ui32>& argsColumns,
855857
const TTypeEnvironment& env,
856-
TType* returnType) const final {
858+
TType* returnType,
859+
ui32 hint) const final {
857860
Y_UNUSED(returnType);
861+
Y_UNUSED(hint);
858862
return PrepareAvg<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
859863
}
860864
};

ydb/library/yql/parser/pg_catalog/catalog.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3430,15 +3430,27 @@ const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>
34303430
}
34313431

34323432
const TAggregateDesc& LookupAggregation(const TString& name, ui32 stateType, ui32 resultType) {
3433+
TStringBuf realName = name;
3434+
TMaybe<ui32> aggId;
3435+
TStringBuf left, right;
3436+
if (realName.TrySplit('#', left, right)) {
3437+
aggId = FromString<ui32>(right);
3438+
realName = left;
3439+
}
3440+
34333441
const auto& catalog = TCatalog::Instance();
3434-
auto aggIdPtr = catalog.State->AggregationsByName.FindPtr(to_lower(name));
3442+
auto aggIdPtr = catalog.State->AggregationsByName.FindPtr(to_lower(TString(realName)));
34353443
if (!aggIdPtr) {
34363444
throw yexception() << "No such aggregate: " << name;
34373445
}
34383446

34393447
for (const auto& id : *aggIdPtr) {
34403448
const auto& d = catalog.State->Aggregations.FindPtr(id);
34413449
Y_ENSURE(d);
3450+
if (aggId && d->AggId != *aggId) {
3451+
continue;
3452+
}
3453+
34423454
if (!ValidateAggregateArgs(*d, stateType, resultType)) {
34433455
continue;
34443456
}

ydb/library/yql/parser/pg_wrapper/arrow.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,14 @@ TExecs::TExecs()
6060
#undef RegisterExec
6161
}
6262

63-
const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType, const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType) {
63+
const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType,
64+
const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType, ui32 hint) {
6465
using namespace NKikimr::NMiniKQL;
6566
if (returnType) {
6667
MKQL_ENSURE(argsColumns.size() == 1, "Expected one column");
6768
TType* stateType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
6869
TType* returnItemType = AS_TYPE(TBlockType, returnType)->GetItemType();
69-
return NPg::LookupAggregation(name, AS_TYPE(TPgType, stateType)->GetTypeId(), AS_TYPE(TPgType, returnItemType)->GetTypeId());
70+
return NPg::LookupAggregation(name + "#" + ToString(hint), AS_TYPE(TPgType, stateType)->GetTypeId(), AS_TYPE(TPgType, returnItemType)->GetTypeId());
7071
} else {
7172
TVector<ui32> argTypeIds;
7273
for (const auto col : argsColumns) {

0 commit comments

Comments
 (0)