Skip to content

Commit 7ba808a

Browse files
committed
[FIX] Left join with predicate rewrite
1 parent b3b79e7 commit 7ba808a

File tree

1 file changed

+113
-72
lines changed

1 file changed

+113
-72
lines changed

yql/essentials/core/common_opt/yql_flatmap_over_join.cpp

Lines changed: 113 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,25 @@ void CountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& cou
317317
}
318318
}
319319

320+
// Walks a join tree and records the content of every atom leaf in `labels`.
// A non-atom node is treated as a join whose two sides are children 1 and 2;
// both sides are visited recursively, child 1 first.
void CollectJoinLabels(TExprNode::TPtr joinTree, THashSet<TString>& labels) {
    if (!joinTree->IsAtom()) {
        // Inner node: descend into both sides of the join.
        for (ui32 side = 1; side <= 2; ++side) {
            CollectJoinLabels(joinTree->ChildPtr(side), labels);
        }
        return;
    }

    // Leaf: the atom's content is the label name.
    labels.emplace(joinTree->Content());
}
328+
329+
330+
// Walks a join tree and decrements `counters[label]` by one for every atom
// leaf found. A non-atom node is treated as a join whose two sides are
// children 1 and 2; both are visited recursively, child 1 first.
// The counter is accessed through operator[], exactly as before.
void DecrementCountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap<TString, int>& counters) {
    if (!joinTree->IsAtom()) {
        // Inner node: descend into both sides of the join.
        for (ui32 side = 1; side <= 2; ++side) {
            DecrementCountLabelsInputUsage(joinTree->ChildPtr(side), counters);
        }
        return;
    }

    // Leaf: one usage of this label is being taken away.
    --counters[joinTree->Content()];
}
338+
320339
// returns the path to join child
321340
std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
322341
const TExprNode::TPtr& joinTree, const TJoinLabels& labels, ui32 inputIndex, const TExprNode::TPtr& parent = nullptr
@@ -350,6 +369,39 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> IsRightSideForLeftJoin(
350369
return {nullptr, nullptr};
351370
}
352371

372+
// Picks the entries of `joinLabels` that are referenced by at least one name
// in `labelNames`.
//
// joinLabels - list of (set of label names, associated node) pairs.
// labelNames - label names to look up.
//
// Returns copies of the matching entries. Each entry is returned at most once,
// even when several of its names are requested. Entries appear in the order in
// which the iteration over `labelNames` first reaches them.
//
// Both arguments are only read, so they are taken by const reference.
TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> MapLabelNamesToJoinLabels(
    const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& joinLabels,
    const THashSet<TString>& labelNames)
{
    const size_t joinLabelSize = joinLabels.size();

    // taken[i] marks entries that are already in the result, so an entry carrying
    // several requested names is not emitted twice.
    TVector<bool> taken(joinLabelSize, false);
    TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> result;

    for (const auto& labelName : labelNames) {
        for (size_t i = 0; i < joinLabelSize; ++i) {
            if (taken[i]) {
                continue;
            }
            if (joinLabels[i].first.contains(labelName)) {
                taken[i] = true;
                result.push_back(joinLabels[i]);
            }
        }
    }
    return result;
}
388+
389+
// Returns the union of the label-name sets of all entries in `labels`.
// The node stored in each pair is ignored.
THashSet<TString> CombineLabels(const TVector<std::pair<THashSet<TString>, TExprNode::TPtr>>& labels) {
    THashSet<TString> merged;
    for (const auto& entry : labels) {
        const auto& names = entry.first;
        merged.insert(names.begin(), names.end());
    }
    return merged;
}
396+
397+
// Builds a list node containing one atom per entry of `labels`.
// Atoms are created in the set's iteration order; the atoms and the list
// itself are all created at `position`.
TExprNode::TPtr CreateLabelList(const THashSet<TString>& labels, TExprContext& ctx, const TPositionHandle& position) {
    TExprNode::TListType atoms;
    for (const auto& name : labels) {
        atoms.emplace_back(ctx.NewAtom(position, name));
    }
    return ctx.NewList(position, std::move(atoms));
}
404+
353405
TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate,
354406
const TSet<TStringBuf>& usedFields, TExprNode::TPtr args, const TJoinLabels& labels,
355407
ui32 inputIndex, const TMap<TStringBuf, TVector<TStringBuf>>& renameMap, bool ordered, bool skipNulls, TExprContext& ctx,
@@ -391,26 +443,47 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
391443
}
392444

393445
THashMap<TString, TExprNode::TPtr> equiJoinLabels;
446+
TVector<std::pair<THashSet<TString>, TExprNode::TPtr>> joinLabels;
394447
for (size_t i = 0; i < equiJoin->ChildrenSize() - 2; i++) {
395448
auto label = equiJoin->Child(i);
396-
equiJoinLabels.emplace(label->Child(1)->Content(), label->ChildPtr(0));
449+
THashSet<TString> labelsName;
450+
if (auto value = TMaybeNode<TCoAtom>(label->Child(1))) {
451+
labelsName.emplace(value.Cast().Value());
452+
equiJoinLabels.emplace(value.Cast().Value(), label->ChildPtr(0));
453+
} else if (auto tuple = TMaybeNode<TCoAtomList>(label->Child(1))) {
454+
for (const auto& value : tuple.Cast()) {
455+
labelsName.emplace(value.Value());
456+
equiJoinLabels.emplace(value.Value(), label->ChildPtr(0));
457+
}
458+
}
459+
joinLabels.push_back({labelsName, label->ChildPtr(0)});
397460
}
398461

399462
THashMap<TString, int> joinLabelCounters;
400463
CountLabelsInputUsage(joinTree, joinLabelCounters);
401464

402465
auto [leftJoinTree, parentJoinPtr] = IsRightSideForLeftJoin(joinTree, labels, inputIndex);
403466
YQL_ENSURE(leftJoinTree);
404-
joinLabelCounters[leftJoinTree->Child(1)->Content()]--;
405-
joinLabelCounters[leftJoinTree->Child(2)->Content()]--;
467+
DecrementCountLabelsInputUsage(leftJoinTree, joinLabelCounters);
406468

407469
auto leftJoinSettings = equiJoin->ChildPtr(equiJoin->ChildrenSize() - 1);
408470

409471
auto innerJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "Inner"));
410472
auto leftOnlyJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "LeftOnly"));
411473

412-
THashMap<TString, int> leftSideJoinLabels;
413-
CountLabelsInputUsage(leftJoinTree->Child(1), leftSideJoinLabels);
474+
THashSet<TString> leftLabelsNoRightChild;
475+
CollectJoinLabels(leftJoinTree->Child(1), leftLabelsNoRightChild);
476+
auto leftJoinLabelsNoRightChild = MapLabelNamesToJoinLabels(joinLabels, leftLabelsNoRightChild);
477+
478+
THashSet<TString> leftLabelsFull;
479+
CollectJoinLabels(leftJoinTree, leftLabelsFull);
480+
auto leftJoinLabelsFull = MapLabelNamesToJoinLabels(joinLabels, leftLabelsFull);
481+
Cerr << "LABELS " << Endl;
482+
for (const auto & label : leftJoinLabelsFull) {
483+
for (auto &e : label.first) {
484+
Cerr << e << Endl;
485+
}
486+
}
414487

415488
YQL_ENSURE(leftJoinTree->Child(2)->IsAtom());
416489
auto rightSideInput = equiJoinLabels.at(leftJoinTree->Child(2)->Content());
@@ -436,11 +509,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
436509
auto innerJoin = ctx.Builder(pos)
437510
.Callable("EquiJoin")
438511
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
439-
for (const auto& [labelName, _] : leftSideJoinLabels) {
440-
parent.List(i++)
441-
.Add(0, equiJoinLabels.at(labelName))
442-
.Atom(1, labelName)
443-
.Seal();
512+
for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) {
513+
if (labelNames.size() == 1) {
514+
parent.List(i++)
515+
.Add(0, labelKeys)
516+
.Atom(1, *labelNames.begin())
517+
.Seal();
518+
} else {
519+
auto labelList = CreateLabelList(labelNames, ctx, pos);
520+
parent.List(i++)
521+
.Add(0, labelKeys)
522+
.Add(1, labelList)
523+
.Seal();
524+
}
444525
}
445526
return parent;
446527
})
@@ -459,11 +540,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
459540
auto leftOnlyJoin = ctx.Builder(pos)
460541
.Callable("EquiJoin")
461542
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
462-
for (const auto& [labelName, _] : leftSideJoinLabels) {
463-
parent.List(i++)
464-
.Add(0, equiJoinLabels.at(labelName))
465-
.Atom(1, labelName)
466-
.Seal();
543+
for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) {
544+
if (labelNames.size() == 1) {
545+
parent.List(i++)
546+
.Add(0, labelKeys)
547+
.Atom(1, *labelNames.begin())
548+
.Seal();
549+
} else {
550+
auto labelList = CreateLabelList(labelNames, ctx, pos);
551+
parent.List(i++)
552+
.Add(0, labelKeys)
553+
.Add(1, labelList)
554+
.Seal();
555+
}
467556
}
468557
return parent;
469558
})
@@ -495,25 +584,6 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
495584
return unionAll;
496585
}
497586

498-
THashSet <TString> joinColumns;
499-
for (const auto& [labelName, _] : leftSideJoinLabels) {
500-
auto tableName = labels.FindInputIndex(labelName);
501-
YQL_ENSURE(tableName);
502-
for (auto column : labels.Inputs[*tableName].EnumerateAllColumns()) {
503-
joinColumns.emplace(std::move(column));
504-
}
505-
}
506-
auto rightSideTableName = labels.FindInputIndex(innerJoinTree->Child(2)->Content());
507-
YQL_ENSURE(rightSideTableName);
508-
for (auto column : labels.Inputs[*rightSideTableName].EnumerateAllColumns()) {
509-
joinColumns.emplace(std::move(column));
510-
}
511-
512-
auto newJoinLabel = ctx.Builder(pos)
513-
.Atom("__yql_right_side_pushdown_input_label")
514-
.Build();
515-
516-
517587
TExprNode::TPtr remJoinKeys;
518588
bool changedLeftSide = false;
519589
if (leftJoinTree == parentJoinPtr->ChildPtr(1)) {
@@ -523,36 +593,14 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
523593
remJoinKeys = parentJoinPtr->ChildPtr(4);
524594
}
525595

526-
TExprNode::TListType newKeys;
527-
newKeys.reserve(remJoinKeys->ChildrenSize());
528-
529-
for (ui32 i = 0; i < remJoinKeys->ChildrenSize(); i += 2) {
530-
auto table = remJoinKeys->ChildPtr(i);
531-
auto column = remJoinKeys->ChildPtr(i + 1);
532-
533-
YQL_ENSURE(table->IsAtom());
534-
YQL_ENSURE(column->IsAtom());
535-
536-
auto fcn = FullColumnName(table->Content(), column->Content());
537-
538-
if (joinColumns.contains(fcn)) {
539-
newKeys.push_back(newJoinLabel);
540-
newKeys.push_back(ctx.NewAtom(column->Pos(), fcn));
541-
} else {
542-
newKeys.push_back(table);
543-
newKeys.push_back(column);
544-
}
545-
}
546-
547-
auto newKeysList = ctx.NewList(remJoinKeys->Pos(), std::move(newKeys));
548-
596+
auto parentJoinLabel = remJoinKeys->ChildPtr(0);
549597
auto newParentJoin = ctx.Builder(joinTree->Pos())
550598
.List()
551599
.Add(0, parentJoinPtr->ChildPtr(0))
552-
.Add(1, changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(1))
553-
.Add(2, !changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(2))
554-
.Add(3, changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(3))
555-
.Add(4, !changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(4))
600+
.Add(1, changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(1))
601+
.Add(2, !changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(2))
602+
.Add(3, parentJoinPtr->ChildPtr(3))
603+
.Add(4, parentJoinPtr->ChildPtr(4))
556604
.Add(5, parentJoinPtr->ChildPtr(5))
557605
.Seal()
558606
.Build();
@@ -568,19 +616,12 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
568616
}
569617
return parent;
570618
})
571-
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
572-
for (const auto& column : joinColumns) {
573-
parent.List(i++)
574-
.Atom(0, "rename")
575-
.Atom(1, FullColumnName("__yql_right_side_pushdown_input_label", column))
576-
.Atom(2, column)
577-
.Seal();
578-
}
579-
return parent;
580-
})
581619
.Seal()
582620
.Build();
583621

622+
auto combinedLabelList = CombineLabels(leftJoinLabelsFull);
623+
auto combinedJoinLabels = CreateLabelList(combinedLabelList, ctx, pos);
624+
584625
i = 0;
585626
auto newEquiJoin = ctx.Builder(pos)
586627
.Callable("EquiJoin")
@@ -598,7 +639,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx
598639
})
599640
.List(i++)
600641
.Add(0, unionAll)
601-
.Add(1, newJoinLabel)
642+
.Add(1, combinedJoinLabels)
602643
.Seal()
603644
.Add(i++, newJoinTree)
604645
.Add(i++, newJoinSettings)

0 commit comments

Comments
 (0)