Skip to content

Commit 7d0a8ef

Browse files
committed
Accept table scan as top root node in from_substrait
1 parent 5469b2e commit 7d0a8ef

File tree

2 files changed

+96
-10
lines changed

2 files changed

+96
-10
lines changed

src/from_substrait.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121

2222
#include "duckdb/parser/expression/comparison_expression.hpp"
2323

24-
#include "substrait/plan.pb.h"
25-
#include "google/protobuf/util/json_util.h"
2624
#include "duckdb/main/client_data.hpp"
25+
#include "google/protobuf/util/json_util.h"
26+
#include "substrait/plan.pb.h"
27+
28+
#include <duckdb/main/relation/table_relation.hpp>
2729

2830
namespace duckdb {
2931
const std::unordered_map<std::string, std::string> SubstraitToDuckDB::function_names_remap = {
@@ -625,21 +627,22 @@ int32_t SkipColumnNames(const LogicalType &type) {
625627
return columns_to_skip;
626628
}
627629

628-
Relation *GetProjectionRelation(Relation &relation, string &error) {
630+
Relation *GetProjectionOrTableRelation(Relation &relation, string &error) {
629631
error += RelationTypeToString(relation.type);
630632
switch (relation.type) {
633+
case RelationType::TABLE_RELATION:
631634
case RelationType::PROJECTION_RELATION:
632635
error += " -> ";
633636
return &relation;
634637
case RelationType::LIMIT_RELATION:
635638
error += " -> ";
636-
return GetProjectionRelation(*relation.Cast<LimitRelation>().child, error);
639+
return GetProjectionOrTableRelation(*relation.Cast<LimitRelation>().child, error);
637640
case RelationType::ORDER_RELATION:
638641
error += " -> ";
639-
return GetProjectionRelation(*relation.Cast<OrderRelation>().child, error);
642+
return GetProjectionOrTableRelation(*relation.Cast<OrderRelation>().child, error);
640643
case RelationType::SET_OPERATION_RELATION:
641644
error += " -> ";
642-
return GetProjectionRelation(*relation.Cast<SetOpRelation>().right, error);
645+
return GetProjectionOrTableRelation(*relation.Cast<SetOpRelation>().right, error);
643646
default:
644647
throw NotImplementedException(
645648
"Relation %s is not yet implemented as a possible root chain type of from_substrait function", error);
@@ -648,15 +651,20 @@ Relation *GetProjectionRelation(Relation &relation, string &error) {
648651

649652
shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot &sop) {
650653
vector<string> aliases;
651-
auto column_names = sop.names();
654+
const auto &column_names = sop.names();
652655
vector<unique_ptr<ParsedExpression>> expressions;
653656
int id = 1;
654657
auto child = TransformOp(sop.input());
655658
string error;
656-
auto first_projection = GetProjectionRelation(*child, error);
657-
auto &columns = first_projection->Cast<ProjectionRelation>().columns;
659+
auto first_projection_or_table = GetProjectionOrTableRelation(*child, error);
660+
vector<ColumnDefinition> *column_definitions;
661+
if (first_projection_or_table->type == RelationType::PROJECTION_RELATION) {
662+
column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
663+
} else {
664+
column_definitions = &first_projection_or_table->Cast<TableRelation>().description->columns;
665+
}
658666
int32_t i = 0;
659-
for (auto &column : columns) {
667+
for (auto &column : *column_definitions) {
660668
aliases.push_back(column_names[i++]);
661669
auto column_type = column.GetType();
662670
i += SkipColumnNames(column.GetType());

test/sql/test_direct_scan.test

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# name: test/sql/test_direct_scan.test
2+
# description: Test that a direct table scan works
3+
# group: [sql]
4+
5+
require substrait
6+
7+
statement ok
8+
PRAGMA enable_verification
9+
10+
statement ok
11+
create table users (user_id varchar, name varchar, paid_for_service bool);
12+
13+
statement ok
14+
insert into users values ('1', 'Pedro', false);
15+
16+
statement ok
17+
CALL get_substrait('FROM users')
18+
19+
query III
20+
CALL from_substrait_json('{
21+
"relations": [
22+
{
23+
"root": {
24+
"input": {
25+
"read": {
26+
"common": {
27+
"direct": {}
28+
},
29+
"baseSchema": {
30+
"names": [
31+
"user_id",
32+
"name",
33+
"paid_for_service"
34+
],
35+
"struct": {
36+
"types": [
37+
{
38+
"string": {
39+
"nullability": "NULLABILITY_NULLABLE"
40+
}
41+
},
42+
{
43+
"string": {
44+
"nullability": "NULLABILITY_NULLABLE"
45+
}
46+
},
47+
{
48+
"bool": {
49+
"nullability": "NULLABILITY_NULLABLE"
50+
}
51+
}
52+
],
53+
"nullability": "NULLABILITY_REQUIRED"
54+
}
55+
},
56+
"namedTable": {
57+
"names": [
58+
"users"
59+
]
60+
}
61+
}
62+
},
63+
"names": [
64+
"user_id",
65+
"name",
66+
"paid_for_service"
67+
]
68+
}
69+
}
70+
],
71+
"version": {
72+
"minorNumber": 52,
73+
"producer": "spark-substrait-gateway"
74+
}
75+
}
76+
')
77+
----
78+
1 Pedro false

0 commit comments

Comments
 (0)