Skip to content

Commit 55922a3

Browse files
authored
Merge pull request #104 from pdet/nested_expressions
Nested expressions
2 parents 7f02b87 + 71037c6 commit 55922a3

File tree

4 files changed

+116
-2
lines changed

4 files changed

+116
-2
lines changed

src/from_substrait.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,37 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformInExpr(const substrait:
325325
return make_uniq<OperatorExpression>(ExpressionType::COMPARE_IN, std::move(values));
326326
}
327327

328+
//! Converts a Substrait nested expression (a struct, list or map constructor) into the
//! DuckDB function call that builds the same value:
//!   - struct -> row(<field expressions...>)
//!   - list   -> list_value(<value expressions...>)
//!   - map    -> map(<key expression>, <value expression>), taken from a single key/value entry
//! Every child expression is converted recursively through TransformExpr.
//! Throws NotImplementedException for any other nested kind, and for a map that does not
//! carry exactly one key/value entry.
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr) {
	auto &nested_expression = sexpr.nested();
	vector<unique_ptr<ParsedExpression>> children;

	if (nested_expression.has_struct_()) {
		for (auto &field : nested_expression.struct_().fields()) {
			children.emplace_back(TransformExpr(field));
		}
		return make_uniq<FunctionExpression>("row", std::move(children));
	}

	if (nested_expression.has_list()) {
		for (auto &value : nested_expression.list().values()) {
			children.emplace_back(TransformExpr(value));
		}
		return make_uniq<FunctionExpression>("list_value", std::move(children));
	}

	if (nested_expression.has_map()) {
		// Bind by reference: taking the repeated field by value would copy every entry.
		auto &key_values = nested_expression.map().key_values();
		// Only a single entry is translated into the two arguments of map(). Reject anything
		// else explicitly instead of reading past the end of an empty field or silently
		// discarding the remaining entries.
		if (key_values.size() != 1) {
			throw NotImplementedException(
			    "Substrait map expressions must contain exactly one key/value entry.");
		}
		children.emplace_back(TransformExpr(key_values[0].key()));
		children.emplace_back(TransformExpr(key_values[0].value()));
		return make_uniq<FunctionExpression>("map", std::move(children));
	}

	throw NotImplementedException("Substrait nested expression is not yet implemented.");
}
358+
328359
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) {
329360
switch (sexpr.rex_type_case()) {
330361
case substrait::Expression::RexTypeCase::kLiteral:
@@ -339,6 +370,8 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::E
339370
return TransformCastExpr(sexpr);
340371
case substrait::Expression::RexTypeCase::kSingularOrList:
341372
return TransformInExpr(sexpr);
373+
case substrait::Expression::RexTypeCase::kNested:
374+
return TransformNested(sexpr);
342375
case substrait::Expression::RexTypeCase::kSubquery:
343376
default:
344377
throw InternalException("Unsupported expression type " + to_string(sexpr.rex_type_case()));

src/include/from_substrait.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class SubstraitToDuckDB {
3737
unique_ptr<ParsedExpression> TransformIfThenExpr(const substrait::Expression &sexpr);
3838
unique_ptr<ParsedExpression> TransformCastExpr(const substrait::Expression &sexpr);
3939
unique_ptr<ParsedExpression> TransformInExpr(const substrait::Expression &sexpr);
40+
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr);
4041

4142
static void VerifyCorrectExtractSubfield(const string &subfield);
4243
static string RemapFunctionName(const string &function_name);
@@ -57,5 +58,6 @@ class SubstraitToDuckDB {
5758
//! names
5859
static const unordered_map<std::string, std::string> function_names_remap;
5960
static const case_insensitive_set_t valid_extract_subfields;
61+
vector<ParsedExpression*> struct_expressions;
6062
};
6163
} // namespace duckdb

src/to_substrait.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,48 @@ bool DuckDBToSubstrait::IsExtractFunction(const string &function_name) {
314314
void DuckDBToSubstrait::TransformFunctionExpression(Expression &dexpr, substrait::Expression &sexpr,
315315
uint64_t col_offset) {
316316
auto &dfun = dexpr.Cast<BoundFunctionExpression>();
317-
auto sfun = sexpr.mutable_scalar_function();
317+
318318

319319
auto function_name = dfun.function.name;
320+
321+
if (function_name == "row") {
322+
auto nested_expression = sexpr.mutable_nested();
323+
auto struct_expression = nested_expression->mutable_struct_();
324+
for (auto& child: dfun.children) {
325+
auto child_expression = struct_expression->add_fields();
326+
TransformExpr(*child, *child_expression);
327+
}
328+
return;
329+
}
330+
if (function_name == "list_value" || function_name == "list_pack") {
331+
auto nested_expression = sexpr.mutable_nested();
332+
auto list_expression = nested_expression->mutable_list();
333+
for (auto& child: dfun.children) {
334+
auto child_value = list_expression->add_values();
335+
TransformExpr(*child, *child_value);
336+
}
337+
return;
338+
}
339+
if (function_name == "map") {
340+
auto nested_expression = sexpr.mutable_nested();
341+
auto map_expression = nested_expression->mutable_map();
342+
D_ASSERT(dfun.children.size() == 2);
343+
auto child_value = map_expression->add_key_values();
344+
auto key = child_value->mutable_key();
345+
auto value = child_value->mutable_value();
346+
TransformExpr(*dfun.children[0], *key);
347+
TransformExpr(*dfun.children[1], *value);
348+
return;
349+
}
350+
auto sfun = sexpr.mutable_scalar_function();
320351
if (IsExtractFunction(function_name)) {
321352
// Change the name to 'extract', and add an Enum argument containing the subfield
322353
auto subfield = function_name;
323354
function_name = "extract";
324355
auto enum_arg = sfun->add_arguments();
325356
*enum_arg->mutable_enum_() = subfield;
326357
}
327-
vector<::substrait::Type> args_types;
358+
vector<substrait::Type> args_types;
328359
for (auto &darg : dfun.children) {
329360
auto sarg = sfun->add_arguments();
330361
TransformExpr(*darg, *sarg->mutable_value(), col_offset);

test/sql/test_nested_expressions.test

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# name: test/sql/test_nested_expressions.test
2+
# description: Test nested expressions
3+
# group: [sql]
4+
5+
# Every case below is a 'statement ok' record: it only checks that get_substrait
# runs without error, the returned plan itself is not inspected.
require substrait
6+
7+
statement ok
8+
PRAGMA enable_verification
9+
10+
# Table t has a single column a filled from range(10)
statement ok
11+
create table t as select range as a from range(10)
12+
13+
# Test struct creation
14+
statement ok
15+
CALL get_substrait('SELECT row(a,a,10) from t;')
16+
17+
# Test list creation
18+
statement ok
19+
CALL get_substrait('SELECT [a,a,10] from t;')
20+
21+
# Test map creation
22+
statement ok
23+
CALL get_substrait('SELECT MAP {a: a} from t;')
24+
25+
# Test nested-> nested
# Map whose key and value are both lists
26+
statement ok
27+
CALL get_substrait('SELECT MAP {[a,a,10]: [a,a,10]} from t;')
28+
29+
# Map whose key and value are lists of structs
statement ok
30+
CALL get_substrait('SELECT MAP {[row(a,a,10),row(a,a,10),row(a,a,10)]: [row(a,a,10),row(a,a,10),row(a,a,10)]} from t;')
31+
32+
# Struct whose fields are lists
statement ok
33+
CALL get_substrait('SELECT row([a,a,10],[a,a,10],[a,a,10]) from t;')
34+
35+
# Struct whose fields are lists of maps
statement ok
36+
CALL get_substrait('SELECT row([MAP {a: a},MAP {a: a},MAP {a: a}],[MAP {a: a},MAP {a: a},MAP {a: a}],[MAP {a: a},MAP {a: a},MAP {a: a}]) from t;')
37+
38+
# Same map-of-lists query as earlier in this file, repeated
statement ok
39+
CALL get_substrait('SELECT MAP {[a,a,10]: [a,a,10]} from t;')
40+
41+
# Same map-of-struct-lists query as earlier in this file, repeated
statement ok
42+
CALL get_substrait('SELECT MAP {[row(a,a,10),row(a,a,10),row(a,a,10)]: [row(a,a,10),row(a,a,10),row(a,a,10)]} from t;')
43+
44+
# Structs nested three levels deep
statement ok
45+
CALL get_substrait('SELECT row(row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10))) from t;')
46+
47+
# Lists nested three levels deep
statement ok
48+
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')

0 commit comments

Comments
 (0)