@@ -37,12 +37,14 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
37
37
TString&& typeConfig,
38
38
NUdf::TSourcePosition pos,
39
39
const TCallableType* callableType,
40
+ const TCallableType* functionType,
40
41
TType* userType)
41
42
: TBaseComputation(mutables, EValueRepresentation::Boxed)
42
43
, FunctionName(std::move(functionName))
43
44
, TypeConfig(std::move(typeConfig))
44
45
, Pos(pos)
45
46
, CallableType(callableType)
47
+ , FunctionType(functionType)
46
48
, UserType(userType)
47
49
{
48
50
this ->Stateless = false ;
@@ -58,16 +60,55 @@ using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePoli
58
60
MKQL_ENSURE (status.IsOk (), status.GetError ());
59
61
MKQL_ENSURE (funcInfo.Implementation , " UDF implementation is not set for function " << FunctionName);
60
62
NUdf::TUnboxedValue udf (NUdf::TUnboxedValuePod (funcInfo.Implementation .Release ()));
61
- TValidate<TValidatePolicy,TValidateMode>::WrapCallable (CallableType, udf, TStringBuilder () << " FunctionWithConfig<" << FunctionName << " >" );
63
+ TValidate<TValidatePolicy,TValidateMode>::WrapCallable (FunctionType, udf, TStringBuilder () << " FunctionWithConfig<" << FunctionName << " >" );
64
+ ExtendArgs (udf, CallableType, funcInfo.FunctionType );
62
65
return udf.Release ();
63
66
}
64
67
private:
68
+ // xXX: This class implements the wrapper to properly handle
69
+ // the case when the signature of the emitted callable (i.e.
70
+ // callable type) requires less arguments than the actual
71
+ // function (i.e. function type). It wraps the unboxed value
72
+ // with the resolved UDF to introduce the bridge in the
73
+ // Run chain, preparing the valid argument vector for the
74
+ // chosen UDF implementation.
75
+ class TExtendedArgsWrapper : public NUdf ::TBoxedValue {
76
+ public:
77
+ TExtendedArgsWrapper (NUdf::TUnboxedValue&& callable, size_t usedArgs, size_t requiredArgs)
78
+ : Callable_(callable)
79
+ , UsedArgs_(usedArgs)
80
+ , RequiredArgs_(requiredArgs)
81
+ {};
82
+
83
+ private:
84
+ NUdf::TUnboxedValue Run (const NUdf::IValueBuilder* valueBuilder, const NUdf::TUnboxedValuePod* args) const final {
85
+ NStackArray::TStackArray<NUdf::TUnboxedValue> values (ALLOC_ON_STACK (NUdf::TUnboxedValue, RequiredArgs_));
86
+ for (size_t i = 0 ; i < UsedArgs_; i++) {
87
+ values[i] = args[i];
88
+ }
89
+ return Callable_.Run (valueBuilder, values.data ());
90
+ }
91
+
92
+ const NUdf::TUnboxedValue Callable_;
93
+ const size_t UsedArgs_;
94
+ const size_t RequiredArgs_;
95
+ };
96
+
97
+ void ExtendArgs (NUdf::TUnboxedValue& callable, const TCallableType* callableType, const TCallableType* functionType) const {
98
+ const auto callableArgc = callableType->GetArgumentsCount ();
99
+ const auto functionArgc = functionType->GetArgumentsCount ();
100
+ if (callableArgc < functionArgc) {
101
+ callable = NUdf::TUnboxedValuePod (new TExtendedArgsWrapper (std::move (callable), callableArgc, functionArgc));
102
+ }
103
+ }
104
+
65
105
void RegisterDependencies () const final {}
66
106
67
107
const TString FunctionName;
68
108
const TString TypeConfig;
69
109
const NUdf::TSourcePosition Pos;
70
110
const TCallableType *const CallableType;
111
+ const TCallableType *const FunctionType;
71
112
TType *const UserType;
72
113
};
73
114
@@ -83,12 +124,13 @@ class TUdfRunCodegeneratorNode: public TSimpleUdfWrapper<TValidateErrorPolicyNon
83
124
TString&& typeConfig,
84
125
NUdf::TSourcePosition pos,
85
126
const TCallableType* callableType,
127
+ const TCallableType* functionType,
86
128
TType* userType,
87
129
TString&& moduleIRUniqID,
88
130
TString&& moduleIR,
89
131
TString&& fuctioNameIR,
90
132
NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl)
91
- : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType)
133
+ : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, functionType, userType)
92
134
, ModuleIRUniqID(std::move(moduleIRUniqID))
93
135
, ModuleIR(std::move(moduleIR))
94
136
, IRFunctionName(std::move(fuctioNameIR))
@@ -301,14 +343,17 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
301
343
302
344
MKQL_ENSURE (status.IsOk (), status.GetError ());
303
345
346
+ const auto callableFuncType = AS_TYPE (TCallableType, funcInfo.FunctionType );
347
+ const auto callableNodeType = AS_TYPE (TCallableType, callable.GetType ()->GetReturnType ());
304
348
const auto runConfigFuncType = funcInfo.RunConfigType ;
305
349
const auto runConfigNodeType = runCfgNode.GetStaticType ();
306
350
307
351
if (!runConfigFuncType->IsSameType (*runConfigNodeType)) {
308
352
// It's only legal, when the compiled UDF declares its
309
353
// signature using run config at compilation phase, but then
310
354
// omits it in favor to function currying at execution phase.
311
- if (!runConfigFuncType->IsVoid ()) {
355
+ // And vice versa for the forward compatibility.
356
+ if (!runConfigNodeType->IsVoid () && !runConfigFuncType->IsVoid ()) {
312
357
TString diff = TStringBuilder ()
313
358
<< " run config type mismatch, expected: "
314
359
<< PrintNode ((runConfigNodeType), true )
@@ -321,17 +366,31 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
321
366
<< TruncateTypeDiff (diff)).c_str ());
322
367
}
323
368
369
+ const auto callableType = runConfigNodeType->IsVoid ()
370
+ ? callableNodeType : callableFuncType;
371
+ const auto runConfigType = runConfigNodeType->IsVoid ()
372
+ ? runConfigFuncType : runConfigNodeType;
373
+
324
374
// If so, check the following invariants:
325
375
// * The first argument of the head function in the sequence
326
376
// of the curried functions has to be the same as the
327
377
// run config type.
378
+ // * All other arguments of the head function in the sequence
379
+ // of the curried function have to be optional.
328
380
// * The type of the resulting callable has to be the same
329
381
// as the function type.
330
- const auto firstArgType = funcInfo.FunctionType ->GetArgumentType (0 );
331
- if (!runConfigNodeType->IsSameType (*firstArgType)) {
382
+ if (callableType->GetArgumentsCount () - callableType->GetOptionalArgumentsCount () != 1U ) {
383
+ UdfTerminate ((TStringBuilder () << pos
384
+ << " Udf Function '"
385
+ << funcName
386
+ << " ' wrapper has more than one required argument: "
387
+ << PrintNode (callableType)).c_str ());
388
+ }
389
+ const auto firstArgType = callableType->GetArgumentType (0 );
390
+ if (!runConfigType->IsSameType (*firstArgType)) {
332
391
TString diff = TStringBuilder ()
333
392
<< " type mismatch, expected run config type: "
334
- << PrintNode (runConfigNodeType , true )
393
+ << PrintNode (runConfigType , true )
335
394
<< " , actual: "
336
395
<< PrintNode (firstArgType, true );
337
396
UdfTerminate ((TStringBuilder () << pos
@@ -340,14 +399,18 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
340
399
<< " ' "
341
400
<< TruncateTypeDiff (diff)).c_str ());
342
401
}
343
- const auto callableFuncType = funcInfo.FunctionType ->GetReturnType ();
344
- const auto callableNodeType = callable.GetType ()->GetReturnType ();
345
- if (!callableNodeType->IsSameType (*callableFuncType)) {
402
+ const auto closureFuncType = runConfigNodeType->IsVoid ()
403
+ ? callableFuncType
404
+ : AS_TYPE (TCallableType, callableFuncType)->GetReturnType ();
405
+ const auto closureNodeType = runConfigNodeType->IsVoid ()
406
+ ? AS_TYPE (TCallableType, callableNodeType)->GetReturnType ()
407
+ : callableNodeType;
408
+ if (!closureNodeType->IsConvertableTo (*closureFuncType)) {
346
409
TString diff = TStringBuilder ()
347
410
<< " type mismatch, expected return type: "
348
- << PrintNode (callableNodeType , true )
411
+ << PrintNode (closureNodeType , true )
349
412
<< " , actual: "
350
- << PrintNode (callableFuncType , true );
413
+ << PrintNode (closureFuncType , true );
351
414
UdfTerminate ((TStringBuilder () << pos
352
415
<< " Udf Function '"
353
416
<< funcName
@@ -357,25 +420,27 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
357
420
358
421
const auto runConfigCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
359
422
const auto runConfigArgs = funcInfo.FunctionType ->GetArgumentsCount ();
360
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType , userType);
423
+ return runConfigNodeType->IsVoid ()
424
+ ? CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, callableNodeType, callableFuncType, userType)
425
+ : CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, callableNodeType, userType);
361
426
}
362
- MKQL_ENSURE (funcInfo. FunctionType ->IsConvertableTo (*callable. GetType ()-> GetReturnType () , true ),
363
- " Function '" << funcName << " ' type mismatch, expected return type: " << PrintNode (callable. GetType ()-> GetReturnType () , true ) <<
364
- " , actual:" << PrintNode (funcInfo. FunctionType , true ));
427
+ MKQL_ENSURE (callableFuncType ->IsConvertableTo (*callableNodeType , true ),
428
+ " Function '" << funcName << " ' type mismatch, expected return type: " << PrintNode (callableNodeType , true ) <<
429
+ " , actual:" << PrintNode (callableFuncType , true ));
365
430
MKQL_ENSURE (funcInfo.Implementation , " UDF implementation is not set for function " << funcName);
366
431
367
432
if (runConfigFuncType->IsVoid ()) {
368
433
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName ) {
369
434
return new TUdfRunCodegeneratorNode (
370
- ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, funcInfo. FunctionType , userType,
435
+ ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, callableNodeType, callableFuncType , userType,
371
436
std::move (funcInfo.ModuleIRUniqID ), std::move (funcInfo.ModuleIR ), std::move (funcInfo.IRFunctionName ), std::move (funcInfo.Implementation )
372
437
);
373
438
}
374
- return CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, funcInfo. FunctionType , userType);
439
+ return CreateUdfWrapper<true >(ctx, std::move (funcName), std::move (typeConfig), pos, callableNodeType, callableFuncType , userType);
375
440
}
376
441
377
442
const auto runCfgCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
378
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , funcInfo. FunctionType , userType);
443
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , callableNodeType , userType);
379
444
}
380
445
381
446
IComputationNode* WrapScriptUdf (TCallable& callable, const TComputationNodeFactoryContext& ctx) {
0 commit comments