Skip to content

Commit 918a12d

Browse files
authored
Merge pull request #115 from pdet/structs
Structs - Fix
2 parents bc9f4b3 + d07d339 commit 918a12d

File tree

3 files changed

+237
-26
lines changed

3 files changed

+237
-26
lines changed

src/from_substrait.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,25 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformInExpr(const substrait:
336336
return make_uniq<OperatorExpression>(ExpressionType::COMPARE_IN, std::move(values));
337337
}
338338

339-
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr) {
339+
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr,
340+
RootNameIterator *iterator) {
340341
auto &nested_expression = sexpr.nested();
341342
if (nested_expression.has_struct_()) {
342343
auto &struct_expression = nested_expression.struct_();
343344
vector<unique_ptr<ParsedExpression>> children;
344345
for (auto &child : struct_expression.fields()) {
345346
children.emplace_back(TransformExpr(child));
346347
}
347-
return make_uniq<FunctionExpression>("row", std::move(children));
348+
if (iterator && !iterator->Finished() && iterator->Unique(children.size())) {
349+
for (auto &child : children) {
350+
child->alias = iterator->GetCurrentName();
351+
iterator->Next();
352+
}
353+
return make_uniq<FunctionExpression>("struct_pack", std::move(children));
354+
} else {
355+
return make_uniq<FunctionExpression>("row", std::move(children));
356+
}
357+
348358
} else if (nested_expression.has_list()) {
349359
auto &list_expression = nested_expression.list();
350360
vector<unique_ptr<ParsedExpression>> children;
@@ -366,7 +376,11 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
366376
}
367377
}
368378

369-
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) {
379+
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr,
380+
RootNameIterator *iterator) {
381+
if (iterator) {
382+
iterator->Next();
383+
}
370384
switch (sexpr.rex_type_case()) {
371385
case substrait::Expression::RexTypeCase::kLiteral:
372386
return TransformLiteralExpr(sexpr);
@@ -381,7 +395,7 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::E
381395
case substrait::Expression::RexTypeCase::kSingularOrList:
382396
return TransformInExpr(sexpr);
383397
case substrait::Expression::RexTypeCase::kNested:
384-
return TransformNested(sexpr);
398+
return TransformNested(sexpr, iterator);
385399
case substrait::Expression::RexTypeCase::kSubquery:
386400
default:
387401
throw InternalException("Unsupported expression type " + to_string(sexpr.rex_type_case()));
@@ -463,22 +477,27 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformCrossProductOp(const substrait:
463477
TransformOp(sub_cross.right())->Alias("right"));
464478
}
465479

466-
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) {
480+
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop,
481+
const google::protobuf::RepeatedPtrField<std::string> *names) {
467482
auto &slimit = sop.fetch();
468483
idx_t limit = slimit.count() == -1 ? NumericLimits<idx_t>::Maximum() : slimit.count();
469484
idx_t offset = slimit.offset();
470-
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input()), limit, offset);
485+
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input(), names), limit, offset);
471486
}
472487

473488
//! Converts a Substrait filter relation into a DuckDB FilterRelation:
//! the filter's input relation is transformed first, then its condition
//! expression, and both are handed to the new FilterRelation.
shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) {
	const auto &filter_rel = sop.filter();
	auto input_relation = TransformOp(filter_rel.input());
	auto predicate = TransformExpr(filter_rel.condition());
	return make_shared_ptr<FilterRelation>(std::move(input_relation), std::move(predicate));
}
477492

478-
shared_ptr<Relation> SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) {
493+
shared_ptr<Relation>
494+
SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
495+
const google::protobuf::RepeatedPtrField<std::string> *names) {
479496
vector<unique_ptr<ParsedExpression>> expressions;
497+
RootNameIterator iterator(names);
498+
480499
for (auto &sexpr : sop.project().expressions()) {
481-
expressions.push_back(TransformExpr(sexpr));
500+
expressions.push_back(TransformExpr(sexpr, &iterator));
482501
}
483502

484503
vector<string> mock_aliases;
@@ -635,12 +654,13 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
635654
return scan;
636655
}
637656

638-
shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop) {
657+
shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop,
658+
const google::protobuf::RepeatedPtrField<std::string> *names) {
639659
vector<OrderByNode> order_nodes;
640660
for (auto &sordf : sop.sort().sorts()) {
641661
order_nodes.push_back(TransformOrder(sordf));
642662
}
643-
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input()), std::move(order_nodes));
663+
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input(), names), std::move(order_nodes));
644664
}
645665

646666
static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) {
@@ -660,7 +680,8 @@ static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop)
660680
}
661681
}
662682

663-
shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop) {
683+
shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop,
684+
const google::protobuf::RepeatedPtrField<std::string> *names) {
664685
D_ASSERT(sop.has_set());
665686
auto &set = sop.set();
666687
auto set_op_type = set.op();
@@ -672,31 +693,32 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
672693
throw NotImplementedException("The amount of inputs (%d) is not supported for this set operation", input_count);
673694
}
674695
auto lhs = TransformOp(inputs[0]);
675-
auto rhs = TransformOp(inputs[1]);
696+
auto rhs = TransformOp(inputs[1], names);
676697

677698
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
678699
}
679700

680-
shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
701+
shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
702+
const google::protobuf::RepeatedPtrField<std::string> *names) {
681703
switch (sop.rel_type_case()) {
682704
case substrait::Rel::RelTypeCase::kJoin:
683705
return TransformJoinOp(sop);
684706
case substrait::Rel::RelTypeCase::kCross:
685707
return TransformCrossProductOp(sop);
686708
case substrait::Rel::RelTypeCase::kFetch:
687-
return TransformFetchOp(sop);
709+
return TransformFetchOp(sop, names);
688710
case substrait::Rel::RelTypeCase::kFilter:
689711
return TransformFilterOp(sop);
690712
case substrait::Rel::RelTypeCase::kProject:
691-
return TransformProjectOp(sop);
713+
return TransformProjectOp(sop, names);
692714
case substrait::Rel::RelTypeCase::kAggregate:
693715
return TransformAggregateOp(sop);
694716
case substrait::Rel::RelTypeCase::kRead:
695717
return TransformReadOp(sop);
696718
case substrait::Rel::RelTypeCase::kSort:
697-
return TransformSortOp(sop);
719+
return TransformSortOp(sop, names);
698720
case substrait::Rel::RelTypeCase::kSet:
699-
return TransformSetOp(sop);
721+
return TransformSetOp(sop, names);
700722
default:
701723
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
702724
}
@@ -738,7 +760,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
738760
const auto &column_names = sop.names();
739761
vector<unique_ptr<ParsedExpression>> expressions;
740762
int id = 1;
741-
auto child = TransformOp(sop.input());
763+
auto child = TransformOp(sop.input(), &column_names);
742764
auto first_projection_or_table = GetProjection(*child);
743765
if (first_projection_or_table) {
744766
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;

src/include/from_substrait.hpp

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,42 @@
1616

1717
namespace duckdb {
1818

19+
struct RootNameIterator {
20+
explicit RootNameIterator(const google::protobuf::RepeatedPtrField<std::string> *names) : names(names) {};
21+
string GetCurrentName() const {
22+
if (!names) {
23+
return "";
24+
}
25+
if (iterator >= names->size()) {
26+
throw InvalidInputException("Trying to access invalid root name at struct creation");
27+
}
28+
return (*names)[iterator];
29+
}
30+
void Next() {
31+
++iterator;
32+
}
33+
bool Unique(idx_t count) const {
34+
idx_t pos = iterator;
35+
set<string> values;
36+
for (idx_t i = 0; i < count; i++) {
37+
if (values.find((*names)[pos]) != values.end()) {
38+
return false;
39+
}
40+
values.insert((*names)[pos]);
41+
pos++;
42+
}
43+
return true;
44+
}
45+
bool Finished() const {
46+
if (!names) {
47+
return true;
48+
}
49+
return iterator >= names->size();
50+
}
51+
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr;
52+
int iterator = 0;
53+
};
54+
1955
class SubstraitToDuckDB {
2056
public:
2157
SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json = false,
@@ -27,26 +63,33 @@ class SubstraitToDuckDB {
2763
//! Transforms Substrait Plan Root To a DuckDB Relation
2864
shared_ptr<Relation> TransformRootOp(const substrait::RelRoot &sop);
2965
//! Transform Substrait Operations to DuckDB Relations
30-
shared_ptr<Relation> TransformOp(const substrait::Rel &sop);
66+
shared_ptr<Relation> TransformOp(const substrait::Rel &sop,
67+
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
3168
shared_ptr<Relation> TransformJoinOp(const substrait::Rel &sop);
3269
shared_ptr<Relation> TransformCrossProductOp(const substrait::Rel &sop);
33-
shared_ptr<Relation> TransformFetchOp(const substrait::Rel &sop);
70+
shared_ptr<Relation> TransformFetchOp(const substrait::Rel &sop,
71+
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
3472
shared_ptr<Relation> TransformFilterOp(const substrait::Rel &sop);
35-
shared_ptr<Relation> TransformProjectOp(const substrait::Rel &sop);
73+
shared_ptr<Relation> TransformProjectOp(const substrait::Rel &sop,
74+
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
3675
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
3776
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
38-
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop);
39-
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop);
77+
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
78+
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
79+
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
80+
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
4081

4182
//! Transform Substrait Expressions to DuckDB Expressions
42-
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr);
83+
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
84+
RootNameIterator *iterator = nullptr);
4385
static unique_ptr<ParsedExpression> TransformLiteralExpr(const substrait::Expression &sexpr);
4486
static unique_ptr<ParsedExpression> TransformSelectionExpr(const substrait::Expression &sexpr);
4587
unique_ptr<ParsedExpression> TransformScalarFunctionExpr(const substrait::Expression &sexpr);
4688
unique_ptr<ParsedExpression> TransformIfThenExpr(const substrait::Expression &sexpr);
4789
unique_ptr<ParsedExpression> TransformCastExpr(const substrait::Expression &sexpr);
4890
unique_ptr<ParsedExpression> TransformInExpr(const substrait::Expression &sexpr);
49-
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr);
91+
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr,
92+
RootNameIterator *iterator = nullptr);
5093

5194
static void VerifyCorrectExtractSubfield(const string &subfield);
5295
static string RemapFunctionName(const string &function_name);

test/sql/test_nested_expressions.test

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,150 @@ statement ok
4545
CALL get_substrait('SELECT row(row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10))) from t;')
4646

4747
statement ok
48-
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')
48+
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')
49+
50+
require tpch
51+
52+
statement ok
53+
CALL dbgen(sf=0.01)
54+
55+
query I
56+
CALL from_substrait_json('
57+
{
58+
"relations": [
59+
{
60+
"root": {
61+
"input": {
62+
"fetch": {
63+
"common": {
64+
"direct": {}
65+
},
66+
"input": {
67+
"project": {
68+
"common": {
69+
"emit": {
70+
"outputMapping": 8
71+
}
72+
},
73+
"input": {
74+
"read": {
75+
"common": {
76+
"direct": {}
77+
},
78+
"baseSchema": {
79+
"names": [
80+
"c_custkey",
81+
"c_name",
82+
"c_address",
83+
"c_nationkey",
84+
"c_phone",
85+
"c_acctbal",
86+
"c_mktsegment",
87+
"c_comment"
88+
],
89+
"struct": {
90+
"types": [
91+
{
92+
"i64": {
93+
"nullability": "NULLABILITY_NULLABLE"
94+
}
95+
},
96+
{
97+
"string": {
98+
"nullability": "NULLABILITY_NULLABLE"
99+
}
100+
},
101+
{
102+
"string": {
103+
"nullability": "NULLABILITY_NULLABLE"
104+
}
105+
},
106+
{
107+
"i32": {
108+
"nullability": "NULLABILITY_NULLABLE"
109+
}
110+
},
111+
{
112+
"string": {
113+
"nullability": "NULLABILITY_NULLABLE"
114+
}
115+
},
116+
{
117+
"decimal": {
118+
"scale": 2,
119+
"precision": 15,
120+
"nullability": "NULLABILITY_NULLABLE"
121+
}
122+
},
123+
{
124+
"string": {
125+
"nullability": "NULLABILITY_NULLABLE"
126+
}
127+
},
128+
{
129+
"string": {
130+
"nullability": "NULLABILITY_NULLABLE"
131+
}
132+
}
133+
],
134+
"nullability": "NULLABILITY_REQUIRED"
135+
}
136+
},
137+
"namedTable": {
138+
"names": [
139+
"customer"
140+
]
141+
}
142+
}
143+
},
144+
"expressions": [
145+
{
146+
"nested": {
147+
"struct": {
148+
"fields": [
149+
{
150+
"selection": {
151+
"directReference": {
152+
"structField": {}
153+
},
154+
"rootReference": {}
155+
}
156+
},
157+
{
158+
"selection": {
159+
"directReference": {
160+
"structField": {
161+
"field": 1
162+
}
163+
},
164+
"rootReference": {}
165+
}
166+
}
167+
]
168+
}
169+
}
170+
}
171+
]
172+
}
173+
},
174+
"count": 3
175+
}
176+
},
177+
"names": [
178+
"test_struct",
179+
"custid",
180+
"custname"
181+
]
182+
}
183+
}
184+
],
185+
"version": {
186+
"minorNumber": 52,
187+
"producer": "spark-substrait-gateway"
188+
}
189+
}
190+
')
191+
----
192+
{'custid': 1, 'custname': Customer#000000001}
193+
{'custid': 2, 'custname': Customer#000000002}
194+
{'custid': 3, 'custname': Customer#000000003}

0 commit comments

Comments
 (0)