Skip to content

Commit d0bb904

Browse files
authored
[backport/stable-25-1] Properly handle signatures with run config and more optional arguments (#18954)
2 parents 76514f2 + 0490569 commit d0bb904

File tree

3 files changed

+473
-45
lines changed

3 files changed

+473
-45
lines changed

yql/essentials/minikql/comp_nodes/mkql_udf.cpp

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
3737
TString&& typeConfig,
3838
NUdf::TSourcePosition pos,
3939
const TCallableType* callableType,
40+
const TCallableType* functionType,
4041
TType* userType)
4142
: TBaseComputation(mutables, EValueRepresentation::Boxed)
4243
, FunctionName(std::move(functionName))
4344
, TypeConfig(std::move(typeConfig))
4445
, Pos(pos)
4546
, CallableType(callableType)
47+
, FunctionType(functionType)
4648
, UserType(userType)
4749
{
4850
this->Stateless = false;
@@ -58,16 +60,55 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
5860
MKQL_ENSURE(status.IsOk(), status.GetError());
5961
MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << FunctionName);
6062
NUdf::TUnboxedValue udf(NUdf::TUnboxedValuePod(funcInfo.Implementation.Release()));
61-
TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
63+
TValidate<TValidatePolicy,TValidateMode>::WrapCallable(FunctionType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
64+
ExtendArgs(udf, CallableType, funcInfo.FunctionType);
6265
return udf.Release();
6366
}
6467
private:
68+
// xXX: This class implements the wrapper to properly handle
69+
// the case when the signature of the emitted callable (i.e.
70+
// callable type) requires less arguments than the actual
71+
// function (i.e. function type). It wraps the unboxed value
72+
// with the resolved UDF to introduce the bridge in the
73+
// Run chain, preparing the valid argument vector for the
74+
// chosen UDF implementation.
75+
class TExtendedArgsWrapper: public NUdf::TBoxedValue {
76+
public:
77+
TExtendedArgsWrapper(NUdf::TUnboxedValue&& callable, size_t usedArgs, size_t requiredArgs)
78+
: Callable_(callable)
79+
, UsedArgs_(usedArgs)
80+
, RequiredArgs_(requiredArgs)
81+
{};
82+
83+
private:
84+
NUdf::TUnboxedValue Run(const NUdf::IValueBuilder* valueBuilder, const NUdf::TUnboxedValuePod* args) const final {
85+
NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, RequiredArgs_));
86+
for (size_t i = 0; i < UsedArgs_; i++) {
87+
values[i] = args[i];
88+
}
89+
return Callable_.Run(valueBuilder, values.data());
90+
}
91+
92+
const NUdf::TUnboxedValue Callable_;
93+
const size_t UsedArgs_;
94+
const size_t RequiredArgs_;
95+
};
96+
97+
void ExtendArgs(NUdf::TUnboxedValue& callable, const TCallableType* callableType, const TCallableType* functionType) const {
98+
const auto callableArgc = callableType->GetArgumentsCount();
99+
const auto functionArgc = functionType->GetArgumentsCount();
100+
if (callableArgc < functionArgc) {
101+
callable = NUdf::TUnboxedValuePod(new TExtendedArgsWrapper(std::move(callable), callableArgc, functionArgc));
102+
}
103+
}
104+
65105
void RegisterDependencies() const final {}
66106

67107
const TString FunctionName;
68108
const TString TypeConfig;
69109
const NUdf::TSourcePosition Pos;
70110
const TCallableType *const CallableType;
111+
const TCallableType *const FunctionType;
71112
TType *const UserType;
72113
};
73114

@@ -83,12 +124,13 @@ class TUdfRunCodegeneratorNode: public TSimpleUdfWrapper<TValidateErrorPolicyNon
83124
TString&& typeConfig,
84125
NUdf::TSourcePosition pos,
85126
const TCallableType* callableType,
127+
const TCallableType* functionType,
86128
TType* userType,
87129
TString&& moduleIRUniqID,
88130
TString&& moduleIR,
89131
TString&& fuctioNameIR,
90132
NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl)
91-
: TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType)
133+
: TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, functionType, userType)
92134
, ModuleIRUniqID(std::move(moduleIRUniqID))
93135
, ModuleIR(std::move(moduleIR))
94136
, IRFunctionName(std::move(fuctioNameIR))
@@ -301,14 +343,17 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
301343

302344
MKQL_ENSURE(status.IsOk(), status.GetError());
303345

346+
const auto callableFuncType = AS_TYPE(TCallableType, funcInfo.FunctionType);
347+
const auto callableNodeType = AS_TYPE(TCallableType, callable.GetType()->GetReturnType());
304348
const auto runConfigFuncType = funcInfo.RunConfigType;
305349
const auto runConfigNodeType = runCfgNode.GetStaticType();
306350

307351
if (!runConfigFuncType->IsSameType(*runConfigNodeType)) {
308352
// It's only legal, when the compiled UDF declares its
309353
// signature using run config at compilation phase, but then
310354
// omits it in favor to function currying at execution phase.
311-
if (!runConfigFuncType->IsVoid()) {
355+
// And vice versa for the forward compatibility.
356+
if (!runConfigNodeType->IsVoid() && !runConfigFuncType->IsVoid()) {
312357
TString diff = TStringBuilder()
313358
<< "run config type mismatch, expected: "
314359
<< PrintNode((runConfigNodeType), true)
@@ -321,17 +366,31 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
321366
<< TruncateTypeDiff(diff)).c_str());
322367
}
323368

369+
const auto callableType = runConfigNodeType->IsVoid()
370+
? callableNodeType : callableFuncType;
371+
const auto runConfigType = runConfigNodeType->IsVoid()
372+
? runConfigFuncType : runConfigNodeType;
373+
324374
// If so, check the following invariants:
325375
// * The first argument of the head function in the sequence
326376
// of the curried functions has to be the same as the
327377
// run config type.
378+
// * All other arguments of the head function in the sequence
379+
// of the curried function have to be optional.
328380
// * The type of the resulting callable has to be the same
329381
// as the function type.
330-
const auto firstArgType = funcInfo.FunctionType->GetArgumentType(0);
331-
if (!runConfigNodeType->IsSameType(*firstArgType)) {
382+
if (callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount() != 1U) {
383+
UdfTerminate((TStringBuilder() << pos
384+
<< " Udf Function '"
385+
<< funcName
386+
<< "' wrapper has more than one required argument: "
387+
<< PrintNode(callableType)).c_str());
388+
}
389+
const auto firstArgType = callableType->GetArgumentType(0);
390+
if (!runConfigType->IsSameType(*firstArgType)) {
332391
TString diff = TStringBuilder()
333392
<< "type mismatch, expected run config type: "
334-
<< PrintNode(runConfigNodeType, true)
393+
<< PrintNode(runConfigType, true)
335394
<< ", actual: "
336395
<< PrintNode(firstArgType, true);
337396
UdfTerminate((TStringBuilder() << pos
@@ -340,14 +399,18 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
340399
<< "' "
341400
<< TruncateTypeDiff(diff)).c_str());
342401
}
343-
const auto callableFuncType = funcInfo.FunctionType->GetReturnType();
344-
const auto callableNodeType = callable.GetType()->GetReturnType();
345-
if (!callableNodeType->IsSameType(*callableFuncType)) {
402+
const auto closureFuncType = runConfigNodeType->IsVoid()
403+
? callableFuncType
404+
: AS_TYPE(TCallableType, callableFuncType)->GetReturnType();
405+
const auto closureNodeType = runConfigNodeType->IsVoid()
406+
? AS_TYPE(TCallableType, callableNodeType)->GetReturnType()
407+
: callableNodeType;
408+
if (!closureNodeType->IsConvertableTo(*closureFuncType)) {
346409
TString diff = TStringBuilder()
347410
<< "type mismatch, expected return type: "
348-
<< PrintNode(callableNodeType, true)
411+
<< PrintNode(closureNodeType, true)
349412
<< ", actual: "
350-
<< PrintNode(callableFuncType, true);
413+
<< PrintNode(closureFuncType, true);
351414
UdfTerminate((TStringBuilder() << pos
352415
<< " Udf Function '"
353416
<< funcName
@@ -357,25 +420,27 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
357420

358421
const auto runConfigCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
359422
const auto runConfigArgs = funcInfo.FunctionType->GetArgumentsCount();
360-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType);
423+
return runConfigNodeType->IsVoid()
424+
? CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType)
425+
: CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, callableNodeType, userType);
361426
}
362-
MKQL_ENSURE(funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true),
363-
"Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<
364-
", actual:" << PrintNode(funcInfo.FunctionType, true));
427+
MKQL_ENSURE(callableFuncType->IsConvertableTo(*callableNodeType, true),
428+
"Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callableNodeType, true) <<
429+
", actual:" << PrintNode(callableFuncType, true));
365430
MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << funcName);
366431

367432
if (runConfigFuncType->IsVoid()) {
368433
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) {
369434
return new TUdfRunCodegeneratorNode(
370-
ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType,
435+
ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType,
371436
std::move(funcInfo.ModuleIRUniqID), std::move(funcInfo.ModuleIR), std::move(funcInfo.IRFunctionName), std::move(funcInfo.Implementation)
372437
);
373438
}
374-
return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType);
439+
return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType);
375440
}
376441

377442
const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
378-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, funcInfo.FunctionType, userType);
443+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, callableNodeType, userType);
379444
}
380445

381446
IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {

yql/essentials/minikql/comp_nodes/ut/mkql_computation_node_ut.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ struct TSetup {
118118
Reset();
119119
Explorer.Walk(pgm.GetNode(), *Env);
120120
TComputationPatternOpts opts(Alloc.Ref(), *Env, NodeFactory,
121-
FunctionRegistry.Get(), NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception,
121+
FunctionRegistry.Get(), NUdf::EValidateMode::Greedy, NUdf::EValidatePolicy::Exception,
122122
UseLLVM ? "" : "OFF", graphPerProcess, StatsRegistry.Get(), nullptr, nullptr);
123123
Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
124124
auto graph = Pattern->Clone(opts.ToComputationOptions(*RandomProvider, *TimeProvider));

0 commit comments

Comments
 (0)