Skip to content

Commit 4d47c26

Browse files
committed
YQL-19967: Properly handle signatures with run config (vice versa)
commit_hash:5b56b1b4416cf5b7abcae727a61e9402e10af067 (cherry picked from commit c6a21d1)
1 parent 75b1be1 commit 4d47c26

File tree

2 files changed

+101
-32
lines changed

2 files changed

+101
-32
lines changed

yql/essentials/minikql/comp_nodes/mkql_udf.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
308308
// It's only legal, when the compiled UDF declares its
309309
// signature using run config at compilation phase, but then
310310
// omits it in favor to function currying at execution phase.
311-
if (!runConfigFuncType->IsVoid()) {
311+
// And vice versa for the forward compatibility.
312+
if (!runConfigNodeType->IsVoid() && !runConfigFuncType->IsVoid()) {
312313
TString diff = TStringBuilder()
313314
<< "run config type mismatch, expected: "
314315
<< PrintNode((runConfigNodeType), true)
@@ -327,11 +328,15 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
327328
// run config type.
328329
// * The type of the resulting callable has to be the same
329330
// as the function type.
330-
const auto firstArgType = funcInfo.FunctionType->GetArgumentType(0);
331-
if (!runConfigNodeType->IsSameType(*firstArgType)) {
331+
const auto firstArgType = runConfigNodeType->IsVoid()
332+
? callable.GetType()->GetArgumentType(0)
333+
: funcInfo.FunctionType->GetArgumentType(0);
334+
const auto runConfigType = runConfigNodeType->IsVoid()
335+
? runConfigFuncType : runConfigNodeType;
336+
if (!runConfigType->IsSameType(*firstArgType)) {
332337
TString diff = TStringBuilder()
333338
<< "type mismatch, expected run config type: "
334-
<< PrintNode(runConfigNodeType, true)
339+
<< PrintNode(runConfigType, true)
335340
<< ", actual: "
336341
<< PrintNode(firstArgType, true);
337342
UdfTerminate((TStringBuilder() << pos
@@ -340,8 +345,12 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
340345
<< "' "
341346
<< TruncateTypeDiff(diff)).c_str());
342347
}
343-
const auto callableFuncType = funcInfo.FunctionType->GetReturnType();
344-
const auto callableNodeType = callable.GetType()->GetReturnType();
348+
const auto callableFuncType = runConfigNodeType->IsVoid()
349+
? funcInfo.FunctionType
350+
: funcInfo.FunctionType->GetReturnType();
351+
const auto callableNodeType = runConfigNodeType->IsVoid()
352+
? AS_TYPE(TCallableType, callable.GetType()->GetReturnType())->GetReturnType()
353+
: callable.GetType()->GetReturnType();
345354
if (!callableNodeType->IsSameType(*callableFuncType)) {
346355
TString diff = TStringBuilder()
347356
<< "type mismatch, expected return type: "
@@ -357,7 +366,9 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
357366

358367
const auto runConfigCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
359368
const auto runConfigArgs = funcInfo.FunctionType->GetArgumentsCount();
360-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType);
369+
return runConfigNodeType->IsVoid()
370+
? CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType)
371+
: CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType);
361372
}
362373
MKQL_ENSURE(funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true),
363374
"Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<

yql/essentials/minikql/comp_nodes/ut/mkql_udf_ut.cpp

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,31 @@
55
namespace NKikimr {
66
namespace NMiniKQL {
77

8+
class TImpl : public NYql::NUdf::TBoxedValue {
9+
public:
10+
explicit TImpl(NYql::NUdf::TSourcePosition pos,
11+
const std::string_view upvalue)
12+
: Pos_(pos)
13+
, Upvalue_(upvalue)
14+
{}
15+
16+
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder* valueBuilder,
17+
const NYql::NUdf::TUnboxedValuePod* args)
18+
const override try {
19+
TStringStream concat;
20+
concat << Upvalue_ << " " << args[0].AsStringRef();
21+
return valueBuilder->NewString(NYql::NUdf::TStringRef(concat.Data(),
22+
concat.Size()));
23+
} catch (const std::exception& e) {
24+
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
25+
}
26+
27+
28+
private:
29+
const NYql::NUdf::TSourcePosition Pos_;
30+
const TString Upvalue_;
31+
};
32+
833
// Class, implementing the closure with run config.
934
class TRunConfig : public NYql::NUdf::TBoxedValue {
1035
public:
@@ -35,6 +60,15 @@ class TRunConfig : public NYql::NUdf::TBoxedValue {
3560
return true;
3661
}
3762

63+
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder*,
64+
const NYql::NUdf::TUnboxedValuePod* args)
65+
const final try {
66+
const std::string_view upvalue(args[0].AsStringRef());
67+
return NYql::NUdf::TUnboxedValuePod(new TImpl(Pos_, upvalue));
68+
} catch (const std::exception& e) {
69+
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
70+
}
71+
3872
private:
3973
const NYql::NUdf::TSourcePosition Pos_;
4074
};
@@ -81,31 +115,6 @@ class TCurrying : public NYql::NUdf::TBoxedValue {
81115
}
82116

83117
private:
84-
class TImpl : public NYql::NUdf::TBoxedValue {
85-
public:
86-
explicit TImpl(NYql::NUdf::TSourcePosition pos,
87-
const std::string_view upvalue)
88-
: Pos_(pos)
89-
, Upvalue_(upvalue)
90-
{}
91-
92-
NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder* valueBuilder,
93-
const NYql::NUdf::TUnboxedValuePod* args)
94-
const override try {
95-
TStringStream concat;
96-
concat << Upvalue_ << " " << args[0].AsStringRef();
97-
return valueBuilder->NewString(NYql::NUdf::TStringRef(concat.Data(),
98-
concat.Size()));
99-
} catch (const std::exception& e) {
100-
UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data());
101-
}
102-
103-
104-
private:
105-
const NYql::NUdf::TSourcePosition Pos_;
106-
const TString Upvalue_;
107-
};
108-
109118
const NYql::NUdf::TSourcePosition Pos_;
110119
};
111120

@@ -166,6 +175,55 @@ Y_UNIT_TEST_SUITE(TMiniKQLUdfTest) {
166175
UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive");
167176
UNIT_ASSERT(!iterator.Next(result));
168177
}
178+
179+
Y_UNIT_TEST_LLVM(CurryingToRunconfig) {
180+
// Create the test setup, using TCurrying implementation
181+
// for TestModule.Test UDF.
182+
TVector<TUdfModuleInfo> compileModules;
183+
compileModules.emplace_back(
184+
TUdfModuleInfo{"", "TestModule", new TCurryingUTModule()}
185+
);
186+
TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules));
187+
TProgramBuilder& pb = *compileSetup.PgmBuilder;
188+
189+
// Build the graph on the setup with TRunConfig implementation.
190+
const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id);
191+
const auto upvalue = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary");
192+
const auto optional = pb.NewOptional(pb.NewDataLiteral(true));
193+
const auto value = pb.NewDataLiteral<NUdf::EDataSlot::String>("is alive");
194+
const auto userType = pb.NewTupleType({
195+
pb.NewTupleType({strType}),
196+
pb.NewEmptyStructType(),
197+
pb.NewEmptyTupleType()});
198+
const auto udf = pb.Udf("TestModule.Test", pb.NewVoid(), userType);
199+
const auto closure = pb.Apply(udf, {upvalue, optional});
200+
201+
const auto list = pb.NewList(strType, {value});
202+
const auto pgmReturn = pb.Map(list, [&pb, closure](const TRuntimeNode item) {
203+
return pb.Apply(closure, {item});
204+
});
205+
206+
// Create the test setup, using TRunConfig implementation
207+
// for TestModule.Test UDF.
208+
TVector<TUdfModuleInfo> runModules;
209+
runModules.emplace_back(
210+
TUdfModuleInfo{"", "TestModule", new TRunConfigUTModule()}
211+
);
212+
TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules));
213+
// Move the graph from the one setup to another as a
214+
// serialized bytecode sequence.
215+
const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env);
216+
const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env);
217+
218+
// Run the graph on the setup with TCurrying implementation.
219+
const auto graph = runSetup.BuildGraph(root);
220+
const auto iterator = graph->GetValue().GetListIterator();
221+
222+
NUdf::TUnboxedValue result;
223+
UNIT_ASSERT(iterator.Next(result));
224+
UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive");
225+
UNIT_ASSERT(!iterator.Next(result));
226+
}
169227
} // Y_UNIT_TEST_SUITE
170228

171229
} // namespace NMiniKQL

0 commit comments

Comments
 (0)