diff --git a/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp b/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp index 52107e4b3114..8ff1883b4cf0 100644 --- a/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp +++ b/yql/essentials/core/common_opt/yql_flatmap_over_join.cpp @@ -317,6 +317,25 @@ void CountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap& cou } } +void CollectJoinLabels(TExprNode::TPtr joinTree, THashSet &labels) { + if (joinTree->IsAtom()) { + labels.emplace(joinTree->Content()); + } else { + CollectJoinLabels(joinTree->ChildPtr(1), labels); + CollectJoinLabels(joinTree->ChildPtr(2), labels); + } +} + + +void DecrementCountLabelsInputUsage(TExprNode::TPtr joinTree, THashMap& counters) { + if (joinTree->IsAtom()) { + counters[joinTree->Content()]--; + } else { + DecrementCountLabelsInputUsage(joinTree->ChildPtr(1), counters); + DecrementCountLabelsInputUsage(joinTree->ChildPtr(2), counters); + } +} + // returns the path to join child std::pair IsRightSideForLeftJoin( const TExprNode::TPtr& joinTree, const TJoinLabels& labels, ui32 inputIndex, const TExprNode::TPtr& parent = nullptr @@ -350,6 +369,39 @@ std::pair IsRightSideForLeftJoin( return {nullptr, nullptr}; } +TVector, TExprNode::TPtr>> MapLabelNamesToJoinLabels(TVector, TExprNode::TPtr>>& joinLabels, + THashSet& labelNames) { + const ui32 joinLabelSize = joinLabels.size(); + TVector taken(joinLabelSize, false); + TVector, TExprNode::TPtr>> result; + for (const auto& labelName : labelNames) { + for (ui32 i = 0; i < joinLabelSize; ++i) { + auto& labelNamesSet = joinLabels[i].first; + if (!taken[i] && labelNamesSet.count(labelName)) { + taken[i] = true; + result.push_back(joinLabels[i]); + } + } + } + return result; +} + +THashSet CombineLabels(const TVector, TExprNode::TPtr>>& labels) { + THashSet combinedResult; + for (const auto &[labelNames, _] : labels) { + combinedResult.insert(labelNames.begin(), labelNames.end()); + } + return combinedResult; +} + +TExprNode::TPtr CreateLabelList(const THashSet& labels, TExprContext& ctx, const TPositionHandle& position) { + TExprNode::TListType newKeys; + for (const auto& label : labels) { + newKeys.push_back(ctx.NewAtom(position, label)); + } + return ctx.NewList(position, std::move(newKeys)); +} + TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TExprNode::TPtr predicate, const TSet& usedFields, TExprNode::TPtr args, const TJoinLabels& labels, ui32 inputIndex, const TMap>& renameMap, bool ordered, bool skipNulls, TExprContext& ctx, @@ -391,9 +443,20 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx } THashMap equiJoinLabels; + TVector, TExprNode::TPtr>> joinLabels; for (size_t i = 0; i < equiJoin->ChildrenSize() - 2; i++) { auto label = equiJoin->Child(i); - equiJoinLabels.emplace(label->Child(1)->Content(), label->ChildPtr(0)); + THashSet labelsName; + if (auto value = TMaybeNode(label->Child(1))) { + labelsName.emplace(value.Cast().Value()); + equiJoinLabels.emplace(value.Cast().Value(), label->ChildPtr(0)); + } else if (auto tuple = TMaybeNode(label->Child(1))) { + for (const auto& value : tuple.Cast()) { + labelsName.emplace(value.Value()); + equiJoinLabels.emplace(value.Value(), label->ChildPtr(0)); + } + } + joinLabels.push_back({labelsName, label->ChildPtr(0)}); } THashMap joinLabelCounters; @@ -401,16 +464,20 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx auto [leftJoinTree, parentJoinPtr] = IsRightSideForLeftJoin(joinTree, labels, inputIndex); YQL_ENSURE(leftJoinTree); - joinLabelCounters[leftJoinTree->Child(1)->Content()]--; - joinLabelCounters[leftJoinTree->Child(2)->Content()]--; + DecrementCountLabelsInputUsage(leftJoinTree, joinLabelCounters); auto leftJoinSettings = equiJoin->ChildPtr(equiJoin->ChildrenSize() - 1); auto innerJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "Inner")); auto leftOnlyJoinTree = ctx.ChangeChild(*leftJoinTree, 0, ctx.NewAtom(leftJoinTree->Pos(), "LeftOnly")); - THashMap leftSideJoinLabels; - CountLabelsInputUsage(leftJoinTree->Child(1), leftSideJoinLabels); + THashSet leftLabelsNoRightChild; + CollectJoinLabels(leftJoinTree->Child(1), leftLabelsNoRightChild); + auto leftJoinLabelsNoRightChild = MapLabelNamesToJoinLabels(joinLabels, leftLabelsNoRightChild); + + THashSet leftLabelsFull; + CollectJoinLabels(leftJoinTree, leftLabelsFull); + auto leftJoinLabelsFull = MapLabelNamesToJoinLabels(joinLabels, leftLabelsFull); YQL_ENSURE(leftJoinTree->Child(2)->IsAtom()); auto rightSideInput = equiJoinLabels.at(leftJoinTree->Child(2)->Content()); @@ -436,11 +503,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx auto innerJoin = ctx.Builder(pos) .Callable("EquiJoin") .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (const auto& [labelName, _] : leftSideJoinLabels) { - parent.List(i++) - .Add(0, equiJoinLabels.at(labelName)) - .Atom(1, labelName) - .Seal(); + for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) { + if (labelNames.size() == 1) { + parent.List(i++) + .Add(0, labelKeys) + .Atom(1, *labelNames.begin()) + .Seal(); + } else { + auto labelList = CreateLabelList(labelNames, ctx, pos); + parent.List(i++) + .Add(0, labelKeys) + .Add(1, labelList) + .Seal(); + } } return parent; }) @@ -459,11 +534,19 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx auto leftOnlyJoin = ctx.Builder(pos) .Callable("EquiJoin") .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (const auto& [labelName, _] : leftSideJoinLabels) { - parent.List(i++) - .Add(0, equiJoinLabels.at(labelName)) - .Atom(1, labelName) - .Seal(); + for (const auto& [labelNames, labelKeys] : leftJoinLabelsNoRightChild) { + if (labelNames.size() == 1) { + parent.List(i++) + .Add(0, labelKeys) + .Atom(1, *labelNames.begin()) + .Seal(); + } else { + auto labelList = CreateLabelList(labelNames, ctx, pos); + parent.List(i++) + .Add(0, labelKeys) + .Add(1, labelList) + .Seal(); + } } return parent; }) @@ -495,25 +578,6 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx return unionAll; } - THashSet joinColumns; - for (const auto& [labelName, _] : leftSideJoinLabels) { - auto tableName = labels.FindInputIndex(labelName); - YQL_ENSURE(tableName); - for (auto column : labels.Inputs[*tableName].EnumerateAllColumns()) { - joinColumns.emplace(std::move(column)); - } - } - auto rightSideTableName = labels.FindInputIndex(innerJoinTree->Child(2)->Content()); - YQL_ENSURE(rightSideTableName); - for (auto column : labels.Inputs[*rightSideTableName].EnumerateAllColumns()) { - joinColumns.emplace(std::move(column)); - } - - auto newJoinLabel = ctx.Builder(pos) - .Atom("__yql_right_side_pushdown_input_label") - .Build(); - - TExprNode::TPtr remJoinKeys; bool changedLeftSide = false; if (leftJoinTree == parentJoinPtr->ChildPtr(1)) { @@ -523,36 +587,14 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx remJoinKeys = parentJoinPtr->ChildPtr(4); } - TExprNode::TListType newKeys; - newKeys.reserve(remJoinKeys->ChildrenSize()); - - for (ui32 i = 0; i < remJoinKeys->ChildrenSize(); i += 2) { - auto table = remJoinKeys->ChildPtr(i); - auto column = remJoinKeys->ChildPtr(i + 1); - - YQL_ENSURE(table->IsAtom()); - YQL_ENSURE(column->IsAtom()); - - auto fcn = FullColumnName(table->Content(), column->Content()); - - if (joinColumns.contains(fcn)) { - newKeys.push_back(newJoinLabel); - newKeys.push_back(ctx.NewAtom(column->Pos(), fcn)); - } else { - newKeys.push_back(table); - newKeys.push_back(column); - } - } - - auto newKeysList = ctx.NewList(remJoinKeys->Pos(), std::move(newKeys)); - + auto parentJoinLabel = remJoinKeys->ChildPtr(0); auto newParentJoin = ctx.Builder(joinTree->Pos()) .List() .Add(0, parentJoinPtr->ChildPtr(0)) - .Add(1, changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(1)) - .Add(2, !changedLeftSide ? newJoinLabel : parentJoinPtr->ChildPtr(2)) - .Add(3, changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(3)) - .Add(4, !changedLeftSide ? newKeysList : parentJoinPtr->ChildPtr(4)) + .Add(1, changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(1)) + .Add(2, !changedLeftSide ? parentJoinLabel : parentJoinPtr->ChildPtr(2)) + .Add(3, parentJoinPtr->ChildPtr(3)) + .Add(4, parentJoinPtr->ChildPtr(4)) .Add(5, parentJoinPtr->ChildPtr(5)) .Seal() .Build(); @@ -568,19 +610,12 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx } return parent; }) - .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { - for (const auto& column : joinColumns) { - parent.List(i++) - .Atom(0, "rename") - .Atom(1, FullColumnName("__yql_right_side_pushdown_input_label", column)) - .Atom(2, column) - .Seal(); - } - return parent; - }) .Seal() .Build(); + auto combinedLabelList = CombineLabels(leftJoinLabelsFull); + auto combinedJoinLabels = CreateLabelList(combinedLabelList, ctx, pos); + i = 0; auto newEquiJoin = ctx.Builder(pos) .Callable("EquiJoin") @@ -598,7 +633,7 @@ TExprNode::TPtr FilterPushdownOverJoinOptionalSide(TExprNode::TPtr equiJoin, TEx }) .List(i++) .Add(0, unionAll) - .Add(1, newJoinLabel) + .Add(1, combinedJoinLabels) .Seal() .Add(i++, newJoinTree) .Add(i++, newJoinSettings)