Skip to content

Commit 3151c0a

Browse files
authored
Support of CREATE AGGREGATE for PG extensions (#7948)
1 parent b524a82 commit 3151c0a

File tree

5 files changed

+352
-0
lines changed

5 files changed

+352
-0
lines changed

ydb/library/yql/parser/pg_catalog/catalog.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,8 @@ class TAggregationsParser : public TParser {
883883
LastAggregation.InitValue = value;
884884
} else if (key == "aggfinalextra") {
885885
LastAggregation.FinalExtra = (value == "t");;
886+
} else if (key == "aggnumdirectargs") {
887+
LastAggregation.NumDirectArgs = FromString<ui32>(value);
886888
}
887889
}
888890

@@ -2154,6 +2156,33 @@ struct TCatalog : public IExtensionSqlBuilder {
21542156
State->AllowedProcs.insert(procPtr->Name);
21552157
}
21562158

2159+
void CreateAggregate(const TAggregateDesc& desc) final {
2160+
Y_ENSURE(desc.ExtensionIndex);
2161+
auto id = 16000 + State->Aggregations.size();
2162+
auto newDesc = desc;
2163+
newDesc.Name = to_lower(newDesc.Name);
2164+
newDesc.AggId = id;
2165+
Y_ENSURE(State->Aggregations.emplace(id, newDesc).second);
2166+
State->AggregationsByName[newDesc.Name].push_back(id);
2167+
if (desc.CombineFuncId) {
2168+
State->AllowedProcs.insert(State->Procs.FindPtr(desc.CombineFuncId)->Name);
2169+
}
2170+
2171+
if (desc.DeserializeFuncId) {
2172+
State->AllowedProcs.insert(State->Procs.FindPtr(desc.DeserializeFuncId)->Name);
2173+
}
2174+
2175+
if (desc.SerializeFuncId) {
2176+
State->AllowedProcs.insert(State->Procs.FindPtr(desc.SerializeFuncId)->Name);
2177+
}
2178+
2179+
if (desc.FinalFuncId) {
2180+
State->AllowedProcs.insert(State->Procs.FindPtr(desc.FinalFuncId)->Name);
2181+
}
2182+
2183+
State->AllowedProcs.insert(State->Procs.FindPtr(desc.TransFuncId)->Name);
2184+
}
2185+
21572186
static const TCatalog& Instance() {
21582187
return *Singleton<TCatalog>();
21592188
}
@@ -3778,6 +3807,39 @@ TString ExportExtensions(const TMaybe<TSet<ui32>>& filter) {
37783807
protoOper->SetNegateId(desc.NegateId);
37793808
}
37803809

3810+
TVector<ui32> extAggs;
3811+
for (const auto& a : catalog.State->Aggregations) {
3812+
const auto& desc = a.second;
3813+
if (!desc.ExtensionIndex) {
3814+
continue;
3815+
}
3816+
3817+
extAggs.push_back(a.first);
3818+
}
3819+
3820+
Sort(extAggs);
3821+
for (const auto a : extAggs) {
3822+
const auto& desc = *catalog.State->Aggregations.FindPtr(a);
3823+
auto protoAggregation = proto.AddAggregation();
3824+
protoAggregation->SetAggId(a);
3825+
protoAggregation->SetName(desc.Name);
3826+
protoAggregation->SetExtensionIndex(desc.ExtensionIndex);
3827+
for (const auto argType : desc.ArgTypes) {
3828+
protoAggregation->AddArgType(argType);
3829+
}
3830+
3831+
protoAggregation->SetKind((ui32)desc.Kind);
3832+
protoAggregation->SetTransTypeId(desc.TransTypeId);
3833+
protoAggregation->SetTransFuncId(desc.TransFuncId);
3834+
protoAggregation->SetFinalFuncId(desc.FinalFuncId);
3835+
protoAggregation->SetCombineFuncId(desc.CombineFuncId);
3836+
protoAggregation->SetSerializeFuncId(desc.SerializeFuncId);
3837+
protoAggregation->SetDeserializeFuncId(desc.DeserializeFuncId);
3838+
protoAggregation->SetInitValue(desc.InitValue);
3839+
protoAggregation->SetFinalExtra(desc.FinalExtra);
3840+
protoAggregation->SetNumDirectArgs(desc.NumDirectArgs);
3841+
}
3842+
37813843
return proto.SerializeAsString();
37823844
}
37833845

@@ -3931,6 +3993,30 @@ void ImportExtensions(const TString& exported, bool typesOnly, IExtensionLoader*
39313993
catalog.State->OperatorsByName[desc.Name].push_back(desc.OperId);
39323994
}
39333995

3996+
for (const auto& protoAggregation : proto.GetAggregation()) {
3997+
TAggregateDesc desc;
3998+
desc.AggId = protoAggregation.GetAggId();
3999+
desc.Name = protoAggregation.GetName();
4000+
desc.ExtensionIndex = protoAggregation.GetExtensionIndex();
4001+
for (const auto argType : protoAggregation.GetArgType()) {
4002+
desc.ArgTypes.push_back(argType);
4003+
}
4004+
4005+
desc.Kind = (NPg::EAggKind)protoAggregation.GetKind();
4006+
desc.TransTypeId = protoAggregation.GetTransTypeId();
4007+
desc.TransFuncId = protoAggregation.GetTransFuncId();
4008+
desc.FinalFuncId = protoAggregation.GetFinalFuncId();
4009+
desc.CombineFuncId = protoAggregation.GetCombineFuncId();
4010+
desc.SerializeFuncId = protoAggregation.GetSerializeFuncId();
4011+
desc.DeserializeFuncId = protoAggregation.GetDeserializeFuncId();
4012+
desc.InitValue = protoAggregation.GetInitValue();
4013+
desc.FinalExtra = protoAggregation.GetFinalExtra();
4014+
desc.NumDirectArgs = protoAggregation.GetNumDirectArgs();
4015+
4016+
Y_ENSURE(catalog.State->Aggregations.emplace(desc.AggId, desc).second);
4017+
catalog.State->AggregationsByName[desc.Name].push_back(desc.AggId);
4018+
}
4019+
39344020
if (!typesOnly && loader) {
39354021
for (ui32 extensionIndex = 1; extensionIndex <= catalog.State->Extensions.size(); ++extensionIndex) {
39364022
const auto& e = catalog.State->Extensions[extensionIndex - 1];

ydb/library/yql/parser/pg_catalog/catalog.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ struct TAggregateDesc {
177177
ui32 DeserializeFuncId = 0;
178178
TString InitValue;
179179
bool FinalExtra = false;
180+
ui32 NumDirectArgs = 0;
181+
ui32 ExtensionIndex = 0;
180182
};
181183

182184
enum class EAmType {
@@ -406,6 +408,8 @@ class IExtensionSqlBuilder {
406408
virtual void PrepareOper(ui32 extensionIndex, const TString& name, const TVector<ui32>& args) = 0;
407409

408410
virtual void UpdateOper(const TOperDesc& desc) = 0;
411+
412+
virtual void CreateAggregate(const TAggregateDesc& desc) = 0;
409413
};
410414

411415
class IExtensionSqlParser {

ydb/library/yql/parser/pg_catalog/proto/pg_catalog.proto

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,29 @@ message TPgOper {
7777
optional uint32 NegateId = 10;
7878
}
7979

80+
message TPgAggregation {
81+
optional uint32 AggId = 1;
82+
optional string Name = 2;
83+
optional uint32 ExtensionIndex = 3;
84+
repeated uint32 ArgType = 4;
85+
optional uint32 Kind = 5;
86+
optional uint32 TransTypeId = 6;
87+
optional uint32 TransFuncId = 7;
88+
optional uint32 FinalFuncId = 8;
89+
optional uint32 CombineFuncId = 9;
90+
optional uint32 SerializeFuncId = 10;
91+
optional uint32 DeserializeFuncId = 11;
92+
optional string InitValue = 12;
93+
optional bool FinalExtra = 13;
94+
optional uint32 NumDirectArgs = 14;
95+
}
96+
8097
message TPgCatalog {
8198
repeated TPgExtension Extension = 1;
8299
repeated TPgType Type = 2;
83100
repeated TPgProc Proc = 3;
84101
repeated TPgTable Table = 4;
85102
repeated TPgCast Cast = 5;
86103
repeated TPgOper Oper = 6;
104+
repeated TPgAggregation Aggregation = 7;
87105
}

ydb/library/yql/sql/pg/pg_sql.cpp

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5204,6 +5204,8 @@ class TExtensionHandler : public IPGParseEvents {
52045204
return ParseDefineType(value);
52055205
case OBJECT_OPERATOR:
52065206
return ParseDefineOperator(value);
5207+
case OBJECT_AGGREGATE:
5208+
return ParseDefineAggregate(value);
52075209
default:
52085210
return false;
52095211
}
@@ -5463,6 +5465,172 @@ class TExtensionHandler : public IPGParseEvents {
54635465
return true;
54645466
}
54655467

5468+
[[nodiscard]]
5469+
bool ParseDefineAggregate(const DefineStmt* value) {
5470+
if (ListLength(value->defnames) != 1) {
5471+
return false;
5472+
}
5473+
5474+
auto nameNode = ListNodeNth(value->defnames, 0);
5475+
auto name = to_lower(TString(StrVal(nameNode)));
5476+
TString sfunc;
5477+
ui32 stype;
5478+
TString combinefunc;
5479+
TString finalfunc;
5480+
TString serialfunc;
5481+
TString deserialfunc;
5482+
bool hypothetical = false;
5483+
for (int i = 0; i < ListLength(value->definition); ++i) {
5484+
auto node = LIST_CAST_NTH(DefElem, value->definition, i);
5485+
auto defnameStr = to_lower(TString(node->defname));
5486+
if (defnameStr == "sfunc") {
5487+
if (NodeTag(node->arg) != T_TypeName) {
5488+
return false;
5489+
}
5490+
5491+
TString value;
5492+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5493+
return false;
5494+
}
5495+
5496+
sfunc = value;
5497+
} else if (defnameStr == "stype") {
5498+
if (NodeTag(node->arg) != T_TypeName) {
5499+
return false;
5500+
}
5501+
5502+
TString value;
5503+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5504+
return false;
5505+
}
5506+
5507+
stype = NPg::LookupType(value).TypeId;
5508+
} else if (defnameStr == "combinefunc") {
5509+
if (NodeTag(node->arg) != T_TypeName) {
5510+
return false;
5511+
}
5512+
5513+
TString value;
5514+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5515+
return false;
5516+
}
5517+
5518+
combinefunc = value;
5519+
} else if (defnameStr == "finalfunc") {
5520+
if (NodeTag(node->arg) != T_TypeName) {
5521+
return false;
5522+
}
5523+
5524+
TString value;
5525+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5526+
return false;
5527+
}
5528+
5529+
finalfunc = value;
5530+
} else if (defnameStr == "serialfunc") {
5531+
if (NodeTag(node->arg) != T_TypeName) {
5532+
return false;
5533+
}
5534+
5535+
TString value;
5536+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5537+
return false;
5538+
}
5539+
5540+
serialfunc = value;
5541+
} else if (defnameStr == "deserialfunc") {
5542+
if (NodeTag(node->arg) != T_TypeName) {
5543+
return false;
5544+
}
5545+
5546+
TString value;
5547+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5548+
return false;
5549+
}
5550+
5551+
deserialfunc = value;
5552+
} else if (defnameStr == "hypothetical") {
5553+
if (NodeTag(node->arg) != T_Boolean) {
5554+
return false;
5555+
}
5556+
5557+
if (BoolVal(node->arg)) {
5558+
hypothetical = true;
5559+
}
5560+
}
5561+
}
5562+
5563+
if (!sfunc || !stype) {
5564+
return false;
5565+
}
5566+
5567+
NPg::TAggregateDesc desc;
5568+
desc.Name = name;
5569+
desc.ExtensionIndex = ExtensionIndex;
5570+
if (ListLength(value->args) != 2) {
5571+
return false;
5572+
}
5573+
5574+
auto numDirectArgs = intVal(lsecond(value->args));
5575+
if (numDirectArgs >= 0) {
5576+
desc.NumDirectArgs = numDirectArgs;
5577+
desc.Kind = NPg::EAggKind::OrderedSet;
5578+
Y_ENSURE(!hypothetical);
5579+
} else if (hypothetical) {
5580+
desc.Kind = NPg::EAggKind::Hypothetical;
5581+
}
5582+
5583+
auto args = linitial_node(List, value->args);
5584+
for (int i = 0; i < ListLength(args); ++i) {
5585+
auto node = LIST_CAST_NTH(FunctionParameter, args, i);
5586+
if (node->mode == FUNC_PARAM_IN || node->mode == FUNC_PARAM_DEFAULT) {
5587+
if (node->defexpr) {
5588+
return false;
5589+
}
5590+
} else {
5591+
return false;
5592+
}
5593+
5594+
TString argTypeStr;
5595+
if (!ParseTypeName(node->argType, argTypeStr)) {
5596+
return false;
5597+
}
5598+
5599+
Builder.PrepareType(ExtensionIndex, argTypeStr);
5600+
auto argTypeId = NPg::LookupType(argTypeStr).TypeId;
5601+
desc.ArgTypes.push_back(argTypeId);
5602+
}
5603+
5604+
desc.TransTypeId = stype;
5605+
TVector<ui32> stateWithArgs;
5606+
stateWithArgs.push_back(stype);
5607+
stateWithArgs.insert(stateWithArgs.end(), desc.ArgTypes.begin(), desc.ArgTypes.end());
5608+
desc.TransFuncId = NPg::LookupProc(sfunc, stateWithArgs).ProcId;
5609+
if (!finalfunc.empty()) {
5610+
desc.FinalFuncId = NPg::LookupProc(finalfunc, { stype }).ProcId;
5611+
}
5612+
5613+
if (!combinefunc.empty()) {
5614+
desc.CombineFuncId = NPg::LookupProc(combinefunc, { stype, stype }).ProcId;
5615+
}
5616+
5617+
if (!serialfunc.empty()) {
5618+
const auto& procDesc = NPg::LookupProc(serialfunc, { stype });
5619+
Y_ENSURE(procDesc.ResultType == NPg::LookupType("bytea").TypeId);
5620+
desc.SerializeFuncId = procDesc.ProcId;
5621+
}
5622+
5623+
if (!deserialfunc.empty()) {
5624+
Y_ENSURE(!serialfunc.empty());
5625+
const auto& procDesc = NPg::LookupProc(deserialfunc, { NPg::LookupType("bytea").TypeId, stype });
5626+
Y_ENSURE(procDesc.ResultType == stype);
5627+
desc.DeserializeFuncId = procDesc.ProcId;
5628+
}
5629+
5630+
Builder.CreateAggregate(desc);
5631+
return true;
5632+
}
5633+
54665634
[[nodiscard]]
54675635
bool ParseCreateFunctionStmt(const CreateFunctionStmt* value) {
54685636
NYql::NPg::TProcDesc desc;

0 commit comments

Comments
 (0)