Skip to content

Commit ae71f52

Browse files
add support for nested types (#160)
1 parent b77690d commit ae71f52

11 files changed

+1333
-42
lines changed

CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES})
107107

108108
set(PARAMETERS "-warnings")
109109
build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES})
110-
IF (DEFINED ENV{SKIP_SUBSTRAIT_C_TESTS})
111-
message(STATUS "Skipping substrait c tests")
112-
ELSE()
113-
add_subdirectory("test/c")
114-
ENDIF()
110+
if(DEFINED ENV{SKIP_SUBSTRAIT_C_TESTS})
111+
message(STATUS "Skipping substrait c tests")
112+
else()
113+
add_subdirectory("test/c")
114+
endif()
115115

116116
install(
117117
TARGETS ${EXTENSION_NAME}

src/from_substrait.cpp

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ Value TransformLiteralToValue(const substrait::Expression_Literal &literal) {
174174
case substrait::Expression_Literal::LiteralTypeCase::kVarChar:
175175
return {literal.var_char().value()};
176176
default:
177-
throw NotImplementedException("literals of this type are not implemented: %s",
178-
substrait::Expression_Literal::GetDescriptor()->FindFieldByNumber(literal.literal_type_case())->name());
177+
throw NotImplementedException(
178+
"literals of this type are not implemented: %s",
179+
substrait::Expression_Literal::GetDescriptor()->FindFieldByNumber(literal.literal_type_case())->name());
179180
}
180181
}
181182

@@ -305,6 +306,8 @@ LogicalType SubstraitToDuckDB::SubstraitToDuckType(const substrait::Type &s_type
305306
switch (s_type.kind_case()) {
306307
case substrait::Type::KindCase::kBool:
307308
return {LogicalTypeId::BOOLEAN};
309+
case substrait::Type::KindCase::kI8:
310+
return {LogicalTypeId::TINYINT};
308311
case substrait::Type::KindCase::kI16:
309312
return {LogicalTypeId::SMALLINT};
310313
case substrait::Type::KindCase::kI32:
@@ -317,14 +320,49 @@ LogicalType SubstraitToDuckDB::SubstraitToDuckType(const substrait::Type &s_type
317320
}
318321
case substrait::Type::KindCase::kDate:
319322
return {LogicalTypeId::DATE};
323+
case substrait::Type::KindCase::kTime:
324+
return {LogicalTypeId::TIME};
320325
case substrait::Type::KindCase::kVarchar:
321326
case substrait::Type::KindCase::kString:
322327
return {LogicalTypeId::VARCHAR};
328+
case substrait::Type::KindCase::kBinary:
329+
return {LogicalTypeId::BLOB};
330+
case substrait::Type::KindCase::kFp32:
331+
return {LogicalTypeId::FLOAT};
323332
case substrait::Type::KindCase::kFp64:
324333
return {LogicalTypeId::DOUBLE};
334+
case substrait::Type::KindCase::kTimestamp:
335+
return {LogicalTypeId::TIMESTAMP};
336+
case substrait::Type::KindCase::kList: {
337+
auto &s_list_type = s_type.list();
338+
auto element_type = SubstraitToDuckType(s_list_type.type());
339+
return LogicalType::LIST(element_type);
340+
}
341+
case substrait::Type::KindCase::kMap: {
342+
auto &s_map_type = s_type.map();
343+
auto key_type = SubstraitToDuckType(s_map_type.key());
344+
auto value_type = SubstraitToDuckType(s_map_type.value());
345+
return LogicalType::MAP(key_type, value_type);
346+
}
347+
case substrait::Type::KindCase::kStruct: {
348+
auto &s_struct_type = s_type.struct_();
349+
child_list_t<LogicalType> children;
350+
351+
for (idx_t i = 0; i < s_struct_type.types_size(); i++) {
352+
auto field_name = "f" + std::to_string(i);
353+
auto field_type = SubstraitToDuckType(s_struct_type.types(i));
354+
children.push_back(make_pair(field_name, field_type));
355+
}
356+
357+
return LogicalType::STRUCT(children);
358+
}
359+
case substrait::Type::KindCase::kUuid:
360+
return {LogicalTypeId::UUID};
361+
case substrait::Type::KindCase::kIntervalDay:
362+
return {LogicalTypeId::INTERVAL};
325363
default:
326364
throw NotImplementedException("Substrait type not yet supported: %s",
327-
substrait::Type::GetDescriptor()->FindFieldByNumber(s_type.kind_case())->name());
365+
substrait::Type::GetDescriptor()->FindFieldByNumber(s_type.kind_case())->name());
328366
}
329367
}
330368

@@ -378,9 +416,10 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
378416
} else if (nested_expression.has_map()) {
379417
auto &map_expression = nested_expression.map();
380418
vector<unique_ptr<ParsedExpression>> children;
381-
auto key_value = map_expression.key_values();
382-
children.emplace_back(TransformExpr(key_value[0].key()));
383-
children.emplace_back(TransformExpr(key_value[0].value()));
419+
for (auto &key_value_pair : map_expression.key_values()) {
420+
children.emplace_back(TransformExpr(key_value_pair.key()));
421+
children.emplace_back(TransformExpr(key_value_pair.value()));
422+
}
384423
return make_uniq<FunctionExpression>("map", std::move(children));
385424

386425
} else {
@@ -410,15 +449,15 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::E
410449
return TransformNested(sexpr, iterator);
411450
case substrait::Expression::RexTypeCase::kSubquery:
412451
default:
413-
throw NotImplementedException("Unsupported expression type %s",
414-
substrait::Expression::GetDescriptor()->FindFieldByNumber(sexpr.rex_type_case())->name());
452+
throw NotImplementedException(
453+
"Unsupported expression type %s",
454+
substrait::Expression::GetDescriptor()->FindFieldByNumber(sexpr.rex_type_case())->name());
415455
}
416456
}
417457

418458
string SubstraitToDuckDB::FindFunction(uint64_t id) {
419459
if (functions_map.find(id) == functions_map.end()) {
420-
throw NotImplementedException("Could not find aggregate function %s",
421-
to_string(id));
460+
throw NotImplementedException("Could not find aggregate function %s", to_string(id));
422461
}
423462
return functions_map[id];
424463
}
@@ -446,8 +485,9 @@ OrderByNode SubstraitToDuckDB::TransformOrder(const substrait::SortField &sordf)
446485
dnullorder = OrderByNullType::NULLS_LAST;
447486
break;
448487
default:
449-
throw NotImplementedException("Unsupported ordering %s",
450-
substrait::SortField::GetDescriptor()->FindFieldByNumber(sordf.direction())->name());
488+
throw NotImplementedException(
489+
"Unsupported ordering %s",
490+
substrait::SortField::GetDescriptor()->FindFieldByNumber(sordf.direction())->name());
451491
}
452492

453493
return {dordertype, dnullorder, TransformExpr(sordf.expr())};
@@ -478,7 +518,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformJoinOp(const substrait::Rel &so
478518
break;
479519
default:
480520
throw NotImplementedException("Unsupported join type: %s",
481-
substrait::JoinRel::GetDescriptor()->FindFieldByNumber(sjoin.type())->name());
521+
substrait::JoinRel::GetDescriptor()->FindFieldByNumber(sjoin.type())->name());
482522
}
483523
unique_ptr<ParsedExpression> join_condition = TransformExpr(sjoin.expression());
484524
return make_shared_ptr<JoinRelation>(TransformOp(sjoin.left())->Alias("left"),
@@ -506,8 +546,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &
506546
return make_shared_ptr<FilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
507547
}
508548

509-
const substrait::RelCommon* GetCommon(const substrait::Rel &sop) {
510-
const substrait::RelCommon * common;
549+
const substrait::RelCommon *GetCommon(const substrait::Rel &sop) {
550+
const substrait::RelCommon *common;
511551
switch (sop.rel_type_case()) {
512552
case substrait::Rel::RelTypeCase::kRead:
513553
return &sop.read().common();
@@ -550,12 +590,12 @@ const substrait::RelCommon* GetCommon(const substrait::Rel &sop) {
550590
case substrait::Rel::RelTypeCase::kDdl:
551591
default:
552592
throw NotImplementedException("Unsupported relation type %s",
553-
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
593+
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
554594
}
555595
}
556596

557-
const google::protobuf::RepeatedField<int32_t>& GetOutputMapping(const substrait::Rel &sop) {
558-
const substrait::RelCommon* common = GetCommon(sop);
597+
const google::protobuf::RepeatedField<int32_t> &GetOutputMapping(const substrait::Rel &sop) {
598+
const substrait::RelCommon *common = GetCommon(sop);
559599
if (!common->has_emit()) {
560600
static google::protobuf::RepeatedField<int32_t> empty_mapping;
561601
return empty_mapping;
@@ -757,15 +797,15 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
757797
}
758798
parameters.push_back(Value::UBIGINT(snapshot_id));
759799
} else if (sget.iceberg_table().direct().has_snapshot_timestamp()) {
760-
parameters.push_back( Value::TIMESTAMP(timestamp_t(sget.iceberg_table().direct().snapshot_timestamp())));
800+
parameters.push_back(Value::TIMESTAMP(timestamp_t(sget.iceberg_table().direct().snapshot_timestamp())));
761801
}
762802
shared_ptr<TableFunctionRelation> scan_rel;
763803
if (acquire_lock) {
764804
scan_rel = make_shared_ptr<TableFunctionRelation>(context, "iceberg_scan", parameters,
765-
std::move(named_parameters));
805+
std::move(named_parameters));
766806
} else {
767807
scan_rel = make_shared_ptr<TableFunctionRelation>(context_wrapper, "iceberg_scan", parameters,
768-
std::move(named_parameters));
808+
std::move(named_parameters));
769809
}
770810
auto rel = static_cast<Relation *>(scan_rel.get());
771811
scan = rel->Alias(name);
@@ -810,7 +850,8 @@ shared_ptr<Relation> SubstraitToDuckDB::GetValueRelationWithSingleBoolColumn() {
810850
return scan;
811851
}
812852

813-
shared_ptr<Relation> SubstraitToDuckDB::GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows) {
853+
shared_ptr<Relation> SubstraitToDuckDB::GetValuesExpression(
854+
const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows) {
814855
vector<vector<unique_ptr<ParsedExpression>>> expressions;
815856
for (auto &row : expression_rows) {
816857
vector<unique_ptr<ParsedExpression>> expression_row;
@@ -852,7 +893,7 @@ static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop)
852893
}
853894
default: {
854895
throw NotImplementedException("SetOperationType transform not implemented for SetRel_SetOp type %s",
855-
substrait::SetRel::GetDescriptor()->FindFieldByNumber(setop)->name());
896+
substrait::SetRel::GetDescriptor()->FindFieldByNumber(setop)->name());
856897
}
857898
}
858899
}
@@ -895,8 +936,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
895936
}
896937
auto input = TransformOp(swrite.input());
897938
switch (swrite.op()) {
898-
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
899-
return input->CreateRel(schema_name, table_name);
939+
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
940+
return input->CreateRel(schema_name, table_name);
900941
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
901942
return input->InsertRel(schema_name, table_name);
902943
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
@@ -916,7 +957,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
916957
}
917958
default:
918959
throw NotImplementedException("Unsupported write operation %s",
919-
substrait::WriteRel::GetDescriptor()->FindFieldByNumber(swrite.op())->name());
960+
substrait::WriteRel::GetDescriptor()->FindFieldByNumber(swrite.op())->name());
920961
}
921962
}
922963

@@ -945,7 +986,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
945986
return TransformWriteOp(sop);
946987
default:
947988
throw NotImplementedException("Unsupported relation type %s",
948-
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
989+
substrait::Rel::GetDescriptor()->FindFieldByNumber(sop.rel_type_case())->name());
949990
}
950991
}
951992

src/include/from_substrait.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ class SubstraitToDuckDB {
7575
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
7676
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
7777
shared_ptr<Relation> GetValueRelationWithSingleBoolColumn();
78-
shared_ptr<Relation> GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows);
78+
shared_ptr<Relation>
79+
GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows);
7980
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
8081
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
8182
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,

src/substrait_extension.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ struct FromSubstraitFunctionData : public TableFunctionData {
289289
};
290290

291291
static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunctionBindInput &input,
292-
vector<LogicalType> &return_types, vector<string> &names, bool is_json) {
292+
vector<LogicalType> &return_types, vector<string> &names, bool is_json) {
293293
auto result = make_uniq<FromSubstraitFunctionData>();
294294
result->conn = make_uniq<Connection>(*context.db);
295295
if (input.inputs[0].IsNull()) {
@@ -306,12 +306,12 @@ static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunct
306306
}
307307

308308
static unique_ptr<FunctionData> FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input,
309-
vector<LogicalType> &return_types, vector<string> &names) {
309+
vector<LogicalType> &return_types, vector<string> &names) {
310310
return SubstraitBind(context, input, return_types, names, false);
311311
}
312312

313313
static unique_ptr<FunctionData> FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input,
314-
vector<LogicalType> &return_types, vector<string> &names) {
314+
vector<LogicalType> &return_types, vector<string> &names) {
315315
return SubstraitBind(context, input, return_types, names, true);
316316
}
317317

@@ -366,7 +366,8 @@ void InitializeFromSubstraitJSON(const Connection &con) {
366366
auto &catalog = Catalog::GetSystemCatalog(*con.context);
367367
// create the from_substrait table function that allows us to get a query
368368
// result from a substrait plan
369-
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction, FromSubstraitBindJSON);
369+
TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction,
370+
FromSubstraitBindJSON);
370371
from_sub_func_json.bind_replace = FromSubstraitBindReplaceJSON;
371372
CreateTableFunctionInfo from_sub_info_json(from_sub_func_json);
372373
catalog.CreateTableFunction(*con.context, from_sub_info_json);

src/to_substrait.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,8 @@ substrait::RelCommon *DuckDBToSubstrait::CreateOutputMapping(vector<int32_t> vec
879879
return rel_common;
880880
}
881881

882-
bool DuckDBToSubstrait::IsPassthroughProjection(LogicalProjection &dproj, idx_t child_column_count, bool &needs_output_mapping) {
882+
bool DuckDBToSubstrait::IsPassthroughProjection(LogicalProjection &dproj, idx_t child_column_count,
883+
bool &needs_output_mapping) {
883884
// check if the projection is just pass through of input columns with no reordering
884885
needs_output_mapping = true;
885886
auto isPassThrough = true;
@@ -1028,12 +1029,12 @@ substrait::Rel *DuckDBToSubstrait::TransformOrderBy(LogicalOperator &dop) {
10281029
return res;
10291030
}
10301031

1031-
void PrintRelAsJson(substrait::Rel * rel) {
1032+
void PrintRelAsJson(substrait::Rel *rel) {
10321033
static int i;
10331034
std::string json_output;
10341035
google::protobuf::util::JsonPrintOptions options;
1035-
options.add_whitespace = false; // Compact output: no pretty-print whitespace or indentation
1036-
options.always_print_primitive_fields = true; // Print even if default values
1036+
options.add_whitespace = false; // Compact output: no pretty-print whitespace or indentation
1037+
options.always_print_primitive_fields = true; // Print even if default values
10371038

10381039
auto status = google::protobuf::util::MessageToJsonString(*rel, &json_output, options);
10391040
if (!status.ok()) {
@@ -1325,6 +1326,37 @@ substrait::Type DuckDBToSubstrait::DuckToSubstraitType(const LogicalType &type,
13251326
s_type.set_allocated_struct_(struct_type);
13261327
return s_type;
13271328
}
1329+
case LogicalTypeId::MAP: {
1330+
auto map_type = new substrait::Type_Map;
1331+
map_type->set_nullability(type_nullability);
1332+
1333+
auto key_type = MapType::KeyType(type);
1334+
auto value_type = MapType::ValueType(type);
1335+
1336+
auto key = new substrait::Type();
1337+
*key = DuckToSubstraitType(key_type, column_statistics, not_null);
1338+
map_type->set_allocated_key(key);
1339+
1340+
auto value = new substrait::Type();
1341+
*value = DuckToSubstraitType(value_type, column_statistics, not_null);
1342+
map_type->set_allocated_value(value);
1343+
1344+
s_type.set_allocated_map(map_type);
1345+
return s_type;
1346+
}
1347+
case LogicalTypeId::LIST: {
1348+
auto list_type = new substrait::Type_List;
1349+
list_type->set_nullability(type_nullability);
1350+
1351+
auto child_type = ListType::GetChildType(type);
1352+
1353+
auto element_type = new substrait::Type();
1354+
*element_type = DuckToSubstraitType(child_type, column_statistics, not_null);
1355+
list_type->set_allocated_type(element_type);
1356+
1357+
s_type.set_allocated_list(list_type);
1358+
return s_type;
1359+
}
13281360
default:
13291361
throw NotImplementedException("Logical Type " + type.ToString() +
13301362
" not implemented as Substrait Schema Result.");

0 commit comments

Comments
 (0)