Skip to content

Commit 617236c

Browse files
authored
Support of CREATE OPERATOR for PG extensions (#7869)
1 parent 5b55710 commit 617236c

File tree

5 files changed

+343
-16
lines changed

5 files changed

+343
-16
lines changed

ydb/library/yql/parser/pg_catalog/catalog.cpp

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,16 +1957,24 @@ struct TCatalog : public IExtensionSqlBuilder {
19571957
void CreateProc(const TProcDesc& desc) final {
19581958
Y_ENSURE(desc.ExtensionIndex);
19591959
TProcDesc newDesc = desc;
1960+
newDesc.Name = to_lower(newDesc.Name);
19601961
newDesc.ProcId = 16000 + State->Procs.size();
19611962
State->Procs[newDesc.ProcId] = newDesc;
19621963
State->ProcByName[newDesc.Name].push_back(newDesc.ProcId);
19631964
}
19641965

19651966
void PrepareType(ui32 extensionIndex, const TString& name) final {
19661967
Y_ENSURE(extensionIndex);
1967-
Y_ENSURE(!State->TypeByName.contains(name));
1968+
auto lowerName = to_lower(name);
1969+
if (auto idPtr = State->TypeByName.FindPtr(lowerName)) {
1970+
auto typePtr = State->Types.FindPtr(*idPtr);
1971+
Y_ENSURE(typePtr);
1972+
Y_ENSURE(!typePtr->ExtensionIndex || typePtr->ExtensionIndex == extensionIndex);
1973+
return;
1974+
}
1975+
19681976
TTypeDesc newDesc;
1969-
newDesc.Name = name;
1977+
newDesc.Name = lowerName;
19701978
newDesc.TypeId = 16000 + State->Types.size();
19711979
newDesc.ExtensionIndex = extensionIndex;
19721980
newDesc.ArrayTypeId = newDesc.TypeId + 1;
@@ -1990,6 +1998,7 @@ struct TCatalog : public IExtensionSqlBuilder {
19901998
}
19911999

19922000
void UpdateType(const TTypeDesc& desc) final {
2001+
Y_ENSURE(desc.ExtensionIndex);
19932002
auto byIdPtr = State->Types.FindPtr(desc.TypeId);
19942003
Y_ENSURE(byIdPtr);
19952004
Y_ENSURE(byIdPtr->Name == desc.Name);
@@ -2091,6 +2100,60 @@ struct TCatalog : public IExtensionSqlBuilder {
20912100
}
20922101
}
20932102

2103+
void PrepareOper(ui32 extensionIndex, const TString& name, const TVector<ui32>& args) final {
2104+
Y_ENSURE(args.size() >= 1 && args.size() <= 2);
2105+
Y_ENSURE(extensionIndex);
2106+
auto lowerName = to_lower(name);
2107+
auto operIdPtr = State->OperatorsByName.FindPtr(lowerName);
2108+
if (operIdPtr) {
2109+
for (const auto& id : *operIdPtr) {
2110+
const auto& d = State->Operators.FindPtr(id);
2111+
Y_ENSURE(d);
2112+
if (d->LeftType == args[0] && (args.size() == 1 || d->RightType == args[1])) {
2113+
Y_ENSURE(!d->ExtensionIndex || d->ExtensionIndex == extensionIndex);
2114+
return;
2115+
}
2116+
}
2117+
}
2118+
2119+
if (!operIdPtr) {
2120+
operIdPtr = &State->OperatorsByName[lowerName];
2121+
}
2122+
2123+
TOperDesc desc;
2124+
desc.Name = name;
2125+
desc.LeftType = args[0];
2126+
if (args.size() == 1) {
2127+
desc.Kind = EOperKind::LeftUnary;
2128+
} else {
2129+
desc.RightType = args[1];
2130+
}
2131+
2132+
auto id = 16000 + State->Operators.size();
2133+
desc.OperId = id;
2134+
desc.ExtensionIndex = extensionIndex;
2135+
Y_ENSURE(State->Operators.emplace(id, desc).second);
2136+
operIdPtr->push_back(id);
2137+
}
2138+
2139+
void UpdateOper(const TOperDesc& desc) final {
2140+
Y_ENSURE(desc.ExtensionIndex);
2141+
const auto& d = State->Operators.FindPtr(desc.OperId);
2142+
Y_ENSURE(d);
2143+
Y_ENSURE(d->Name == desc.Name);
2144+
Y_ENSURE(d->ExtensionIndex == desc.ExtensionIndex);
2145+
Y_ENSURE(d->LeftType == desc.LeftType);
2146+
Y_ENSURE(d->RightType == desc.RightType);
2147+
Y_ENSURE(d->Kind == desc.Kind);
2148+
d->ProcId = desc.ProcId;
2149+
d->ComId = desc.ComId;
2150+
d->NegateId = desc.NegateId;
2151+
d->ResultType = desc.ResultType;
2152+
auto procPtr = State->Procs.FindPtr(desc.ProcId);
2153+
Y_ENSURE(procPtr);
2154+
State->AllowedProcs.insert(procPtr->Name);
2155+
}
2156+
20942157
static const TCatalog& Instance() {
20952158
return *Singleton<TCatalog>();
20962159
}
@@ -3689,6 +3752,32 @@ TString ExportExtensions(const TMaybe<TSet<ui32>>& filter) {
36893752
protoCast->SetCoercionCode((ui32)desc.CoercionCode);
36903753
}
36913754

3755+
TVector<ui32> extOpers;
3756+
for (const auto& o : catalog.State->Operators) {
3757+
const auto& desc = o.second;
3758+
if (!desc.ExtensionIndex) {
3759+
continue;
3760+
}
3761+
3762+
extOpers.push_back(o.first);
3763+
}
3764+
3765+
Sort(extOpers);
3766+
for (const auto o : extOpers) {
3767+
const auto& desc = *catalog.State->Operators.FindPtr(o);
3768+
auto protoOper = proto.AddOper();
3769+
protoOper->SetOperId(o);
3770+
protoOper->SetName(desc.Name);
3771+
protoOper->SetExtensionIndex(desc.ExtensionIndex);
3772+
protoOper->SetLeftType(desc.LeftType);
3773+
protoOper->SetRightType(desc.RightType);
3774+
protoOper->SetKind((ui32)desc.Kind);
3775+
protoOper->SetProcId(desc.ProcId);
3776+
protoOper->SetResultType(desc.ResultType);
3777+
protoOper->SetComId(desc.ComId);
3778+
protoOper->SetNegateId(desc.NegateId);
3779+
}
3780+
36923781
return proto.SerializeAsString();
36933782
}
36943783

@@ -3826,6 +3915,22 @@ void ImportExtensions(const TString& exported, bool typesOnly, IExtensionLoader*
38263915
Y_ENSURE(catalog.State->CastsByDir.insert(std::make_pair(std::make_pair(desc.SourceId, desc.TargetId), id)).second);
38273916
}
38283917

3918+
for (const auto& protoOper : proto.GetOper()) {
3919+
TOperDesc desc;
3920+
desc.OperId = protoOper.GetOperId();
3921+
desc.Name = protoOper.GetName();
3922+
desc.ExtensionIndex = protoOper.GetExtensionIndex();
3923+
desc.LeftType = protoOper.GetLeftType();
3924+
desc.RightType = protoOper.GetRightType();
3925+
desc.Kind = (EOperKind)protoOper.GetKind();
3926+
desc.ProcId = protoOper.GetProcId();
3927+
desc.ResultType = protoOper.GetResultType();
3928+
desc.ComId = protoOper.GetComId();
3929+
desc.NegateId = protoOper.GetNegateId();
3930+
Y_ENSURE(catalog.State->Operators.emplace(desc.OperId, desc).second);
3931+
catalog.State->OperatorsByName[desc.Name].push_back(desc.OperId);
3932+
}
3933+
38293934
if (!typesOnly && loader) {
38303935
for (ui32 extensionIndex = 1; extensionIndex <= catalog.State->Extensions.size(); ++extensionIndex) {
38313936
const auto& e = catalog.State->Extensions[extensionIndex - 1];

ydb/library/yql/parser/pg_catalog/catalog.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct TOperDesc {
4444
ui32 ProcId = 0;
4545
ui32 ComId = 0;
4646
ui32 NegateId = 0;
47+
ui32 ExtensionIndex = 0;
4748
};
4849

4950
enum class EProcKind : char {
@@ -401,6 +402,10 @@ class IExtensionSqlBuilder {
401402
const TVector<TMaybe<TString>>& data) = 0; // row based layout
402403

403404
virtual void CreateCast(const TCastDesc& desc) = 0;
405+
406+
virtual void PrepareOper(ui32 extensionIndex, const TString& name, const TVector<ui32>& args) = 0;
407+
408+
virtual void UpdateOper(const TOperDesc& desc) = 0;
404409
};
405410

406411
class IExtensionSqlParser {

ydb/library/yql/parser/pg_catalog/proto/pg_catalog.proto

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,24 @@ message TPgCast {
6464
optional uint32 CoercionCode = 7;
6565
}
6666

67+
message TPgOper {
68+
optional uint32 OperId = 1;
69+
optional string Name = 2;
70+
optional uint32 ExtensionIndex = 3;
71+
optional uint32 LeftType = 4;
72+
optional uint32 RightType = 5;
73+
optional uint32 Kind = 6;
74+
optional uint32 ProcId = 7;
75+
optional uint32 ResultType = 8;
76+
optional uint32 ComId = 9;
77+
optional uint32 NegateId = 10;
78+
}
79+
6780
message TPgCatalog {
6881
repeated TPgExtension Extension = 1;
6982
repeated TPgType Type = 2;
7083
repeated TPgProc Proc = 3;
7184
repeated TPgTable Table = 4;
7285
repeated TPgCast Cast = 5;
86+
repeated TPgOper Oper = 6;
7387
}

ydb/library/yql/sql/pg/pg_sql.cpp

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5202,6 +5202,8 @@ class TExtensionHandler : public IPGParseEvents {
52025202
switch (value->kind) {
52035203
case OBJECT_TYPE:
52045204
return ParseDefineType(value);
5205+
case OBJECT_OPERATOR:
5206+
return ParseDefineOperator(value);
52055207
default:
52065208
return false;
52075209
}
@@ -5215,15 +5217,13 @@ class TExtensionHandler : public IPGParseEvents {
52155217

52165218
auto nameNode = ListNodeNth(value->defnames, 0);
52175219
auto name = to_lower(TString(StrVal(nameNode)));
5218-
if (!NPg::HasType(name)) {
5219-
Builder.PrepareType(ExtensionIndex, name);
5220-
}
5220+
Builder.PrepareType(ExtensionIndex, name);
52215221

52225222
NPg::TTypeDesc desc = NPg::LookupType(name);
52235223

52245224
for (int i = 0; i < ListLength(value->definition); ++i) {
52255225
auto node = LIST_CAST_NTH(DefElem, value->definition, i);
5226-
TString defnameStr(node->defname);
5226+
auto defnameStr = to_lower(TString(node->defname));
52275227
if (defnameStr == "internallength") {
52285228
if (NodeTag(node->arg) == T_Integer) {
52295229
desc.TypeLen = IntVal(node->arg);
@@ -5363,6 +5363,106 @@ class TExtensionHandler : public IPGParseEvents {
53635363
return true;
53645364
}
53655365

5366+
[[nodiscard]]
5367+
bool ParseDefineOperator(const DefineStmt* value) {
5368+
if (ListLength(value->defnames) != 1) {
5369+
return false;
5370+
}
5371+
5372+
auto nameNode = ListNodeNth(value->defnames, 0);
5373+
auto name = to_lower(TString(StrVal(nameNode)));
5374+
TString procedureName;
5375+
TString commutator;
5376+
TString negator;
5377+
ui32 leftType = 0;
5378+
ui32 rightType = 0;
5379+
for (int i = 0; i < ListLength(value->definition); ++i) {
5380+
auto node = LIST_CAST_NTH(DefElem, value->definition, i);
5381+
auto defnameStr = to_lower(TString(node->defname));
5382+
if (defnameStr == "leftarg") {
5383+
if (NodeTag(node->arg) != T_TypeName) {
5384+
return false;
5385+
}
5386+
5387+
TString value;
5388+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5389+
return false;
5390+
}
5391+
5392+
leftType = NPg::LookupType(value).TypeId;
5393+
} else if (defnameStr == "rightarg") {
5394+
if (NodeTag(node->arg) != T_TypeName) {
5395+
return false;
5396+
}
5397+
5398+
TString value;
5399+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5400+
return false;
5401+
}
5402+
5403+
rightType = NPg::LookupType(value).TypeId;
5404+
} else if (defnameStr == "procedure") {
5405+
if (NodeTag(node->arg) != T_TypeName) {
5406+
return false;
5407+
}
5408+
5409+
TString value;
5410+
if (!ParseTypeName(CAST_NODE_EXT(PG_TypeName, T_TypeName, node->arg), value)) {
5411+
return false;
5412+
}
5413+
5414+
procedureName = value;
5415+
} else if (defnameStr == "commutator") {
5416+
if (NodeTag(node->arg) != T_String) {
5417+
return false;
5418+
}
5419+
5420+
commutator = StrVal(node->arg);
5421+
} else if (defnameStr == "negator") {
5422+
if (NodeTag(node->arg) != T_String) {
5423+
return false;
5424+
}
5425+
5426+
negator = StrVal(node->arg);
5427+
}
5428+
}
5429+
5430+
if (!leftType) {
5431+
return false;
5432+
}
5433+
5434+
if (procedureName.empty()) {
5435+
return false;
5436+
}
5437+
5438+
TVector<ui32> args;
5439+
args.push_back(leftType);
5440+
if (rightType) {
5441+
args.push_back(rightType);
5442+
}
5443+
5444+
Builder.PrepareOper(ExtensionIndex, name, args);
5445+
auto desc = NPg::LookupOper(name, args);
5446+
if (!commutator.empty()) {
5447+
TVector<ui32> commArgs;
5448+
commArgs.push_back(rightType);
5449+
commArgs.push_back(leftType);
5450+
Builder.PrepareOper(ExtensionIndex, commutator, commArgs);
5451+
desc.ComId = NPg::LookupOper(commutator, commArgs).OperId;
5452+
}
5453+
5454+
if (!negator.empty()) {
5455+
Builder.PrepareOper(ExtensionIndex, negator, args);
5456+
desc.NegateId = NPg::LookupOper(negator, args).OperId;
5457+
}
5458+
5459+
const auto& procDesc = NPg::LookupProc(procedureName, args);
5460+
desc.ProcId = procDesc.ProcId;
5461+
desc.ResultType = procDesc.ResultType;
5462+
Builder.UpdateOper(desc);
5463+
return true;
5464+
}
5465+
53665466
[[nodiscard]]
53675467
bool ParseCreateFunctionStmt(const CreateFunctionStmt* value) {
53685468
NYql::NPg::TProcDesc desc;
@@ -5384,10 +5484,7 @@ class TExtensionHandler : public IPGParseEvents {
53845484
return false;
53855485
}
53865486

5387-
if (!NPg::HasType(resultTypeStr)) {
5388-
Builder.PrepareType(ExtensionIndex, resultTypeStr);
5389-
}
5390-
5487+
Builder.PrepareType(ExtensionIndex, resultTypeStr);
53915488
desc.ResultType = NPg::LookupType(resultTypeStr).TypeId;
53925489
} else {
53935490
desc.ResultType = NPg::LookupType("record").TypeId;
@@ -5466,10 +5563,7 @@ class TExtensionHandler : public IPGParseEvents {
54665563
return false;
54675564
}
54685565

5469-
if (!NPg::HasType(argTypeStr)) {
5470-
Builder.PrepareType(ExtensionIndex, argTypeStr);
5471-
}
5472-
5566+
Builder.PrepareType(ExtensionIndex, argTypeStr);
54735567
auto argTypeId = NPg::LookupType(argTypeStr).TypeId;
54745568
if (node->mode == FUNC_PARAM_IN || node->mode == FUNC_PARAM_DEFAULT) {
54755569
desc.ArgTypes.push_back(argTypeId);

0 commit comments

Comments
 (0)