Skip to content

Commit f35aa93

Browse files
authored
Merge pull request #107 from pdet/substrait_regressions
Make projection finding function optional
2 parents 800be49 + 62316a6 commit f35aa93

File tree

4 files changed

+2404
-36
lines changed

4 files changed

+2404
-36
lines changed

src/from_substrait.cpp

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -331,31 +331,30 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
331331
auto &nested_expression = sexpr.nested();
332332
if (nested_expression.has_struct_()) {
333333
auto &struct_expression = nested_expression.struct_();
334-
vector<unique_ptr<ParsedExpression>> children;
335-
for (auto& child: struct_expression.fields()) {
334+
vector<unique_ptr<ParsedExpression>> children;
335+
for (auto &child : struct_expression.fields()) {
336336
children.emplace_back(TransformExpr(child));
337337
}
338338
return make_uniq<FunctionExpression>("row", std::move(children));
339339
} else if (nested_expression.has_list()) {
340340
auto &list_expression = nested_expression.list();
341-
vector<unique_ptr<ParsedExpression>> children;
342-
for (auto& child: list_expression.values()) {
341+
vector<unique_ptr<ParsedExpression>> children;
342+
for (auto &child : list_expression.values()) {
343343
children.emplace_back(TransformExpr(child));
344344
}
345345
return make_uniq<FunctionExpression>("list_value", std::move(children));
346346

347347
} else if (nested_expression.has_map()) {
348348
auto &map_expression = nested_expression.map();
349-
vector<unique_ptr<ParsedExpression>> children;
349+
vector<unique_ptr<ParsedExpression>> children;
350350
auto key_value = map_expression.key_values();
351351
children.emplace_back(TransformExpr(key_value[0].key()));
352352
children.emplace_back(TransformExpr(key_value[0].value()));
353353
return make_uniq<FunctionExpression>("map", std::move(children));
354354

355-
} else{
355+
} else {
356356
throw NotImplementedException("Substrait nested expression is not yet implemented.");
357357
}
358-
359358
}
360359

361360
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) {
@@ -663,25 +662,18 @@ int32_t SkipColumnNames(const LogicalType &type) {
663662
return columns_to_skip;
664663
}
665664

666-
Relation *GetProjectionOrTableRelation(Relation &relation, string &error) {
667-
error += RelationTypeToString(relation.type);
665+
Relation *GetProjection(Relation &relation) {
668666
switch (relation.type) {
669-
case RelationType::TABLE_RELATION:
670667
case RelationType::PROJECTION_RELATION:
671-
error += " -> ";
672668
return &relation;
673669
case RelationType::LIMIT_RELATION:
674-
error += " -> ";
675-
return GetProjectionOrTableRelation(*relation.Cast<LimitRelation>().child, error);
670+
return GetProjection(*relation.Cast<LimitRelation>().child);
676671
case RelationType::ORDER_RELATION:
677-
error += " -> ";
678-
return GetProjectionOrTableRelation(*relation.Cast<OrderRelation>().child, error);
672+
return GetProjection(*relation.Cast<OrderRelation>().child);
679673
case RelationType::SET_OPERATION_RELATION:
680-
error += " -> ";
681-
return GetProjectionOrTableRelation(*relation.Cast<SetOpRelation>().right, error);
674+
return GetProjection(*relation.Cast<SetOpRelation>().right);
682675
default:
683-
throw NotImplementedException(
684-
"Relation %s is not yet implemented as a possible root chain type of from_substrait function", error);
676+
return nullptr;
685677
}
686678
}
687679

@@ -691,21 +683,23 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
691683
vector<unique_ptr<ParsedExpression>> expressions;
692684
int id = 1;
693685
auto child = TransformOp(sop.input());
694-
string error;
695-
auto first_projection_or_table = GetProjectionOrTableRelation(*child, error);
696-
vector<ColumnDefinition> *column_definitions;
697-
if (first_projection_or_table->type == RelationType::PROJECTION_RELATION) {
698-
column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
686+
auto first_projection_or_table = GetProjection(*child);
687+
if (first_projection_or_table) {
688+
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
689+
int32_t i = 0;
690+
for (auto &column : *column_definitions) {
691+
aliases.push_back(column_names[i++]);
692+
auto column_type = column.GetType();
693+
i += SkipColumnNames(column.GetType());
694+
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
695+
}
699696
} else {
700-
column_definitions = &first_projection_or_table->Cast<TableRelation>().description->columns;
701-
}
702-
int32_t i = 0;
703-
for (auto &column : *column_definitions) {
704-
aliases.push_back(column_names[i++]);
705-
auto column_type = column.GetType();
706-
i += SkipColumnNames(column.GetType());
707-
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
697+
for (auto &column_name : column_names) {
698+
aliases.push_back(column_name);
699+
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
700+
}
708701
}
702+
709703
return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
710704
}
711705

src/include/from_substrait.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,6 @@ class SubstraitToDuckDB {
5858
//! names
5959
static const unordered_map<std::string, std::string> function_names_remap;
6060
static const case_insensitive_set_t valid_extract_subfields;
61-
vector<ParsedExpression*> struct_expressions;
61+
vector<ParsedExpression *> struct_expressions;
6262
};
6363
} // namespace duckdb

src/to_substrait.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,12 @@ void DuckDBToSubstrait::TransformFunctionExpression(Expression &dexpr, substrait
315315
uint64_t col_offset) {
316316
auto &dfun = dexpr.Cast<BoundFunctionExpression>();
317317

318-
319318
auto function_name = dfun.function.name;
320319

321320
if (function_name == "row") {
322321
auto nested_expression = sexpr.mutable_nested();
323322
auto struct_expression = nested_expression->mutable_struct_();
324-
for (auto& child: dfun.children) {
323+
for (auto &child : dfun.children) {
325324
auto child_expression = struct_expression->add_fields();
326325
TransformExpr(*child, *child_expression);
327326
}
@@ -330,7 +329,7 @@ void DuckDBToSubstrait::TransformFunctionExpression(Expression &dexpr, substrait
330329
if (function_name == "list_value" || function_name == "list_pack") {
331330
auto nested_expression = sexpr.mutable_nested();
332331
auto list_expression = nested_expression->mutable_list();
333-
for (auto& child: dfun.children) {
332+
for (auto &child : dfun.children) {
334333
auto child_value = list_expression->add_values();
335334
TransformExpr(*child, *child_value);
336335
}

0 commit comments

Comments
 (0)