Skip to content

Commit b47a8b8

Browse files
authored
Support PG types in arrow and clickhouse (#9335)
1 parent 29fff7e commit b47a8b8

File tree

5 files changed: 151 additions and 46 deletions.

ydb/core/formats/arrow/converter.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,6 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err
297297
for (; row < rowsUnroll; row += unroll) {
298298
ui32 col = 0;
299299
for (auto& [colName, colType] : YdbSchema_) {
300-
// TODO: support pg types
301-
Y_ABORT_UNLESS(colType.GetTypeId() != NScheme::NTypeIds::Pg, "pg types are not supported");
302-
303300
auto& column = allColumns[col];
304301
bool success = SwitchYqlTypeToArrowType(colType, [&]<typename TType>(TTypeWrapper<TType> typeHolder) {
305302
Y_UNUSED(typeHolder);
@@ -347,9 +344,6 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err
347344

348345
ui32 col = 0;
349346
for (auto& [colName, colType] : YdbSchema_) {
350-
// TODO: support pg types
351-
Y_ABORT_UNLESS(colType.GetTypeId() != NScheme::NTypeIds::Pg, "pg types are not supported");
352-
353347
auto& column = allColumns[col];
354348
auto& curCell = cells[0][col];
355349
if (column->IsNull(row)) {

ydb/core/formats/arrow/switch/switch_type.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ template <typename TFunc>
9393
case TEXTOID:
9494
return callback(TTypeWrapper<arrow::StringType>());
9595
default:
96-
break;
96+
return false;
9797
}
98-
break; // TODO: support pg types
98+
break;
9999
}
100100
return false;
101101
}

ydb/core/formats/arrow/ut/ut_arrow.cpp

Lines changed: 125 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,39 @@ using TTypeId = NScheme::TTypeId;
2121
using TTypeInfo = NScheme::TTypeInfo;
2222

2323
struct TDataRow {
24-
static const constexpr TTypeInfo Types[20] = {
25-
TTypeInfo(NTypeIds::Bool),
26-
TTypeInfo(NTypeIds::Int8),
27-
TTypeInfo(NTypeIds::Int16),
28-
TTypeInfo(NTypeIds::Int32),
29-
TTypeInfo(NTypeIds::Int64),
30-
TTypeInfo(NTypeIds::Uint8),
31-
TTypeInfo(NTypeIds::Uint16),
32-
TTypeInfo(NTypeIds::Uint32),
33-
TTypeInfo(NTypeIds::Uint64),
34-
TTypeInfo(NTypeIds::Float),
35-
TTypeInfo(NTypeIds::Double),
36-
TTypeInfo(NTypeIds::String),
37-
TTypeInfo(NTypeIds::Utf8),
38-
TTypeInfo(NTypeIds::Json),
39-
TTypeInfo(NTypeIds::Yson),
40-
TTypeInfo(NTypeIds::Date),
41-
TTypeInfo(NTypeIds::Datetime),
42-
TTypeInfo(NTypeIds::Timestamp),
43-
TTypeInfo(NTypeIds::Interval),
44-
TTypeInfo(NTypeIds::JsonDocument),
45-
// TODO: DyNumber, Decimal
46-
};
24+
static const TTypeInfo* MakeTypeInfos() {
25+
static const TTypeInfo types[27] = {
26+
TTypeInfo(NTypeIds::Bool),
27+
TTypeInfo(NTypeIds::Int8),
28+
TTypeInfo(NTypeIds::Int16),
29+
TTypeInfo(NTypeIds::Int32),
30+
TTypeInfo(NTypeIds::Int64),
31+
TTypeInfo(NTypeIds::Uint8),
32+
TTypeInfo(NTypeIds::Uint16),
33+
TTypeInfo(NTypeIds::Uint32),
34+
TTypeInfo(NTypeIds::Uint64),
35+
TTypeInfo(NTypeIds::Float),
36+
TTypeInfo(NTypeIds::Double),
37+
TTypeInfo(NTypeIds::String),
38+
TTypeInfo(NTypeIds::Utf8),
39+
TTypeInfo(NTypeIds::Json),
40+
TTypeInfo(NTypeIds::Yson),
41+
TTypeInfo(NTypeIds::Date),
42+
TTypeInfo(NTypeIds::Datetime),
43+
TTypeInfo(NTypeIds::Timestamp),
44+
TTypeInfo(NTypeIds::Interval),
45+
TTypeInfo(NTypeIds::JsonDocument),
46+
TTypeInfo(NPg::TypeDescFromPgTypeId(INT2OID)),
47+
TTypeInfo(NPg::TypeDescFromPgTypeId(INT4OID)),
48+
TTypeInfo(NPg::TypeDescFromPgTypeId(INT8OID)),
49+
TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT4OID)),
50+
TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT8OID)),
51+
TTypeInfo(NPg::TypeDescFromPgTypeId(BYTEAOID)),
52+
TTypeInfo(NPg::TypeDescFromPgTypeId(TEXTOID)),
53+
// TODO: DyNumber, Decimal
54+
};
55+
return types;
56+
}
4757

4858
bool Bool;
4959
i8 Int8;
@@ -65,6 +75,13 @@ struct TDataRow {
6575
i64 Timestamp;
6676
i64 Interval;
6777
std::string JsonDocument;
78+
i16 PgInt2;
79+
i32 PgInt4;
80+
i64 PgInt8;
81+
float PgFloat4;
82+
double PgFloat8;
83+
std::string PgBytea;
84+
std::string PgText;
6885
//ui64 Decimal[2];
6986

7087
bool operator == (const TDataRow& r) const {
@@ -87,7 +104,14 @@ struct TDataRow {
87104
(Datetime == r.Datetime) &&
88105
(Timestamp == r.Timestamp) &&
89106
(Interval == r.Interval) &&
90-
(JsonDocument == r.JsonDocument);
107+
(JsonDocument == r.JsonDocument) &&
108+
(PgInt2 == r.PgInt2) &&
109+
(PgInt4 == r.PgInt4) &&
110+
(PgInt8 == r.PgInt8) &&
111+
(PgFloat4 == r.PgFloat4) &&
112+
(PgFloat8 == r.PgFloat8) &&
113+
(PgBytea == r.PgBytea) &&
114+
(PgText == r.PgText);
91115
//(Decimal[0] == r.Decimal[0] && Decimal[1] == r.Decimal[1]);
92116
}
93117

@@ -113,6 +137,13 @@ struct TDataRow {
113137
arrow::field("ts", arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO)),
114138
arrow::field("ival", arrow::duration(arrow::TimeUnit::TimeUnit::MICRO)),
115139
arrow::field("json_doc", arrow::binary()),
140+
arrow::field("pgint2", arrow::int16()),
141+
arrow::field("pgint4", arrow::int32()),
142+
arrow::field("pgint8", arrow::int64()),
143+
arrow::field("pgfloat4", arrow::float32()),
144+
arrow::field("pgfloat8", arrow::float64()),
145+
arrow::field("pgbytea", arrow::binary()),
146+
arrow::field("pgtext", arrow::utf8()),
116147
//arrow::field("dec", arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE)),
117148
};
118149

@@ -141,13 +172,20 @@ struct TDataRow {
141172
{"ts", TTypeInfo(NTypeIds::Timestamp) },
142173
{"ival", TTypeInfo(NTypeIds::Interval) },
143174
{"json_doc", TTypeInfo(NTypeIds::JsonDocument) },
175+
{"pgint2", TTypeInfo(NPg::TypeDescFromPgTypeId(INT2OID)) },
176+
{"pgint4", TTypeInfo(NPg::TypeDescFromPgTypeId(INT4OID)) },
177+
{"pgint8", TTypeInfo(NPg::TypeDescFromPgTypeId(INT8OID)) },
178+
{"pgfloat4", TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT4OID)) },
179+
{"pgfloat8", TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT8OID)) },
180+
{"pgbytea", TTypeInfo(NPg::TypeDescFromPgTypeId(BYTEAOID)) },
181+
{"pgtext", TTypeInfo(NPg::TypeDescFromPgTypeId(TEXTOID)) },
144182
//{"dec", TTypeInfo(NTypeIds::Decimal) }
145183
};
146184
return columns;
147185
}
148186

149187
NKikimr::TDbTupleRef ToDbTupleRef() const {
150-
static TCell Cells[20];
188+
static TCell Cells[27];
151189
Cells[0] = TCell::Make<bool>(Bool);
152190
Cells[1] = TCell::Make<i8>(Int8);
153191
Cells[2] = TCell::Make<i16>(Int16);
@@ -168,9 +206,16 @@ struct TDataRow {
168206
Cells[17] = TCell::Make<i64>(Timestamp);
169207
Cells[18] = TCell::Make<i64>(Interval);
170208
Cells[19] = TCell(JsonDocument.data(), JsonDocument.size());
209+
Cells[20] = TCell::Make<i16>(PgInt2);
210+
Cells[21] = TCell::Make<i32>(PgInt4);
211+
Cells[22] = TCell::Make<i64>(PgInt8);
212+
Cells[23] = TCell::Make<float>(PgFloat4);
213+
Cells[24] = TCell::Make<double>(PgFloat8);
214+
Cells[25] = TCell(PgBytea.data(), PgBytea.size());
215+
Cells[26] = TCell(PgText.data(), PgText.size());
171216
//Cells[19] = TCell((const char *)&Decimal[0], 16);
172217

173-
return NKikimr::TDbTupleRef(Types, Cells, 20);
218+
return NKikimr::TDbTupleRef(MakeTypeInfos(), Cells, 27);
174219
}
175220

176221
TOwnedCellVec SerializedCells() const {
@@ -216,6 +261,13 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) {
216261
auto arival = std::static_pointer_cast<arrow::DurationArray>(GetColumn(*table, 18));
217262

218263
auto arjd = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 19));
264+
auto arpgi2 = std::static_pointer_cast<arrow::Int16Array>(GetColumn(*table, 20));
265+
auto arpgi4 = std::static_pointer_cast<arrow::Int32Array>(GetColumn(*table, 21));
266+
auto arpgi8 = std::static_pointer_cast<arrow::Int64Array>(GetColumn(*table, 22));
267+
auto arpgf4 = std::static_pointer_cast<arrow::FloatArray>(GetColumn(*table, 23));
268+
auto arpgf8 = std::static_pointer_cast<arrow::DoubleArray>(GetColumn(*table, 24));
269+
auto arpgb = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 25));
270+
auto arpgt = std::static_pointer_cast<arrow::StringArray>(GetColumn(*table, 26));
219271
//auto ardec = std::static_pointer_cast<arrow::Decimal128Array>(GetColumn(*table, 19));
220272

221273
for (int64_t i = 0; i < table->num_rows(); ++i) {
@@ -226,7 +278,9 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) {
226278
aru8->Value(i), aru16->Value(i), aru32->Value(i), aru64->Value(i),
227279
arf32->Value(i), arf64->Value(i),
228280
arstr->GetString(i), arutf->GetString(i), arj->GetString(i), ary->GetString(i),
229-
ard->Value(i), ardt->Value(i), arts->Value(i), arival->Value(i), arjd->GetString(i)
281+
ard->Value(i), ardt->Value(i), arts->Value(i), arival->Value(i), arjd->GetString(i),
282+
arpgi2->Value(i), arpgi4->Value(i), arpgi8->Value(i), arpgf4->Value(i), arpgf8->Value(i),
283+
arpgb->GetString(i), arpgt->GetString(i)
230284
//{dec[0], dec[1]}
231285
};
232286
rows.emplace_back(std::move(r));
@@ -268,6 +322,13 @@ class TDataRowTableBuilder
268322
UNIT_ASSERT(Bival.Append(row.Interval).ok());
269323

270324
UNIT_ASSERT(Bjd.Append(row.JsonDocument).ok());
325+
UNIT_ASSERT(Bpgi2.Append(row.PgInt2).ok());
326+
UNIT_ASSERT(Bpgi4.Append(row.PgInt4).ok());
327+
UNIT_ASSERT(Bpgi8.Append(row.PgInt8).ok());
328+
UNIT_ASSERT(Bpgf4.Append(row.PgFloat4).ok());
329+
UNIT_ASSERT(Bpgf8.Append(row.PgFloat8).ok());
330+
UNIT_ASSERT(Bpgb.Append(row.PgBytea).ok());
331+
UNIT_ASSERT(Bpgt.Append(row.PgText).ok());
271332
//UNIT_ASSERT(Bdec.Append((const char *)&row.Decimal).ok());
272333
}
273334

@@ -295,6 +356,13 @@ class TDataRowTableBuilder
295356
std::shared_ptr<arrow::DurationArray> arival;
296357

297358
std::shared_ptr<arrow::BinaryArray> arjd;
359+
std::shared_ptr<arrow::Int16Array> arpgi2;
360+
std::shared_ptr<arrow::Int32Array> arpgi4;
361+
std::shared_ptr<arrow::Int64Array> arpgi8;
362+
std::shared_ptr<arrow::FloatArray> arpgf4;
363+
std::shared_ptr<arrow::DoubleArray> arpgf8;
364+
std::shared_ptr<arrow::BinaryArray> arpgb;
365+
std::shared_ptr<arrow::StringArray> arpgt;
298366
//std::shared_ptr<arrow::Decimal128Array> ardec;
299367

300368
UNIT_ASSERT(Bbool.Finish(&arbool).ok());
@@ -320,6 +388,13 @@ class TDataRowTableBuilder
320388
UNIT_ASSERT(Bival.Finish(&arival).ok());
321389

322390
UNIT_ASSERT(Bjd.Finish(&arjd).ok());
391+
UNIT_ASSERT(Bpgi2.Finish(&arpgi2).ok());
392+
UNIT_ASSERT(Bpgi4.Finish(&arpgi4).ok());
393+
UNIT_ASSERT(Bpgi8.Finish(&arpgi8).ok());
394+
UNIT_ASSERT(Bpgf4.Finish(&arpgf4).ok());
395+
UNIT_ASSERT(Bpgf8.Finish(&arpgf8).ok());
396+
UNIT_ASSERT(Bpgb.Finish(&arpgb).ok());
397+
UNIT_ASSERT(Bpgt.Finish(&arpgt).ok());
323398
//UNIT_ASSERT(Bdec.Finish(&ardec).ok());
324399

325400
std::shared_ptr<arrow::Schema> schema = TDataRow::MakeArrowSchema();
@@ -329,7 +404,9 @@ class TDataRowTableBuilder
329404
aru8, aru16, aru32, aru64,
330405
arf32, arf64,
331406
arstr, arutf, arj, ary,
332-
ard, ardt, arts, arival, arjd
407+
ard, ardt, arts, arival, arjd,
408+
arpgi2, arpgi4, arpgi8, arpgf4, arpgf8,
409+
arpgb, arpgt
333410
//ardec
334411
});
335412
}
@@ -363,13 +440,21 @@ class TDataRowTableBuilder
363440
arrow::TimestampBuilder Bts;
364441
arrow::DurationBuilder Bival;
365442
arrow::BinaryBuilder Bjd;
443+
arrow::Int16Builder Bpgi2;
444+
arrow::Int32Builder Bpgi4;
445+
arrow::Int64Builder Bpgi8;
446+
arrow::FloatBuilder Bpgf4;
447+
arrow::DoubleBuilder Bpgf8;
448+
arrow::BinaryBuilder Bpgb;
449+
arrow::StringBuilder Bpgt;
366450
//arrow::Decimal128Builder Bdec;
367451
};
368452

369453
std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TDataRow>& rows) {
370454
TString err;
371455
NArrow::TArrowBatchBuilder batchBuilder;
372456
batchBuilder.Start(TDataRow::MakeYdbSchema(), 0, 0, err);
457+
UNIT_ASSERT_C(err.Empty(), err);
373458

374459
for (const TDataRow& row : rows) {
375460
NKikimr::TDbTupleRef key;
@@ -382,10 +467,14 @@ std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TData
382467

383468
std::vector<TDataRow> TestRows() {
384469
std::vector<TDataRow> rows = {
385-
{false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1", "u1", "{\"j\":1}", "{y:1}", 0, 0, 0, 0, "{\"jd\":1}" },
386-
{false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2", "u2", "{\"j\":2}", "{y:2}", 0, 0, 0, 0, "{\"jd\":1}" },
387-
{false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3", "u3", "{\"j\":3}", "{y:3}", 0, 0, 0, 0, "{\"jd\":1}" },
388-
{false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4", "u4", "{\"j\":4}", "{y:4}", 0, 0, 0, 0, "{\"jd\":1}" },
470+
{false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1", "u1", "{\"j\":1}", "{y:1}", 0, 0, 0, 0, "{\"jd\":1}",
471+
-5, -5, -5, -5.1f, -5.1, "s5", "u5"},
472+
{false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2", "u2", "{\"j\":2}", "{y:2}", 0, 0, 0, 0, "{\"jd\":1}",
473+
-3, -3, -3, -3.1f, -3.1, "s3", "u3"},
474+
{false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3", "u3", "{\"j\":3}", "{y:3}", 0, 0, 0, 0, "{\"jd\":1}",
475+
-2, -2, -2, -2.1f, -2.1, "s2", "u2"},
476+
{false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4", "u4", "{\"j\":4}", "{y:4}", 0, 0, 0, 0, "{\"jd\":1}",
477+
-7, -7, -7, -7.1f, -7.1, "s7", "u7"},
389478
};
390479
return rows;
391480
}
@@ -412,7 +501,9 @@ std::shared_ptr<arrow::Table> MakeTable1000() {
412501
i8 a = i/100;
413502
i16 b = (i%100)/10;
414503
i32 c = i%10;
415-
builder.AddRow(TDataRow{false, a, b, c, i, 1, 1, 1, 1, 1.0f, 1.0, "", "", "", "", 0, 0, 0, 0, {0,0} });
504+
builder.AddRow(
505+
TDataRow{false, a, b, c, i, 1, 1, 1, 1, 1.0f, 1.0, "", "", "", "", 0, 0, 0, 0, {0,0},
506+
0, 0, 0, 0.0f, 0.0, "", ""});
416507
}
417508

418509
auto table = builder.Finish();
@@ -575,7 +666,7 @@ Y_UNIT_TEST_SUITE(ArrowTest) {
575666
for (size_t i = 0; i < rows.size(); ++i) {
576667
UNIT_ASSERT(0 == CompareTypedCellVectors(
577668
cellRows[i].data(), rowWriter.Rows[i].data(),
578-
TDataRow::Types,
669+
TDataRow::MakeTypeInfos(),
579670
cellRows[i].size(), rowWriter.Rows[i].size()));
580671
}
581672
}

ydb/core/formats/arrow/ut/ya.make

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ PEERDIR(
1010

1111
# for NYql::NUdf alloc stuff used in binary_json
1212
ydb/library/yql/public/udf/service/exception_policy
13-
ydb/library/yql/sql/pg_dummy
13+
ydb/library/yql/sql/pg
14+
ydb/library/yql/parser/pg_wrapper
1415
)
1516

1617
ADDINCL(

ydb/core/formats/clickhouse_block.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include <util/generic/string.h>
1313
#include <util/generic/hash.h>
1414

15+
extern "C" {
16+
#include <ydb/library/yql/parser/pg_wrapper/postgresql/src/include/catalog/pg_type_d.h>
17+
}
18+
1519
namespace NKikHouse {
1620
namespace NSerialization {
1721

@@ -480,9 +484,24 @@ class TDataTypeRegistry : public TThrRefBase {
480484
CONVERT(StepOrderId, String);
481485

482486
case NScheme::NTypeIds::Pg:
483-
// TODO: support pg types
484-
throw yexception() << "Unsupported pg type";
485-
487+
switch (NPg::PgTypeIdFromTypeDesc(type.GetPgTypeDesc())) {
488+
case INT2OID:
489+
return Get("Int16");
490+
case INT4OID:
491+
return Get("Int32");
492+
case INT8OID:
493+
return Get("Int64");
494+
case FLOAT4OID:
495+
return Get("Float");
496+
case FLOAT8OID:
497+
return Get("Double");
498+
case BYTEAOID:
499+
return Get("String");
500+
case TEXTOID:
501+
return Get("String");
502+
default:
503+
throw yexception() << "Unsupported pg type";
504+
}
486505
default:
487506
throw yexception() << "Unsupported type: " << type.GetTypeId();
488507
}

0 commit comments

Comments
 (0)