Skip to content

Commit c73194d

Browse files
zlobobernadya02
authored andcommitted
Support Decimal128/Decimal256 in Arrow parser
No description --- b24f71e64f22e615ebb32f33fb2cfc5c88198c1a Pull Request resolved: ytsaurus/ytsaurus#769 Co-authored-by: nadya02 <nadya02@yandex-team.com>
1 parent feda194 commit c73194d

File tree

3 files changed

+127
-9
lines changed

3 files changed

+127
-9
lines changed

yt/yt/library/decimal/decimal.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,21 @@ TStringBuf TDecimal::WriteBinary128(int precision, TValue128 value, char* buffer
535535
return TStringBuf{buffer, sizeof(TValue128)};
536536
}
537537

538+
TStringBuf TDecimal::WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength)
539+
{
540+
const size_t resultLength = GetValueBinarySize(precision);
541+
switch (resultLength) {
542+
case 4:
543+
return WriteBinary32(precision, static_cast<i32>(value.Low), buffer, bufferLength);
544+
case 8:
545+
return WriteBinary64(precision, static_cast<i64>(value.Low), buffer, bufferLength);
546+
case 16:
547+
return WriteBinary128(precision, value, buffer, bufferLength);
548+
default:
549+
THROW_ERROR_EXCEPTION("Invalid precision %v", precision);
550+
}
551+
}
552+
538553
template <typename T>
539554
Y_FORCE_INLINE void CheckBufferLength(int precision, size_t bufferLength)
540555
{

yt/yt/library/decimal/decimal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class TDecimal
5050
static TStringBuf WriteBinary64(int precision, i64 value, char* buffer, size_t bufferLength);
5151
static TStringBuf WriteBinary128(int precision, TValue128 value, char* buffer, size_t bufferLength);
5252

53+
// Writes either 32-bit, 64-bit or 128-bit binary value depending on precision, provided a TValue128.
54+
static TStringBuf WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength);
55+
5356
static i32 ParseBinary32(int precision, TStringBuf buffer);
5457
static i64 ParseBinary64(int precision, TStringBuf buffer);
5558
static TValue128 ParseBinary128(int precision, TStringBuf buffer);

yt/yt/library/formats/arrow_parser.cpp

Lines changed: 109 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
#include <yt/yt/client/formats/parser.h>
99

10+
#include <yt/yt/library/decimal/decimal.h>
11+
1012
#include <library/cpp/yt/memory/chunked_output_stream.h>
1113

1214
#include <util/stream/buffer.h>
@@ -19,10 +21,13 @@
1921

2022
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/api.h>
2123

24+
#include <contrib/libs/apache/arrow/cpp/src/arrow/util/decimal.h>
25+
2226
namespace NYT::NFormats {
2327

2428
using namespace NTableClient;
2529
using TUnversionedRowValues = std::vector<NTableClient::TUnversionedValue>;
30+
using namespace NDecimal;
2631

2732
namespace {
2833

@@ -31,7 +36,7 @@ namespace {
3136
void ThrowOnError(const arrow::Status& status)
3237
{
3338
if (!status.ok()) {
34-
THROW_ERROR_EXCEPTION("Arrow error occurred: %Qv", status.message());
39+
THROW_ERROR_EXCEPTION("Arrow error [%v]: %Qv", status.CodeAsString(), status.message());
3540
}
3641
}
3742

@@ -158,6 +163,31 @@ class TArraySimpleVisitor
158163
return ParseNull();
159164
}
160165

166+
// Decimal types. For now, YT natively supports only Decimal128 with scale up to 35.
167+
// Thus, we represent short enough decimals as native YT decimals, and wider decimals as
168+
// their decimal string representation; but the latter is subject to change whenever we
169+
// get the native support for Decimal128 with scale up to 38 or Decimal256 with scale up to 76.
170+
arrow::Status Visit(const arrow::Decimal128Type& type) override
171+
{
172+
constexpr int MaximumYTDecimalPrecision = 35;
173+
if (type.precision() <= MaximumYTDecimalPrecision) {
174+
return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) {
175+
return MakeDecimalBinaryValue(value, columnId, type.precision());
176+
});
177+
} else {
178+
return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) {
179+
return MakeDecimalTextValue<arrow::Decimal128>(value, columnId, type.scale());
180+
});
181+
}
182+
}
183+
184+
arrow::Status Visit(const arrow::Decimal256Type& type) override
185+
{
186+
return ParseStringLikeArray<arrow::Decimal256Array>([&] (const TStringBuf& value, i64 columnId) {
187+
return MakeDecimalTextValue<arrow::Decimal256>(value, columnId, type.scale());
188+
});
189+
}
190+
161191
private:
162192
const i64 ColumnId_;
163193

@@ -209,7 +239,7 @@ class TArraySimpleVisitor
209239
}
210240

211241
template <typename ArrayType>
212-
arrow::Status ParseStringLikeArray()
242+
arrow::Status ParseStringLikeArray(auto makeUnversionedValueFunc)
213243
{
214244
auto array = std::static_pointer_cast<ArrayType>(Array_);
215245
for (int rowIndex = 0; rowIndex < array->length(); ++rowIndex) {
@@ -225,12 +255,23 @@ class TArraySimpleVisitor
225255
BufferForStringLikeValues_->Advance(element.size());
226256
auto value = TStringBuf(buffer, element.size());
227257

228-
(*RowValues_)[rowIndex] = MakeUnversionedStringValue(value, ColumnId_);
258+
(*RowValues_)[rowIndex] = makeUnversionedValueFunc(value, ColumnId_);
229259
}
230260
}
231261
return arrow::Status::OK();
232262
}
233263

264+
template <typename ArrayType>
265+
arrow::Status ParseStringLikeArray()
266+
{
267+
// Note that MakeUnversionedStringValue actually has third argument in its signature,
268+
// which leads to a "too few arguments" in the point of its invocation if we try to pass
269+
// it directly to ParseStringLikeArray.
270+
return ParseStringLikeArray<ArrayType>([] (const TStringBuf& value, i64 columnId) {
271+
return MakeUnversionedStringValue(value, columnId);
272+
});
273+
}
274+
234275
arrow::Status ParseBoolean()
235276
{
236277
auto array = std::static_pointer_cast<arrow::BooleanArray>(Array_);
@@ -252,6 +293,34 @@ class TArraySimpleVisitor
252293
}
253294
return arrow::Status::OK();
254295
}
296+
297+
TUnversionedValue MakeDecimalBinaryValue(const TStringBuf& value, i64 columnId, int precision)
298+
{
299+
// NB: arrow wire representation of Decimal128 is little-endian and (obviously) 128 bit,
300+
// while YT in-memory representation of Decimal is big-endian, variadic-length of either 32 bit, 64 bit or 128 bit,
301+
// and MSB-flipped to ensure lexical sorting order.
302+
TDecimal::TValue128 value128;
303+
YT_VERIFY(value.size() == sizeof(value128));
304+
std::memcpy(&value128, value.data(), value.size());
305+
306+
const auto maxByteCount = sizeof(value128);
307+
char* buffer = BufferForStringLikeValues_->Preallocate(maxByteCount);
308+
auto decimalBinary = TDecimal::WriteBinaryVariadic(precision, value128, buffer, maxByteCount);
309+
BufferForStringLikeValues_->Advance(decimalBinary.size());
310+
311+
return MakeUnversionedStringValue(decimalBinary, columnId);
312+
}
313+
314+
template <class TArrowDecimalType>
315+
TUnversionedValue MakeDecimalTextValue(const TStringBuf& value, i64 columnId, int scale)
316+
{
317+
TArrowDecimalType decimal(reinterpret_cast<const uint8_t*>(value.data()));
318+
auto string = decimal.ToString(scale);
319+
char* buffer = BufferForStringLikeValues_->Preallocate(string.size());
320+
std::memcpy(buffer, string.data(), string.size());
321+
BufferForStringLikeValues_->Advance(string.size());
322+
return MakeUnversionedStringValue(TStringBuf(buffer, string.size()), columnId);
323+
}
255324
};
256325

257326
////////////////////////////////////////////////////////////////////////////////
@@ -552,12 +621,14 @@ class TArrayCompositeVisitor
552621
////////////////////////////////////////////////////////////////////////////////
553622

554623
void CheckArrowType(
624+
auto ytTypeOrMetatype,
555625
const std::shared_ptr<arrow::DataType>& arrowType,
556626
std::initializer_list<arrow::Type::type> allowedTypes)
557627
{
558628
if (std::find(allowedTypes.begin(), allowedTypes.end(), arrowType->id()) == allowedTypes.end()) {
559-
THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv",
560-
arrowType->name());
629+
THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv",
630+
arrowType->name(),
631+
ytTypeOrMetatype);
561632
}
562633
}
563634

@@ -573,6 +644,7 @@ void CheckMatchingArrowTypes(
573644

574645
case ESimpleLogicalValueType::Interval:
575646
CheckArrowType(
647+
columnType,
576648
column->type(),
577649
{
578650
arrow::Type::INT8,
@@ -597,6 +669,7 @@ void CheckMatchingArrowTypes(
597669
case ESimpleLogicalValueType::Datetime:
598670
case ESimpleLogicalValueType::Timestamp:
599671
CheckArrowType(
672+
columnType,
600673
column->type(),
601674
{
602675
arrow::Type::UINT8,
@@ -611,20 +684,24 @@ void CheckMatchingArrowTypes(
611684
case ESimpleLogicalValueType::Json:
612685
case ESimpleLogicalValueType::Utf8:
613686
CheckArrowType(
687+
columnType,
614688
column->type(),
615689
{
616690
arrow::Type::STRING,
617691
arrow::Type::BINARY,
618692
arrow::Type::LARGE_STRING,
619693
arrow::Type::LARGE_BINARY,
620694
arrow::Type::FIXED_SIZE_BINARY,
621-
arrow::Type::DICTIONARY
695+
arrow::Type::DICTIONARY,
696+
arrow::Type::DECIMAL128,
697+
arrow::Type::DECIMAL256,
622698
});
623699
break;
624700

625701
case ESimpleLogicalValueType::Float:
626702
case ESimpleLogicalValueType::Double:
627703
CheckArrowType(
704+
columnType,
628705
column->type(),
629706
{
630707
arrow::Type::HALF_FLOAT,
@@ -636,12 +713,14 @@ void CheckMatchingArrowTypes(
636713

637714
case ESimpleLogicalValueType::Boolean:
638715
CheckArrowType(
716+
columnType,
639717
column->type(),
640718
{arrow::Type::BOOL, arrow::Type::DICTIONARY});
641719
break;
642720

643721
case ESimpleLogicalValueType::Any:
644722
CheckArrowType(
723+
columnType,
645724
column->type(),
646725
{
647726
arrow::Type::INT8,
@@ -679,6 +758,7 @@ void CheckMatchingArrowTypes(
679758
case ESimpleLogicalValueType::Null:
680759
case ESimpleLogicalValueType::Void:
681760
CheckArrowType(
761+
columnType,
682762
column->type(),
683763
{
684764
arrow::Type::NA,
@@ -688,6 +768,7 @@ void CheckMatchingArrowTypes(
688768

689769
case ESimpleLogicalValueType::Uuid:
690770
CheckArrowType(
771+
columnType,
691772
column->type(),
692773
{
693774
arrow::Type::STRING,
@@ -749,9 +830,10 @@ void PrepareArrayForComplexType(
749830
int columnIndex,
750831
int columnId)
751832
{
752-
switch (denullifiedLogicalType->GetMetatype()) {
833+
switch (auto metatype = denullifiedLogicalType->GetMetatype()) {
753834
case ELogicalMetatype::List:
754835
CheckArrowType(
836+
metatype,
755837
column->type(),
756838
{
757839
arrow::Type::LIST,
@@ -761,6 +843,7 @@ void PrepareArrayForComplexType(
761843

762844
case ELogicalMetatype::Dict:
763845
CheckArrowType(
846+
metatype,
764847
column->type(),
765848
{
766849
arrow::Type::MAP,
@@ -770,32 +853,49 @@ void PrepareArrayForComplexType(
770853

771854
case ELogicalMetatype::Struct:
772855
CheckArrowType(
856+
metatype,
773857
column->type(),
774858
{
775859
arrow::Type::STRUCT,
776860
arrow::Type::BINARY
777861
});
778862
break;
863+
779864
case ELogicalMetatype::Decimal:
865+
CheckArrowType(
866+
metatype,
867+
column->type(),
868+
{
869+
arrow::Type::DECIMAL128,
870+
arrow::Type::DECIMAL256
871+
});
872+
break;
873+
780874
case ELogicalMetatype::Optional:
781875
case ELogicalMetatype::Tuple:
782876
case ELogicalMetatype::VariantTuple:
783877
case ELogicalMetatype::VariantStruct:
784-
CheckArrowType(column->type(), {arrow::Type::BINARY});
878+
CheckArrowType(metatype, column->type(), {arrow::Type::BINARY});
785879
break;
786880

787881
default:
788882
THROW_ERROR_EXCEPTION("Unexpected arrow type in complex type %Qv", column->type()->name());
789883
}
790884

791-
if (column->type()->id() == arrow::Type::BINARY) {
885+
if (column->type()->id() == arrow::Type::BINARY ||
886+
column->type()->id() == arrow::Type::DECIMAL128 ||
887+
column->type()->id() == arrow::Type::DECIMAL256)
888+
{
792889
TUnversionedRowValues stringValues(rowsValues[columnIndex].size());
793890
TArraySimpleVisitor visitor(columnId, column, bufferForStringLikeValues, &stringValues);
794891
ThrowOnError(column->type()->Accept(&visitor));
795892
for (int offset = 0; offset < std::ssize(rowsValues[columnIndex]); offset++) {
796893
if (column->IsNull(offset)) {
797894
rowsValues[columnIndex][offset] = MakeUnversionedNullValue(columnId);
895+
} else if (column->type()->id() == arrow::Type::DECIMAL128 || column->type()->id() == arrow::Type::DECIMAL256) {
896+
rowsValues[columnIndex][offset] = MakeUnversionedStringValue(stringValues[offset].AsStringBuf(), columnId);
798897
} else {
898+
// TODO(max): is it even correct? Binary is not necessarily a correct YSON...
799899
rowsValues[columnIndex][offset] = MakeUnversionedCompositeValue(stringValues[offset].AsStringBuf(), columnId);
800900
}
801901
}

0 commit comments

Comments
 (0)