Skip to content

Commit e1f5aa7

Browse files
authored
Merge pull request #101 from pdet/struct
Implementing Struct Types, adding their names in a DF, to schema, and root names.
2 parents 213b9cd + 47108e3 commit e1f5aa7

File tree

9 files changed

+327
-186
lines changed

9 files changed

+327
-186
lines changed

data/bug-17/test_table.parquet

888 Bytes
Binary file not shown.

src/custom_extensions.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace duckdb {
66

77
// FIXME: This cannot be the best way of getting string names of the types
8-
string TransformTypes(const ::substrait::Type &type) {
8+
string TransformTypes(const substrait::Type &type) {
99
auto str = type.DebugString();
1010
string str_type;
1111
for (auto &c : str) {
@@ -109,22 +109,22 @@ string SubstraitCustomFunction::GetName() {
109109
return function_signature;
110110
}
111111

112-
string SubstraitFunctionExtensions::GetExtensionURI() {
112+
string SubstraitFunctionExtensions::GetExtensionURI() const {
113113
if (IsNative()) {
114114
return "";
115115
}
116116
return "https://github.com/substrait-io/substrait/blob/main/extensions/" + extension_path;
117117
}
118118

119-
bool SubstraitFunctionExtensions::IsNative() {
119+
bool SubstraitFunctionExtensions::IsNative() const {
120120
return extension_path == "native";
121121
}
122122

123123
SubstraitCustomFunctions::SubstraitCustomFunctions() {
124124
Initialize();
125125
};
126126

127-
vector<string> SubstraitCustomFunctions::GetTypes(const vector<::substrait::Type> &types) const {
127+
vector<string> SubstraitCustomFunctions::GetTypes(const vector<substrait::Type> &types) {
128128
vector<string> transformed_types;
129129
for (auto &type : types) {
130130
transformed_types.emplace_back(TransformTypes(type));

src/from_substrait.cpp

Lines changed: 76 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ const case_insensitive_set_t SubstraitToDuckDB::valid_extract_subfields = {
3636
"year", "month", "day", "decade", "century", "millenium",
3737
"quarter", "microsecond", "milliseconds", "second", "minute", "hour"};
3838

39-
std::string SubstraitToDuckDB::RemapFunctionName(std::string &function_name) {
39+
string SubstraitToDuckDB::RemapFunctionName(const string &function_name) {
4040
// Lets first drop any extension id
4141
string name;
4242
for (auto &c : function_name) {
@@ -52,7 +52,7 @@ std::string SubstraitToDuckDB::RemapFunctionName(std::string &function_name) {
5252
return name;
5353
}
5454

55-
std::string SubstraitToDuckDB::RemoveExtension(std::string &function_name) {
55+
string SubstraitToDuckDB::RemoveExtension(const string &function_name) {
5656
// Lets first drop any extension id
5757
string name;
5858
for (auto &c : function_name) {
@@ -97,10 +97,10 @@ Value TransformLiteralToValue(const substrait::Expression_Literal &literal) {
9797
return {literal.string()};
9898
case substrait::Expression_Literal::LiteralTypeCase::kDecimal: {
9999
const auto &substrait_decimal = literal.decimal();
100-
auto raw_value = (uint64_t *)substrait_decimal.value().c_str();
100+
auto raw_value = reinterpret_cast<const uint64_t *>(substrait_decimal.value().c_str());
101101
hugeint_t substrait_value {};
102102
substrait_value.lower = raw_value[0];
103-
substrait_value.upper = raw_value[1];
103+
substrait_value.upper = static_cast<int64_t>(raw_value[1]);
104104
Value val = Value::HUGEINT(substrait_value);
105105
auto decimal_type = LogicalType::DECIMAL(substrait_decimal.precision(), substrait_decimal.scale());
106106
// cast to correct value
@@ -123,7 +123,7 @@ Value TransformLiteralToValue(const substrait::Expression_Literal &literal) {
123123
return Value(literal.boolean());
124124
}
125125
case substrait::Expression_Literal::LiteralTypeCase::kI8:
126-
return Value::TINYINT(literal.i8());
126+
return Value::TINYINT(static_cast<int8_t>(literal.i8()));
127127
case substrait::Expression_Literal::LiteralTypeCase::kI32:
128128
return Value::INTEGER(literal.i32());
129129
case substrait::Expression_Literal::LiteralTypeCase::kI64:
@@ -278,27 +278,29 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformIfThenExpr(const substr
278278
return std::move(dcase);
279279
}
280280

281-
LogicalType SubstraitToDuckDB::SubstraitToDuckType(const ::substrait::Type &s_type) {
282-
283-
if (s_type.has_bool_()) {
284-
return LogicalType(LogicalTypeId::BOOLEAN);
285-
} else if (s_type.has_i16()) {
286-
return LogicalType(LogicalTypeId::SMALLINT);
287-
} else if (s_type.has_i32()) {
288-
return LogicalType(LogicalTypeId::INTEGER);
289-
} else if (s_type.has_decimal()) {
281+
LogicalType SubstraitToDuckDB::SubstraitToDuckType(const substrait::Type &s_type) {
282+
switch (s_type.kind_case()) {
283+
case substrait::Type::KindCase::kBool:
284+
return {LogicalTypeId::BOOLEAN};
285+
case substrait::Type::KindCase::kI16:
286+
return {LogicalTypeId::SMALLINT};
287+
case substrait::Type::KindCase::kI32:
288+
return {LogicalTypeId::INTEGER};
289+
case substrait::Type::KindCase::kI64:
290+
return {LogicalTypeId::BIGINT};
291+
case substrait::Type::KindCase::kDecimal: {
290292
auto &s_decimal_type = s_type.decimal();
291293
return LogicalType::DECIMAL(s_decimal_type.precision(), s_decimal_type.scale());
292-
} else if (s_type.has_i64()) {
293-
return LogicalType(LogicalTypeId::BIGINT);
294-
} else if (s_type.has_date()) {
295-
return LogicalType(LogicalTypeId::DATE);
296-
} else if (s_type.has_varchar() || s_type.has_string()) {
297-
return LogicalType(LogicalTypeId::VARCHAR);
298-
} else if (s_type.has_fp64()) {
299-
return LogicalType(LogicalTypeId::DOUBLE);
300-
} else {
301-
throw InternalException("Substrait type not yet supported");
294+
}
295+
case substrait::Type::KindCase::kDate:
296+
return {LogicalTypeId::DATE};
297+
case substrait::Type::KindCase::kVarchar:
298+
case substrait::Type::KindCase::kString:
299+
return {LogicalTypeId::VARCHAR};
300+
case substrait::Type::KindCase::kFp64:
301+
return {LogicalTypeId::DOUBLE};
302+
default:
303+
throw NotImplementedException("Substrait type not yet supported");
302304
}
303305
}
304306

@@ -315,7 +317,7 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformInExpr(const substrait:
315317
vector<unique_ptr<ParsedExpression>> values;
316318
values.emplace_back(TransformExpr(substrait_in.value()));
317319

318-
for (idx_t i = 0; i < (idx_t)substrait_in.options_size(); i++) {
320+
for (int32_t i = 0; i < substrait_in.options_size(); i++) {
319321
values.emplace_back(TransformExpr(substrait_in.options(i)));
320322
}
321323

@@ -416,9 +418,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformCrossProductOp(const substrait:
416418

417419
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) {
418420
auto &slimit = sop.fetch();
419-
idx_t limit, offset;
420-
limit = slimit.count() == -1 ? NumericLimits<idx_t>::Maximum() : slimit.count();
421-
offset = slimit.offset();
421+
idx_t limit = slimit.count() == -1 ? NumericLimits<idx_t>::Maximum() : slimit.count();
422+
idx_t offset = slimit.offset();
422423
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input()), limit, offset);
423424
}
424425

@@ -607,16 +608,60 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
607608
}
608609
}
609610

611+
void SkipColumnNamesRecurse(int32_t &columns_to_skip, const LogicalType &type) {
612+
if (type.id() == LogicalTypeId::STRUCT) {
613+
idx_t struct_size = StructType::GetChildCount(type);
614+
columns_to_skip += static_cast<int32_t>(struct_size);
615+
for (auto &struct_type : StructType::GetChildTypes(type)) {
616+
SkipColumnNamesRecurse(columns_to_skip, struct_type.second);
617+
}
618+
}
619+
}
620+
621+
int32_t SkipColumnNames(const LogicalType &type) {
622+
int32_t columns_to_skip = 0;
623+
SkipColumnNamesRecurse(columns_to_skip, type);
624+
return columns_to_skip;
625+
}
626+
627+
Relation *GetProjectionRelation(Relation &relation, string &error) {
628+
error += RelationTypeToString(relation.type);
629+
switch (relation.type) {
630+
case RelationType::PROJECTION_RELATION:
631+
error += " -> ";
632+
return &relation;
633+
case RelationType::LIMIT_RELATION:
634+
error += " -> ";
635+
return GetProjectionRelation(*relation.Cast<LimitRelation>().child, error);
636+
case RelationType::ORDER_RELATION:
637+
error += " -> ";
638+
return GetProjectionRelation(*relation.Cast<OrderRelation>().child, error);
639+
case RelationType::SET_OPERATION_RELATION:
640+
error += " -> ";
641+
return GetProjectionRelation(*relation.Cast<SetOpRelation>().right, error);
642+
default:
643+
throw NotImplementedException(
644+
"Relation %s is not yet implemented as a possible root chain type of from_substrait function", error);
645+
}
646+
}
647+
610648
shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot &sop) {
611649
vector<string> aliases;
612650
auto column_names = sop.names();
613651
vector<unique_ptr<ParsedExpression>> expressions;
614652
int id = 1;
615-
for (auto &column_name : column_names) {
616-
aliases.push_back(column_name);
653+
auto child = TransformOp(sop.input());
654+
string error;
655+
auto first_projection = GetProjectionRelation(*child, error);
656+
auto &columns = first_projection->Cast<ProjectionRelation>().columns;
657+
int32_t i = 0;
658+
for (auto &column : columns) {
659+
aliases.push_back(column_names[i++]);
660+
auto column_type = column.GetType();
661+
i += SkipColumnNames(column.GetType());
617662
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
618663
}
619-
return make_shared_ptr<ProjectionRelation>(TransformOp(sop.input()), std::move(expressions), aliases);
664+
return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
620665
}
621666

622667
shared_ptr<Relation> SubstraitToDuckDB::TransformPlan() {

src/include/custom_extensions/custom_extensions.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
namespace duckdb {
1616

1717
struct SubstraitCustomFunction {
18-
public:
1918
SubstraitCustomFunction(string name_p, vector<string> arg_types_p)
2019
: name(std::move(name_p)), arg_types(std::move(arg_types_p)) {};
2120

@@ -34,8 +33,8 @@ class SubstraitFunctionExtensions {
3433
: function(std::move(function_p)), extension_path(std::move(extension_path_p)) {};
3534
SubstraitFunctionExtensions() = default;
3635

37-
string GetExtensionURI();
38-
bool IsNative();
36+
string GetExtensionURI() const;
37+
bool IsNative() const;
3938

4039
SubstraitCustomFunction function;
4140
string extension_path;
@@ -66,8 +65,8 @@ struct HashSubstraitFunctionsName {
6665
class SubstraitCustomFunctions {
6766
public:
6867
SubstraitCustomFunctions();
69-
SubstraitFunctionExtensions Get(const string &name, const vector<::substrait::Type> &types) const;
70-
vector<string> GetTypes(const vector<::substrait::Type> &types) const;
68+
SubstraitFunctionExtensions Get(const string &name, const vector<substrait::Type> &types) const;
69+
static vector<string> GetTypes(const vector<substrait::Type> &types);
7170
void Initialize();
7271

7372
private:

src/include/from_substrait.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,17 @@ class SubstraitToDuckDB {
3131

3232
//! Transform Substrait Expressions to DuckDB Expressions
3333
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr);
34-
unique_ptr<ParsedExpression> TransformLiteralExpr(const substrait::Expression &sexpr);
35-
unique_ptr<ParsedExpression> TransformSelectionExpr(const substrait::Expression &sexpr);
34+
static unique_ptr<ParsedExpression> TransformLiteralExpr(const substrait::Expression &sexpr);
35+
static unique_ptr<ParsedExpression> TransformSelectionExpr(const substrait::Expression &sexpr);
3636
unique_ptr<ParsedExpression> TransformScalarFunctionExpr(const substrait::Expression &sexpr);
3737
unique_ptr<ParsedExpression> TransformIfThenExpr(const substrait::Expression &sexpr);
3838
unique_ptr<ParsedExpression> TransformCastExpr(const substrait::Expression &sexpr);
3939
unique_ptr<ParsedExpression> TransformInExpr(const substrait::Expression &sexpr);
4040

41-
void VerifyCorrectExtractSubfield(const string &subfield);
42-
std::string RemapFunctionName(std::string &function_name);
43-
std::string RemoveExtension(std::string &function_name);
44-
LogicalType SubstraitToDuckType(const ::substrait::Type &s_type);
41+
static void VerifyCorrectExtractSubfield(const string &subfield);
42+
static string RemapFunctionName(const string &function_name);
43+
static string RemoveExtension(const string &function_name);
44+
static LogicalType SubstraitToDuckType(const substrait::Type &s_type);
4545
//! Looks up for aggregation function in functions_map
4646
string FindFunction(uint64_t id);
4747

src/include/to_substrait.hpp

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,19 @@ class DuckDBToSubstrait {
2525
plan.Clear();
2626
}
2727
//! Serializes the substrait plan to a string
28-
string SerializeToString();
29-
string SerializeToJson();
28+
string SerializeToString() const;
29+
string SerializeToJson() const;
3030

3131
private:
3232
//! Transform DuckDB Plan to Substrait Plan
3333
void TransformPlan(LogicalOperator &dop);
3434
//! Registers a function
3535
uint64_t RegisterFunction(const std::string &name, vector<::substrait::Type> &args_types);
3636
//! Creates a reference to a table column
37-
void CreateFieldRef(substrait::Expression *expr, uint64_t col_idx);
37+
static void CreateFieldRef(substrait::Expression *expr, uint64_t col_idx);
38+
//! In case of struct types we might we do DFS to get all names
39+
static vector<string> DepthFirstNames(const LogicalType &type);
40+
static void DepthFirstNamesRecurse(vector<string> &names, const LogicalType &type);
3841

3942
//! Transforms Relation Root
4043
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);
@@ -54,36 +57,36 @@ class DuckDBToSubstrait {
5457
substrait::Rel *TransformDistinct(LogicalOperator &dop);
5558
substrait::Rel *TransformExcept(LogicalOperator &dop);
5659
substrait::Rel *TransformIntersect(LogicalOperator &dop);
57-
substrait::Rel *TransformDummyScan();
60+
static substrait::Rel *TransformDummyScan();
5861
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
5962
//! To Substrait;
60-
void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget);
63+
void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget) const;
6164
void TransformParquetScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget, BindInfo &bind_info,
62-
FunctionData &bind_data);
65+
const FunctionData &bind_data) const;
6366

6467
//! Methods to transform DuckDBConstants to Substrait Expressions
65-
void TransformConstant(Value &dval, substrait::Expression &sexpr);
66-
void TransformInteger(Value &dval, substrait::Expression &sexpr);
67-
void TransformDouble(Value &dval, substrait::Expression &sexpr);
68-
void TransformBigInt(Value &dval, substrait::Expression &sexpr);
69-
void TransformDate(Value &dval, substrait::Expression &sexpr);
70-
void TransformVarchar(Value &dval, substrait::Expression &sexpr);
71-
void TransformBoolean(Value &dval, substrait::Expression &sexpr);
72-
void TransformDecimal(Value &dval, substrait::Expression &sexpr);
73-
void TransformHugeInt(Value &dval, substrait::Expression &sexpr);
74-
void TransformSmallInt(Value &dval, substrait::Expression &sexpr);
75-
void TransformFloat(Value &dval, substrait::Expression &sexpr);
76-
void TransformTime(Value &dval, substrait::Expression &sexpr);
77-
void TransformInterval(Value &dval, substrait::Expression &sexpr);
78-
void TransformTimestamp(Value &dval, substrait::Expression &sexpr);
79-
void TransformEnum(Value &dval, substrait::Expression &sexpr);
68+
static void TransformConstant(const Value &dval, substrait::Expression &sexpr);
69+
static void TransformInteger(const Value &dval, substrait::Expression &sexpr);
70+
static void TransformDouble(const Value &dval, substrait::Expression &sexpr);
71+
static void TransformBigInt(const Value &dval, substrait::Expression &sexpr);
72+
static void TransformDate(const Value &dval, substrait::Expression &sexpr);
73+
static void TransformVarchar(const Value &dval, substrait::Expression &sexpr);
74+
static void TransformBoolean(const Value &dval, substrait::Expression &sexpr);
75+
static void TransformDecimal(const Value &dval, substrait::Expression &sexpr);
76+
static void TransformHugeInt(const Value &dval, substrait::Expression &sexpr);
77+
static void TransformSmallInt(const Value &dval, substrait::Expression &sexpr);
78+
static void TransformFloat(const Value &dval, substrait::Expression &sexpr);
79+
static void TransformTime(const Value &dval, substrait::Expression &sexpr);
80+
static void TransformInterval(const Value &dval, substrait::Expression &sexpr);
81+
static void TransformTimestamp(const Value &dval, substrait::Expression &sexpr);
82+
static void TransformEnum(const Value &dval, substrait::Expression &sexpr);
8083

8184
//! Methods to transform a DuckDB Expression to a Substrait Expression
8285
void TransformExpr(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset = 0);
83-
void TransformBoundRefExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
86+
static void TransformBoundRefExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
8487
void TransformCastExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
8588
void TransformFunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
86-
void TransformConstantExpression(Expression &dexpr, substrait::Expression &sexpr);
89+
static void TransformConstantExpression(Expression &dexpr, substrait::Expression &sexpr);
8790
void TransformComparisonExpression(Expression &dexpr, substrait::Expression &sexpr);
8891
void TransformConjunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
8992
void TransformNotNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
@@ -93,31 +96,32 @@ class DuckDBToSubstrait {
9396
void TransformInExpression(Expression &dexpr, substrait::Expression &sexpr);
9497

9598
//! Transforms a DuckDB Logical Type into a Substrait Type
96-
::substrait::Type DuckToSubstraitType(const LogicalType &type, BaseStatistics *column_statistics = nullptr,
97-
bool not_null = false);
99+
static substrait::Type DuckToSubstraitType(const LogicalType &type, BaseStatistics *column_statistics = nullptr,
100+
bool not_null = false);
98101

99102
//! Methods to transform DuckDB Filters to Substrait Expression
100103
substrait::Expression *TransformFilter(uint64_t col_idx, LogicalType &column_type, TableFilter &dfilter,
101104
LogicalType &return_type);
102-
substrait::Expression *TransformIsNotNullFilter(uint64_t col_idx, LogicalType &column_type, TableFilter &dfilter,
103-
LogicalType &return_type);
105+
substrait::Expression *TransformIsNotNullFilter(uint64_t col_idx, const LogicalType &column_type,
106+
TableFilter &dfilter, const LogicalType &return_type);
104107
substrait::Expression *TransformConjuctionAndFilter(uint64_t col_idx, LogicalType &column_type,
105108
TableFilter &dfilter, LogicalType &return_type);
106-
substrait::Expression *TransformConstantComparisonFilter(uint64_t col_idx, LogicalType &column_type,
107-
TableFilter &dfilter, LogicalType &return_type);
109+
substrait::Expression *TransformConstantComparisonFilter(uint64_t col_idx, const LogicalType &column_type,
110+
TableFilter &dfilter, const LogicalType &return_type);
108111

109112
//! Transforms DuckDB Join Conditions to Substrait Expression
110-
substrait::Expression *TransformJoinCond(JoinCondition &dcond, uint64_t left_ncol);
113+
substrait::Expression *TransformJoinCond(const JoinCondition &dcond, uint64_t left_ncol);
111114
//! Transforms DuckDB Sort Order to Substrait Sort Order
112-
void TransformOrder(BoundOrderByNode &dordf, substrait::SortField &sordf);
115+
void TransformOrder(const BoundOrderByNode &dordf, substrait::SortField &sordf);
113116

114-
void AllocateFunctionArgument(substrait::Expression_ScalarFunction *scalar_fun, substrait::Expression *value);
117+
static void AllocateFunctionArgument(substrait::Expression_ScalarFunction *scalar_fun,
118+
substrait::Expression *value);
115119
static std::string &RemapFunctionName(std::string &function_name);
116-
bool IsExtractFunction(const string &function_name) const;
120+
static bool IsExtractFunction(const string &function_name);
117121

118122
//! Creates a Conjunction
119123
template <typename T, typename FUNC>
120-
substrait::Expression *CreateConjunction(T &source, FUNC f) {
124+
substrait::Expression *CreateConjunction(T &source, const FUNC f) {
121125
substrait::Expression *res = nullptr;
122126
for (auto &ele : source) {
123127
auto child_expression = f(ele);

0 commit comments

Comments
 (0)