@@ -322,17 +322,30 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
322
322
<< TruncateTypeDiff (diff)).c_str ());
323
323
}
324
324
325
+ const auto callableFuncType = AS_TYPE (TCallableType, funcInfo.FunctionType );
326
+ const auto callableNodeType = AS_TYPE (TCallableType, callable.GetType ()->GetReturnType ());
327
+
328
+ const auto callableType = runConfigNodeType->IsVoid ()
329
+ ? callableNodeType : callableFuncType;
330
+ const auto runConfigType = runConfigNodeType->IsVoid ()
331
+ ? runConfigFuncType : runConfigNodeType;
332
+
325
333
// If so, check the following invariants:
326
334
// * The first argument of the head function in the sequence
327
335
// of the curried functions has to be the same as the
328
336
// run config type.
337
+ // * All other arguments of the head function in the sequence
338
+ // of the curried function have to be optional.
329
339
// * The type of the resulting callable has to be the same
330
340
// as the function type.
331
- const auto firstArgType = runConfigNodeType->IsVoid ()
332
- ? callable.GetType ()->GetArgumentType (0 )
333
- : funcInfo.FunctionType ->GetArgumentType (0 );
334
- const auto runConfigType = runConfigNodeType->IsVoid ()
335
- ? runConfigFuncType : runConfigNodeType;
341
+ if (callableType->GetArgumentsCount () - callableType->GetOptionalArgumentsCount () != 1U ) {
342
+ UdfTerminate ((TStringBuilder () << pos
343
+ << " Udf Function '"
344
+ << funcName
345
+ << " ' wrapper has more than one required argument: "
346
+ << PrintNode (callableType)).c_str ());
347
+ }
348
+ const auto firstArgType = callableType->GetArgumentType (0 );
336
349
if (!runConfigType->IsSameType (*firstArgType)) {
337
350
TString diff = TStringBuilder ()
338
351
<< " type mismatch, expected run config type: "
@@ -345,18 +358,18 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
345
358
<< " ' "
346
359
<< TruncateTypeDiff (diff)).c_str ());
347
360
}
348
- const auto callableFuncType = runConfigNodeType->IsVoid ()
349
- ? funcInfo. FunctionType
350
- : funcInfo. FunctionType ->GetReturnType ();
351
- const auto callableNodeType = runConfigNodeType->IsVoid ()
352
- ? AS_TYPE (TCallableType, callable. GetType ()-> GetReturnType () )->GetReturnType ()
353
- : callable. GetType ()-> GetReturnType () ;
354
- if (!callableNodeType ->IsSameType (*callableFuncType )) {
361
+ const auto closureFuncType = runConfigNodeType->IsVoid ()
362
+ ? callableFuncType
363
+ : AS_TYPE (TCallableType, callableFuncType) ->GetReturnType ();
364
+ const auto closureNodeType = runConfigNodeType->IsVoid ()
365
+ ? AS_TYPE (TCallableType, callableNodeType )->GetReturnType ()
366
+ : callableNodeType ;
367
+ if (!closureNodeType ->IsSameType (*closureFuncType )) {
355
368
TString diff = TStringBuilder ()
356
369
<< " type mismatch, expected return type: "
357
- << PrintNode (callableNodeType , true )
370
+ << PrintNode (closureNodeType , true )
358
371
<< " , actual: "
359
- << PrintNode (callableFuncType , true );
372
+ << PrintNode (closureFuncType , true );
360
373
UdfTerminate ((TStringBuilder () << pos
361
374
<< " Udf Function '"
362
375
<< funcName
0 commit comments