Skip to content

Commit 21395e1

Browse files
committed
YQL-19967: Properly handle signatures with run config
commit_hash:e13229863b189b9ad804f0c6772204399430179e
1 parent 048fe25 commit 21395e1

File tree

3 files changed

+259
-16
lines changed

3 files changed

+259
-16
lines changed

yql/essentials/minikql/comp_nodes/mkql_udf.cpp

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <yql/essentials/minikql/mkql_utils.h>
1111
#include <yql/essentials/utils/yql_panic.h>
1212

13+
#include <library/cpp/containers/stack_array/stack_array.h>
14+
1315
namespace NKikimr {
1416
namespace NMiniKQL {
1517

@@ -135,13 +137,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
135137
TString&& typeConfig,
136138
NUdf::TSourcePosition pos,
137139
IComputationNode* runConfigNode,
140+
ui32 runConfigArgs,
138141
const TCallableType* callableType,
139142
TType* userType)
140143
: TBaseComputation(mutables, EValueRepresentation::Boxed)
141144
, FunctionName(std::move(functionName))
142145
, TypeConfig(std::move(typeConfig))
143146
, Pos(pos)
144147
, RunConfigNode(runConfigNode)
148+
, RunConfigArgs(runConfigArgs)
145149
, CallableType(callableType)
146150
, UserType(userType)
147151
, UdfIndex(mutables.CurValueIndex++)
@@ -154,15 +158,17 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
154158
if (!udf.HasValue()) {
155159
MakeUdf(ctx, udf);
156160
}
157-
const auto runConfig = RunConfigNode->GetValue(ctx);
158-
auto callable = udf.Run(ctx.Builder, &runConfig);
161+
NStackArray::TStackArray<NUdf::TUnboxedValue> args(ALLOC_ON_STACK(NUdf::TUnboxedValue, RunConfigArgs));
162+
args[0] = RunConfigNode->GetValue(ctx);
163+
auto callable = udf.Run(ctx.Builder, args.data());
159164
Wrap(callable);
160165
return callable;
161166
}
162167
#ifndef MKQL_DISABLE_CODEGEN
163168
void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
164169
auto& context = ctx.Codegen.GetContext();
165170

171+
const auto indexType = Type::getInt32Ty(context);
166172
const auto valueType = Type::getInt128Ty(context);
167173

168174
const auto udfPtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), UdfIndex)}, "udf_ptr", block);
@@ -185,13 +191,24 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
185191

186192
block = main;
187193

188-
GetNodeValue(pointer, RunConfigNode, ctx, block);
189-
const auto conf = new LoadInst(valueType, pointer, "conf", block);
194+
const auto argsType = ArrayType::get(valueType, RunConfigArgs);
195+
const auto args = new AllocaInst(argsType, 0U, "args", block);
196+
Value* runConfigValue;
197+
for (ui32 i = 0; i < RunConfigArgs; i++) {
198+
const auto argIndex = ConstantInt::get(indexType, i);
199+
const auto argSlot = GetElementPtrInst::CreateInBounds(valueType, args, {argIndex}, "arg", block);
200+
if (i == 0) {
201+
GetNodeValue(argSlot, RunConfigNode, ctx, block);
202+
runConfigValue = new LoadInst(valueType, argSlot, "runconfig", block);
203+
} else {
204+
new StoreInst(ConstantInt::get(valueType, 0U), argSlot, block);
205+
}
206+
}
190207
const auto udf = new LoadInst(valueType, udfPtr, "udf", block);
191208

192-
CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen, block, ctx.GetBuilder(), pointer);
209+
CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen, block, ctx.GetBuilder(), args);
193210

194-
ValueUnRef(RunConfigNode->GetRepresentation(), conf, ctx, block);
211+
ValueUnRef(RunConfigNode->GetRepresentation(), runConfigValue, ctx, block);
195212

196213
const auto wrap = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TUdfWrapper::Wrap>());
197214
const auto funType = FunctionType::get(Type::getVoidTy(context), {self->getType(), pointer->getType()}, false);
@@ -231,6 +248,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
231248
const TString TypeConfig;
232249
const NUdf::TSourcePosition Pos;
233250
IComputationNode* const RunConfigNode;
251+
const ui32 RunConfigArgs;
234252
const TCallableType* CallableType;
235253
TType* const UserType;
236254
const ui32 UdfIndex;
@@ -298,6 +316,65 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
298316
<< status.GetError()).c_str());
299317
}
300318

319+
const auto runConfigFuncType = funcInfo.RunConfigType;
320+
const auto runConfigNodeType = runCfgNode.GetStaticType();
321+
322+
if (!runConfigFuncType->IsSameType(*runConfigNodeType)) {
323+
// It's only legal, when the compiled UDF declares its
324+
// signature using run config at compilation phase, but then
325+
// omits it in favor to function currying at execution phase.
326+
if (!runConfigFuncType->IsVoid()) {
327+
TString diff = TStringBuilder()
328+
<< "run config type mismatch, expected: "
329+
<< PrintNode((runConfigNodeType), true)
330+
<< ", actual: "
331+
<< PrintNode(runConfigFuncType, true);
332+
UdfTerminate((TStringBuilder() << pos
333+
<< " UDF Function '"
334+
<< funcName
335+
<< "' "
336+
<< TruncateTypeDiff(diff)).c_str());
337+
}
338+
339+
// If so, check the following invariants:
340+
// * The first argument of the head function in the sequence
341+
// of the curried functions has to be the same as the
342+
// run config type.
343+
// * The type of the resulting callable has to be the same
344+
// as the function type.
345+
const auto firstArgType = funcInfo.FunctionType->GetArgumentType(0);
346+
if (!runConfigNodeType->IsSameType(*firstArgType)) {
347+
TString diff = TStringBuilder()
348+
<< "type mismatch, expected run config type: "
349+
<< PrintNode(runConfigNodeType, true)
350+
<< ", actual: "
351+
<< PrintNode(firstArgType, true);
352+
UdfTerminate((TStringBuilder() << pos
353+
<< " Udf Function '"
354+
<< funcName
355+
<< "' "
356+
<< TruncateTypeDiff(diff)).c_str());
357+
}
358+
const auto callableFuncType = funcInfo.FunctionType->GetReturnType();
359+
const auto callableNodeType = callable.GetType()->GetReturnType();
360+
if (!callableNodeType->IsSameType(*callableFuncType)) {
361+
TString diff = TStringBuilder()
362+
<< "type mismatch, expected return type: "
363+
<< PrintNode(callableNodeType, true)
364+
<< ", actual: "
365+
<< PrintNode(callableFuncType, true);
366+
UdfTerminate((TStringBuilder() << pos
367+
<< " Udf Function '"
368+
<< funcName
369+
<< "' "
370+
<< TruncateTypeDiff(diff)).c_str());
371+
}
372+
373+
const auto runConfigCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
374+
const auto runConfigArgs = funcInfo.FunctionType->GetArgumentsCount();
375+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType);
376+
}
377+
301378
if (!funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true)) {
302379
TString diff = TStringBuilder() << "type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<
303380
", actual:" << PrintNode(funcInfo.FunctionType, true);
@@ -308,14 +385,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
308385
UdfTerminate((TStringBuilder() << pos << " UDF implementation is not set for function " << funcName).c_str());
309386
}
310387

311-
const auto runConfigType = funcInfo.RunConfigType;
312-
if (!runConfigType->IsSameType(*runCfgNode.GetStaticType())) {
313-
TString diff = TStringBuilder() << "run config type mismatch, expected: " << PrintNode(runCfgNode.GetStaticType(), true) <<
314-
", actual:" << PrintNode(runConfigType, true);
315-
UdfTerminate((TStringBuilder() << pos << " UDF Function '" << funcName << "' " << TruncateTypeDiff(diff)).c_str());
316-
}
317-
318-
if (runConfigType->IsVoid()) {
388+
if (runConfigFuncType->IsVoid()) {
319389
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) {
320390
return new TUdfRunCodegeneratorNode(
321391
ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType,
@@ -326,7 +396,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
326396
}
327397

328398
const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
329-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, funcInfo.FunctionType, userType);
399+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, funcInfo.FunctionType, userType);
330400
}
331401

332402
IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -375,7 +445,7 @@ IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFacto
375445
const auto funcTypeInfo = static_cast<TCallableType*>(callableResultType);
376446

377447
const auto programCompNode = LocateNode(ctx.NodeLocator, *programNode.GetNode());
378-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, programCompNode, funcTypeInfo, userType);
448+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, programCompNode, 1U, funcTypeInfo, userType);
379449
}
380450

381451
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#include "mkql_computation_node_ut.h"
2+
#include <yql/essentials/public/udf/udf_helpers.h>
3+
#include <yql/essentials/minikql/mkql_node_serialization.h>
4+
5+
namespace NKikimr {
6+
namespace NMiniKQL {
7+
8+
// Class, implementing the closure with run config.
9+
class TRunConfig : public NYql::NUdf::TBoxedValue {
10+
public:
11+
explicit TRunConfig(NYql::NUdf::TSourcePosition pos)
12+
: Pos_(pos)
13+
{}
14+
15+
static const NYql::NUdf::TStringRef& Name() {
16+
static auto name = NYql::NUdf::TStringRef::Of("Test");
17+
return name;
18+
}
19+
20+
static bool DeclareSignature(const NYql::NUdf::TStringRef& name,
21+
NYql::NUdf::TType*,
22+
NYql::NUdf::IFunctionTypeInfoBuilder& builder,
23+
bool typesOnly)
24+
{
25+
if (Name() != name) {
26+
return false;
27+
}
28+
29+
builder.RunConfig<char*>().Args(1)->Add<char*>();
30+
builder.Returns<char*>();
31+
if (!typesOnly) {
32+
builder.Implementation(new TRunConfig(builder.GetSourcePosition()));
33+
}
34+
35+
return true;
36+
}
37+
38+
private:
39+
const NYql::NUdf::TSourcePosition Pos_;
40+
};
41+
42+
// Class, implementing the closure with currying.
43+
class TCurrying : public NYql::NUdf::TBoxedValue {
44+
public:
45+
explicit TCurrying(NYql::NUdf::TSourcePosition pos)
46+
: Pos_(pos)
47+
{}
48+
49+
static const NYql::NUdf::TStringRef& Name() {
50+
static auto name = NYql::NUdf::TStringRef::Of("Test");
51+
return name;
52+
}
53+
54+
static bool DeclareSignature(const NYql::NUdf::TStringRef& name,
55+
NYql::NUdf::TType*,
56+
NYql::NUdf::IFunctionTypeInfoBuilder& builder,
57+
bool typesOnly)
58+
{
59+
if (Name() != name) {
60+
return false;
61+
}
62+
63+
builder.OptionalArgs(1).Args(2)->Add<char*>()
64+
.Add<NYql::NUdf::TOptional<bool>>().Name("NewOptionalArg");
65+
builder.Returns(builder.SimpleSignatureType<char*(char*)>());
66+
if (!typesOnly) {
67+
builder.Implementation(new TCurrying(builder.GetSourcePosition()));
68+
}
69+
70+
return true;
71+
}
72+
73+
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder*,
74+
const NYql::NUdf::TUnboxedValuePod* args)
75+
const final try {
76+
const std::string_view upvalue(args[0].AsStringRef());
77+
UNIT_ASSERT(!args[1]);
78+
return NYql::NUdf::TUnboxedValuePod(new TImpl(Pos_, upvalue));
79+
} catch (const std::exception& e) {
80+
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
81+
}
82+
83+
private:
84+
class TImpl : public NYql::NUdf::TBoxedValue {
85+
public:
86+
explicit TImpl(NYql::NUdf::TSourcePosition pos,
87+
const std::string_view upvalue)
88+
: Pos_(pos)
89+
, Upvalue_(upvalue)
90+
{}
91+
92+
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder* valueBuilder,
93+
const NYql::NUdf::TUnboxedValuePod* args)
94+
const override try {
95+
TStringStream concat;
96+
concat << Upvalue_ << " " << args[0].AsStringRef();
97+
return valueBuilder->NewString(NYql::NUdf::TStringRef(concat.Data(),
98+
concat.Size()));
99+
} catch (const std::exception& e) {
100+
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
101+
}
102+
103+
104+
private:
105+
const NYql::NUdf::TSourcePosition Pos_;
106+
const TString Upvalue_;
107+
};
108+
109+
const NYql::NUdf::TSourcePosition Pos_;
110+
};
111+
112+
113+
// XXX: To properly test the issue described in YQL-19967, there
114+
// should be two modules, registered by the same name, which
115+
// provide the functions with different signatures, exported by
116+
// the same name. Hence, class names for UDFs are different, but
117+
// use the same name in MKQL bytecode.
118+
SIMPLE_MODULE(TRunConfigUTModule, TRunConfig)
119+
SIMPLE_MODULE(TCurryingUTModule, TCurrying)
120+
121+
Y_UNIT_TEST_SUITE(TMiniKQLUdfTest) {
122+
Y_UNIT_TEST_LLVM(RunconfigToCurrying) {
123+
// Create the test setup, using TRunConfig implementation
124+
// for TestModule.Test UDF.
125+
TVector<TUdfModuleInfo> compileModules;
126+
compileModules.emplace_back(
127+
TUdfModuleInfo{"", "TestModule", new TRunConfigUTModule()}
128+
);
129+
TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules));
130+
TProgramBuilder& pb = *compileSetup.PgmBuilder;
131+
132+
// Build the graph on the setup with TRunConfig implementation.
133+
const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id);
134+
const auto upvalue = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary");
135+
const auto value = pb.NewDataLiteral<NUdf::EDataSlot::String>("is alive");
136+
const auto userType = pb.NewTupleType({
137+
pb.NewTupleType({strType}),
138+
pb.NewEmptyStructType(),
139+
pb.NewEmptyTupleType()});
140+
const auto udf = pb.Udf("TestModule.Test", upvalue, userType);
141+
142+
const auto list = pb.NewList(strType, {value});
143+
const auto pgmReturn = pb.Map(list, [&pb, udf](const TRuntimeNode item) {
144+
return pb.Apply(udf, {item});
145+
});
146+
147+
// Create the test setup, using TCurrying implementation
148+
// for TestModule.Test UDF.
149+
TVector<TUdfModuleInfo> runModules;
150+
runModules.emplace_back(
151+
TUdfModuleInfo{"", "TestModule", new TCurryingUTModule()}
152+
);
153+
TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules));
154+
155+
// Move the graph from the one setup to another as a
156+
// serialized bytecode sequence.
157+
const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env);
158+
const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env);
159+
160+
// Run the graph on the setup with TCurrying implementation.
161+
const auto graph = runSetup.BuildGraph(root);
162+
const auto iterator = graph->GetValue().GetListIterator();
163+
164+
NUdf::TUnboxedValue result;
165+
UNIT_ASSERT(iterator.Next(result));
166+
UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive");
167+
UNIT_ASSERT(!iterator.Next(result));
168+
}
169+
} // Y_UNIT_TEST_SUITE
170+
171+
} // namespace NMiniKQL
172+
} // namespace NKikimr

yql/essentials/minikql/comp_nodes/ut/ya.make.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ SET(ORIG_SOURCES
6666
mkql_switch_ut.cpp
6767
mkql_time_order_recover_saveload_ut.cpp
6868
mkql_todict_ut.cpp
69+
mkql_udf_ut.cpp
6970
mkql_variant_ut.cpp
7071
mkql_wide_chain_map_ut.cpp
7172
mkql_wide_chopper_ut.cpp

0 commit comments

Comments
 (0)