Skip to content

Commit ca1a884

Browse files
committed
[FIX] Left join with predicate rewrite
1 parent b3b79e7 commit ca1a884

File tree

1 file changed

+107
-72
lines changed

1 file changed

+107
-72
lines changed

yql/essentials/core/common_opt/yql_flatmap_over_join.cpp

Lines changed: 107 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -317,6 +317,25 @@ void CountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& cou
317317
}
318318
}
319319

320+
// Collects into `labels` the content of every atom found in `joinTree`.
//
// An atom node contributes its own content. Any other node is treated as an
// inner node and the walk continues into its children at index 1 and index 2
// (presumably the left and right inputs of a join node -- confirm against the
// join tree layout used by the callers).
void CollectJoinLabels(TExprNode::TPtr joinTree, THashSet<TString> &labels) {
    if (!joinTree->IsAtom()) {
        // Inner node: descend into child 1 first, then child 2.
        CollectJoinLabels(joinTree->ChildPtr(1), labels);
        CollectJoinLabels(joinTree->ChildPtr(2), labels);
        return;
    }

    // Atom node: record its content as a label.
    labels.emplace(joinTree->Content());
}
328+
329+
330+
// Decrements, for every atom found in `joinTree`, the entry of `counters`
// keyed by that atom's content.
//
// Non-atom nodes are not counted themselves; the walk simply continues into
// their children at index 1 and index 2.
void DecrementCountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& counters) {
    if (!joinTree->IsAtom()) {
        // Inner node: descend into child 1 first, then child 2.
        DecrementCountLabelsInputUsage(joinTree->ChildPtr(1), counters);
        DecrementCountLabelsInputUsage(joinTree->ChildPtr(2), counters);
        return;
    }

    // Atom node: one usage less for this label.
    counters[joinTree->Content()]--;
}
338+
320339
// returns the path to join child
321340
std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
322341
const TExprNode::TPtr& joinTree, const TJoinLabels& labels, ui32 inputIndex, const TExprNode::TPtr& parent = nullptr
@@ -350,6 +369,39 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
350369
return {nullptr, nullptr};
351370
}
352371

372+
// Selects the entries of `joinLabels` whose label-name set contains at least
// one of the names in `labelNames`.
//
// Each entry of `joinLabels` pairs a set of label names with an expression
// node. An entry is copied into the result at most once, even when several of
// its names are requested. Entries appear in the result in the order in which
// they are first matched while iterating `labelNames`, so the result order
// follows the iteration order of that hash set.
TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> MapLabelNamesToJoinLabels(TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& joinLabels,
    THashSet<TString>& labelNames) {
    const ui32 entryCount = joinLabels.size();

    // alreadySelected[idx] is set once joinLabels[idx] has been copied to the result.
    TVector<bool> alreadySelected(entryCount, false);
    TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> selected;

    for (const auto& requestedName : labelNames) {
        for (ui32 idx = 0; idx < entryCount; ++idx) {
            if (alreadySelected[idx]) {
                continue;
            }
            const auto& entry = joinLabels[idx];
            if (!entry.first.count(requestedName)) {
                continue;
            }
            alreadySelected[idx] = true;
            selected.push_back(entry);
        }
    }

    return selected;
}
388+
389+
// Returns the union of the label-name sets of all entries in `labels`.
// The expression node stored alongside each set is ignored.
THashSet<TString> CombineLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& labels) {
    THashSet<TString> merged;
    for (const auto& entry : labels) {
        const auto& names = entry.first;
        merged.insert(names.begin(), names.end());
    }
    return merged;
}
396+
397+
// Builds a list node at `position` holding one atom per element of `labels`.
//
// Atoms are created in the iteration order of the `labels` hash set, so the
// order of the list items follows that iteration order.
TExprNode::TPtr CreateLabelList(const THashSet<TString>& labels, TExprContext& ctx, const TPositionHandle& position) {
    TExprNode::TListType atoms;
    for (const auto& name : labels) {
        auto atom = ctx.NewAtom(position, name);
        atoms.push_back(std::move(atom));
    }
    return ctx.NewList(position, std::move(atoms));
}
404+
353405
TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate,
354406
const TSet<TStringBuf>& usedFields, TExprNode::TPtr args, const TJoinLabels& labels,
355407
ui32 inputIndex, const TMap<TStringBuf, TVector<TStringBuf>>& renameMap, bool ordered, bool skipNulls, TExprContext& ctx,
@@ -391,26 +443,41 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
391443
}
392444

393445
THashMap<TString, TExprNode::TPtr> equiJoinLabels;
446+
TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> joinLabels;
394447
for (size_t i = 0; i < equiJoin->ChildrenSize() - 2; i++) {
395448
auto label = equiJoin->Child(i);
396-
equiJoinLabels.emplace(label->Child(1)->Content(), label->ChildPtr(0));
449+
THashSet<TString> labelsName;
450+
if (auto value = TMaybeNode<TCoAtom>(label->Child(1))) {
451+
labelsName.emplace(value.Cast().Value());
452+
equiJoinLabels.emplace(value.Cast().Value(), label->ChildPtr(0));
453+
} else if (auto tuple = TMaybeNode<TCoAtomList>(label->Child(1))) {
454+
for (const auto& value : tuple.Cast()) {
455+
labelsName.emplace(value.Value());
456+
equiJoinLabels.emplace(value.Value(), label->ChildPtr(0));
457+
}
458+
}
459+
joinLabels.push_back({labelsName, label->ChildPtr(0)});
397460
}
398461

399462
THashMap<TString, int> joinLabelCounters;
400463
CountLabelsInputUsage(joinTree, joinLabelCounters);
401464

402465
auto [leftJoinTree, parentJoinPtr] = IsRightSideForLeftJoin(joinTree, labels, inputIndex);
403466
YQL_ENSURE(leftJoinTree);
404-
joinLabelCounters[leftJoinTree->Child(1)->Content()]--;
405-
joinLabelCounters[leftJoinTree->Child(2)->Content()]--;
467+
DecrementCountLabelsInputUsage(leftJoinTree, joinLabelCounters);
406468

407469
auto leftJoinSettings = equiJoin->ChildPtr(equiJoin->ChildrenSize() - 1);
408470

409471
auto innerJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "Inner"));
410472
auto leftOnlyJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "LeftOnly"));
411473

412-
THashMap<TString, int> leftSideJoinLabels;
413-
CountLabelsInputUsage(leftJoinTree->Child(1), leftSideJoinLabels);
474+
THashSet<TString> leftLabelsNoRightChild;
475+
CollectJoinLabels(leftJoinTree->Child(1), leftLabelsNoRightChild);
476+
auto leftJoinLabelsNoRightChild = MapLabelNamesToJoinLabels(joinLabels, leftLabelsNoRightChild);
477+
478+
THashSet<TString> leftLabelsFull;
479+
CollectJoinLabels(leftJoinTree, leftLabelsFull);
480+
auto leftJoinLabelsFull = MapLabelNamesToJoinLabels(joinLabels, leftLabelsFull);
414481

415482
YQL_ENSURE(leftJoinTree->Child(2)->IsAtom());
416483
auto rightSideInput = equiJoinLabels.at(leftJoinTree->Child(2)->Content());
@@ -436,11 +503,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
436503
auto innerJoin = ctx.Builder(pos)
437504
.Callable("EquiJoin")
438505
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
439-
for (const auto& [labelName, _] : leftSideJoinLabels) {
440-
parent.List(i++)
441-
.Add(0, equiJoinLabels.at(labelName))
442-
.Atom(1, labelName)
443-
.Seal();
506+
for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) {
507+
if (labelNames.size() == 1) {
508+
parent.List(i++)
509+
.Add(0, labelKeys)
510+
.Atom(1, *labelNames.begin())
511+
.Seal();
512+
} else {
513+
auto labelList = CreateLabelList(labelNames, ctx, pos);
514+
parent.List(i++)
515+
.Add(0, labelKeys)
516+
.Add(1, labelList)
517+
.Seal();
518+
}
444519
}
445520
return parent;
446521
})
@@ -459,11 +534,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
459534
auto leftOnlyJoin = ctx.Builder(pos)
460535
.Callable("EquiJoin")
461536
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
462-
for (const auto& [labelName, _] : leftSideJoinLabels) {
463-
parent.List(i++)
464-
.Add(0, equiJoinLabels.at(labelName))
465-
.Atom(1, labelName)
466-
.Seal();
537+
for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) {
538+
if (labelNames.size() == 1) {
539+
parent.List(i++)
540+
.Add(0, labelKeys)
541+
.Atom(1, *labelNames.begin())
542+
.Seal();
543+
} else {
544+
auto labelList = CreateLabelList(labelNames, ctx, pos);
545+
parent.List(i++)
546+
.Add(0, labelKeys)
547+
.Add(1, labelList)
548+
.Seal();
549+
}
467550
}
468551
return parent;
469552
})
@@ -495,25 +578,6 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
495578
return unionAll;
496579
}
497580

498-
THashSet <TString> joinColumns;
499-
for (const auto& [labelName, _] : leftSideJoinLabels) {
500-
auto tableName = labels.FindInputIndex(labelName);
501-
YQL_ENSURE(tableName);
502-
for (auto column : labels.Inputs[*tableName].EnumerateAllColumns()) {
503-
joinColumns.emplace(std::move(column));
504-
}
505-
}
506-
auto rightSideTableName = labels.FindInputIndex(innerJoinTree->Child(2)->Content());
507-
YQL_ENSURE(rightSideTableName);
508-
for (auto column : labels.Inputs[*rightSideTableName].EnumerateAllColumns()) {
509-
joinColumns.emplace(std::move(column));
510-
}
511-
512-
auto newJoinLabel = ctx.Builder(pos)
513-
.Atom("__yql_right_side_pushdown_input_label")
514-
.Build();
515-
516-
517581
TExprNode::TPtr remJoinKeys;
518582
bool changedLeftSide = false;
519583
if (leftJoinTree == parentJoinPtr->ChildPtr(1)) {
@@ -523,36 +587,14 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
523587
remJoinKeys = parentJoinPtr->ChildPtr(4);
524588
}
525589

526-
TExprNode::TListType newKeys;
527-
newKeys.reserve(remJoinKeys->ChildrenSize());
528-
529-
for (ui32 i = 0; i < remJoinKeys->ChildrenSize(); i += 2) {
530-
auto table = remJoinKeys->ChildPtr(i);
531-
auto column = remJoinKeys->ChildPtr(i + 1);
532-
533-
YQL_ENSURE(table->IsAtom());
534-
YQL_ENSURE(column->IsAtom());
535-
536-
auto fcn = FullColumnName(table->Content(), column->Content());
537-
538-
if (joinColumns.contains(fcn)) {
539-
newKeys.push_back(newJoinLabel);
540-
newKeys.push_back(ctx.NewAtom(column->Pos(), fcn));
541-
} else {
542-
newKeys.push_back(table);
543-
newKeys.push_back(column);
544-
}
545-
}
546-
547-
auto newKeysList = ctx.NewList(remJoinKeys->Pos(), std::move(newKeys));
548-
590+
auto parentJoinLabel = remJoinKeys->ChildPtr(0);
549591
auto newParentJoin = ctx.Builder(joinTree->Pos())
550592
.List()
551593
.Add(0, parentJoinPtr->ChildPtr(0))
552-
.Add(1, changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(1))
553-
.Add(2, !changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(2))
554-
.Add(3, changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(3))
555-
.Add(4, !changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(4))
594+
.Add(1, changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(1))
595+
.Add(2, !changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(2))
596+
.Add(3, parentJoinPtr->ChildPtr(3))
597+
.Add(4, parentJoinPtr->ChildPtr(4))
556598
.Add(5, parentJoinPtr->ChildPtr(5))
557599
.Seal()
558600
.Build();
@@ -568,19 +610,12 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
568610
}
569611
return parent;
570612
})
571-
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
572-
for (const auto& column : joinColumns) {
573-
parent.List(i++)
574-
.Atom(0, "rename")
575-
.Atom(1, FullColumnName("__yql_right_side_pushdown_input_label", column))
576-
.Atom(2, column)
577-
.Seal();
578-
}
579-
return parent;
580-
})
581613
.Seal()
582614
.Build();
583615

616+
auto combinedLabelList = CombineLabels(leftJoinLabelsFull);
617+
auto combinedJoinLabels = CreateLabelList(combinedLabelList, ctx, pos);
618+
584619
i = 0;
585620
auto newEquiJoin = ctx.Builder(pos)
586621
.Callable("EquiJoin")
@@ -598,7 +633,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
598633
})
599634
.List(i++)
600635
.Add(0, unionAll)
601-
.Add(1, newJoinLabel)
636+
.Add(1, combinedJoinLabels)
602637
.Seal()
603638
.Add(i++, newJoinTree)
604639
.Add(i++, newJoinSettings)

0 commit comments

Comments
 (0)