Skip to content

Commit 3e299f1

Browse files
authored
Take the return type of BlockMapJoinCore computation node as an argument (#9529)
1 parent 2e0e43b commit 3e299f1

File tree

3 files changed

+77
-43
lines changed

3 files changed

+77
-43
lines changed

ydb/library/yql/minikql/comp_nodes/ut/mkql_block_map_join_ut.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,42 @@ const TRuntimeNode MakeDict(TProgramBuilder& pgmBuilder,
5555
});
5656
}
5757

58+
// XXX: Copy-pasted from program builder sources. Adjusted on demand.
59+
const std::vector<TType*> ValidateBlockStreamType(const TType* streamType) {
60+
const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
61+
Y_ENSURE(wideComponents.size() > 0, "Expected at least one column");
62+
std::vector<TType*> items;
63+
items.reserve(wideComponents.size());
64+
// XXX: Declare these variables outside the loop body to use for the last
65+
// item (i.e. block length column) in the assertions below.
66+
bool isScalar;
67+
TType* itemType;
68+
for (const auto& wideComponent : wideComponents) {
69+
auto blockType = AS_TYPE(TBlockType, wideComponent);
70+
isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
71+
itemType = blockType->GetItemType();
72+
items.push_back(blockType);
73+
}
74+
75+
Y_ENSURE(isScalar, "Last column should be scalar");
76+
Y_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
77+
return items;
78+
}
79+
80+
bool IsOptionalOrNull(const TType* type) {
81+
return type->IsOptional() || type->IsNull() || type->IsPg();
82+
}
83+
5884
const TRuntimeNode BuildBlockJoin(TProgramBuilder& pgmBuilder, EJoinKind joinKind,
5985
const TVector<ui32>& leftKeyColumns, const TVector<ui32>& leftKeyDrops,
6086
TRuntimeNode& leftArg, TType* leftTuple, const TRuntimeNode& dictNode
6187
) {
88+
// 1. Make left argument node.
6289
const auto tupleType = AS_TYPE(TTupleType, leftTuple);
6390
const auto listTupleType = pgmBuilder.NewListType(leftTuple);
6491
leftArg = pgmBuilder.Arg(listTupleType);
6592

93+
// 2. Make left wide stream node.
6694
const auto leftWideStream = pgmBuilder.FromFlow(pgmBuilder.ExpandMap(pgmBuilder.ToFlow(leftArg),
6795
[&](TRuntimeNode tupleNode) -> TRuntimeNode::TList {
6896
TRuntimeNode::TList wide;
@@ -73,8 +101,53 @@ const TRuntimeNode BuildBlockJoin(TProgramBuilder& pgmBuilder, EJoinKind joinKin
73101
return wide;
74102
}));
75103

104+
// 3. Calculate the resulting join type.
105+
const auto leftStreamItems = ValidateBlockStreamType(leftWideStream.GetStaticType());
106+
const THashSet<ui32> leftKeyDropsSet(leftKeyDrops.cbegin(), leftKeyDrops.cend());
107+
TVector<TType*> returnJoinItems;
108+
for (size_t i = 0; i < leftStreamItems.size(); i++) {
109+
if (leftKeyDropsSet.contains(i)) {
110+
continue;
111+
}
112+
returnJoinItems.push_back(leftStreamItems[i]);
113+
}
114+
115+
const auto payloadType = AS_TYPE(TDictType, dictNode.GetStaticType())->GetPayloadType();
116+
const auto payloadItemType = payloadType->IsList()
117+
? AS_TYPE(TListType, payloadType)->GetItemType()
118+
: payloadType;
119+
if (joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left) {
120+
// XXX: This is the contract ensured by the expression compiler and
121+
// optimizers to ease the processing of the dict payload in wide context.
122+
Y_ENSURE(payloadItemType->IsTuple(), "Dict payload has to be a Tuple");
123+
const auto payloadItems = AS_TYPE(TTupleType, payloadItemType)->GetElements();
124+
TVector<TType*> dictBlockItems;
125+
dictBlockItems.reserve(payloadItems.size());
126+
for (const auto& payloadItem : payloadItems) {
127+
MKQL_ENSURE(!payloadItem->IsBlock(), "Dict payload item has to be non-block");
128+
const auto itemType = joinKind == EJoinKind::Inner ? payloadItem
129+
: IsOptionalOrNull(payloadItem) ? payloadItem
130+
: pgmBuilder.NewOptionalType(payloadItem);
131+
dictBlockItems.emplace_back(pgmBuilder.NewBlockType(itemType, TBlockType::EShape::Many));
132+
}
133+
// Block length column has to be the last column in wide block stream item,
134+
// so all contents of the dict payload should be appended to the resulting
135+
// wide type before the block size column.
136+
const auto blockLenPos = std::prev(returnJoinItems.end());
137+
returnJoinItems.insert(blockLenPos, dictBlockItems.cbegin(), dictBlockItems.cend());
138+
} else {
139+
// XXX: This is the contract ensured by the expression compiler and
140+
// optimizers for join types that don't require the right (i.e. dict) part.
141+
Y_ENSURE(payloadItemType->IsVoid(), "Dict payload has to be Void");
142+
}
143+
TType* returnJoinType = pgmBuilder.NewStreamType(pgmBuilder.NewMultiType(returnJoinItems));
144+
145+
// 4. Build BlockMapJoinCore node.
76146
const auto joinNode = pgmBuilder.BlockMapJoinCore(leftWideStream, dictNode, joinKind,
77-
leftKeyColumns, leftKeyDrops);
147+
leftKeyColumns, leftKeyDrops,
148+
returnJoinType);
149+
150+
// 5. Build the root node with list of tuples.
78151
const auto joinItems = GetWideComponents(AS_TYPE(TStreamType, joinNode.GetStaticType()));
79152
const auto resultType = AS_TYPE(TTupleType, pgmBuilder.NewTupleType(joinItems));
80153

ydb/library/yql/minikql/mkql_program_builder.cpp

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5894,7 +5894,7 @@ TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& a
58945894

58955895
TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode stream, TRuntimeNode dict,
58965896
EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns,
5897-
const TArrayRef<const ui32>& leftKeyDrops
5897+
const TArrayRef<const ui32>& leftKeyDrops, TType* returnType
58985898
) {
58995899
if constexpr (RuntimeVersion < 51U) {
59005900
THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
@@ -5923,46 +5923,7 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode stream, TRuntimeNode
59235923
return NewDataLiteral(idx);
59245924
});
59255925

5926-
const auto leftStreamItems = ValidateBlockStreamType(stream.GetStaticType(), false);
5927-
const THashSet<ui32> leftKeyDropsSet(leftKeyDrops.cbegin(), leftKeyDrops.cend());
5928-
TVector<TType*> returnJoinItems;
5929-
for (size_t i = 0; i < leftStreamItems.size(); i++) {
5930-
if (leftKeyDropsSet.contains(i)) {
5931-
continue;
5932-
}
5933-
returnJoinItems.push_back(leftStreamItems[i]);
5934-
}
5935-
5936-
const auto payloadType = AS_TYPE(TDictType, dict.GetStaticType())->GetPayloadType();
5937-
const auto payloadItemType = payloadType->IsList()
5938-
? AS_TYPE(TListType, payloadType)->GetItemType()
5939-
: payloadType;
5940-
if (joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left) {
5941-
// XXX: This is the contract ensured by the expression compiler and
5942-
// optimizers to ease the processing of the dict payload in wide context.
5943-
MKQL_ENSURE(payloadItemType->IsTuple(), "Dict payload has to be a Tuple");
5944-
const auto payloadItems = AS_TYPE(TTupleType, payloadItemType)->GetElements();
5945-
TVector<TType*> dictBlockItems;
5946-
dictBlockItems.reserve(payloadItems.size());
5947-
for (const auto& payloadItem : payloadItems) {
5948-
MKQL_ENSURE(!payloadItem->IsBlock(), "Dict payload item has to be non-block");
5949-
const auto itemType = joinKind == EJoinKind::Inner ? payloadItem
5950-
: NewOptionalType(payloadItem);
5951-
dictBlockItems.emplace_back(NewBlockType(itemType, TBlockType::EShape::Many));
5952-
}
5953-
// Block length column has to be the last column in wide block stream item,
5954-
// so all contents of the dict payload should be appended to the resulting
5955-
// wide type before the block size column.
5956-
const auto blockLenPos = std::prev(returnJoinItems.end());
5957-
returnJoinItems.insert(blockLenPos, dictBlockItems.cbegin(), dictBlockItems.cend());
5958-
} else {
5959-
// XXX: This is the contract ensured by the expression compiler and
5960-
// optimizers for join types that don't require the right (i.e. dict) part.
5961-
MKQL_ENSURE(payloadItemType->IsVoid(), "Dict payload has to be Void");
5962-
}
5963-
TType* returnJoinType = NewStreamType(NewMultiType(returnJoinItems));
5964-
5965-
TCallableBuilder callableBuilder(Env, __func__, returnJoinType);
5926+
TCallableBuilder callableBuilder(Env, __func__, returnType);
59665927
callableBuilder.Add(stream);
59675928
callableBuilder.Add(dict);
59685929
callableBuilder.Add(NewDataLiteral((ui32)joinKind));

ydb/library/yql/minikql/mkql_program_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class TProgramBuilder : public TTypeBuilder {
257257
const TArrayRef<const TRuntimeNode>& args, TType* returnType);
258258
TRuntimeNode BlockMapJoinCore(TRuntimeNode flow, TRuntimeNode dict,
259259
EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns,
260-
const TArrayRef<const ui32>& leftKeyDrops = {});
260+
const TArrayRef<const ui32>& leftKeyDrops, TType* returnType);
261261

262262
//-- logical functions
263263
TRuntimeNode BlockNot(TRuntimeNode data);

0 commit comments

Comments
 (0)