@@ -37,12 +37,14 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
37
37
TString&& typeConfig,
38
38
NUdf::TSourcePosition pos,
39
39
const TCallableType* callableType,
40
+ const TCallableType* functionType,
40
41
TType* userType)
41
42
: TBaseComputation(mutables, EValueRepresentation::Boxed)
42
43
, FunctionName(std::move(functionName))
43
44
, TypeConfig(std::move(typeConfig))
44
45
, Pos(pos)
45
46
, CallableType(callableType)
47
+ , FunctionType(functionType)
46
48
, UserType(userType)
47
49
{
48
50
this ->Stateless = false ;
@@ -58,16 +60,55 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
58
60
MKQL_ENSURE (status.IsOk (), status.GetError ());
59
61
MKQL_ENSURE (funcInfo.Implementation , " UDF implementation is not set for function " << FunctionName);
60
62
NUdf::TUnboxedValue udf (NUdf::TUnboxedValuePod (funcInfo.Implementation .Release ()));
61
- TValidate<TValidatePolicy,TValidateMode>::WrapCallable (CallableType, udf, TStringBuilder () << " FunctionWithConfig<" << FunctionName << " >" );
63
+ TValidate<TValidatePolicy,TValidateMode>::WrapCallable (FunctionType, udf, TStringBuilder () << " FunctionWithConfig<" << FunctionName << " >" );
64
+ ExtendArgs (udf, CallableType, funcInfo.FunctionType );
62
65
return udf.Release ();
63
66
}
64
67
private:
68
+ // xXX: This class implements the wrapper to properly handle
69
+ // the case when the signature of the emitted callable (i.e.
70
+ // callable type) requires less arguments than the actual
71
+ // function (i.e. function type). It wraps the unboxed value
72
+ // with the resolved UDF to introduce the bridge in the
73
+ // Run chain, preparing the valid argument vector for the
74
+ // chosen UDF implementation.
75
+ class TExtendedArgsWrapper : public NUdf ::TBoxedValue {
76
+ public:
77
+ TExtendedArgsWrapper (NUdf::TUnboxedValue&& callable, size_t usedArgs, size_t requiredArgs)
78
+ : Callable_(callable)
79
+ , UsedArgs_(usedArgs)
80
+ , RequiredArgs_(requiredArgs)
81
+ {};
82
+
83
+ private:
84
+ NUdf::TUnboxedValue Run (const NUdf::IValueBuilder* valueBuilder, const NUdf::TUnboxedValuePod* args) const final {
85
+ NStackArray::TStackArray<NUdf::TUnboxedValue> values (ALLOC_ON_STACK (NUdf::TUnboxedValue, RequiredArgs_));
86
+ for (size_t i = 0 ; i < UsedArgs_; i++) {
87
+ values[i] = args[i];
88
+ }
89
+ return Callable_.Run (valueBuilder, values.data ());
90
+ }
91
+
92
+ const NUdf::TUnboxedValue Callable_;
93
+ const size_t UsedArgs_;
94
+ const size_t RequiredArgs_;
95
+ };
96
+
97
+ void ExtendArgs (NUdf::TUnboxedValue& callable, const TCallableType* callableType, const TCallableType* functionType) const {
98
+ const auto callableArgc = callableType->GetArgumentsCount ();
99
+ const auto functionArgc = functionType->GetArgumentsCount ();
100
+ if (callableArgc < functionArgc) {
101
+ callable = NUdf::TUnboxedValuePod (new TExtendedArgsWrapper (std::move (callable), callableArgc, functionArgc));
102
+ }
103
+ }
104
+
65
105
void RegisterDependencies () const final {}
66
106
67
107
const TString FunctionName;
68
108
const TString TypeConfig;
69
109
const NUdf::TSourcePosition Pos;
70
110
const TCallableType *const CallableType;
111
+ const TCallableType *const FunctionType;
71
112
TType *const UserType;
72
113
};
73
114
@@ -83,12 +124,13 @@ class TUdfRunCodegeneratorNode: public TSimpleUdfWrapper<TValidateErrorPolicyNon
83
124
TString&& typeConfig,
84
125
NUdf::TSourcePosition pos,
85
126
const TCallableType* callableType,
127
+ const TCallableType* functionType,
86
128
TType* userType,
87
129
TString&& moduleIRUniqID,
88
130
TString&& moduleIR,
89
131
TString&& fuctioNameIR,
90
132
NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl)
91
- : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType)
133
+ : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, functionType, userType)
92
134
, ModuleIRUniqID(std::move(moduleIRUniqID))
93
135
, ModuleIR(std::move(moduleIR))
94
136
, IRFunctionName(std::move(fuctioNameIR))
@@ -131,15 +173,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
131
173
NUdf::TSourcePosition pos,
132
174
IComputationNode* runConfigNode,
133
175
ui32 runConfigArgs,
134
- const TCallableType* callableType ,
176
+ const TCallableType* functionType ,
135
177
TType* userType)
136
178
: TBaseComputation(mutables, EValueRepresentation::Boxed)
137
179
, FunctionName(std::move(functionName))
138
180
, TypeConfig(std::move(typeConfig))
139
181
, Pos(pos)
140
182
, RunConfigNode(runConfigNode)
141
183
, RunConfigArgs(runConfigArgs)
142
- , CallableType(callableType )
184
+ , FunctionType(functionType )
143
185
, UserType(userType)
144
186
, UdfIndex(mutables.CurValueIndex++)
145
187
{
@@ -225,7 +267,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
225
267
226
268
void Wrap (NUdf::TUnboxedValue& callable) const {
227
269
MKQL_ENSURE (bool (callable), " Returned empty value in function: " << FunctionName);
228
- TValidate<TValidatePolicy,TValidateMode>::WrapCallable (CallableType , callable, TStringBuilder () << " FunctionWithConfig<" << FunctionName << " >" );
270
+ TValidate<TValidatePolicy,TValidateMode>::WrapCallable (FunctionType , callable, TStringBuilder () << " FunctionWithConfig<" << FunctionName << " >" );
229
271
}
230
272
231
273
void RegisterDependencies () const final {
@@ -237,7 +279,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
237
279
const NUdf::TSourcePosition Pos;
238
280
IComputationNode* const RunConfigNode;
239
281
const ui32 RunConfigArgs;
240
- const TCallableType* CallableType ;
282
+ const TCallableType* FunctionType ;
241
283
TType* const UserType;
242
284
const ui32 UdfIndex;
243
285
};
@@ -301,6 +343,8 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
301
343
302
344
MKQL_ENSURE (status.IsOk (), status.GetError ());
303
345
346
+ const auto callableFuncType = AS_TYPE (TCallableType, funcInfo.FunctionType );
347
+ const auto callableNodeType = AS_TYPE (TCallableType, callable.GetType ()->GetReturnType ());
304
348
const auto runConfigFuncType = funcInfo.RunConfigType ;
305
349
const auto runConfigNodeType = runCfgNode.GetStaticType ();
306
350
@@ -322,9 +366,6 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
322
366
<< TruncateTypeDiff (diff)).c_str ());
323
367
}
324
368
325
- const auto callableFuncType = AS_TYPE (TCallableType, funcInfo.FunctionType );
326
- const auto callableNodeType = AS_TYPE (TCallableType, callable.GetType ()->GetReturnType ());
327
-
328
369
const auto callableType = runConfigNodeType->IsVoid ()
329
370
? callableNodeType : callableFuncType;
330
371
const auto runConfigType = runConfigNodeType->IsVoid ()
@@ -380,26 +421,26 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
380
421
const auto runConfigCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
381
422
const auto runConfigArgs = funcInfo.FunctionType ->GetArgumentsCount ();
382
423
return runConfigNodeType->IsVoid ()
383
- ? CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, funcInfo. FunctionType , userType)
384
- : CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo. FunctionType , userType);
424
+ ? CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, callableNodeType, callableFuncType , userType)
425
+ : CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, callableFuncType , userType);
385
426
}
386
- MKQL_ENSURE (funcInfo. FunctionType ->IsConvertableTo (*callable. GetType ()-> GetReturnType () , true ),
387
- " Function '" << funcName << " ' type mismatch, expected return type: " << PrintNode (callable. GetType ()-> GetReturnType () , true ) <<
388
- " , actual:" << PrintNode (funcInfo. FunctionType , true ));
427
+ MKQL_ENSURE (callableFuncType ->IsConvertableTo (*callableNodeType , true ),
428
+ " Function '" << funcName << " ' type mismatch, expected return type: " << PrintNode (callableNodeType , true ) <<
429
+ " , actual:" << PrintNode (callableFuncType , true ));
389
430
MKQL_ENSURE (funcInfo.Implementation , " UDF implementation is not set for function " << funcName);
390
431
391
432
if (runConfigFuncType->IsVoid ()) {
392
433
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName ) {
393
434
return new TUdfRunCodegeneratorNode (
394
- ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, funcInfo. FunctionType , userType,
435
+ ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, callableNodeType, callableFuncType , userType,
395
436
std::move (funcInfo.ModuleIRUniqID ), std::move (funcInfo.ModuleIR ), std::move (funcInfo.IRFunctionName ), std::move (funcInfo.Implementation )
396
437
);
397
438
}
398
- return CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, funcInfo. FunctionType , userType);
439
+ return CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, callableNodeType, callableFuncType , userType);
399
440
}
400
441
401
442
const auto runCfgCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
402
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , funcInfo. FunctionType , userType);
443
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , callableFuncType , userType);
403
444
}
404
445
405
446
IComputationNode* WrapScriptUdf (TCallable& callable, const TComputationNodeFactoryContext& ctx) {
0 commit comments