Skip to content

Commit f009fcf

Browse files
committed
YQL-19967: Introduce TExtendedArgsWrapper helper
commit_hash:8aa01a548ffd87f8f1f6aa6df7eeddb66dad1a27 (cherry picked from commit 2dcd562)
1 parent 1e03aed commit f009fcf

File tree

2 files changed

+314
-17
lines changed

2 files changed

+314
-17
lines changed

yql/essentials/minikql/comp_nodes/mkql_udf.cpp

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
3737
TString&& typeConfig,
3838
NUdf::TSourcePosition pos,
3939
const TCallableType* callableType,
40+
const TCallableType* functionType,
4041
TType* userType)
4142
: TBaseComputation(mutables, EValueRepresentation::Boxed)
4243
, FunctionName(std::move(functionName))
4344
, TypeConfig(std::move(typeConfig))
4445
, Pos(pos)
4546
, CallableType(callableType)
47+
, FunctionType(functionType)
4648
, UserType(userType)
4749
{
4850
this->Stateless = false;
@@ -58,16 +60,55 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
5860
MKQL_ENSURE(status.IsOk(), status.GetError());
5961
MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << FunctionName);
6062
NUdf::TUnboxedValue udf(NUdf::TUnboxedValuePod(funcInfo.Implementation.Release()));
61-
TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
63+
TValidate<TValidatePolicy,TValidateMode>::WrapCallable(FunctionType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
64+
ExtendArgs(udf, CallableType, funcInfo.FunctionType);
6265
return udf.Release();
6366
}
6467
private:
68+
// xXX: This class implements the wrapper to properly handle
69+
// the case when the signature of the emitted callable (i.e.
70+
// callable type) requires less arguments than the actual
71+
// function (i.e. function type). It wraps the unboxed value
72+
// with the resolved UDF to introduce the bridge in the
73+
// Run chain, preparing the valid argument vector for the
74+
// chosen UDF implementation.
75+
class TExtendedArgsWrapper: public NUdf::TBoxedValue {
76+
public:
77+
TExtendedArgsWrapper(NUdf::TUnboxedValue&& callable, size_t usedArgs, size_t requiredArgs)
78+
: Callable_(callable)
79+
, UsedArgs_(usedArgs)
80+
, RequiredArgs_(requiredArgs)
81+
{};
82+
83+
private:
84+
NUdf::TUnboxedValue Run(const NUdf::IValueBuilder* valueBuilder, const NUdf::TUnboxedValuePod* args) const final {
85+
NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, RequiredArgs_));
86+
for (size_t i = 0; i < UsedArgs_; i++) {
87+
values[i] = args[i];
88+
}
89+
return Callable_.Run(valueBuilder, values.data());
90+
}
91+
92+
const NUdf::TUnboxedValue Callable_;
93+
const size_t UsedArgs_;
94+
const size_t RequiredArgs_;
95+
};
96+
97+
void ExtendArgs(NUdf::TUnboxedValue& callable, const TCallableType* callableType, const TCallableType* functionType) const {
98+
const auto callableArgc = callableType->GetArgumentsCount();
99+
const auto functionArgc = functionType->GetArgumentsCount();
100+
if (callableArgc < functionArgc) {
101+
callable = NUdf::TUnboxedValuePod(new TExtendedArgsWrapper(std::move(callable), callableArgc, functionArgc));
102+
}
103+
}
104+
65105
void RegisterDependencies() const final {}
66106

67107
const TString FunctionName;
68108
const TString TypeConfig;
69109
const NUdf::TSourcePosition Pos;
70110
const TCallableType *const CallableType;
111+
const TCallableType *const FunctionType;
71112
TType *const UserType;
72113
};
73114

@@ -83,12 +124,13 @@ class TUdfRunCodegeneratorNode: public TSimpleUdfWrapper<TValidateErrorPolicyNon
83124
TString&& typeConfig,
84125
NUdf::TSourcePosition pos,
85126
const TCallableType* callableType,
127+
const TCallableType* functionType,
86128
TType* userType,
87129
TString&& moduleIRUniqID,
88130
TString&& moduleIR,
89131
TString&& fuctioNameIR,
90132
NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl)
91-
: TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType)
133+
: TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, functionType, userType)
92134
, ModuleIRUniqID(std::move(moduleIRUniqID))
93135
, ModuleIR(std::move(moduleIR))
94136
, IRFunctionName(std::move(fuctioNameIR))
@@ -131,15 +173,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
131173
NUdf::TSourcePosition pos,
132174
IComputationNode* runConfigNode,
133175
ui32 runConfigArgs,
134-
const TCallableType* callableType,
176+
const TCallableType* functionType,
135177
TType* userType)
136178
: TBaseComputation(mutables, EValueRepresentation::Boxed)
137179
, FunctionName(std::move(functionName))
138180
, TypeConfig(std::move(typeConfig))
139181
, Pos(pos)
140182
, RunConfigNode(runConfigNode)
141183
, RunConfigArgs(runConfigArgs)
142-
, CallableType(callableType)
184+
, FunctionType(functionType)
143185
, UserType(userType)
144186
, UdfIndex(mutables.CurValueIndex++)
145187
{
@@ -225,7 +267,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
225267

226268
void Wrap(NUdf::TUnboxedValue& callable) const {
227269
MKQL_ENSURE(bool(callable), "Returned empty value in function: " << FunctionName);
228-
TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, callable, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
270+
TValidate<TValidatePolicy,TValidateMode>::WrapCallable(FunctionType, callable, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
229271
}
230272

231273
void RegisterDependencies() const final {
@@ -237,7 +279,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
237279
const NUdf::TSourcePosition Pos;
238280
IComputationNode* const RunConfigNode;
239281
const ui32 RunConfigArgs;
240-
const TCallableType* CallableType;
282+
const TCallableType* FunctionType;
241283
TType* const UserType;
242284
const ui32 UdfIndex;
243285
};
@@ -301,6 +343,8 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
301343

302344
MKQL_ENSURE(status.IsOk(), status.GetError());
303345

346+
const auto callableFuncType = AS_TYPE(TCallableType, funcInfo.FunctionType);
347+
const auto callableNodeType = AS_TYPE(TCallableType, callable.GetType()->GetReturnType());
304348
const auto runConfigFuncType = funcInfo.RunConfigType;
305349
const auto runConfigNodeType = runCfgNode.GetStaticType();
306350

@@ -322,9 +366,6 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
322366
<< TruncateTypeDiff(diff)).c_str());
323367
}
324368

325-
const auto callableFuncType = AS_TYPE(TCallableType, funcInfo.FunctionType);
326-
const auto callableNodeType = AS_TYPE(TCallableType, callable.GetType()->GetReturnType());
327-
328369
const auto callableType = runConfigNodeType->IsVoid()
329370
? callableNodeType : callableFuncType;
330371
const auto runConfigType = runConfigNodeType->IsVoid()
@@ -380,26 +421,26 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
380421
const auto runConfigCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
381422
const auto runConfigArgs = funcInfo.FunctionType->GetArgumentsCount();
382423
return runConfigNodeType->IsVoid()
383-
? CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType)
384-
: CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType);
424+
? CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType)
425+
: CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, callableFuncType, userType);
385426
}
386-
MKQL_ENSURE(funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true),
387-
"Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<
388-
", actual:" << PrintNode(funcInfo.FunctionType, true));
427+
MKQL_ENSURE(callableFuncType->IsConvertableTo(*callableNodeType, true),
428+
"Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callableNodeType, true) <<
429+
", actual:" << PrintNode(callableFuncType, true));
389430
MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << funcName);
390431

391432
if (runConfigFuncType->IsVoid()) {
392433
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) {
393434
return new TUdfRunCodegeneratorNode(
394-
ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType,
435+
ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType,
395436
std::move(funcInfo.ModuleIRUniqID), std::move(funcInfo.ModuleIR), std::move(funcInfo.IRFunctionName), std::move(funcInfo.Implementation)
396437
);
397438
}
398-
return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType);
439+
return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType);
399440
}
400441

401442
const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
402-
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, funcInfo.FunctionType, userType);
443+
return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, callableFuncType, userType);
403444
}
404445

405446
IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {

0 commit comments

Comments
 (0)