Skip to content

Left join with predicate rewrite fix #16187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 107 additions & 72 deletions yql/essentials/core/common_opt/yql_flatmap_over_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,25 @@ void CountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& cou
}
}

// Collects into `labels` the content of every atom reached while walking
// `joinTree`.
//
// An atom node contributes its own content. Any other node is descended into
// through ChildPtr(1) and ChildPtr(2) only; its remaining children are not
// visited.
void CollectJoinLabels(TExprNode::TPtr joinTree, THashSet<TString> &labels) {
    if (!joinTree->IsAtom()) {
        // Recurse into child 1 first, then child 2.
        for (ui32 side = 1; side <= 2; ++side) {
            CollectJoinLabels(joinTree->ChildPtr(side), labels);
        }
        return;
    }
    labels.emplace(joinTree->Content());
}


// Decrements, in `counters`, the entry of every atom reached while walking
// `joinTree`.
//
// An atom node decrements the counter keyed by its content. Any other node is
// descended into through ChildPtr(1) and ChildPtr(2) only. The counter is
// accessed with operator[], so a label that has no entry yet presumably gets
// one first (as with std::unordered_map) -- TODO confirm for THashMap.
//
// NOTE(review): remarks left by reviewers on this hunk -- a collaborator
// pointed out that the YQL directory has to be edited through Arcadia; the
// author answered that the change is still being tested and that Arcadia has
// no tests for this feature.
void DecrementCountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& counters) {
    if (!joinTree->IsAtom()) {
        DecrementCountLabelsInputUsage(joinTree->ChildPtr(1), counters);
        DecrementCountLabelsInputUsage(joinTree->ChildPtr(2), counters);
        return;
    }
    --counters[joinTree->Content()];
}

// Returns the matched join subtree together with its parent join node; both are nullptr when no match is found.
std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
const TExprNode::TPtr& joinTree, const TJoinLabels& labels, ui32 inputIndex, const TExprNode::TPtr& parent = nullptr
Expand Down Expand Up @@ -350,6 +369,39 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
return {nullptr, nullptr};
}

// Selects the entries of `joinLabels` whose label-name set shares at least one
// name with `labelNames`.
//
// Each entry of `joinLabels` pairs a set of label names with an expression
// node. Every matching entry is copied into the result exactly once.
//
// The result keeps the relative order the entries have in `joinLabels`.
// Driving the outer loop by `joinLabels` (instead of by `labelNames`) makes
// the output independent of hash-set iteration order and removes the need for
// a separate "already taken" bitmap.
//
// Neither argument is modified, hence both are taken by const reference.
TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> MapLabelNamesToJoinLabels(
    const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& joinLabels,
    const THashSet<TString>& labelNames)
{
    TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> result;
    for (const auto& entry : joinLabels) {
        for (const auto& name : entry.first) {
            if (labelNames.contains(name)) {
                result.push_back(entry);
                // One shared name is enough; stop so the entry is added once.
                break;
            }
        }
    }
    return result;
}

// Builds the union of the label-name sets stored as the first member of every
// pair in `labels`. The second member of each pair is ignored.
THashSet<TString> CombineLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& labels) {
    THashSet<TString> merged;
    for (const auto& entry : labels) {
        const THashSet<TString>& names = entry.first;
        merged.insert(names.begin(), names.end());
    }
    return merged;
}

// Builds a list node that holds one atom per entry of `labels`.
//
// Every atom and the enclosing list are created through `ctx` at `position`.
// Atoms appear in the iteration order of `labels`.
TExprNode::TPtr CreateLabelList(const THashSet<TString>& labels, TExprContext& ctx, const TPositionHandle& position) {
    TExprNode::TListType newKeys;
    // The final size is known up front, so allocate once instead of growing
    // the vector inside the loop.
    newKeys.reserve(labels.size());
    for (const auto& label : labels) {
        newKeys.push_back(ctx.NewAtom(position, label));
    }
    return ctx.NewList(position, std::move(newKeys));
}

TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate,
const TSet<TStringBuf>& usedFields, TExprNode::TPtr args, const TJoinLabels& labels,
ui32 inputIndex, const TMap<TStringBuf, TVector<TStringBuf>>& renameMap, bool ordered, bool skipNulls, TExprContext& ctx,
Expand Down Expand Up @@ -391,26 +443,41 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
}

THashMap<TString, TExprNode::TPtr> equiJoinLabels;
TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> joinLabels;
for (size_t i = 0; i < equiJoin->ChildrenSize() - 2; i++) {
auto label = equiJoin->Child(i);
equiJoinLabels.emplace(label->Child(1)->Content(), label->ChildPtr(0));
THashSet<TString> labelsName;
if (auto value = TMaybeNode<TCoAtom>(label->Child(1))) {
labelsName.emplace(value.Cast().Value());
equiJoinLabels.emplace(value.Cast().Value(), label->ChildPtr(0));
} else if (auto tuple = TMaybeNode<TCoAtomList>(label->Child(1))) {
for (const auto& value : tuple.Cast()) {
labelsName.emplace(value.Value());
equiJoinLabels.emplace(value.Value(), label->ChildPtr(0));
}
}
joinLabels.push_back({labelsName, label->ChildPtr(0)});
}

THashMap<TString, int> joinLabelCounters;
CountLabelsInputUsage(joinTree, joinLabelCounters);

auto [leftJoinTree, parentJoinPtr] = IsRightSideForLeftJoin(joinTree, labels, inputIndex);
YQL_ENSURE(leftJoinTree);
joinLabelCounters[leftJoinTree->Child(1)->Content()]--;
joinLabelCounters[leftJoinTree->Child(2)->Content()]--;
DecrementCountLabelsInputUsage(leftJoinTree, joinLabelCounters);

auto leftJoinSettings = equiJoin->ChildPtr(equiJoin->ChildrenSize() - 1);

auto innerJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "Inner"));
auto leftOnlyJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "LeftOnly"));

THashMap<TString, int> leftSideJoinLabels;
CountLabelsInputUsage(leftJoinTree->Child(1), leftSideJoinLabels);
THashSet<TString> leftLabelsNoRightChild;
CollectJoinLabels(leftJoinTree->Child(1), leftLabelsNoRightChild);
auto leftJoinLabelsNoRightChild = MapLabelNamesToJoinLabels(joinLabels, leftLabelsNoRightChild);

THashSet<TString> leftLabelsFull;
CollectJoinLabels(leftJoinTree, leftLabelsFull);
auto leftJoinLabelsFull = MapLabelNamesToJoinLabels(joinLabels, leftLabelsFull);

YQL_ENSURE(leftJoinTree->Child(2)->IsAtom());
auto rightSideInput = equiJoinLabels.at(leftJoinTree->Child(2)->Content());
Expand All @@ -436,11 +503,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
auto innerJoin = ctx.Builder(pos)
.Callable("EquiJoin")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (const auto& [labelName, _] : leftSideJoinLabels) {
parent.List(i++)
.Add(0, equiJoinLabels.at(labelName))
.Atom(1, labelName)
.Seal();
for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) {
if (labelNames.size() == 1) {
parent.List(i++)
.Add(0, labelKeys)
.Atom(1, *labelNames.begin())
.Seal();
} else {
auto labelList = CreateLabelList(labelNames, ctx, pos);
parent.List(i++)
.Add(0, labelKeys)
.Add(1, labelList)
.Seal();
}
}
return parent;
})
Expand All @@ -459,11 +534,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
auto leftOnlyJoin = ctx.Builder(pos)
.Callable("EquiJoin")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (const auto& [labelName, _] : leftSideJoinLabels) {
parent.List(i++)
.Add(0, equiJoinLabels.at(labelName))
.Atom(1, labelName)
.Seal();
for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) {
if (labelNames.size() == 1) {
parent.List(i++)
.Add(0, labelKeys)
.Atom(1, *labelNames.begin())
.Seal();
} else {
auto labelList = CreateLabelList(labelNames, ctx, pos);
parent.List(i++)
.Add(0, labelKeys)
.Add(1, labelList)
.Seal();
}
}
return parent;
})
Expand Down Expand Up @@ -495,25 +578,6 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
return unionAll;
}

THashSet <TString> joinColumns;
for (const auto& [labelName, _] : leftSideJoinLabels) {
auto tableName = labels.FindInputIndex(labelName);
YQL_ENSURE(tableName);
for (auto column : labels.Inputs[*tableName].EnumerateAllColumns()) {
joinColumns.emplace(std::move(column));
}
}
auto rightSideTableName = labels.FindInputIndex(innerJoinTree->Child(2)->Content());
YQL_ENSURE(rightSideTableName);
for (auto column : labels.Inputs[*rightSideTableName].EnumerateAllColumns()) {
joinColumns.emplace(std::move(column));
}

auto newJoinLabel = ctx.Builder(pos)
.Atom("__yql_right_side_pushdown_input_label")
.Build();


TExprNode::TPtr remJoinKeys;
bool changedLeftSide = false;
if (leftJoinTree == parentJoinPtr->ChildPtr(1)) {
Expand All @@ -523,36 +587,14 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
remJoinKeys = parentJoinPtr->ChildPtr(4);
}

TExprNode::TListType newKeys;
newKeys.reserve(remJoinKeys->ChildrenSize());

for (ui32 i = 0; i < remJoinKeys->ChildrenSize(); i += 2) {
auto table = remJoinKeys->ChildPtr(i);
auto column = remJoinKeys->ChildPtr(i + 1);

YQL_ENSURE(table->IsAtom());
YQL_ENSURE(column->IsAtom());

auto fcn = FullColumnName(table->Content(), column->Content());

if (joinColumns.contains(fcn)) {
newKeys.push_back(newJoinLabel);
newKeys.push_back(ctx.NewAtom(column->Pos(), fcn));
} else {
newKeys.push_back(table);
newKeys.push_back(column);
}
}

auto newKeysList = ctx.NewList(remJoinKeys->Pos(), std::move(newKeys));

auto parentJoinLabel = remJoinKeys->ChildPtr(0);
auto newParentJoin = ctx.Builder(joinTree->Pos())
.List()
.Add(0, parentJoinPtr->ChildPtr(0))
.Add(1, changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(1))
.Add(2, !changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(2))
.Add(3, changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(3))
.Add(4, !changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(4))
.Add(1, changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(1))
.Add(2, !changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(2))
.Add(3, parentJoinPtr->ChildPtr(3))
.Add(4, parentJoinPtr->ChildPtr(4))
.Add(5, parentJoinPtr->ChildPtr(5))
.Seal()
.Build();
Expand All @@ -568,19 +610,12 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
}
return parent;
})
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (const auto& column : joinColumns) {
parent.List(i++)
.Atom(0, "rename")
.Atom(1, FullColumnName("__yql_right_side_pushdown_input_label", column))
.Atom(2, column)
.Seal();
}
return parent;
})
.Seal()
.Build();

auto combinedLabelList = CombineLabels(leftJoinLabelsFull);
auto combinedJoinLabels = CreateLabelList(combinedLabelList, ctx, pos);

i = 0;
auto newEquiJoin = ctx.Builder(pos)
.Callable("EquiJoin")
Expand All @@ -598,7 +633,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
})
.List(i++)
.Add(0, unionAll)
.Add(1, newJoinLabel)
.Add(1, combinedJoinLabels)
.Seal()
.Add(i++, newJoinTree)
.Add(i++, newJoinSettings)
Expand Down
Loading