Skip to content

Commit 3ecd4d1

Browse files
igormunkinlll-phill-lll
authored andcommitted
YQL-19967: Properly handle signatures with run config
commit_hash:e13229863b189b9ad804f0c6772204399430179e (cherry picked from commit 21395e1) (cherry picked from commit 8b2a785)
1 parent 0cd9040 commit 3ecd4d1

File tree

3 files changed

+269
-14
lines changed

3 files changed

+269
-14
lines changed

yql/essentials/minikql/comp_nodes/mkql_udf.cpp

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,23 @@
1010
#include <yql/essentials/minikql/mkql_utils.h>
1111
#include <yql/essentials/utils/yql_panic.h>
1212

13+
#include <library/cpp/containers/stack_array/stack_array.h>
14+
1315
namespace NKikimr {
1416
namespace NMiniKQL {
1517

1618
namespace {
1719

20+
constexpr size_t TypeDiffLimit = 1000;
21+
22+
TString TruncateTypeDiff(const TString& s) {
23+
if (s.size() < TypeDiffLimit) {
24+
return s;
25+
}
26+
27+
return s.substr(0,TypeDiffLimit) + "...";
28+
}
29+
1830
template<class TValidatePolicy, class TValidateMode>
1931
class TSimpleUdfWrapper: public TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>> {
2032
using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>>;
@@ -118,13 +130,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
118130
TString&& typeConfig,
119131
NUdf::TSourcePosition pos,
120132
IComputationNode* runConfigNode,
133+
ui32 runConfigArgs,
121134
const TCallableType* callableType,
122135
TType* userType)
123136
: TBaseComputation(mutables, EValueRepresentation::Boxed)
124137
, FunctionName(std::move(functionName))
125138
, TypeConfig(std::move(typeConfig))
126139
, Pos(pos)
127140
, RunConfigNode(runConfigNode)
141+
, RunConfigArgs(runConfigArgs)
128142
, CallableType(callableType)
129143
, UserType(userType)
130144
, UdfIndex(mutables.CurValueIndex++)
@@ -137,15 +151,17 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
137151
if (!udf.HasValue()) {
138152
MakeUdf(ctx, udf);
139153
}
140-
const auto runConfig = RunConfigNode->GetValue(ctx);
141-
auto callable = udf.Run(ctx.Builder, &runConfig);
154+
NStackArray::TStackArray<NUdf::TUnboxedValue> args(ALLOC_ON_STACK(NUdf::TUnboxedValue, RunConfigArgs));
155+
args[0] = RunConfigNode->GetValue(ctx);
156+
auto callable = udf.Run(ctx.Builder, args.data());
142157
Wrap(callable);
143158
return callable;
144159
}
145160
#ifndef MKQL_DISABLE_CODEGEN
146161
void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
147162
auto& context = ctx.Codegen.GetContext();
148163

164+
const auto indexType = Type::getInt32Ty(context);
149165
const auto valueType = Type::getInt128Ty(context);
150166

151167
const auto udfPtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), UdfIndex)}, "udf_ptr", block);
@@ -168,13 +184,24 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
168184

169185
block = main;
170186

171-
GetNodeValue(pointer, RunConfigNode, ctx, block);
172-
const auto conf = new LoadInst(valueType, pointer, "conf", block);
187+
const auto argsType = ArrayType::get(valueType, RunConfigArgs);
188+
const auto args = new AllocaInst(argsType, 0U, "args", block);
189+
Value* runConfigValue;
190+
for (ui32 i = 0; i < RunConfigArgs; i++) {
191+
const auto argIndex = ConstantInt::get(indexType, i);
192+
const auto argSlot = GetElementPtrInst::CreateInBounds(valueType, args, {argIndex}, "arg", block);
193+
if (i == 0) {
194+
GetNodeValue(argSlot, RunConfigNode, ctx, block);
195+
runConfigValue = new LoadInst(valueType, argSlot, "runconfig", block);
196+
} else {
197+
new StoreInst(ConstantInt::get(valueType, 0U), argSlot, block);
198+
}
199+
}
173200
const auto udf = new LoadInst(valueType, udfPtr, "udf", block);
174201

175-
CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen, block, ctx.GetBuilder(), pointer);
202+
CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen, block, ctx.GetBuilder(), args);
176203

177-
ValueUnRef(RunConfigNode->GetRepresentation(), conf, ctx, block);
204+
ValueUnRef(RunConfigNode->GetRepresentation(), runConfigValue, ctx, block);
178205

179206
const auto wrap = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TUdfWrapper::Wrap));
180207
const auto funType = FunctionType::get(Type::getVoidTy(context), {self->getType(), pointer->getType()}, false);
@@ -208,6 +235,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
208235
const TString TypeConfig;
209236
const NUdf::TSourcePosition Pos;
210237
IComputationNode* const RunConfigNode;
238+
const ui32 RunConfigArgs;
211239
const TCallableType* CallableType;
212240
TType* const UserType;
213241
const ui32 UdfIndex;
@@ -271,17 +299,71 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
271299
typeConfig, flags, pos, ctx.SecureParamsProvider, &funcInfo);
272300

273301
MKQL_ENSURE(status.IsOk(), status.GetError());
302+
303+
const auto runConfigFuncType = funcInfo.RunConfigType;
304+
const auto runConfigNodeType = runCfgNode.GetStaticType();
305+
306+
if (!runConfigFuncType->IsSameType(*runConfigNodeType)) {
307+
// It's only legal, when the compiled UDF declares its
308+
// signature using run config at compilation phase, but then
309+
// omits it in favor to function currying at execution phase.
310+
if (!runConfigFuncType->IsVoid()) {
311+
TString diff = TStringBuilder()
312+
<< "run config type mismatch, expected: "
313+
<< PrintNode((runConfigNodeType), true)
314+
<< ", actual: "
315+
<< PrintNode(runConfigFuncType, true);
316+
UdfTerminate((TStringBuilder() << pos
317+
<< " UDF Function '"
318+
<< funcName
319+
<< "' "
320+
<< TruncateTypeDiff(diff)).c_str());
321+
}
322+
323+
// If so, check the following invariants:
324+
// * The first argument of the head function in the sequence
325+
// of the curried functions has to be the same as the
326+
// run config type.
327+
// * The type of the resulting callable has to be the same
328+
// as the function type.
329+
const auto firstArgType = funcInfo.FunctionType->GetArgumentType(0);
330+
if (!runConfigNodeType->IsSameType(*firstArgType)) {
331+
TString diff = TStringBuilder()
332+
<< "type mismatch, expected run config type: "
333+
<< PrintNode(runConfigNodeType, true)
334+
<< ", actual: "
335+
<< PrintNode(firstArgType, true);
336+
UdfTerminate((TStringBuilder() << pos
337+
<< " Udf Function '"
338+
<< funcName
339+
<< "' "
340+
<< TruncateTypeDiff(diff)).c_str());
341+
}
342+
const auto callableFuncType = funcInfo.FunctionType->GetReturnType();
343+
const auto callableNodeType = callable.GetType()->GetReturnType();
344+
if (!callableNodeType->IsSameType(*callableFuncType)) {
345+
TString diff = TStringBuilder()
346+
<< "type mismatch, expected return type: "
347+
<< PrintNode(callableNodeType, true)
348+
<< ", actual: "
349+
<< PrintNode(callableFuncType, true);
350+
UdfTerminate((TStringBuilder() << pos
351+
<< " Udf Function '"
352+
<< funcName
353+
<< "' "
354+
<< TruncateTypeDiff(diff)).c_str());
355+
}
356+
357+
const auto runConfigCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
358+
const auto runConfigArgs = funcInfo.FunctionType->GetArgumentsCount();
359+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType);
360+
}
274361
MKQL_ENSURE(funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true),
275362
"Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<
276363
", actual:" << PrintNode(funcInfo.FunctionType, true));
277364
MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << funcName);
278365

279-
const auto runConfigType = funcInfo.RunConfigType;
280-
const bool typesMatch = runConfigType->IsSameType(*runCfgNode.GetStaticType());
281-
MKQL_ENSURE(typesMatch, "RunConfig '" << funcName << "' type mismatch, expected: " << PrintNode(runCfgNode.GetStaticType(), true) <<
282-
", actual: " << PrintNode(runConfigType, true));
283-
284-
if (runConfigType->IsVoid()) {
366+
if (runConfigFuncType->IsVoid()) {
285367
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) {
286368
return new TUdfRunCodegeneratorNode(
287369
ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType,
@@ -292,7 +374,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
292374
}
293375

294376
const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
295-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, funcInfo.FunctionType, userType);
377+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, funcInfo.FunctionType, userType);
296378
}
297379

298380
IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -331,7 +413,7 @@ IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFacto
331413
const auto funcTypeInfo = static_cast<TCallableType*>(callableResultType);
332414

333415
const auto programCompNode = LocateNode(ctx.NodeLocator, *programNode.GetNode());
334-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, programCompNode, funcTypeInfo, userType);
416+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, programCompNode, 1U, funcTypeInfo, userType);
335417
}
336418

337419
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#include "mkql_computation_node_ut.h"
2+
#include <yql/essentials/public/udf/udf_helpers.h>
3+
#include <yql/essentials/minikql/mkql_node_serialization.h>
4+
5+
namespace NKikimr {
6+
namespace NMiniKQL {
7+
8+
// Class, implementing the closure with run config.
9+
class TRunConfig : public NYql::NUdf::TBoxedValue {
10+
public:
11+
explicit TRunConfig(NYql::NUdf::TSourcePosition pos)
12+
: Pos_(pos)
13+
{}
14+
15+
static const NYql::NUdf::TStringRef& Name() {
16+
static auto name = NYql::NUdf::TStringRef::Of("Test");
17+
return name;
18+
}
19+
20+
static bool DeclareSignature(const NYql::NUdf::TStringRef& name,
21+
NYql::NUdf::TType*,
22+
NYql::NUdf::IFunctionTypeInfoBuilder& builder,
23+
bool typesOnly)
24+
{
25+
if (Name() != name) {
26+
return false;
27+
}
28+
29+
builder.RunConfig<char*>().Args(1)->Add<char*>();
30+
builder.Returns<char*>();
31+
if (!typesOnly) {
32+
builder.Implementation(new TRunConfig(builder.GetSourcePosition()));
33+
}
34+
35+
return true;
36+
}
37+
38+
private:
39+
const NYql::NUdf::TSourcePosition Pos_;
40+
};
41+
42+
// Class, implementing the closure with currying.
43+
class TCurrying : public NYql::NUdf::TBoxedValue {
44+
public:
45+
explicit TCurrying(NYql::NUdf::TSourcePosition pos)
46+
: Pos_(pos)
47+
{}
48+
49+
static const NYql::NUdf::TStringRef& Name() {
50+
static auto name = NYql::NUdf::TStringRef::Of("Test");
51+
return name;
52+
}
53+
54+
static bool DeclareSignature(const NYql::NUdf::TStringRef& name,
55+
NYql::NUdf::TType*,
56+
NYql::NUdf::IFunctionTypeInfoBuilder& builder,
57+
bool typesOnly)
58+
{
59+
if (Name() != name) {
60+
return false;
61+
}
62+
63+
builder.OptionalArgs(1).Args(2)->Add<char*>()
64+
.Add<NYql::NUdf::TOptional<bool>>().Name("NewOptionalArg");
65+
builder.Returns(builder.SimpleSignatureType<char*(char*)>());
66+
if (!typesOnly) {
67+
builder.Implementation(new TCurrying(builder.GetSourcePosition()));
68+
}
69+
70+
return true;
71+
}
72+
73+
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder*,
74+
const NYql::NUdf::TUnboxedValuePod* args)
75+
const final try {
76+
const std::string_view upvalue(args[0].AsStringRef());
77+
UNIT_ASSERT(!args[1]);
78+
return NYql::NUdf::TUnboxedValuePod(new TImpl(Pos_, upvalue));
79+
} catch (const std::exception& e) {
80+
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
81+
}
82+
83+
private:
84+
class TImpl : public NYql::NUdf::TBoxedValue {
85+
public:
86+
explicit TImpl(NYql::NUdf::TSourcePosition pos,
87+
const std::string_view upvalue)
88+
: Pos_(pos)
89+
, Upvalue_(upvalue)
90+
{}
91+
92+
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder* valueBuilder,
93+
const NYql::NUdf::TUnboxedValuePod* args)
94+
const override try {
95+
TStringStream concat;
96+
concat << Upvalue_ << " " << args[0].AsStringRef();
97+
return valueBuilder->NewString(NYql::NUdf::TStringRef(concat.Data(),
98+
concat.Size()));
99+
} catch (const std::exception& e) {
100+
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
101+
}
102+
103+
104+
private:
105+
const NYql::NUdf::TSourcePosition Pos_;
106+
const TString Upvalue_;
107+
};
108+
109+
const NYql::NUdf::TSourcePosition Pos_;
110+
};
111+
112+
113+
// XXX: To properly test the issue described in YQL-19967, there
114+
// should be two modules, registered by the same name, which
115+
// provide the functions with different signatures, exported by
116+
// the same name. Hence, class names for UDFs are different, but
117+
// use the same name in MKQL bytecode.
118+
SIMPLE_MODULE(TRunConfigUTModule, TRunConfig)
119+
SIMPLE_MODULE(TCurryingUTModule, TCurrying)
120+
121+
Y_UNIT_TEST_SUITE(TMiniKQLUdfTest) {
122+
Y_UNIT_TEST_LLVM(RunconfigToCurrying) {
123+
// Create the test setup, using TRunConfig implementation
124+
// for TestModule.Test UDF.
125+
TVector<TUdfModuleInfo> compileModules;
126+
compileModules.emplace_back(
127+
TUdfModuleInfo{"", "TestModule", new TRunConfigUTModule()}
128+
);
129+
TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules));
130+
TProgramBuilder& pb = *compileSetup.PgmBuilder;
131+
132+
// Build the graph on the setup with TRunConfig implementation.
133+
const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id);
134+
const auto upvalue = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary");
135+
const auto value = pb.NewDataLiteral<NUdf::EDataSlot::String>("is alive");
136+
const auto userType = pb.NewTupleType({
137+
pb.NewTupleType({strType}),
138+
pb.NewEmptyStructType(),
139+
pb.NewEmptyTupleType()});
140+
const auto udf = pb.Udf("TestModule.Test", upvalue, userType);
141+
142+
const auto list = pb.NewList(strType, {value});
143+
const auto pgmReturn = pb.Map(list, [&pb, udf](const TRuntimeNode item) {
144+
return pb.Apply(udf, {item});
145+
});
146+
147+
// Create the test setup, using TCurrying implementation
148+
// for TestModule.Test UDF.
149+
TVector<TUdfModuleInfo> runModules;
150+
runModules.emplace_back(
151+
TUdfModuleInfo{"", "TestModule", new TCurryingUTModule()}
152+
);
153+
TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules));
154+
155+
// Move the graph from the one setup to another as a
156+
// serialized bytecode sequence.
157+
const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env);
158+
const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env);
159+
160+
// Run the graph on the setup with TCurrying implementation.
161+
const auto graph = runSetup.BuildGraph(root);
162+
const auto iterator = graph->GetValue().GetListIterator();
163+
164+
NUdf::TUnboxedValue result;
165+
UNIT_ASSERT(iterator.Next(result));
166+
UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive");
167+
UNIT_ASSERT(!iterator.Next(result));
168+
}
169+
} // Y_UNIT_TEST_SUITE
170+
171+
} // namespace NMiniKQL
172+
} // namespace NKikimr

yql/essentials/minikql/comp_nodes/ut/ya.make.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ SET(ORIG_SOURCES
5959
mkql_switch_ut.cpp
6060
mkql_time_order_recover_saveload_ut.cpp
6161
mkql_todict_ut.cpp
62+
mkql_udf_ut.cpp
6263
mkql_variant_ut.cpp
6364
mkql_wide_chain_map_ut.cpp
6465
mkql_wide_chopper_ut.cpp

0 commit comments

Comments
 (0)