Skip to content

Commit 846c491

Browse files
committed
Override bind function of relations to bypass context lock
1 parent 922c7f2 commit 846c491

File tree

5 files changed

+167
-43
lines changed

5 files changed

+167
-43
lines changed

src/from_substrait.cpp

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,6 @@
22

33
#include "duckdb/common/types/value.hpp"
44
#include "duckdb/parser/expression/list.hpp"
5-
#include "duckdb/main/relation/join_relation.hpp"
6-
#include "duckdb/main/relation/cross_product_relation.hpp"
7-
8-
#include "duckdb/main/relation/limit_relation.hpp"
9-
#include "duckdb/main/relation/projection_relation.hpp"
10-
#include "duckdb/main/relation/setop_relation.hpp"
11-
#include "duckdb/main/relation/aggregate_relation.hpp"
12-
#include "duckdb/main/relation/filter_relation.hpp"
13-
#include "duckdb/main/relation/order_relation.hpp"
145
#include "duckdb/main/connection.hpp"
156
#include "duckdb/parser/parser.hpp"
167
#include "duckdb/common/exception.hpp"
@@ -25,12 +16,7 @@
2516
#include "google/protobuf/util/json_util.h"
2617
#include "substrait/plan.pb.h"
2718

28-
#include "duckdb/main/relation/table_relation.hpp"
29-
30-
#include "duckdb/main/relation/table_function_relation.hpp"
31-
#include "duckdb/main/relation/view_relation.hpp"
32-
#include "duckdb/main/relation/value_relation.hpp"
33-
#include "duckdb/main/relation.hpp"
19+
#include "substrait_relations.hpp"
3420
#include "duckdb/common/helper.hpp"
3521
#include "duckdb/main/table_description.hpp"
3622
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
@@ -75,11 +61,8 @@ string SubstraitToDuckDB::RemoveExtension(const string &function_name) {
7561
return name;
7662
}
7763

78-
void do_nothing(ClientContext*) {}
7964

80-
SubstraitToDuckDB::SubstraitToDuckDB(ClientContext &context_p, const string &serialized, bool json) {
81-
shared_ptr<ClientContext> c_ptr(&context_p, do_nothing);
82-
context = std::move(c_ptr);
65+
SubstraitToDuckDB::SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json):context(context_p) {
8366
if (!json) {
8467
if (!plan.ParseFromString(serialized)) {
8568
throw std::runtime_error("Was not possible to convert binary into Substrait plan");
@@ -454,28 +437,28 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformJoinOp(const substrait::Rel &so
454437
throw InternalException("Unsupported join type");
455438
}
456439
unique_ptr<ParsedExpression> join_condition = TransformExpr(sjoin.expression());
457-
return make_shared_ptr<JoinRelation>(TransformOp(sjoin.left())->Alias("left"),
440+
return make_shared_ptr<SubstraitJoinRelation>(TransformOp(sjoin.left())->Alias("left"),
458441
TransformOp(sjoin.right())->Alias("right"), std::move(join_condition),
459442
djointype);
460443
}
461444

462445
shared_ptr<Relation> SubstraitToDuckDB::TransformCrossProductOp(const substrait::Rel &sop) {
463446
auto &sub_cross = sop.cross();
464447

465-
return make_shared_ptr<CrossProductRelation>(TransformOp(sub_cross.left())->Alias("left"),
448+
return make_shared_ptr<SubstraitCrossProductRelation>(TransformOp(sub_cross.left())->Alias("left"),
466449
TransformOp(sub_cross.right())->Alias("right"));
467450
}
468451

469452
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) {
470453
auto &slimit = sop.fetch();
471454
idx_t limit = slimit.count() == -1 ? NumericLimits<idx_t>::Maximum() : slimit.count();
472455
idx_t offset = slimit.offset();
473-
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input()), limit, offset);
456+
return make_shared_ptr<SubstraitLimitRelation>(TransformOp(slimit.input()), limit, offset);
474457
}
475458

476459
shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) {
477460
auto &sfilter = sop.filter();
478-
return make_shared_ptr<FilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
461+
return make_shared_ptr<SubstraitFilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
479462
}
480463

481464
shared_ptr<Relation> SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) {
@@ -488,7 +471,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformProjectOp(const substrait::Rel
488471
for (size_t i = 0; i < expressions.size(); i++) {
489472
mock_aliases.push_back("expr_" + to_string(i));
490473
}
491-
return make_shared_ptr<ProjectionRelation>(TransformOp(sop.project().input()), std::move(expressions),
474+
return make_shared_ptr<SubstraitProjectionRelation>(TransformOp(sop.project().input()), std::move(expressions),
492475
std::move(mock_aliases));
493476
}
494477

@@ -520,7 +503,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformAggregateOp(const substrait::Re
520503
nullptr, nullptr, is_distinct));
521504
}
522505

523-
return make_shared_ptr<AggregateRelation>(TransformOp(sop.aggregate().input()), std::move(expressions),
506+
return make_shared_ptr<SubstraitAggregateRelation>(TransformOp(sop.aggregate().input()), std::move(expressions),
524507
std::move(groups));
525508
}
526509
unique_ptr<TableDescription> TableInfo(ClientContext& context, const string &schema_name, const string &table_name) {
@@ -552,9 +535,9 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
552535
if (!table_info) {
553536
throw CatalogException("Table '%s' does not exist!", table_name);
554537
}
555-
return make_shared_ptr<TableRelation>(context, std::move(table_info));
538+
return make_shared_ptr<SubstraitTableRelation>(context, std::move(table_info));
556539
} catch (...) {
557-
scan = make_shared_ptr<ViewRelation>(context, DEFAULT_SCHEMA, table_name);
540+
scan = make_shared_ptr<SubstraitViewRelation>(context, DEFAULT_SCHEMA, table_name);
558541
}
559542
} else if (sget.has_local_files()) {
560543
vector<Value> parquet_files;
@@ -575,7 +558,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
575558
}
576559
string name = "parquet_" + StringUtil::GenerateRandomName();
577560
named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(false)}});
578-
// auto scan_rel = make_shared_ptr<TableFunctionRelation>(context, "parquet_scan", {Value::LIST(parquet_files)}, named_parameters);
561+
// auto scan_rel = make_shared_ptr<SubstraitTableFunctionRelation>(context, "parquet_scan", {Value::LIST(parquet_files)}, named_parameters);
579562
// auto rel = static_cast<Relation*>(scan_rel.get());
580563
// scan = rel->Alias(name);
581564
} else if (sget.has_virtual_table()) {
@@ -591,13 +574,13 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
591574
expression_rows.emplace_back(expression_row);
592575
}
593576
vector<string> column_names;
594-
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names, "values");
577+
scan = make_shared_ptr<SubstraitValueRelation>(context, expression_rows, column_names, "values");
595578
} else {
596579
throw NotImplementedException("Unsupported type of read operator for substrait");
597580
}
598581

599582
if (sget.has_filter()) {
600-
scan = make_shared_ptr<FilterRelation>(std::move(scan), TransformExpr(sget.filter()));
583+
scan = make_shared_ptr<SubstraitFilterRelation>(std::move(scan), TransformExpr(sget.filter()));
601584
}
602585

603586
if (sget.has_projection()) {
@@ -610,7 +593,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
610593
// TODO make sure nothing else is in there
611594
expressions.push_back(make_uniq<PositionalReferenceExpression>(sproj.field() + 1));
612595
}
613-
scan = make_shared_ptr<ProjectionRelation>(std::move(scan), std::move(expressions), std::move(aliases));
596+
scan = make_shared_ptr<SubstraitProjectionRelation>(std::move(scan), std::move(expressions), std::move(aliases));
614597
}
615598

616599
return scan;
@@ -621,7 +604,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &so
621604
for (auto &sordf : sop.sort().sorts()) {
622605
order_nodes.push_back(TransformOrder(sordf));
623606
}
624-
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input()), std::move(order_nodes));
607+
return make_shared_ptr<SubstraitOrderRelation>(TransformOp(sop.sort().input()), std::move(order_nodes));
625608
}
626609

627610
static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) {
@@ -655,7 +638,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
655638
auto lhs = TransformOp(inputs[0]);
656639
auto rhs = TransformOp(inputs[1]);
657640

658-
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
641+
return make_shared_ptr<SubstraitSetOpRelation>(std::move(lhs), std::move(rhs), type);
659642
}
660643

661644
shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
@@ -704,11 +687,11 @@ Relation *GetProjection(Relation &relation) {
704687
case RelationType::PROJECTION_RELATION:
705688
return &relation;
706689
case RelationType::LIMIT_RELATION:
707-
return GetProjection(*relation.Cast<LimitRelation>().child);
690+
return GetProjection(*relation.Cast<SubstraitLimitRelation>().child);
708691
case RelationType::ORDER_RELATION:
709-
return GetProjection(*relation.Cast<OrderRelation>().child);
692+
return GetProjection(*relation.Cast<SubstraitOrderRelation>().child);
710693
case RelationType::SET_OPERATION_RELATION:
711-
return GetProjection(*relation.Cast<SetOpRelation>().right);
694+
return GetProjection(*relation.Cast<SubstraitSetOpRelation>().right);
712695
default:
713696
return nullptr;
714697
}
@@ -722,7 +705,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
722705
auto child = TransformOp(sop.input());
723706
auto first_projection_or_table = GetProjection(*child);
724707
if (first_projection_or_table) {
725-
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
708+
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<SubstraitProjectionRelation>().columns;
726709
int32_t i = 0;
727710
for (auto &column : *column_definitions) {
728711
aliases.push_back(column_names[i++]);
@@ -737,7 +720,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
737720
}
738721
}
739722

740-
return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
723+
return make_shared_ptr<SubstraitProjectionRelation>(child, std::move(expressions), aliases);
741724
}
742725

743726
shared_ptr<Relation> SubstraitToDuckDB::TransformPlan() {

src/include/from_substrait.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
//===----------------------------------------------------------------------===//
2+
// DuckDB
3+
//
4+
// from_substrait.hpp
5+
//
6+
//
7+
//===----------------------------------------------------------------------===//
8+
19
#pragma once
210

311
#include <string>
@@ -10,7 +18,7 @@ namespace duckdb {
1018

1119
class SubstraitToDuckDB {
1220
public:
13-
SubstraitToDuckDB(ClientContext &context_p, const string &serialized, bool json = false);
21+
SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json = false);
1422
//! Transforms Substrait Plan to DuckDB Relation
1523
shared_ptr<Relation> TransformPlan();
1624

src/include/substrait_relations.hpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===----------------------------------------------------------------------===//
2+
// DuckDB
3+
//
4+
// substrait_relations
5+
//
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "duckdb/main/relation/table_function_relation.hpp"
10+
#include "duckdb/main/relation/table_relation.hpp"
11+
#include "duckdb/main/relation/value_relation.hpp"
12+
#include "duckdb/main/relation/view_relation.hpp"
13+
#include "duckdb/main/relation/limit_relation.hpp"
14+
#include "duckdb/main/relation/projection_relation.hpp"
15+
#include "duckdb/main/relation/setop_relation.hpp"
16+
#include "duckdb/main/relation/aggregate_relation.hpp"
17+
#include "duckdb/main/relation/filter_relation.hpp"
18+
#include "duckdb/main/relation/order_relation.hpp"
19+
#include "duckdb/main/relation/join_relation.hpp"
20+
#include "duckdb/main/relation/cross_product_relation.hpp"
21+
#include "duckdb/main/relation.hpp"
22+
23+
namespace duckdb {
24+
25+
class SubstraitJoinRelation : public JoinRelation {
26+
using JoinRelation::JoinRelation;
27+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
28+
context.GetContext()->InternalTryBindRelation(*this, columns);
29+
}
30+
};
31+
32+
class SubstraitCrossProductRelation : public CrossProductRelation {
33+
using CrossProductRelation::CrossProductRelation;
34+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
35+
context.GetContext()->InternalTryBindRelation(*this, columns);
36+
}
37+
};
38+
39+
class SubstraitLimitRelation : public LimitRelation {
40+
using LimitRelation::LimitRelation;
41+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
42+
context.GetContext()->InternalTryBindRelation(*this, columns);
43+
}
44+
};
45+
46+
47+
class SubstraitFilterRelation : public FilterRelation {
48+
using FilterRelation::FilterRelation;
49+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
50+
context.GetContext()->InternalTryBindRelation(*this, columns);
51+
}
52+
};
53+
54+
55+
class SubstraitProjectionRelation : public ProjectionRelation {
56+
using ProjectionRelation::ProjectionRelation;
57+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
58+
context.GetContext()->InternalTryBindRelation(*this, columns);
59+
}
60+
};
61+
62+
63+
class SubstraitAggregateRelation : public AggregateRelation {
64+
using AggregateRelation::AggregateRelation;
65+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
66+
context.GetContext()->InternalTryBindRelation(*this, columns);
67+
}
68+
};
69+
70+
71+
class SubstraitTableRelation : public TableRelation {
72+
using TableRelation::TableRelation;
73+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
74+
context.GetContext()->InternalTryBindRelation(*this, columns);
75+
}
76+
};
77+
78+
79+
class SubstraitViewRelation : public ViewRelation {
80+
using ViewRelation::ViewRelation;
81+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
82+
context.GetContext()->InternalTryBindRelation(*this, columns);
83+
}
84+
};
85+
86+
87+
class SubstraitTableFunctionRelation : public TableFunctionRelation {
88+
using TableFunctionRelation::TableFunctionRelation;
89+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
90+
context.GetContext()->InternalTryBindRelation(*this, columns);
91+
}
92+
};
93+
94+
95+
class SubstraitValueRelation : public ValueRelation {
96+
using ValueRelation::ValueRelation;
97+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
98+
context.GetContext()->InternalTryBindRelation(*this, columns);
99+
}
100+
};
101+
102+
103+
class SubstraitOrderRelation : public OrderRelation {
104+
using OrderRelation::OrderRelation;
105+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
106+
context.GetContext()->InternalTryBindRelation(*this, columns);
107+
}
108+
};
109+
110+
111+
class SubstraitSetOpRelation : public SetOpRelation {
112+
using SetOpRelation::SetOpRelation;
113+
void TryBindRelation(vector<ColumnDefinition> &columns) override {
114+
context.GetContext()->InternalTryBindRelation(*this, columns);
115+
}
116+
};
117+
118+
}

src/include/to_substrait.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
//===----------------------------------------------------------------------===//
2+
// DuckDB
3+
//
4+
// to_substrait.hpp
5+
//
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
110
#pragma once
211

312
#include "custom_extensions/custom_extensions.hpp"

src/substrait_extension.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
namespace duckdb {
2424

25+
void do_nothing(ClientContext*) {}
26+
27+
2528
struct ToSubstraitFunctionData : public TableFunctionData {
2629
ToSubstraitFunctionData() = default;
2730
string query;
@@ -144,7 +147,7 @@ static unique_ptr<FunctionData> ToJsonBind(ClientContext &context, TableFunction
144147
return InitToSubstraitFunctionData(context.config, input);
145148
}
146149

147-
shared_ptr<Relation> SubstraitPlanToDuckDBRel(ClientContext &context, const string &serialized, bool json = false) {
150+
shared_ptr<Relation> SubstraitPlanToDuckDBRel(shared_ptr<ClientContext> &context, const string &serialized, bool json = false) {
148151
SubstraitToDuckDB transformer_s2d(context, serialized, json);
149152
return transformer_s2d.TransformPlan();
150153
}
@@ -154,8 +157,8 @@ static void VerifySubstraitRoundtrip(unique_ptr<LogicalOperator> &query_plan, Cl
154157
// We round-trip the generated json and verify if the result is the same
155158
auto con = Connection(*context.db);
156159
auto actual_result = con.Query(data.query);
157-
158-
auto sub_relation = SubstraitPlanToDuckDBRel(context, serialized, is_json);
160+
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
161+
auto sub_relation = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json);
159162
auto substrait_result = sub_relation->Execute();
160163
substrait_result->names = actual_result->names;
161164
unique_ptr<MaterializedQueryResult> substrait_materialized;
@@ -255,6 +258,7 @@ static void ToJsonFunction(ClientContext &context, TableFunctionInput &data_p, D
255258

256259
struct FromSubstraitFunctionData : public TableFunctionData {
257260
FromSubstraitFunctionData() = default;
261+
shared_ptr<ClientContext> context;
258262
shared_ptr<Relation> plan;
259263
unique_ptr<QueryResult> res;
260264
};
@@ -266,7 +270,9 @@ static unique_ptr<FunctionData> SubstraitBind(ClientContext &context, TableFunct
266270
throw BinderException("from_substrait cannot be called with a NULL parameter");
267271
}
268272
string serialized = input.inputs[0].GetValueUnsafe<string>();
269-
result->plan = SubstraitPlanToDuckDBRel(context, serialized, is_json);
273+
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
274+
result->context = move(c_ptr);
275+
result->plan = SubstraitPlanToDuckDBRel(result->context, serialized, is_json);
270276
for (auto &column : result->plan->Columns()) {
271277
return_types.emplace_back(column.Type());
272278
names.emplace_back(column.Name());

0 commit comments

Comments
 (0)