10
10
#include < yql/essentials/minikql/mkql_utils.h>
11
11
#include < yql/essentials/utils/yql_panic.h>
12
12
13
+ #include < library/cpp/containers/stack_array/stack_array.h>
14
+
13
15
namespace NKikimr {
14
16
namespace NMiniKQL {
15
17
16
18
namespace {
17
19
20
+ constexpr size_t TypeDiffLimit = 1000 ;
21
+
22
+ TString TruncateTypeDiff (const TString& s) {
23
+ if (s.size () < TypeDiffLimit) {
24
+ return s;
25
+ }
26
+
27
+ return s.substr (0 ,TypeDiffLimit) + " ..." ;
28
+ }
29
+
18
30
template <class TValidatePolicy , class TValidateMode >
19
31
class TSimpleUdfWrapper : public TMutableComputationNode <TSimpleUdfWrapper<TValidatePolicy,TValidateMode>> {
20
32
using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>>;
@@ -118,13 +130,15 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
118
130
TString&& typeConfig,
119
131
NUdf::TSourcePosition pos,
120
132
IComputationNode* runConfigNode,
133
+ ui32 runConfigArgs,
121
134
const TCallableType* callableType,
122
135
TType* userType)
123
136
: TBaseComputation(mutables, EValueRepresentation::Boxed)
124
137
, FunctionName(std::move(functionName))
125
138
, TypeConfig(std::move(typeConfig))
126
139
, Pos(pos)
127
140
, RunConfigNode(runConfigNode)
141
+ , RunConfigArgs(runConfigArgs)
128
142
, CallableType(callableType)
129
143
, UserType(userType)
130
144
, UdfIndex(mutables.CurValueIndex++)
@@ -137,15 +151,17 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
137
151
if (!udf.HasValue ()) {
138
152
MakeUdf (ctx, udf);
139
153
}
140
- const auto runConfig = RunConfigNode->GetValue (ctx);
141
- auto callable = udf.Run (ctx.Builder , &runConfig);
154
+ NStackArray::TStackArray<NUdf::TUnboxedValue> args (ALLOC_ON_STACK (NUdf::TUnboxedValue, RunConfigArgs));
155
+ args[0 ] = RunConfigNode->GetValue (ctx);
156
+ auto callable = udf.Run (ctx.Builder , args.data ());
142
157
Wrap (callable);
143
158
return callable;
144
159
}
145
160
#ifndef MKQL_DISABLE_CODEGEN
146
161
void DoGenerateGetValue (const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
147
162
auto & context = ctx.Codegen .GetContext ();
148
163
164
+ const auto indexType = Type::getInt32Ty (context);
149
165
const auto valueType = Type::getInt128Ty (context);
150
166
151
167
const auto udfPtr = GetElementPtrInst::CreateInBounds (valueType, ctx.GetMutables (), {ConstantInt::get (Type::getInt32Ty (context), UdfIndex)}, " udf_ptr" , block);
@@ -168,13 +184,24 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
168
184
169
185
block = main;
170
186
171
- GetNodeValue (pointer, RunConfigNode, ctx, block);
172
- const auto conf = new LoadInst (valueType, pointer, " conf" , block);
187
+ const auto argsType = ArrayType::get (valueType, RunConfigArgs);
188
+ const auto args = new AllocaInst (argsType, 0U , " args" , block);
189
+ Value* runConfigValue;
190
+ for (ui32 i = 0 ; i < RunConfigArgs; i++) {
191
+ const auto argIndex = ConstantInt::get (indexType, i);
192
+ const auto argSlot = GetElementPtrInst::CreateInBounds (valueType, args, {argIndex}, " arg" , block);
193
+ if (i == 0 ) {
194
+ GetNodeValue (argSlot, RunConfigNode, ctx, block);
195
+ runConfigValue = new LoadInst (valueType, argSlot, " runconfig" , block);
196
+ } else {
197
+ new StoreInst (ConstantInt::get (valueType, 0U ), argSlot, block);
198
+ }
199
+ }
173
200
const auto udf = new LoadInst (valueType, udfPtr, " udf" , block);
174
201
175
- CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen , block, ctx.GetBuilder (), pointer );
202
+ CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen , block, ctx.GetBuilder (), args );
176
203
177
- ValueUnRef (RunConfigNode->GetRepresentation (), conf , ctx, block);
204
+ ValueUnRef (RunConfigNode->GetRepresentation (), runConfigValue , ctx, block);
178
205
179
206
const auto wrap = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TUdfWrapper::Wrap));
180
207
const auto funType = FunctionType::get (Type::getVoidTy (context), {self->getType (), pointer->getType ()}, false );
@@ -208,6 +235,7 @@ using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolic
208
235
const TString TypeConfig;
209
236
const NUdf::TSourcePosition Pos;
210
237
IComputationNode* const RunConfigNode;
238
+ const ui32 RunConfigArgs;
211
239
const TCallableType* CallableType;
212
240
TType* const UserType;
213
241
const ui32 UdfIndex;
@@ -271,17 +299,71 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
271
299
typeConfig, flags, pos, ctx.SecureParamsProvider , &funcInfo);
272
300
273
301
MKQL_ENSURE (status.IsOk (), status.GetError ());
302
+
303
+ const auto runConfigFuncType = funcInfo.RunConfigType ;
304
+ const auto runConfigNodeType = runCfgNode.GetStaticType ();
305
+
306
+ if (!runConfigFuncType->IsSameType (*runConfigNodeType)) {
307
+ // It's only legal, when the compiled UDF declares its
308
+ // signature using run config at compilation phase, but then
309
+ // omits it in favor to function currying at execution phase.
310
+ if (!runConfigFuncType->IsVoid ()) {
311
+ TString diff = TStringBuilder ()
312
+ << " run config type mismatch, expected: "
313
+ << PrintNode ((runConfigNodeType), true )
314
+ << " , actual: "
315
+ << PrintNode (runConfigFuncType, true );
316
+ UdfTerminate ((TStringBuilder () << pos
317
+ << " UDF Function '"
318
+ << funcName
319
+ << " ' "
320
+ << TruncateTypeDiff (diff)).c_str ());
321
+ }
322
+
323
+ // If so, check the following invariants:
324
+ // * The first argument of the head function in the sequence
325
+ // of the curried functions has to be the same as the
326
+ // run config type.
327
+ // * The type of the resulting callable has to be the same
328
+ // as the function type.
329
+ const auto firstArgType = funcInfo.FunctionType ->GetArgumentType (0 );
330
+ if (!runConfigNodeType->IsSameType (*firstArgType)) {
331
+ TString diff = TStringBuilder ()
332
+ << " type mismatch, expected run config type: "
333
+ << PrintNode (runConfigNodeType, true )
334
+ << " , actual: "
335
+ << PrintNode (firstArgType, true );
336
+ UdfTerminate ((TStringBuilder () << pos
337
+ << " Udf Function '"
338
+ << funcName
339
+ << " ' "
340
+ << TruncateTypeDiff (diff)).c_str ());
341
+ }
342
+ const auto callableFuncType = funcInfo.FunctionType ->GetReturnType ();
343
+ const auto callableNodeType = callable.GetType ()->GetReturnType ();
344
+ if (!callableNodeType->IsSameType (*callableFuncType)) {
345
+ TString diff = TStringBuilder ()
346
+ << " type mismatch, expected return type: "
347
+ << PrintNode (callableNodeType, true )
348
+ << " , actual: "
349
+ << PrintNode (callableFuncType, true );
350
+ UdfTerminate ((TStringBuilder () << pos
351
+ << " Udf Function '"
352
+ << funcName
353
+ << " ' "
354
+ << TruncateTypeDiff (diff)).c_str ());
355
+ }
356
+
357
+ const auto runConfigCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
358
+ const auto runConfigArgs = funcInfo.FunctionType ->GetArgumentsCount ();
359
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType , userType);
360
+ }
274
361
MKQL_ENSURE (funcInfo.FunctionType ->IsConvertableTo (*callable.GetType ()->GetReturnType (), true ),
275
362
" Function '" << funcName << " ' type mismatch, expected return type: " << PrintNode (callable.GetType ()->GetReturnType (), true ) <<
276
363
" , actual:" << PrintNode (funcInfo.FunctionType , true ));
277
364
MKQL_ENSURE (funcInfo.Implementation , " UDF implementation is not set for function " << funcName);
278
365
279
- const auto runConfigType = funcInfo.RunConfigType ;
280
- const bool typesMatch = runConfigType->IsSameType (*runCfgNode.GetStaticType ());
281
- MKQL_ENSURE (typesMatch, " RunConfig '" << funcName << " ' type mismatch, expected: " << PrintNode (runCfgNode.GetStaticType (), true ) <<
282
- " , actual: " << PrintNode (runConfigType, true ));
283
-
284
- if (runConfigType->IsVoid ()) {
366
+ if (runConfigFuncType->IsVoid ()) {
285
367
if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName ) {
286
368
return new TUdfRunCodegeneratorNode (
287
369
ctx.Mutables , std::move (funcName), std::move (typeConfig), pos, funcInfo.FunctionType , userType,
@@ -292,7 +374,7 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont
292
374
}
293
375
294
376
const auto runCfgCompNode = LocateNode (ctx.NodeLocator , *runCfgNode.GetNode ());
295
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, funcInfo.FunctionType , userType);
377
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, runCfgCompNode, 1U , funcInfo.FunctionType , userType);
296
378
}
297
379
298
380
IComputationNode* WrapScriptUdf (TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -331,7 +413,7 @@ IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFacto
331
413
const auto funcTypeInfo = static_cast <TCallableType*>(callableResultType);
332
414
333
415
const auto programCompNode = LocateNode (ctx.NodeLocator , *programNode.GetNode ());
334
- return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, programCompNode, funcTypeInfo, userType);
416
+ return CreateUdfWrapper<false >(ctx, std::move (funcName), std::move (typeConfig), pos, programCompNode, 1U , funcTypeInfo, userType);
335
417
}
336
418
337
419
}
0 commit comments