Skip to content

Commit 1e03aed

Browse files
committed
YQL-19967: Properly handle optional args for curried function
Follows up e13229863b189b9ad804f0c6772204399430179e Follows up 5b56b1b4416cf5b7abcae727a61e9402e10af067 commit_hash:c14af47d9a4960f42df3fad23da4b03721d27f7c (cherry picked from commit 896fcc7)
1 parent 4d47c26 commit 1e03aed

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

yql/essentials/minikql/comp_nodes/mkql_udf.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -322,17 +322,30 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
322322
<< TruncateTypeDiff(diff)).c_str());
323323
}
324324

325+
const auto callableFuncType = AS_TYPE(TCallableType, funcInfo.FunctionType);
326+
const auto callableNodeType = AS_TYPE(TCallableType, callable.GetType()->GetReturnType());
327+
328+
const auto callableType = runConfigNodeType->IsVoid()
329+
? callableNodeType : callableFuncType;
330+
const auto runConfigType = runConfigNodeType->IsVoid()
331+
? runConfigFuncType : runConfigNodeType;
332+
325333
// If so, check the following invariants:
326334
// * The first argument of the head function in the sequence
327335
// of the curried functions has to be the same as the
328336
// run config type.
337+
// * All other arguments of the head function in the sequence
338+
// of the curried function have to be optional.
329339
// * The type of the resulting callable has to be the same
330340
// as the function type.
331-
const auto firstArgType = runConfigNodeType->IsVoid()
332-
? callable.GetType()->GetArgumentType(0)
333-
: funcInfo.FunctionType->GetArgumentType(0);
334-
const auto runConfigType = runConfigNodeType->IsVoid()
335-
? runConfigFuncType : runConfigNodeType;
341+
if (callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount() != 1U) {
342+
UdfTerminate((TStringBuilder() << pos
343+
<< " Udf Function '"
344+
<< funcName
345+
<< "' wrapper has more than one required argument: "
346+
<< PrintNode(callableType)).c_str());
347+
}
348+
const auto firstArgType = callableType->GetArgumentType(0);
336349
if (!runConfigType->IsSameType(*firstArgType)) {
337350
TString diff = TStringBuilder()
338351
<< "type mismatch, expected run config type: "
@@ -345,18 +358,18 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
345358
<< "' "
346359
<< TruncateTypeDiff(diff)).c_str());
347360
}
348-
const auto callableFuncType = runConfigNodeType->IsVoid()
349-
? funcInfo.FunctionType
350-
: funcInfo.FunctionType->GetReturnType();
351-
const auto callableNodeType = runConfigNodeType->IsVoid()
352-
? AS_TYPE(TCallableType, callable.GetType()->GetReturnType())->GetReturnType()
353-
: callable.GetType()->GetReturnType();
354-
if (!callableNodeType->IsSameType(*callableFuncType)) {
361+
const auto closureFuncType = runConfigNodeType->IsVoid()
362+
? callableFuncType
363+
: AS_TYPE(TCallableType, callableFuncType)->GetReturnType();
364+
const auto closureNodeType = runConfigNodeType->IsVoid()
365+
? AS_TYPE(TCallableType, callableNodeType)->GetReturnType()
366+
: callableNodeType;
367+
if (!closureNodeType->IsSameType(*closureFuncType)) {
355368
TString diff = TStringBuilder()
356369
<< "type mismatch, expected return type: "
357-
<< PrintNode(callableNodeType, true)
370+
<< PrintNode(closureNodeType, true)
358371
<< ", actual: "
359-
<< PrintNode(callableFuncType, true);
372+
<< PrintNode(closureFuncType, true);
360373
UdfTerminate((TStringBuilder() << pos
361374
<< " Udf Function '"
362375
<< funcName

0 commit comments

Comments
 (0)