7
7
8
8
#include < yt/yt/client/formats/parser.h>
9
9
10
+ #include < yt/yt/library/decimal/decimal.h>
11
+
10
12
#include < library/cpp/yt/memory/chunked_output_stream.h>
11
13
12
14
#include < util/stream/buffer.h>
19
21
20
22
#include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/api.h>
21
23
24
+ #include < contrib/libs/apache/arrow/cpp/src/arrow/util/decimal.h>
25
+
22
26
namespace NYT ::NFormats {
23
27
24
28
using namespace NTableClient ;
25
29
using TUnversionedRowValues = std::vector<NTableClient::TUnversionedValue>;
30
+ using namespace NDecimal ;
26
31
27
32
namespace {
28
33
@@ -31,7 +36,7 @@ namespace {
31
36
void ThrowOnError (const arrow::Status& status)
32
37
{
33
38
if (!status.ok ()) {
34
- THROW_ERROR_EXCEPTION (" Arrow error occurred : %Qv" , status.message ());
39
+ THROW_ERROR_EXCEPTION (" Arrow error [%v] : %Qv" , status. CodeAsString () , status.message ());
35
40
}
36
41
}
37
42
@@ -158,6 +163,31 @@ class TArraySimpleVisitor
158
163
return ParseNull ();
159
164
}
160
165
166
+ // Decimal types. For now, YT natively supports only Decimal128 with scale up to 35.
167
+ // Thus, we represent short enough decimals as native YT decimals, and wider decimals as
168
+ // their decimal string representation; but the latter is subject to change whenever we
169
+ // get the native support for Decimal128 with scale up to 38 or Decimal256 with scale up to 76.
170
+ arrow::Status Visit (const arrow::Decimal128Type& type) override
171
+ {
172
+ constexpr int MaximumYTDecimalPrecision = 35 ;
173
+ if (type.precision () <= MaximumYTDecimalPrecision) {
174
+ return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) {
175
+ return MakeDecimalBinaryValue (value, columnId, type.precision ());
176
+ });
177
+ } else {
178
+ return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) {
179
+ return MakeDecimalTextValue<arrow::Decimal128>(value, columnId, type.scale ());
180
+ });
181
+ }
182
+ }
183
+
184
+ arrow::Status Visit (const arrow::Decimal256Type& type) override
185
+ {
186
+ return ParseStringLikeArray<arrow::Decimal256Array>([&] (const TStringBuf& value, i64 columnId) {
187
+ return MakeDecimalTextValue<arrow::Decimal256>(value, columnId, type.scale ());
188
+ });
189
+ }
190
+
161
191
private:
162
192
const i64 ColumnId_;
163
193
@@ -209,7 +239,7 @@ class TArraySimpleVisitor
209
239
}
210
240
211
241
template <typename ArrayType>
212
- arrow::Status ParseStringLikeArray ()
242
+ arrow::Status ParseStringLikeArray (auto makeUnversionedValueFunc )
213
243
{
214
244
auto array = std::static_pointer_cast<ArrayType>(Array_);
215
245
for (int rowIndex = 0 ; rowIndex < array->length (); ++rowIndex) {
@@ -225,12 +255,23 @@ class TArraySimpleVisitor
225
255
BufferForStringLikeValues_->Advance (element.size ());
226
256
auto value = TStringBuf (buffer, element.size ());
227
257
228
- (*RowValues_)[rowIndex] = MakeUnversionedStringValue (value, ColumnId_);
258
+ (*RowValues_)[rowIndex] = makeUnversionedValueFunc (value, ColumnId_);
229
259
}
230
260
}
231
261
return arrow::Status::OK ();
232
262
}
233
263
264
+ template <typename ArrayType>
265
+ arrow::Status ParseStringLikeArray ()
266
+ {
267
+ // Note that MakeUnversionedStringValue actually has third argument in its signature,
268
+ // which leads to a "too few arguments" in the point of its invocation if we try to pass
269
+ // it directly to ParseStringLikeArray.
270
+ return ParseStringLikeArray<ArrayType>([] (const TStringBuf& value, i64 columnId) {
271
+ return MakeUnversionedStringValue (value, columnId);
272
+ });
273
+ }
274
+
234
275
arrow::Status ParseBoolean ()
235
276
{
236
277
auto array = std::static_pointer_cast<arrow::BooleanArray>(Array_);
@@ -252,6 +293,34 @@ class TArraySimpleVisitor
252
293
}
253
294
return arrow::Status::OK ();
254
295
}
296
+
297
+ TUnversionedValue MakeDecimalBinaryValue (const TStringBuf& value, i64 columnId, int precision)
298
+ {
299
+ // NB: arrow wire representation of Decimal128 is little-endian and (obviously) 128 bit,
300
+ // while YT in-memory representation of Decimal is big-endian, variadic-length of either 32 bit, 64 bit or 128 bit,
301
+ // and MSB-flipped to ensure lexical sorting order.
302
+ TDecimal::TValue128 value128;
303
+ YT_VERIFY (value.size () == sizeof (value128));
304
+ std::memcpy (&value128, value.data (), value.size ());
305
+
306
+ const auto maxByteCount = sizeof (value128);
307
+ char * buffer = BufferForStringLikeValues_->Preallocate (maxByteCount);
308
+ auto decimalBinary = TDecimal::WriteBinaryVariadic (precision, value128, buffer, maxByteCount);
309
+ BufferForStringLikeValues_->Advance (decimalBinary.size ());
310
+
311
+ return MakeUnversionedStringValue (decimalBinary, columnId);
312
+ }
313
+
314
+ template <class TArrowDecimalType >
315
+ TUnversionedValue MakeDecimalTextValue (const TStringBuf& value, i64 columnId, int scale)
316
+ {
317
+ TArrowDecimalType decimal (reinterpret_cast <const uint8_t *>(value.data ()));
318
+ auto string = decimal.ToString (scale);
319
+ char * buffer = BufferForStringLikeValues_->Preallocate (string.size ());
320
+ std::memcpy (buffer, string.data (), string.size ());
321
+ BufferForStringLikeValues_->Advance (string.size ());
322
+ return MakeUnversionedStringValue (TStringBuf (buffer, string.size ()), columnId);
323
+ }
255
324
};
256
325
257
326
// //////////////////////////////////////////////////////////////////////////////
@@ -552,12 +621,14 @@ class TArrayCompositeVisitor
552
621
// //////////////////////////////////////////////////////////////////////////////
553
622
554
623
void CheckArrowType (
624
+ auto ytTypeOrMetatype,
555
625
const std::shared_ptr<arrow::DataType>& arrowType,
556
626
std::initializer_list<arrow::Type::type> allowedTypes)
557
627
{
558
628
if (std::find (allowedTypes.begin (), allowedTypes.end (), arrowType->id ()) == allowedTypes.end ()) {
559
- THROW_ERROR_EXCEPTION (" Unexpected arrow type %Qv" ,
560
- arrowType->name ());
629
+ THROW_ERROR_EXCEPTION (" Unexpected arrow type %Qv for YT type or metatype %Qlv" ,
630
+ arrowType->name (),
631
+ ytTypeOrMetatype);
561
632
}
562
633
}
563
634
@@ -573,6 +644,7 @@ void CheckMatchingArrowTypes(
573
644
574
645
case ESimpleLogicalValueType::Interval:
575
646
CheckArrowType (
647
+ columnType,
576
648
column->type (),
577
649
{
578
650
arrow::Type::INT8,
@@ -597,6 +669,7 @@ void CheckMatchingArrowTypes(
597
669
case ESimpleLogicalValueType::Datetime:
598
670
case ESimpleLogicalValueType::Timestamp:
599
671
CheckArrowType (
672
+ columnType,
600
673
column->type (),
601
674
{
602
675
arrow::Type::UINT8,
@@ -611,20 +684,24 @@ void CheckMatchingArrowTypes(
611
684
case ESimpleLogicalValueType::Json:
612
685
case ESimpleLogicalValueType::Utf8:
613
686
CheckArrowType (
687
+ columnType,
614
688
column->type (),
615
689
{
616
690
arrow::Type::STRING,
617
691
arrow::Type::BINARY,
618
692
arrow::Type::LARGE_STRING,
619
693
arrow::Type::LARGE_BINARY,
620
694
arrow::Type::FIXED_SIZE_BINARY,
621
- arrow::Type::DICTIONARY
695
+ arrow::Type::DICTIONARY,
696
+ arrow::Type::DECIMAL128,
697
+ arrow::Type::DECIMAL256,
622
698
});
623
699
break ;
624
700
625
701
case ESimpleLogicalValueType::Float:
626
702
case ESimpleLogicalValueType::Double:
627
703
CheckArrowType (
704
+ columnType,
628
705
column->type (),
629
706
{
630
707
arrow::Type::HALF_FLOAT,
@@ -636,12 +713,14 @@ void CheckMatchingArrowTypes(
636
713
637
714
case ESimpleLogicalValueType::Boolean:
638
715
CheckArrowType (
716
+ columnType,
639
717
column->type (),
640
718
{arrow::Type::BOOL, arrow::Type::DICTIONARY});
641
719
break ;
642
720
643
721
case ESimpleLogicalValueType::Any:
644
722
CheckArrowType (
723
+ columnType,
645
724
column->type (),
646
725
{
647
726
arrow::Type::INT8,
@@ -679,6 +758,7 @@ void CheckMatchingArrowTypes(
679
758
case ESimpleLogicalValueType::Null:
680
759
case ESimpleLogicalValueType::Void:
681
760
CheckArrowType (
761
+ columnType,
682
762
column->type (),
683
763
{
684
764
arrow::Type::NA,
@@ -688,6 +768,7 @@ void CheckMatchingArrowTypes(
688
768
689
769
case ESimpleLogicalValueType::Uuid:
690
770
CheckArrowType (
771
+ columnType,
691
772
column->type (),
692
773
{
693
774
arrow::Type::STRING,
@@ -749,9 +830,10 @@ void PrepareArrayForComplexType(
749
830
int columnIndex,
750
831
int columnId)
751
832
{
752
- switch (denullifiedLogicalType->GetMetatype ()) {
833
+ switch (auto metatype = denullifiedLogicalType->GetMetatype ()) {
753
834
case ELogicalMetatype::List:
754
835
CheckArrowType (
836
+ metatype,
755
837
column->type (),
756
838
{
757
839
arrow::Type::LIST,
@@ -761,6 +843,7 @@ void PrepareArrayForComplexType(
761
843
762
844
case ELogicalMetatype::Dict:
763
845
CheckArrowType (
846
+ metatype,
764
847
column->type (),
765
848
{
766
849
arrow::Type::MAP,
@@ -770,32 +853,49 @@ void PrepareArrayForComplexType(
770
853
771
854
case ELogicalMetatype::Struct:
772
855
CheckArrowType (
856
+ metatype,
773
857
column->type (),
774
858
{
775
859
arrow::Type::STRUCT,
776
860
arrow::Type::BINARY
777
861
});
778
862
break ;
863
+
779
864
case ELogicalMetatype::Decimal:
865
+ CheckArrowType (
866
+ metatype,
867
+ column->type (),
868
+ {
869
+ arrow::Type::DECIMAL128,
870
+ arrow::Type::DECIMAL256
871
+ });
872
+ break ;
873
+
780
874
case ELogicalMetatype::Optional:
781
875
case ELogicalMetatype::Tuple:
782
876
case ELogicalMetatype::VariantTuple:
783
877
case ELogicalMetatype::VariantStruct:
784
- CheckArrowType (column->type (), {arrow::Type::BINARY});
878
+ CheckArrowType (metatype, column->type (), {arrow::Type::BINARY});
785
879
break ;
786
880
787
881
default :
788
882
THROW_ERROR_EXCEPTION (" Unexpected arrow type in complex type %Qv" , column->type ()->name ());
789
883
}
790
884
791
- if (column->type ()->id () == arrow::Type::BINARY) {
885
+ if (column->type ()->id () == arrow::Type::BINARY ||
886
+ column->type ()->id () == arrow::Type::DECIMAL128 ||
887
+ column->type ()->id () == arrow::Type::DECIMAL256)
888
+ {
792
889
TUnversionedRowValues stringValues (rowsValues[columnIndex].size ());
793
890
TArraySimpleVisitor visitor (columnId, column, bufferForStringLikeValues, &stringValues);
794
891
ThrowOnError (column->type ()->Accept (&visitor));
795
892
for (int offset = 0 ; offset < std::ssize (rowsValues[columnIndex]); offset++) {
796
893
if (column->IsNull (offset)) {
797
894
rowsValues[columnIndex][offset] = MakeUnversionedNullValue (columnId);
895
+ } else if (column->type ()->id () == arrow::Type::DECIMAL128 || column->type ()->id () == arrow::Type::DECIMAL256) {
896
+ rowsValues[columnIndex][offset] = MakeUnversionedStringValue (stringValues[offset].AsStringBuf (), columnId);
798
897
} else {
898
+ // TODO(max): is it even correct? Binary is not necessarily a correct YSON...
799
899
rowsValues[columnIndex][offset] = MakeUnversionedCompositeValue (stringValues[offset].AsStringBuf (), columnId);
800
900
}
801
901
}
0 commit comments