Skip to content

Commit 4b61f51

Browse files
add support for nested types
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
1 parent ec9f872 commit 4b61f51

11 files changed

+592
-44
lines changed

CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES})
107107

108108
set(PARAMETERS "-warnings")
109109
build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES})
110-
IF (DEFINED ENV{SKIP_SUBSTRAIT_C_TESTS})
111-
message(STATUS "Skipping substrait c tests")
112-
ELSE()
113-
add_subdirectory("test/c")
114-
ENDIF()
110+
if(DEFINED ENV{SKIP_SUBSTRAIT_C_TESTS})
111+
message(STATUS "Skipping substrait c tests")
112+
else()
113+
add_subdirectory("test/c")
114+
endif()
115115

116116
install(
117117
TARGETS ${EXTENSION_NAME}

src/from_substrait.cpp

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ Value TransformLiteralToValue(const substrait::Expression_Literal &literal) {
174174
case substrait::Expression_Literal::LiteralTypeCase::kVarChar:
175175
return {literal.var_char().value()};
176176
default:
177-
throw NotImplementedException("literals of this type are not implemented: %s",
178-
substrait::Expression_Literal::GetDescriptor()->FindFieldByNumber(literal.literal_type_case())->name());
177+
throw NotImplementedException(
178+
"literals of this type are not implemented: %s",
179+
substrait::Expression_Literal::GetDescriptor()->FindFieldByNumber(literal.literal_type_case())->name());
179180
}
180181
}
181182

@@ -305,6 +306,8 @@ LogicalType SubstraitToDuckDB::SubstraitToDuckType(const substrait::Type &s_type
305306
switch (s_type.kind_case()) {
306307
case substrait::Type::KindCase::kBool:
307308
return {LogicalTypeId::BOOLEAN};
309+
case substrait::Type::KindCase::kI8:
310+
return {LogicalTypeId::TINYINT};
308311
case substrait::Type::KindCase::kI16:
309312
return {LogicalTypeId::SMALLINT};
310313
case substrait::Type::KindCase::kI32:
@@ -317,14 +320,50 @@ LogicalType SubstraitToDuckDB::SubstraitToDuckType(const substrait::Type &s_type
317320
}
318321
case substrait::Type::KindCase::kDate:
319322
return {LogicalTypeId::DATE};
323+
case substrait::Type::KindCase::kTime:
324+
return {LogicalTypeId::TIME};
320325
case substrait::Type::KindCase::kVarchar:
321326
case substrait::Type::KindCase::kString:
322327
return {LogicalTypeId::VARCHAR};
328+
case substrait::Type::KindCase::kBinary:
329+
return {LogicalTypeId::BLOB};
330+
case substrait::Type::KindCase::kFp32:
331+
return {LogicalTypeId::FLOAT};
323332
case substrait::Type::KindCase::kFp64:
324333
return {LogicalTypeId::DOUBLE};
334+
case substrait::Type::KindCase::kTimestamp:
335+
return {LogicalTypeId::TIMESTAMP};
336+
case substrait::Type::KindCase::kList: {
337+
auto &s_list_type = s_type.list();
338+
auto element_type = SubstraitToDuckType(s_list_type.type());
339+
return LogicalType::LIST(element_type);
340+
}
341+
case substrait::Type::KindCase::kMap: {
342+
auto &s_map_type = s_type.map();
343+
auto key_type = SubstraitToDuckType(s_map_type.key());
344+
auto value_type = SubstraitToDuckType(s_map_type.value());
345+
return LogicalType::MAP(key_type, value_type);
346+
}
347+
case substrait::Type::KindCase::kStruct: {
348+
auto &s_struct_type = s_type.struct_();
349+
child_list_t<LogicalType> children;
350+
351+
for (idx_t i = 0; i < s_struct_type.types_size(); i++) {
352+
auto field_name = "f" + std::to_string(i);
353+
auto field_type = SubstraitToDuckType(s_struct_type.types(i));
354+
children.push_back(make_pair(field_name, field_type));
355+
}
356+
357+
return LogicalType::STRUCT(children);
358+
}
359+
case substrait::Type::KindCase::kUuid:
360+
return {LogicalTypeId::UUID};
361+
case substrait::Type::KindCase::kIntervalDay:
362+
case substrait::Type::KindCase::kIntervalYear:
363+
return {LogicalTypeId::INTERVAL};
325364
default:
326365
throw NotImplementedException("Substrait type not yet supported: %s",
327-
substrait::Type::GetDescriptor()->FindFieldByNumber(s_type.kind_case())->name());
366+
substrait::Type::GetDescriptor()->FindFieldByNumber(s_type.kind_case())->name());
328367
}
329368
}
330369

@@ -378,9 +417,10 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
378417
} else if (nested_expression.has_map()) {
379418
auto &map_expression = nested_expression.map();
380419
vector<unique_ptr<ParsedExpression>> children;
381-
auto key_value = map_expression.key_values();
382-
children.emplace_back(TransformExpr(key_value[0].key()));
383-
children.emplace_back(TransformExpr(key_value[0].value()));
420+
for (auto &key_value_pair : map_expression.key_values()) {
421+
children.emplace_back(TransformExpr(key_value_pair.key()));
422+
children.emplace_back(TransformExpr(key_value_pair.value()));
423+
}
384424
return make_uniq<FunctionExpression>("map", std::move(children));
385425

386426
} else {
@@ -410,15 +450,15 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::E
410450
return TransformNested(sexpr, iterator);
411451
case substrait::Expression::RexTypeCase::kSubquery:
412452
default:
413-
throw NotImplementedException("Unsupported expression type %s",
414-
substrait::Expression::GetDescriptor()->FindFieldByNumber(sexpr.rex_type_case())->name());
453+
throw NotImplementedException(
454+
"Unsupported expression type %s",
455+
substrait::Expression::GetDescriptor()->FindFieldByNumber(sexpr.rex_type_case())->name());
415456
}
416457
}
417458

418459
string SubstraitToDuckDB::FindFunction(uint64_t id) {
419460
if (functions_map.find(id) == functions_map.end()) {
420-
throw NotImplementedException("Could not find aggregate function %s",
421-
to_string(id));
461+
throw NotImplementedException("Could not find aggregate function %s", to_string(id));
422462
}
423463
return functions_map[id];
424464
}
@@ -446,8 +486,9 @@ OrderByNode SubstraitToDuckDB::TransformOrder(const substrait::SortField &sordf)
446486
dnullorder = OrderByNullType::NULLS_LAST;
447487
break;
448488
default:
449-
throw NotImplementedException("Unsupported ordering %s",
450-
substrait::SortField::GetDescriptor()->FindFieldByNumber(sordf.direction())->name());
489+
throw NotImplementedException(
490+
"Unsupported ordering %s",
491+
substrait::SortField::GetDescriptor()->FindFieldByNumber(sordf.direction())->name());
451492
}
452493

453494
return {dordertype, dnullorder, TransformExpr(sordf.expr())};
@@ -478,7 +519,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformJoinOp(const substrait::Rel &so
478519
break;
479520
default:
480521
throw NotImplementedException("Unsupported join type: %s",
481-
substrait::JoinRel::GetDescriptor()->FindFieldByNumber(sjoin.type())->name());
522+
substrait::JoinRel::GetDescriptor()->FindFieldByNumber(sjoin.type())->name());
482523
}
483524
unique_ptr<ParsedExpression> join_condition = TransformExpr(sjoin.expression());
484525
return make_shared_ptr<JoinRelation>(TransformOp(sjoin.left())->Alias("left"),
@@ -506,8 +547,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &
506547
return make_shared_ptr<FilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
507548
}
508549

509-
const substrait::RelCommon* GetCommon(const substrait::Rel &sop) {
510-
const substrait::RelCommon * common;
550+
const substrait::RelCommon *GetCommon(const substrait::Rel &sop) {
551+
const substrait::RelCommon *common;
511552
switch (sop.rel_type_case()) {
512553
case substrait::Rel::RelTypeCase::kRead:
513554
return &sop.read().common();
@@ -550,12 +591,12 @@ const substrait::RelCommon* GetCommon(const substrait::Rel &sop) {
550591
case substrait::Rel::RelTypeCase::kDdl:
551592
default:
552593
throw NotImplementedException("Unsupported relation type %s",
553-
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
594+
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
554595
}
555596
}
556597

557-
const google::protobuf::RepeatedField<int32_t>& GetOutputMapping(const substrait::Rel &sop) {
558-
const substrait::RelCommon* common = GetCommon(sop);
598+
const google::protobuf::RepeatedField<int32_t> &GetOutputMapping(const substrait::Rel &sop) {
599+
const substrait::RelCommon *common = GetCommon(sop);
559600
if (!common->has_emit()) {
560601
static google::protobuf::RepeatedField<int32_t> empty_mapping;
561602
return empty_mapping;
@@ -757,15 +798,15 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
757798
}
758799
parameters.push_back(Value::UBIGINT(snapshot_id));
759800
} else if (sget.iceberg_table().direct().has_snapshot_timestamp()) {
760-
parameters.push_back( Value::TIMESTAMP(timestamp_t(sget.iceberg_table().direct().snapshot_timestamp())));
801+
parameters.push_back(Value::TIMESTAMP(timestamp_t(sget.iceberg_table().direct().snapshot_timestamp())));
761802
}
762803
shared_ptr<TableFunctionRelation> scan_rel;
763804
if (acquire_lock) {
764805
scan_rel = make_shared_ptr<TableFunctionRelation>(context, "iceberg_scan", parameters,
765-
std::move(named_parameters));
806+
std::move(named_parameters));
766807
} else {
767808
scan_rel = make_shared_ptr<TableFunctionRelation>(context_wrapper, "iceberg_scan", parameters,
768-
std::move(named_parameters));
809+
std::move(named_parameters));
769810
}
770811
auto rel = static_cast<Relation *>(scan_rel.get());
771812
scan = rel->Alias(name);
@@ -810,7 +851,8 @@ shared_ptr<Relation> SubstraitToDuckDB::GetValueRelationWithSingleBoolColumn() {
810851
return scan;
811852
}
812853

813-
shared_ptr<Relation> SubstraitToDuckDB::GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows) {
854+
shared_ptr<Relation> SubstraitToDuckDB::GetValuesExpression(
855+
const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows) {
814856
vector<vector<unique_ptr<ParsedExpression>>> expressions;
815857
for (auto &row : expression_rows) {
816858
vector<unique_ptr<ParsedExpression>> expression_row;
@@ -852,7 +894,7 @@ static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop)
852894
}
853895
default: {
854896
throw NotImplementedException("SetOperationType transform not implemented for SetRel_SetOp type %s",
855-
substrait::SetRel::GetDescriptor()->FindFieldByNumber(setop)->name());
897+
substrait::SetRel::GetDescriptor()->FindFieldByNumber(setop)->name());
856898
}
857899
}
858900
}
@@ -895,28 +937,30 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
895937
}
896938
auto input = TransformOp(swrite.input());
897939
switch (swrite.op()) {
898-
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
899-
return input->CreateRel(schema_name, table_name);
940+
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
941+
return input->CreateRel(schema_name, table_name);
900942
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
901943
return input->InsertRel(schema_name, table_name);
902944
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
903945
switch (input->type) {
904946
case RelationType::PROJECTION_RELATION: {
905947
auto project = std::move(input.get()->Cast<ProjectionRelation>());
906948
auto filter = std::move(project.child->Cast<FilterRelation>());
907-
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
949+
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name,
950+
table_name);
908951
}
909952
case RelationType::FILTER_RELATION: {
910953
auto filter = std::move(input.get()->Cast<FilterRelation>());
911-
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
954+
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name,
955+
table_name);
912956
}
913957
default:
914958
throw NotImplementedException("Unsupported relation type for delete operation");
915959
}
916960
}
917961
default:
918962
throw NotImplementedException("Unsupported write operation %s",
919-
substrait::WriteRel::GetDescriptor()->FindFieldByNumber(swrite.op())->name());
963+
substrait::WriteRel::GetDescriptor()->FindFieldByNumber(swrite.op())->name());
920964
}
921965
}
922966

@@ -945,7 +989,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
945989
return TransformWriteOp(sop);
946990
default:
947991
throw NotImplementedException("Unsupported relation type %s",
948-
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
992+
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
949993
}
950994
}
951995

src/include/from_substrait.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ class SubstraitToDuckDB {
7575
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
7676
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
7777
shared_ptr<Relation> GetValueRelationWithSingleBoolColumn();
78-
shared_ptr<Relation> GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows);
78+
shared_ptr<Relation>
79+
GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows);
7980
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
8081
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
8182
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,

src/substrait_extension.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ struct FromSubstraitFunctionData : public TableFunctionData {
289289
};
290290

291291
static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunctionBindInput &input,
292-
vector<LogicalType> &return_types, vector<string> &names, bool is_json) {
292+
vector<LogicalType> &return_types, vector<string> &names, bool is_json) {
293293
auto result = make_uniq<FromSubstraitFunctionData>();
294294
result->conn = make_uniq<Connection>(*context.db);
295295
if (input.inputs[0].IsNull()) {
@@ -306,12 +306,12 @@ static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunct
306306
}
307307

308308
static unique_ptr<FunctionData> FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input,
309-
vector<LogicalType> &return_types, vector<string> &names) {
309+
vector<LogicalType> &return_types, vector<string> &names) {
310310
return SubstraitBind(context, input, return_types, names, false);
311311
}
312312

313313
static unique_ptr<FunctionData> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input,
314-
vector<LogicalType> &return_types, vector<string> &names) {
314+
vector<LogicalType> &return_types, vector<string> &names) {
315315
return SubstraitBind(context, input, return_types, names, true);
316316
}
317317

@@ -366,7 +366,8 @@ void InitializeFromSubstraitJSON(const Connection &con) {
366366
auto &catalog = Catalog::GetSystemCatalog(*con.context);
367367
// create the from_substrait table function that allows us to get a query
368368
// result from a substrait plan
369-
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction, FromSubstraitBindJSON);
369+
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction,
370+
FromSubstraitBindJSON);
370371
from_sub_func_json.bind_replace = FromSubstraitBindReplaceJSON;
371372
CreateTableFunctionInfo from_sub_info_json(from_sub_func_json);
372373
catalog.CreateTableFunction(*con.context, from_sub_info_json);

src/to_substrait.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,8 @@ substrait::RelCommon *DuckDBToSubstrait::CreateOutputMapping(vector<int32_t> vec
879879
return rel_common;
880880
}
881881

882-
bool DuckDBToSubstrait::IsPassthroughProjection(LogicalProjection &dproj, idx_t child_column_count, bool &needs_output_mapping) {
882+
bool DuckDBToSubstrait::IsPassthroughProjection(LogicalProjection &dproj, idx_t child_column_count,
883+
bool &needs_output_mapping) {
883884
// check if the projection is just pass through of input columns with no reordering
884885
needs_output_mapping = true;
885886
auto isPassThrough = true;
@@ -1028,12 +1029,12 @@ substrait::Rel *DuckDBToSubstrait::TransformOrderBy(LogicalOperator &dop) {
10281029
return res;
10291030
}
10301031

1031-
void PrintRelAsJson(substrait::Rel * rel) {
1032+
void PrintRelAsJson(substrait::Rel *rel) {
10321033
static int i;
10331034
std::string json_output;
10341035
google::protobuf::util::JsonPrintOptions options;
1035-
options.add_whitespace = false; // Pretty-print with indentation
1036-
options.always_print_primitive_fields = true; // Print even if default values
1036+
options.add_whitespace = false; // Compact output: no added spaces, line breaks, or indentation
1037+
options.always_print_primitive_fields = true; // Print even if default values
10371038

10381039
auto status = google::protobuf::util::MessageToJsonString(*rel, &json_output, options);
10391040
if (!status.ok()) {
@@ -1325,6 +1326,37 @@ substrait::Type DuckDBToSubstrait::DuckToSubstraitType(const LogicalType &type,
13251326
s_type.set_allocated_struct_(struct_type);
13261327
return s_type;
13271328
}
1329+
case LogicalTypeId::MAP: {
1330+
auto map_type = new substrait::Type_Map;
1331+
map_type->set_nullability(type_nullability);
1332+
1333+
auto key_type = MapType::KeyType(type);
1334+
auto value_type = MapType::ValueType(type);
1335+
1336+
auto key = new substrait::Type();
1337+
*key = DuckToSubstraitType(key_type, column_statistics, not_null);
1338+
map_type->set_allocated_key(key);
1339+
1340+
auto value = new substrait::Type();
1341+
*value = DuckToSubstraitType(value_type, column_statistics, not_null);
1342+
map_type->set_allocated_value(value);
1343+
1344+
s_type.set_allocated_map(map_type);
1345+
return s_type;
1346+
}
1347+
case LogicalTypeId::LIST: {
1348+
auto list_type = new substrait::Type_List;
1349+
list_type->set_nullability(type_nullability);
1350+
1351+
auto child_type = ListType::GetChildType(type);
1352+
1353+
auto element_type = new substrait::Type();
1354+
*element_type = DuckToSubstraitType(child_type, column_statistics, not_null);
1355+
list_type->set_allocated_type(element_type);
1356+
1357+
s_type.set_allocated_list(list_type);
1358+
return s_type;
1359+
}
13281360
default:
13291361
throw NotImplementedException("Logical Type " + type.ToString() +
13301362
" not implemented as Substrait Schema Result.");

0 commit comments

Comments
 (0)