10
10
#include < yql/essentials/minikql/mkql_utils.h>
11
11
#include < yql/essentials/utils/yql_panic.h>
12
12
13
+ #include < library/cpp/containers/stack_array/stack_array.h>
14
+
13
15
namespace NKikimr {
14
16
namespace NMiniKQL {
15
17
@@ -135,13 +137,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
135
137
TString&& typeConfig,
136
138
NUdf::TSourcePosition pos,
137
139
IComputationNode* runConfigNode,
140
+ ui32 runConfigArgs,
138
141
const TCallableType* callableType,
139
142
TType* userType)
140
143
: TBaseComputation(mutables, EValueRepresentation::Boxed)
141
144
, FunctionName(std::move(functionName))
142
145
, TypeConfig(std::move(typeConfig))
143
146
, Pos(pos)
144
147
, RunConfigNode(runConfigNode)
148
+ , RunConfigArgs(runConfigArgs)
145
149
, CallableType(callableType)
146
150
, UserType(userType)
147
151
, UdfIndex(mutables.CurValueIndex++)
@@ -154,15 +158,17 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
154
158
if (!udf.HasValue ()) {
155
159
MakeUdf (ctx, udf);
156
160
}
157
- const auto runConfig = RunConfigNode->GetValue (ctx);
158
- auto callable = udf.Run (ctx.Builder , &runConfig);
161
+ NStackArray::TStackArray<NUdf::TUnboxedValue> args (ALLOC_ON_STACK (NUdf::TUnboxedValue, RunConfigArgs));
162
+ args[0 ] = RunConfigNode->GetValue (ctx);
163
+ auto callable = udf.Run (ctx.Builder , args.data ());
159
164
Wrap (callable);
160
165
return callable;
161
166
}
162
167
#ifndef MKQL_DISABLE_CODEGEN
163
168
void DoGenerateGetValue (const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
164
169
auto & context = ctx.Codegen .GetContext ();
165
170
171
+ const auto indexType = Type::getInt32Ty (context);
166
172
const auto valueType = Type::getInt128Ty (context);
167
173
168
174
const auto udfPtr = GetElementPtrInst::CreateInBounds (valueType, ctx.GetMutables (), {ConstantInt::get (Type::getInt32Ty (context), UdfIndex)}, " udf_ptr" , block);
@@ -185,13 +191,24 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
185
191
186
192
block = main;
187
193
188
- GetNodeValue (pointer, RunConfigNode, ctx, block);
189
- const auto conf = new LoadInst (valueType, pointer, " conf" , block);
194
+ const auto argsType = ArrayType::get (valueType, RunConfigArgs);
195
+ const auto args = new AllocaInst (argsType, 0U , " args" , block);
196
+ Value* runConfigValue;
197
+ for (ui32 i = 0 ; i < RunConfigArgs; i++) {
198
+ const auto argIndex = ConstantInt::get (indexType, i);
199
+ const auto argSlot = GetElementPtrInst::CreateInBounds (valueType, args, {argIndex}, " arg" , block);
200
+ if (i == 0 ) {
201
+ GetNodeValue (argSlot, RunConfigNode, ctx, block);
202
+ runConfigValue = new LoadInst (valueType, argSlot, " runconfig" , block);
203
+ } else {
204
+ new StoreInst (ConstantInt::get (valueType, 0U ), argSlot, block);
205
+ }
206
+ }
190
207
const auto udf = new LoadInst (valueType, udfPtr, " udf" , block);
191
208
192
- CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen , block, ctx.GetBuilder (), pointer );
209
+ CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen , block, ctx.GetBuilder (), args );
193
210
194
- ValueUnRef (RunConfigNode->GetRepresentation (), conf , ctx, block);
211
+ ValueUnRef (RunConfigNode->GetRepresentation (), runConfigValue , ctx, block);
195
212
196
213
const auto wrap = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr<&TUdfWrapper::Wrap>());
197
214
const auto funType = FunctionType::get (Type::getVoidTy (context), {self->getType (), pointer->getType ()}, false );
@@ -231,6 +248,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
231
248
const TString TypeConfig;
232
249
const NUdf::TSourcePosition Pos;
233
250
IComputationNode* const RunConfigNode;
251
+ const ui32 RunConfigArgs;
234
252
const TCallableType* CallableType;
235
253
TType* const UserType;
236
254
const ui32 UdfIndex;
@@ -298,6 +316,65 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
298
316
<< status.GetError ()).c_str ());
299
317
}
300
318
319
+ const auto runConfigFuncType = funcInfo.RunConfigType ;
320
+ const auto runConfigNodeType = runCfgNode.GetStaticType ();
321
+
322
+ if (!runConfigFuncType->IsSameType (*runConfigNodeType)) {
323
+ // It's only legal, when the compiled UDF declares its
324
+ // signature using run config at compilation phase, but then
325
+ // omits it in favor to function currying at execution phase.
326
+ if (!runConfigFuncType->IsVoid ()) {
327
+ TString diff = TStringBuilder ()
328
+ << " run config type mismatch, expected: "
329
+ << PrintNode ((runConfigNodeType), true )
330
+ << " , actual: "
331
+ << PrintNode (runConfigFuncType, true );
332
+ UdfTerminate ((TStringBuilder () << pos
333
+ << " UDF Function '"
334
+ << funcName
335
+ << " ' "
336
+ << TruncateTypeDiff (diff)).c_str ());
337
+ }
338
+
339
+ // If so, check the following invariants:
340
+ // * The first argument of the head function in the sequence
341
+ // of the curried functions has to be the same as the
342
+ // run config type.
343
+ // * The type of the resulting callable has to be the same
344
+ // as the function type.
345
+ const auto firstArgType = funcInfo.FunctionType ->GetArgumentType (0 );
346
+ if (!runConfigNodeType->IsSameType (*firstArgType)) {
347
+ TString diff = TStringBuilder ()
348
+ << " type mismatch, expected run config type: "
349
+ << PrintNode (runConfigNodeType, true )
350
+ << " , actual: "
351
+ << PrintNode (firstArgType, true );
352
+ UdfTerminate ((TStringBuilder () << pos
353
+ << " Udf Function '"
354
+ << funcName
355
+ << " ' "
356
+ << TruncateTypeDiff (diff)).c_str ());
357
+ }
358
+ const auto callableFuncType = funcInfo.FunctionType ->GetReturnType ();
359
+ const auto callableNodeType = callable.GetType ()->GetReturnType ();
360
+ if (!callableNodeType->IsSameType (*callableFuncType)) {
361
+ TString diff = TStringBuilder ()
362
+ << " type mismatch, expected return type: "
363
+ << PrintNode (callableNodeType, true )
364
+ << " , actual: "
365
+ << PrintNode (callableFuncType, true );
366
+ UdfTerminate ((TStringBuilder () << pos
367
+ << " Udf Function '"
368
+ << funcName
369
+ << " ' "
370
+ << TruncateTypeDiff (diff)).c_str ());
371
+ }
372
+
373
+ const auto runConfigCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
374
+ const auto runConfigArgs = funcInfo.FunctionType ->GetArgumentsCount ();
375
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType , userType);
376
+ }
377
+
301
378
if (!funcInfo.FunctionType ->IsConvertableTo (*callable.GetType ()->GetReturnType (), true )) {
302
379
TString diff = TStringBuilder () << " type mismatch, expected return type: " << PrintNode (callable.GetType ()->GetReturnType (), true ) <<
303
380
" , actual:" << PrintNode (funcInfo.FunctionType , true );
@@ -308,14 +385,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
308
385
UdfTerminate ((TStringBuilder () << pos << " UDF implementation is not set for function " << funcName).c_str ());
309
386
}
310
387
311
- const auto runConfigType = funcInfo.RunConfigType ;
312
- if (!runConfigType->IsSameType (*runCfgNode.GetStaticType ())) {
313
- TString diff = TStringBuilder () << " run config type mismatch, expected: " << PrintNode (runCfgNode.GetStaticType (), true ) <<
314
- " , actual:" << PrintNode (runConfigType, true );
315
- UdfTerminate ((TStringBuilder () << pos << " UDF Function '" << funcName << " ' " << TruncateTypeDiff (diff)).c_str ());
316
- }
317
-
318
- if (runConfigType->IsVoid ()) {
388
+ if (runConfigFuncType->IsVoid ()) {
319
389
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName ) {
320
390
return new TUdfRunCodegeneratorNode (
321
391
ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, funcInfo.FunctionType , userType,
@@ -326,7 +396,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
326
396
}
327
397
328
398
const auto runCfgCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
329
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, funcInfo.FunctionType , userType);
399
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , funcInfo.FunctionType , userType);
330
400
}
331
401
332
402
IComputationNode* WrapScriptUdf (TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -375,7 +445,7 @@ IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFacto
375
445
const auto funcTypeInfo = static_cast <TCallableType*>(callableResultType);
376
446
377
447
const auto programCompNode = LocateNode (ctx.NodeLocator , *programNode.GetNode ());
378
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, programCompNode, funcTypeInfo, userType);
448
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, programCompNode, 1U , funcTypeInfo, userType);
379
449
}
380
450
381
451
}
0 commit comments