10
10
#include < yql/essentials/minikql/mkql_utils.h>
11
11
#include < yql/essentials/utils/yql_panic.h>
12
12
13
+ #include < library/cpp/containers/stack_array/stack_array.h>
14
+
13
15
namespace NKikimr {
14
16
namespace NMiniKQL {
15
17
16
18
namespace {
17
19
20
+ constexpr size_t TypeDiffLimit = 1000 ;
21
+
22
+ TString TruncateTypeDiff (const TString& s) {
23
+ if (s.size () < TypeDiffLimit) {
24
+ return s;
25
+ }
26
+
27
+ return s.substr (0 ,TypeDiffLimit) + " ..." ;
28
+ }
29
+
18
30
template <class TValidatePolicy , class TValidateMode >
19
31
class TSimpleUdfWrapper : public TMutableComputationNode <TSimpleUdfWrapper<TValidatePolicy,TValidateMode>> {
20
32
using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>>;
@@ -118,13 +130,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
118
130
TString&& typeConfig,
119
131
NUdf::TSourcePosition pos,
120
132
IComputationNode* runConfigNode,
133
+ ui32 runConfigArgs,
121
134
const TCallableType* callableType,
122
135
TType* userType)
123
136
: TBaseComputation(mutables, EValueRepresentation::Boxed)
124
137
, FunctionName(std::move(functionName))
125
138
, TypeConfig(std::move(typeConfig))
126
139
, Pos(pos)
127
140
, RunConfigNode(runConfigNode)
141
+ , RunConfigArgs(runConfigArgs)
128
142
, CallableType(callableType)
129
143
, UserType(userType)
130
144
, UdfIndex(mutables.CurValueIndex++)
@@ -137,15 +151,17 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
137
151
if (!udf.HasValue ()) {
138
152
MakeUdf (ctx, udf);
139
153
}
140
- const auto runConfig = RunConfigNode->GetValue (ctx);
141
- auto callable = udf.Run (ctx.Builder , &runConfig);
154
+ NStackArray::TStackArray<NUdf::TUnboxedValue> args (ALLOC_ON_STACK (NUdf::TUnboxedValue, RunConfigArgs));
155
+ args[0 ] = RunConfigNode->GetValue (ctx);
156
+ auto callable = udf.Run (ctx.Builder , args.data ());
142
157
Wrap (callable);
143
158
return callable;
144
159
}
145
160
#ifndef MKQL_DISABLE_CODEGEN
146
161
void DoGenerateGetValue (const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
147
162
auto & context = ctx.Codegen .GetContext ();
148
163
164
+ const auto indexType = Type::getInt32Ty (context);
149
165
const auto valueType = Type::getInt128Ty (context);
150
166
151
167
const auto udfPtr = GetElementPtrInst::CreateInBounds (valueType, ctx.GetMutables (), {ConstantInt::get (Type::getInt32Ty (context), UdfIndex)}, " udf_ptr" , block);
@@ -168,13 +184,25 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
168
184
169
185
block = main;
170
186
171
- GetNodeValue (pointer, RunConfigNode, ctx, block);
172
- const auto conf = new LoadInst (valueType, pointer, " conf" , block);
187
+ const auto argsType = ArrayType::get (valueType, RunConfigArgs);
188
+ const auto args = new AllocaInst (argsType, 0U , " args" , block);
189
+ const auto zero = ConstantInt::get (indexType, 0 );
190
+ Value* runConfigValue;
191
+ for (ui32 i = 0 ; i < RunConfigArgs; i++) {
192
+ const auto argIndex = ConstantInt::get (indexType, i);
193
+ const auto argSlot = GetElementPtrInst::CreateInBounds (argsType, args, {zero, argIndex}, " arg" , block);
194
+ if (i == 0 ) {
195
+ GetNodeValue (argSlot, RunConfigNode, ctx, block);
196
+ runConfigValue = new LoadInst (valueType, argSlot, " runconfig" , block);
197
+ } else {
198
+ new StoreInst (ConstantInt::get (valueType, 0U ), argSlot, block);
199
+ }
200
+ }
173
201
const auto udf = new LoadInst (valueType, udfPtr, " udf" , block);
174
202
175
- CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen , block, ctx.GetBuilder (), pointer );
203
+ CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen , block, ctx.GetBuilder (), args );
176
204
177
- ValueUnRef (RunConfigNode->GetRepresentation (), conf , ctx, block);
205
+ ValueUnRef (RunConfigNode->GetRepresentation (), runConfigValue , ctx, block);
178
206
179
207
const auto wrap = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TUdfWrapper::Wrap));
180
208
const auto funType = FunctionType::get (Type::getVoidTy (context), {self->getType (), pointer->getType ()}, false );
@@ -208,6 +236,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
208
236
const TString TypeConfig;
209
237
const NUdf::TSourcePosition Pos;
210
238
IComputationNode* const RunConfigNode;
239
+ const ui32 RunConfigArgs;
211
240
const TCallableType* CallableType;
212
241
TType* const UserType;
213
242
const ui32 UdfIndex;
@@ -271,17 +300,71 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
271
300
typeConfig, flags, pos, ctx.SecureParamsProvider , &funcInfo);
272
301
273
302
MKQL_ENSURE (status.IsOk (), status.GetError ());
303
+
304
+ const auto runConfigFuncType = funcInfo.RunConfigType ;
305
+ const auto runConfigNodeType = runCfgNode.GetStaticType ();
306
+
307
+ if (!runConfigFuncType->IsSameType (*runConfigNodeType)) {
308
+ // It's only legal, when the compiled UDF declares its
309
+ // signature using run config at compilation phase, but then
310
+ // omits it in favor to function currying at execution phase.
311
+ if (!runConfigFuncType->IsVoid ()) {
312
+ TString diff = TStringBuilder ()
313
+ << " run config type mismatch, expected: "
314
+ << PrintNode ((runConfigNodeType), true )
315
+ << " , actual: "
316
+ << PrintNode (runConfigFuncType, true );
317
+ UdfTerminate ((TStringBuilder () << pos
318
+ << " UDF Function '"
319
+ << funcName
320
+ << " ' "
321
+ << TruncateTypeDiff (diff)).c_str ());
322
+ }
323
+
324
+ // If so, check the following invariants:
325
+ // * The first argument of the head function in the sequence
326
+ // of the curried functions has to be the same as the
327
+ // run config type.
328
+ // * The type of the resulting callable has to be the same
329
+ // as the function type.
330
+ const auto firstArgType = funcInfo.FunctionType ->GetArgumentType (0 );
331
+ if (!runConfigNodeType->IsSameType (*firstArgType)) {
332
+ TString diff = TStringBuilder ()
333
+ << " type mismatch, expected run config type: "
334
+ << PrintNode (runConfigNodeType, true )
335
+ << " , actual: "
336
+ << PrintNode (firstArgType, true );
337
+ UdfTerminate ((TStringBuilder () << pos
338
+ << " Udf Function '"
339
+ << funcName
340
+ << " ' "
341
+ << TruncateTypeDiff (diff)).c_str ());
342
+ }
343
+ const auto callableFuncType = funcInfo.FunctionType ->GetReturnType ();
344
+ const auto callableNodeType = callable.GetType ()->GetReturnType ();
345
+ if (!callableNodeType->IsSameType (*callableFuncType)) {
346
+ TString diff = TStringBuilder ()
347
+ << " type mismatch, expected return type: "
348
+ << PrintNode (callableNodeType, true )
349
+ << " , actual: "
350
+ << PrintNode (callableFuncType, true );
351
+ UdfTerminate ((TStringBuilder () << pos
352
+ << " Udf Function '"
353
+ << funcName
354
+ << " ' "
355
+ << TruncateTypeDiff (diff)).c_str ());
356
+ }
357
+
358
+ const auto runConfigCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
359
+ const auto runConfigArgs = funcInfo.FunctionType ->GetArgumentsCount ();
360
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType , userType);
361
+ }
274
362
MKQL_ENSURE (funcInfo.FunctionType ->IsConvertableTo (*callable.GetType ()->GetReturnType (), true ),
275
363
" Function '" << funcName << " ' type mismatch, expected return type: " << PrintNode (callable.GetType ()->GetReturnType (), true ) <<
276
364
" , actual:" << PrintNode (funcInfo.FunctionType , true ));
277
365
MKQL_ENSURE (funcInfo.Implementation , " UDF implementation is not set for function " << funcName);
278
366
279
- const auto runConfigType = funcInfo.RunConfigType ;
280
- const bool typesMatch = runConfigType->IsSameType (*runCfgNode.GetStaticType ());
281
- MKQL_ENSURE (typesMatch, " RunConfig '" << funcName << " ' type mismatch, expected: " << PrintNode (runCfgNode.GetStaticType (), true ) <<
282
- " , actual: " << PrintNode (runConfigType, true ));
283
-
284
- if (runConfigType->IsVoid ()) {
367
+ if (runConfigFuncType->IsVoid ()) {
285
368
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName ) {
286
369
return new TUdfRunCodegeneratorNode (
287
370
ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, funcInfo.FunctionType , userType,
@@ -292,7 +375,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
292
375
}
293
376
294
377
const auto runCfgCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
295
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, funcInfo.FunctionType , userType);
378
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , funcInfo.FunctionType , userType);
296
379
}
297
380
298
381
IComputationNode* WrapScriptUdf (TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -331,7 +414,7 @@ IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFacto
331
414
const auto funcTypeInfo = static_cast <TCallableType*>(callableResultType);
332
415
333
416
const auto programCompNode = LocateNode (ctx.NodeLocator , *programNode.GetNode ());
334
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, programCompNode, funcTypeInfo, userType);
417
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, programCompNode, 1U , funcTypeInfo, userType);
335
418
}
336
419
337
420
}
0 commit comments