|
1 | 1 | #include "mkql_builtins_impl.h" // Y_IGNORE
|
| 2 | +#include <ydb/library/yql/minikql/mkql_node_builder.h> // UnpackOptionalData |
2 | 3 |
|
3 | 4 | namespace NKikimr {
|
4 | 5 | namespace NMiniKQL {
|
@@ -70,6 +71,7 @@ arrow::compute::InputType GetPrimitiveInputArrowType(NUdf::EDataSlot slot) {
|
70 | 71 | case NUdf::EDataSlot::Timestamp64: return GetPrimitiveInputArrowType<i64>();
|
71 | 72 | case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveInputArrowType<i64>(true);
|
72 | 73 | case NUdf::EDataSlot::Interval64: return GetPrimitiveInputArrowType<i64>();
|
| 74 | + case NUdf::EDataSlot::Decimal: return GetPrimitiveInputArrowType<NYql::NDecimal::TInt128>(); |
73 | 75 | default:
|
74 | 76 | ythrow yexception() << "Unexpected data slot: " << slot;
|
75 | 77 | }
|
@@ -104,6 +106,7 @@ arrow::compute::OutputType GetPrimitiveOutputArrowType(NUdf::EDataSlot slot) {
|
104 | 106 | case NUdf::EDataSlot::Timestamp64: return GetPrimitiveOutputArrowType<i64>();
|
105 | 107 | case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveOutputArrowType<i64>(true);
|
106 | 108 | case NUdf::EDataSlot::Interval64: return GetPrimitiveOutputArrowType<i64>();
|
| 109 | + case NUdf::EDataSlot::Decimal: return GetPrimitiveOutputArrowType<NYql::NDecimal::TInt128>(); |
107 | 110 | default:
|
108 | 111 | ythrow yexception() << "Unexpected data slot: " << slot;
|
109 | 112 | }
|
@@ -260,6 +263,68 @@ const arrow::compute::ScalarKernel& TPlainKernel::GetArrowKernel() const {
|
260 | 263 | return *ArrowKernel;
|
261 | 264 | }
|
262 | 265 |
|
| 266 | +std::shared_ptr<arrow::compute::ScalarKernel> TPlainKernel::MakeArrowKernel(const TVector<TType*>&, TType*) const { |
| 267 | + ythrow yexception() << "Unsupported kernel"; |
| 268 | +} |
| 269 | + |
| 270 | +bool TPlainKernel::IsPolymorphic() const { |
| 271 | + return false; |
| 272 | +} |
| 273 | + |
| 274 | +TDecimalKernel::TDecimalKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, |
| 275 | + NUdf::TDataTypeId returnType, TStatelessArrayKernelExec exec, |
| 276 | + TKernel::ENullMode nullMode) |
| 277 | + : TKernel(family, argTypes, returnType, nullMode) |
| 278 | + , Exec(exec) |
| 279 | +{ |
| 280 | +} |
| 281 | + |
| 282 | +const arrow::compute::ScalarKernel& TDecimalKernel::GetArrowKernel() const { |
| 283 | + ythrow yexception() << "Unsupported kernel"; |
| 284 | +} |
| 285 | + |
| 286 | +std::shared_ptr<arrow::compute::ScalarKernel> TDecimalKernel::MakeArrowKernel(const TVector<TType*>& argTypes, TType* resultType) const { |
| 287 | + MKQL_ENSURE(argTypes.size() == 2, "Require 2 arguments"); |
| 288 | + MKQL_ENSURE(argTypes[0]->GetKind() == TType::EKind::Block, "Require block"); |
| 289 | + MKQL_ENSURE(argTypes[1]->GetKind() == TType::EKind::Block, "Require block"); |
| 290 | + MKQL_ENSURE(resultType->GetKind() == TType::EKind::Block, "Require block"); |
| 291 | + |
| 292 | + bool isOptional = false; |
| 293 | + auto dataType1 = UnpackOptionalData(static_cast<TBlockType*>(argTypes[0])->GetItemType(), isOptional); |
| 294 | + auto dataType2 = UnpackOptionalData(static_cast<TBlockType*>(argTypes[1])->GetItemType(), isOptional); |
| 295 | + auto dataResultType = UnpackOptionalData(static_cast<TBlockType*>(resultType)->GetItemType(), isOptional); |
| 296 | + |
| 297 | + MKQL_ENSURE(*dataType1->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal"); |
| 298 | + MKQL_ENSURE(*dataType2->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal"); |
| 299 | + MKQL_ENSURE(*dataResultType->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal"); |
| 300 | + |
| 301 | + auto decimalType1 = static_cast<TDataDecimalType*>(dataType1); |
| 302 | + auto decimalType2 = static_cast<TDataDecimalType*>(dataType2); |
| 303 | + auto decimalResultType = static_cast<TDataDecimalType*>(dataResultType); |
| 304 | + |
| 305 | + MKQL_ENSURE(decimalType1->GetParams() == decimalType2->GetParams(), "Require same precision/scale"); |
| 306 | + MKQL_ENSURE(decimalType1->GetParams() == decimalResultType->GetParams(), "Require same precision/scale"); |
| 307 | + |
| 308 | + ui8 precision = decimalType1->GetParams().first; |
| 309 | + MKQL_ENSURE(precision >= 1&& precision <= 35, TStringBuilder() << "Wrong precision: " << (int)precision); |
| 310 | + |
| 311 | + auto k = std::make_shared<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{ |
| 312 | + GetPrimitiveInputArrowType(NUdf::EDataSlot::Decimal), GetPrimitiveInputArrowType(NUdf::EDataSlot::Decimal) |
| 313 | + }, GetPrimitiveOutputArrowType(NUdf::EDataSlot::Decimal), Exec); |
| 314 | + k->null_handling = arrow::compute::NullHandling::INTERSECTION; |
| 315 | + k->init = [precision](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) { |
| 316 | + auto state = std::make_unique<TDecimalKernel::TKernelState>(); |
| 317 | + state->Precision = precision; |
| 318 | + return arrow::Result(std::move(state)); |
| 319 | + }; |
| 320 | + |
| 321 | + return k; |
| 322 | +} |
| 323 | + |
| 324 | +bool TDecimalKernel::IsPolymorphic() const { |
| 325 | + return true; |
| 326 | +} |
| 327 | + |
263 | 328 | void AddUnaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot res,
|
264 | 329 | TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) {
|
265 | 330 | auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId;
|
@@ -600,5 +665,100 @@ arrow::Status ExecBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
|
600 | 665 | }
|
601 | 666 | }
|
602 | 667 |
|
| 668 | +arrow::Status ExecDecimalArrayScalarOptImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res, |
| 669 | + TUntypedBinaryArrayOptFuncPtr func) { |
| 670 | + MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); |
| 671 | + const auto& arg1 = batch.values[0]; |
| 672 | + const auto& arg2 = batch.values[1]; |
| 673 | + auto& resArr = *res->array(); |
| 674 | + if (arg2.scalar()->is_valid) { |
| 675 | + const auto& arr1 = *arg1.array(); |
| 676 | + const auto val1Ptr = arr1.buffers[1]->data(); |
| 677 | + auto length = arr1.length; |
| 678 | + const auto nullCount1 = arr1.GetNullCount(); |
| 679 | + const auto valid1 = (nullCount1 == 0) ? nullptr : arr1.GetValues<uint8_t>(0); |
| 680 | + const auto val2Ptr = GetStringScalarValue(*arg2.scalar()); |
| 681 | + auto resPtr = resArr.buffers[1]->mutable_data(); |
| 682 | + auto resValid = res->array()->GetMutableValues<uint8_t>(0); |
| 683 | + func(val1Ptr, valid1, val2Ptr.data(), nullptr, resPtr, resValid, length, arr1.offset, 0); |
| 684 | + } else { |
| 685 | + GetBitmap(resArr, 0).SetBitsTo(false); |
| 686 | + } |
| 687 | + |
| 688 | + return arrow::Status::OK(); |
| 689 | +} |
| 690 | + |
| 691 | +arrow::Status ExecDecimalScalarArrayOptImpl(const arrow::compute::ExecBatch& batch, arrow::Datum* res, |
| 692 | + TUntypedBinaryArrayOptFuncPtr func) { |
| 693 | + MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); |
| 694 | + const auto& arg1 = batch.values[0]; |
| 695 | + const auto& arg2 = batch.values[1]; |
| 696 | + auto& resArr = *res->array(); |
| 697 | + if (arg1.scalar()->is_valid) { |
| 698 | + const auto val1Ptr = GetStringScalarValue(*arg1.scalar()); |
| 699 | + const auto& arr2 = *arg2.array(); |
| 700 | + auto length = arr2.length; |
| 701 | + const auto val2Ptr = arr2.buffers[1]->data(); |
| 702 | + const auto nullCount2 = arr2.GetNullCount(); |
| 703 | + const auto valid2 = (nullCount2 == 0) ? nullptr : arr2.GetValues<uint8_t>(0); |
| 704 | + auto resPtr = resArr.buffers[1]->mutable_data(); |
| 705 | + auto resValid = res->array()->GetMutableValues<uint8_t>(0); |
| 706 | + func(val1Ptr.data(), nullptr, val2Ptr, valid2, resPtr, resValid, length, 0, arr2.offset); |
| 707 | + } else { |
| 708 | + GetBitmap(resArr, 0).SetBitsTo(false); |
| 709 | + } |
| 710 | + |
| 711 | + return arrow::Status::OK(); |
| 712 | +} |
| 713 | + |
| 714 | +arrow::Status ExecDecimalScalarScalarOptImpl(arrow::compute::KernelContext* kernelCtx, |
| 715 | + const arrow::compute::ExecBatch& batch, arrow::Datum* res, |
| 716 | + TPrimitiveDataTypeGetter typeGetter, TUntypedBinaryScalarOptFuncPtr func) { |
| 717 | + MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); |
| 718 | + const auto& arg1 = batch.values[0]; |
| 719 | + const auto& arg2 = batch.values[1]; |
| 720 | + if (!arg1.scalar()->is_valid || !arg2.scalar()->is_valid) { |
| 721 | + *res = arrow::MakeNullScalar(typeGetter()); |
| 722 | + } else { |
| 723 | + const auto val1Ptr = GetStringScalarValue(*arg1.scalar()); |
| 724 | + const auto val2Ptr = GetStringScalarValue(*arg2.scalar()); |
| 725 | + std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, kernelCtx->memory_pool()))); |
| 726 | + auto resDatum = arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer)); |
| 727 | + if (!func(val1Ptr.data(), val2Ptr.data(), buffer->mutable_data())) { |
| 728 | + *res = arrow::MakeNullScalar(typeGetter()); |
| 729 | + } else { |
| 730 | + *res = resDatum.scalar(); |
| 731 | + } |
| 732 | + } |
| 733 | + |
| 734 | + return arrow::Status::OK(); |
| 735 | +} |
| 736 | + |
| 737 | +arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx, |
| 738 | + const arrow::compute::ExecBatch& batch, arrow::Datum* res, |
| 739 | + TPrimitiveDataTypeGetter typeGetter, |
| 740 | + size_t outputSizeOf, |
| 741 | + TUntypedBinaryScalarOptFuncPtr scalarScalarFunc, |
| 742 | + TUntypedBinaryArrayOptFuncPtr scalarArrayFunc, |
| 743 | + TUntypedBinaryArrayOptFuncPtr arrayScalarFunc, |
| 744 | + TUntypedBinaryArrayOptFuncPtr arrayArrayFunc) { |
| 745 | + MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); |
| 746 | + const auto& arg1 = batch.values[0]; |
| 747 | + const auto& arg2 = batch.values[1]; |
| 748 | + if (arg1.is_scalar()) { |
| 749 | + if (arg2.is_scalar()) { |
| 750 | + return ExecDecimalScalarScalarOptImpl(kernelCtx, batch, res, typeGetter, scalarScalarFunc); |
| 751 | + } else { |
| 752 | + return ExecDecimalScalarArrayOptImpl(batch, res, scalarArrayFunc); |
| 753 | + } |
| 754 | + } else { |
| 755 | + if (arg2.is_scalar()) { |
| 756 | + return ExecDecimalArrayScalarOptImpl(batch, res, arrayScalarFunc); |
| 757 | + } else { |
| 758 | + return ExecArrayArrayOptImpl(kernelCtx, batch, res, arrayArrayFunc, outputSizeOf, typeGetter, false, false, EPropagateTz::None); |
| 759 | + } |
| 760 | + } |
| 761 | +} |
| 762 | + |
603 | 763 | } // namespace NMiniKQL
|
604 | 764 | } // namespace NKikimr
|
0 commit comments