Skip to content

Commit 5a2282b

Browse files
Refactored join conditions in CBO (#10366)
1 parent 2baa7c5 commit 5a2282b

19 files changed

+260
-243
lines changed

ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ TMaybeNode<TKqlKeyInc> GetRightTableKeyPrefix(const TKqlKeyRange& range) {
3636
/**
3737
* KQP specific rule to check if a LookupJoin is applicable
3838
*/
39-
bool IsLookupJoinApplicableDetailed(const std::shared_ptr<NYql::TRelOptimizerNode>& node, const TVector<TString>& joinColumns, const TKqpProviderContext& ctx) {
39+
bool IsLookupJoinApplicableDetailed(const std::shared_ptr<NYql::TRelOptimizerNode>& node, const TVector<TJoinColumn>& joinColumns, const TKqpProviderContext& ctx) {
4040

4141
auto rel = std::static_pointer_cast<TKqpRelOptimizerNode>(node);
4242
auto expr = TExprBase(rel->Node);
@@ -45,7 +45,7 @@ bool IsLookupJoinApplicableDetailed(const std::shared_ptr<NYql::TRelOptimizerNod
4545
return false;
4646
}
4747

48-
if (find_if(joinColumns.begin(), joinColumns.end(), [&] (const TString& s) { return node->Stats->KeyColumns->Data[0] == s;}) != joinColumns.end()) {
48+
if (std::find_if(joinColumns.begin(), joinColumns.end(), [&] (const TJoinColumn& c) { return node->Stats->KeyColumns->Data[0] == c.AttributeName;}) != joinColumns.end()) {
4949
return true;
5050
}
5151

@@ -97,8 +97,8 @@ bool IsLookupJoinApplicableDetailed(const std::shared_ptr<NYql::TRelOptimizerNod
9797
return false;
9898
}
9999

100-
if (prefixSize < node->Stats->KeyColumns->Data.size() && (find_if(joinColumns.begin(), joinColumns.end(), [&] (const TString& s) {
101-
return node->Stats->KeyColumns->Data[prefixSize] == s;
100+
if (prefixSize < node->Stats->KeyColumns->Data.size() && (std::find_if(joinColumns.begin(), joinColumns.end(), [&] (const TJoinColumn& c) {
101+
return node->Stats->KeyColumns->Data[prefixSize] == c.AttributeName;
102102
}) == joinColumns.end())){
103103
return false;
104104
}
@@ -108,12 +108,11 @@ bool IsLookupJoinApplicableDetailed(const std::shared_ptr<NYql::TRelOptimizerNod
108108

109109
bool IsLookupJoinApplicable(std::shared_ptr<IBaseOptimizerNode> left,
110110
std::shared_ptr<IBaseOptimizerNode> right,
111-
const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions,
112-
const TVector<TString>& leftJoinKeys,
113-
const TVector<TString>& rightJoinKeys,
111+
const TVector<TJoinColumn>& leftJoinKeys,
112+
const TVector<TJoinColumn>& rightJoinKeys,
114113
TKqpProviderContext& ctx
115114
) {
116-
Y_UNUSED(left, joinConditions, leftJoinKeys);
115+
Y_UNUSED(left, leftJoinKeys);
117116

118117
if (!(right->Stats->StorageType == EStorageType::RowStorage)) {
119118
return false;
@@ -130,7 +129,7 @@ bool IsLookupJoinApplicable(std::shared_ptr<IBaseOptimizerNode> left,
130129
}
131130

132131
// Iterate by const reference so each join key is not copied on every iteration.
for (const auto& rightCol : rightJoinKeys) {
133-
if (std::find(rightStats->KeyColumns->Data.begin(), rightStats->KeyColumns->Data.end(), rightCol) == rightStats->KeyColumns->Data.end()) {
132+
// Qualified std::find: consistent with the std::find_if calls in this file and independent of ADL.
if (std::find(rightStats->KeyColumns->Data.begin(), rightStats->KeyColumns->Data.end(), rightCol.AttributeName) == rightStats->KeyColumns->Data.end()) {
134133
return false;
135134
}
136135
}
@@ -142,18 +141,17 @@ bool IsLookupJoinApplicable(std::shared_ptr<IBaseOptimizerNode> left,
142141

143142
bool TKqpProviderContext::IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
144143
const std::shared_ptr<IBaseOptimizerNode>& right,
145-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
146-
const TVector<TString>& leftJoinKeys,
147-
const TVector<TString>& rightJoinKeys,
144+
const TVector<TJoinColumn>& leftJoinKeys,
145+
const TVector<TJoinColumn>& rightJoinKeys,
148146
EJoinAlgoType joinAlgo,
149-
EJoinKind joinKind) {
147+
EJoinKind joinKind) {
150148

151149
switch( joinAlgo ) {
152150
case EJoinAlgoType::LookupJoin:
153151
if ((OptLevel != 3) && (left->Stats->Nrows > 1000)) {
154152
return false;
155153
}
156-
return IsLookupJoinApplicable(left, right, joinConditions, leftJoinKeys, rightJoinKeys, *this);
154+
return IsLookupJoinApplicable(left, right, leftJoinKeys, rightJoinKeys, *this);
157155

158156
case EJoinAlgoType::LookupJoinReverse:
159157
if (joinKind != EJoinKind::LeftSemi) {
@@ -162,7 +160,7 @@ bool TKqpProviderContext::IsJoinApplicable(const std::shared_ptr<IBaseOptimizerN
162160
if ((OptLevel != 3) && (right->Stats->Nrows > 1000)) {
163161
return false;
164162
}
165-
return IsLookupJoinApplicable(right, left, joinConditions, rightJoinKeys, leftJoinKeys, *this);
163+
return IsLookupJoinApplicable(right, left, rightJoinKeys, leftJoinKeys, *this);
166164

167165
case EJoinAlgoType::MapJoin:
168166
return joinKind != EJoinKind::OuterJoin && joinKind != EJoinKind::Exclusion && right->Stats->ByteSize < 1e6;

ydb/core/kqp/opt/logical/kqp_opt_cbo.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ struct TKqpProviderContext : public NYql::TBaseProviderContext {
2525

2626
virtual bool IsJoinApplicable(const std::shared_ptr<NYql::IBaseOptimizerNode>& left,
2727
const std::shared_ptr<NYql::IBaseOptimizerNode>& right,
28-
const std::set<std::pair<NYql::NDq::TJoinColumn, NYql::NDq::TJoinColumn>>& joinConditions,
29-
const TVector<TString>& leftJoinKeys, const TVector<TString>& rightJoinKeys,
28+
const TVector<NYql::NDq::TJoinColumn>& leftJoinKeys, const TVector<NYql::NDq::TJoinColumn>& rightJoinKeys,
3029
NYql::EJoinAlgoType joinAlgo, NYql::EJoinKind joinKind) override;
3130

3231
virtual double ComputeJoinCost(const NYql::TOptimizerStatistics& leftStats, const NYql::TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, NYql::EJoinAlgoType joinAlgo) const override;

ydb/library/yql/core/cbo/cbo_optimizer_new.cpp

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ void TRelOptimizerNode::Print(std::stringstream& stream, int ntabs) {
7777
TJoinOptimizerNode::TJoinOptimizerNode(
7878
const std::shared_ptr<IBaseOptimizerNode>& left,
7979
const std::shared_ptr<IBaseOptimizerNode>& right,
80-
const std::set<std::pair<TJoinColumn, TJoinColumn>>& joinConditions,
80+
TVector<TJoinColumn> leftKeys,
81+
TVector<TJoinColumn> rightKeys,
8182
const EJoinKind joinType,
8283
const EJoinAlgoType joinAlgo,
8384
bool leftAny,
@@ -86,18 +87,14 @@ TJoinOptimizerNode::TJoinOptimizerNode(
8687
) : IBaseOptimizerNode(JoinNodeType)
8788
, LeftArg(left)
8889
, RightArg(right)
89-
, JoinConditions(joinConditions)
90+
, LeftJoinKeys(leftKeys)
91+
, RightJoinKeys(rightKeys)
9092
, JoinType(joinType)
9193
, JoinAlgo(joinAlgo)
9294
, LeftAny(leftAny)
9395
, RightAny(rightAny)
9496
, IsReorderable(!nonReorderable)
95-
{
96-
for (const auto& [l,r] : joinConditions ) {
97-
LeftJoinKeys.push_back(l.AttributeName);
98-
RightJoinKeys.push_back(r.AttributeName);
99-
}
100-
}
97+
{}
10198

10299
TVector<TString> TJoinOptimizerNode::Labels() {
103100
auto res = LeftArg->Labels();
@@ -120,10 +117,10 @@ void TJoinOptimizerNode::Print(std::stringstream& stream, int ntabs) {
120117
}
121118
stream << ") ";
122119

123-
for (auto c : JoinConditions){
124-
stream << c.first.RelName << "." << c.first.AttributeName
125-
<< "=" << c.second.RelName << "."
126-
<< c.second.AttributeName << ",";
120+
for (size_t i=0; i<LeftJoinKeys.size(); i++){
121+
stream << LeftJoinKeys[i].RelName << "." << LeftJoinKeys[i].AttributeName
122+
<< "=" << RightJoinKeys[i].RelName << "."
123+
<< RightJoinKeys[i].AttributeName << ",";
127124
}
128125
stream << "\n";
129126

@@ -138,13 +135,14 @@ void TJoinOptimizerNode::Print(std::stringstream& stream, int ntabs) {
138135
RightArg->Print(stream, ntabs+1);
139136
}
140137

141-
bool IsPKJoin(const TOptimizerStatistics& stats, const TVector<TString>& joinKeys) {
138+
bool IsPKJoin(const TOptimizerStatistics& stats, const TVector<TJoinColumn>& joinKeys) {
142139
if (!stats.KeyColumns) {
143140
return false;
144141
}
145142

146143
for(size_t i = 0; i < stats.KeyColumns->Data.size(); i++){
147-
if (std::find(joinKeys.begin(), joinKeys.end(), stats.KeyColumns->Data[i]) == joinKeys.end()) {
144+
if (std::find_if(joinKeys.begin(), joinKeys.end(),
145+
[&] (const TJoinColumn& c) { return c.AttributeName == stats.KeyColumns->Data[i];}) == joinKeys.end()) {
148146
return false;
149147
}
150148
}
@@ -153,15 +151,13 @@ bool IsPKJoin(const TOptimizerStatistics& stats, const TVector<TString>& joinKey
153151

154152
bool TBaseProviderContext::IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
155153
const std::shared_ptr<IBaseOptimizerNode>& right,
156-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
157-
const TVector<TString>& leftJoinKeys,
158-
const TVector<TString>& rightJoinKeys,
154+
const TVector<TJoinColumn>& leftJoinKeys,
155+
const TVector<TJoinColumn>& rightJoinKeys,
159156
EJoinAlgoType joinAlgo,
160157
EJoinKind joinKind) {
161158

162159
Y_UNUSED(left);
163160
Y_UNUSED(right);
164-
Y_UNUSED(joinConditions);
165161
Y_UNUSED(leftJoinKeys);
166162
Y_UNUSED(rightJoinKeys);
167163
Y_UNUSED(joinKind);
@@ -182,30 +178,12 @@ double TBaseProviderContext::ComputeJoinCost(const TOptimizerStatistics& leftSta
182178
*
183179
* The build is on the right side, so we make the build side a bit more expensive than the probe
184180
*/
185-
TOptimizerStatistics TBaseProviderContext::ComputeJoinStats(
186-
const TOptimizerStatistics& leftStats,
187-
const TOptimizerStatistics& rightStats,
188-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
189-
EJoinAlgoType joinAlgo,
190-
EJoinKind joinKind,
191-
TCardinalityHints::TCardinalityHint* maybeHint) const
192-
{
193-
TVector<TString> leftJoinKeys;
194-
TVector<TString> rightJoinKeys;
195-
196-
for (auto c : joinConditions) {
197-
leftJoinKeys.emplace_back(c.first.AttributeName);
198-
rightJoinKeys.emplace_back(c.second.AttributeName);
199-
}
200-
201-
return ComputeJoinStats(leftStats, rightStats, leftJoinKeys, rightJoinKeys, joinAlgo, joinKind, maybeHint);
202-
}
203181

204182
TOptimizerStatistics TBaseProviderContext::ComputeJoinStats(
205183
const TOptimizerStatistics& leftStats,
206184
const TOptimizerStatistics& rightStats,
207-
const TVector<TString>& leftJoinKeys,
208-
const TVector<TString>& rightJoinKeys,
185+
const TVector<TJoinColumn>& leftJoinKeys,
186+
const TVector<TJoinColumn>& rightJoinKeys,
209187
EJoinAlgoType joinAlgo,
210188
EJoinKind joinKind,
211189
TCardinalityHints::TCardinalityHint* maybeHint) const
@@ -265,9 +243,9 @@ TOptimizerStatistics TBaseProviderContext::ComputeJoinStats(
265243
std::optional<double> lhsUniqueVals;
266244
std::optional<double> rhsUniqueVals;
267245
if (leftStats.ColumnStatistics && rightStats.ColumnStatistics && !leftJoinKeys.empty() && !rightJoinKeys.empty()) {
268-
auto lhs = leftJoinKeys[0];
246+
auto lhs = leftJoinKeys[0].AttributeName;
269247
lhsUniqueVals = leftStats.ColumnStatistics->Data[lhs].NumUniqueVals;
270-
auto rhs = rightJoinKeys[0];
248+
auto rhs = rightJoinKeys[0].AttributeName;
271249
rightStats.ColumnStatistics->Data[rhs];
272250
// The right-hand distinct-value count must come from the right relation's own join column
// (previously this re-read the left relation's column, so both sides always got the same value).
rhsUniqueVals = rightStats.ColumnStatistics->Data[rhs].NumUniqueVals;
273251
}

ydb/library/yql/core/cbo/cbo_optimizer_new.h

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -201,27 +201,18 @@ struct IProviderContext {
201201
virtual TOptimizerStatistics ComputeJoinStats(
202202
const TOptimizerStatistics& leftStats,
203203
const TOptimizerStatistics& rightStats,
204-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
205-
EJoinAlgoType joinAlgo,
206-
EJoinKind joinKind,
207-
TCardinalityHints::TCardinalityHint* maybeHint = nullptr) const = 0;
208-
209-
virtual TOptimizerStatistics ComputeJoinStats(
210-
const TOptimizerStatistics& leftStats,
211-
const TOptimizerStatistics& rightStats,
212-
const TVector<TString>& leftJoinKeys,
213-
const TVector<TString>& rightJoinKeys,
204+
const TVector<NDq::TJoinColumn>& leftJoinKeys,
205+
const TVector<NDq::TJoinColumn>& rightJoinKeys,
214206
EJoinAlgoType joinAlgo,
215207
EJoinKind joinKind,
216208
TCardinalityHints::TCardinalityHint* maybeHint = nullptr) const = 0;
217209

218210
virtual bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
219211
const std::shared_ptr<IBaseOptimizerNode>& right,
220-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
221-
const TVector<TString>& leftJoinKeys,
222-
const TVector<TString>& rightJoinKeys,
212+
const TVector<NDq::TJoinColumn>& leftJoinKeys,
213+
const TVector<NDq::TJoinColumn>& rightJoinKeys,
223214
EJoinAlgoType joinAlgo,
224-
EJoinKind joinKind) = 0;
215+
EJoinKind joinKind) = 0;
225216
};
226217

227218
/**
@@ -233,27 +224,19 @@ struct TBaseProviderContext : public IProviderContext {
233224

234225
double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgo) const override;
235226

236-
bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
237-
const std::shared_ptr<IBaseOptimizerNode>& right,
238-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
239-
const TVector<TString>& leftJoinKeys,
240-
const TVector<TString>& rightJoinKeys,
227+
bool IsJoinApplicable(
228+
const std::shared_ptr<IBaseOptimizerNode>& leftStats,
229+
const std::shared_ptr<IBaseOptimizerNode>& rightStats,
230+
const TVector<NDq::TJoinColumn>& leftJoinKeys,
231+
const TVector<NDq::TJoinColumn>& rightJoinKeys,
241232
EJoinAlgoType joinAlgo,
242233
EJoinKind joinKind) override;
243234

244235
virtual TOptimizerStatistics ComputeJoinStats(
245236
const TOptimizerStatistics& leftStats,
246237
const TOptimizerStatistics& rightStats,
247-
const TVector<TString>& leftJoinKeys,
248-
const TVector<TString>& rightJoinKeys,
249-
EJoinAlgoType joinAlgo,
250-
EJoinKind joinKind,
251-
TCardinalityHints::TCardinalityHint* maybeHint = nullptr) const override;
252-
253-
virtual TOptimizerStatistics ComputeJoinStats(
254-
const TOptimizerStatistics& leftStats,
255-
const TOptimizerStatistics& rightStats,
256-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
238+
const TVector<NDq::TJoinColumn>& leftJoinKeys,
239+
const TVector<NDq::TJoinColumn>& rightJoinKeys,
257240
EJoinAlgoType joinAlgo,
258241
EJoinKind joinKind,
259242
TCardinalityHints::TCardinalityHint* maybeHint = nullptr) const override;
@@ -290,9 +273,8 @@ struct TRelOptimizerNode : public IBaseOptimizerNode {
290273
struct TJoinOptimizerNode : public IBaseOptimizerNode {
291274
std::shared_ptr<IBaseOptimizerNode> LeftArg;
292275
std::shared_ptr<IBaseOptimizerNode> RightArg;
293-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>> JoinConditions;
294-
TVector<TString> LeftJoinKeys;
295-
TVector<TString> RightJoinKeys;
276+
TVector<NDq::TJoinColumn> LeftJoinKeys;
277+
TVector<NDq::TJoinColumn> RightJoinKeys;
296278
EJoinKind JoinType;
297279
EJoinAlgoType JoinAlgo;
298280
/////////////////// 'ANY' flag means leaving only one row from the join side.
@@ -303,7 +285,8 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
303285

304286
TJoinOptimizerNode(const std::shared_ptr<IBaseOptimizerNode>& left,
305287
const std::shared_ptr<IBaseOptimizerNode>& right,
306-
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
288+
TVector<NDq::TJoinColumn> leftKeys,
289+
TVector<NDq::TJoinColumn> rightKeys,
307290
const EJoinKind joinType,
308291
const EJoinAlgoType joinAlgo,
309292
bool leftAny,

ydb/library/yql/core/yql_cost_function.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,14 @@ namespace NDq {
3838
struct TJoinColumn {
3939
TString RelName;
4040
TString AttributeName;
41+
TString AttributeNameWithAliases;
42+
ui32 EquivalenceClass = 0;
43+
bool IsConstant = false;
4144

42-
TJoinColumn(TString relName, TString attributeName) : RelName(relName),
43-
AttributeName(std::move(attributeName)) {}
45+
TJoinColumn(TString relName, TString attributeName) :
46+
RelName(relName),
47+
AttributeName(attributeName),
48+
AttributeNameWithAliases(attributeName) {}
4449

4550
bool operator == (const TJoinColumn& other) const {
4651
return RelName == other.RelName && AttributeName == other.AttributeName;

ydb/library/yql/dq/opt/dq_cbo_ut.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,14 @@ Y_UNIT_TEST(JoinSearch2Rels) {
4545
auto rel2 = std::make_shared<TRelOptimizerNode>("b",
4646
std::make_shared<TOptimizerStatistics>(BaseTable, 1000000, 1, 0, 9000009));
4747

48-
std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>> joinConditions;
49-
joinConditions.insert({
50-
NDq::TJoinColumn("a", "1"),
51-
NDq::TJoinColumn("b", "1")
52-
});
48+
TVector<NDq::TJoinColumn> leftKeys = {NDq::TJoinColumn("a", "1")};
49+
TVector<NDq::TJoinColumn> rightKeys ={NDq::TJoinColumn("b", "1")};
50+
5351
auto op = std::make_shared<TJoinOptimizerNode>(
5452
std::static_pointer_cast<IBaseOptimizerNode>(rel1),
5553
std::static_pointer_cast<IBaseOptimizerNode>(rel2),
56-
joinConditions,
54+
leftKeys,
55+
rightKeys,
5756
InnerJoin,
5857
EJoinAlgoType::GraceJoin,
5958
true,
@@ -86,30 +85,28 @@ Y_UNIT_TEST(JoinSearch3Rels) {
8685
auto rel3 = std::make_shared<TRelOptimizerNode>("c",
8786
std::make_shared<TOptimizerStatistics>(BaseTable, 10000, 1, 0, 9009));
8887

89-
std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>> joinConditions;
90-
joinConditions.insert({
91-
NDq::TJoinColumn("a", "1"),
92-
NDq::TJoinColumn("b", "1")
93-
});
88+
TVector<NDq::TJoinColumn> leftKeys = {NDq::TJoinColumn("a", "1")};
89+
TVector<NDq::TJoinColumn> rightKeys ={NDq::TJoinColumn("b", "1")};
90+
9491
auto op1 = std::make_shared<TJoinOptimizerNode>(
9592
std::static_pointer_cast<IBaseOptimizerNode>(rel1),
9693
std::static_pointer_cast<IBaseOptimizerNode>(rel2),
97-
joinConditions,
94+
leftKeys,
95+
rightKeys,
9896
InnerJoin,
9997
EJoinAlgoType::GraceJoin,
10098
false,
10199
false
102100
);
103101

104-
joinConditions.insert({
105-
NDq::TJoinColumn("a", "1"),
106-
NDq::TJoinColumn("c", "1")
107-
});
102+
leftKeys.push_back(NDq::TJoinColumn("a", "1"));
103+
rightKeys.push_back(NDq::TJoinColumn("c", "1"));
108104

109105
auto op2 = std::make_shared<TJoinOptimizerNode>(
110106
std::static_pointer_cast<IBaseOptimizerNode>(op1),
111107
std::static_pointer_cast<IBaseOptimizerNode>(rel3),
112-
joinConditions,
108+
leftKeys,
109+
rightKeys,
113110
InnerJoin,
114111
EJoinAlgoType::GraceJoin,
115112
true,

0 commit comments

Comments
 (0)