Skip to content

Commit dd2a9ce

Browse files
authored
Refactor tests for BlockMapJoinCore computation node (#8129)
1 parent c5366d7 commit dd2a9ce

File tree

1 file changed

+204
-171
lines changed

1 file changed

+204
-171
lines changed

ydb/library/yql/minikql/comp_nodes/ut/mkql_block_map_join_ut.cpp

Lines changed: 204 additions & 171 deletions
Original file line number | Diff line number | Diff line change
@@ -13,195 +13,228 @@ namespace NKikimr {
1313
namespace NMiniKQL {
1414

1515
namespace {
16-
TMap<const TString, ui64> NameToIndex(const TStructType* structType) {
17-
TMap<const TString, ui64> map;
18-
for (size_t i = 0; i < structType->GetMembersCount(); i++) {
19-
const TString name(structType->GetMemberName(i));
20-
map[name] = i;
21-
}
22-
return map;
23-
}
2416

25-
TVector<TString> GeneratePayload(size_t level) {
26-
constexpr size_t alphaSize = 'Z' - 'A' + 1;
27-
if (level == 1) {
28-
TVector<TString> alphabet(alphaSize);
29-
std::iota(alphabet.begin(), alphabet.end(), 'A');
30-
return alphabet;
31-
}
32-
const auto subPayload = GeneratePayload(level - 1);
33-
TVector<TString> payload;
34-
payload.reserve(alphaSize * subPayload.size());
35-
for (char ch = 'A'; ch <= 'Z'; ch++) {
36-
for (const auto& tail : subPayload) {
37-
payload.emplace_back(ch + tail);
38-
}
39-
}
40-
return payload;
41-
}
17+
// One test row: (key, subkey, value). The value column is a TStringBuf, i.e. a view
// over bytes stored elsewhere rather than an owned string.
// NOTE(review): ArraysToKSV builds these views directly over an Arrow buffer —
// confirm that buffer outlives every row that is still being read.
using TKSV = std::tuple<ui64, ui64, TStringBuf>;
18+
// One Arrow ArrayData per TKSV column, indexed in tuple order (0 = key, 1 = subkey, 2 = value).
using TArrays = std::array<std::shared_ptr<arrow::ArrayData>, std::tuple_size_v<TKSV>>;
4219

43-
constexpr size_t payloadSize = 2;
44-
static const TVector<TString> twoLetterPayloads = GeneratePayload(payloadSize);
45-
46-
template <typename T, bool isOptional = false>
47-
const TRuntimeNode MakeSimpleKey(
48-
TProgramBuilder& pgmBuilder,
49-
T value,
50-
bool isEmpty = false
51-
) {
52-
if constexpr (!isOptional) {
53-
return pgmBuilder.NewDataLiteral<T>(value);
54-
}
55-
const auto keyType = pgmBuilder.NewDataType(NUdf::TDataType<T>::Id, true);
56-
if (isEmpty) {
57-
return pgmBuilder.NewEmptyOptional(keyType);
20+
TVector<TString> GenerateValues(size_t level) {
21+
constexpr size_t alphaSize = 'Z' - 'A' + 1;
22+
if (level == 1) {
23+
TVector<TString> alphabet(alphaSize);
24+
std::iota(alphabet.begin(), alphabet.end(), 'A');
25+
return alphabet;
26+
}
27+
const auto subValues = GenerateValues(level - 1);
28+
TVector<TString> values;
29+
values.reserve(alphaSize * subValues.size());
30+
for (char ch = 'A'; ch <= 'Z'; ch++) {
31+
for (const auto& tail : subValues) {
32+
values.emplace_back(ch + tail);
5833
}
59-
return pgmBuilder.NewOptional(pgmBuilder.NewDataLiteral<T>(value));
6034
}
35+
return values;
36+
}
6137

62-
template <typename TKey>
63-
const TRuntimeNode MakeSet(
64-
TProgramBuilder& pgmBuilder,
65-
const TVector<TKey>& keyValues
66-
) {
67-
const auto keyType = pgmBuilder.NewDataType(NUdf::TDataType<TKey>::Id);
68-
69-
TRuntimeNode::TList keyListItems;
70-
std::transform(keyValues.cbegin(), keyValues.cend(),
71-
std::back_inserter(keyListItems), [&pgmBuilder](const auto key) {
72-
return pgmBuilder.NewDataLiteral<TKey>(key);
73-
});
74-
75-
const auto keyList = pgmBuilder.NewList(keyType, keyListItems);
76-
return pgmBuilder.ToHashedDict(keyList, false,
77-
[&](TRuntimeNode item) {
78-
return item;
79-
}, [&](TRuntimeNode) {
80-
return pgmBuilder.NewVoid();
81-
});
38+
template <typename T, bool isOptional = false>
39+
const TRuntimeNode MakeSimpleKey(TProgramBuilder& pgmBuilder, T value, bool isEmpty = false) {
40+
if constexpr (!isOptional) {
41+
return pgmBuilder.NewDataLiteral<T>(value);
42+
}
43+
const auto keyType = pgmBuilder.NewDataType(NUdf::TDataType<T>::Id, true);
44+
if (isEmpty) {
45+
return pgmBuilder.NewEmptyOptional(keyType);
8246
}
47+
return pgmBuilder.NewOptional(pgmBuilder.NewDataLiteral<T>(value));
48+
}
49+
50+
template <typename TKey>
51+
const TRuntimeNode MakeSet(TProgramBuilder& pgmBuilder, const TSet<TKey>& keyValues) {
52+
const auto keyType = pgmBuilder.NewDataType(NUdf::TDataType<TKey>::Id);
8353

84-
void DoTestBlockJoinOnUint64(EJoinKind joinKind, size_t blockSize, size_t testSize) {
85-
TSetup<false> setup;
86-
TProgramBuilder& pb = *setup.PgmBuilder;
87-
88-
const TVector<ui64> dictKeys = {1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144};
89-
const auto dict = MakeSet(pb, dictKeys);
90-
91-
const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id);
92-
const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id);
93-
const auto ui64BlockType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many);
94-
const auto strBlockType = pb.NewBlockType(strType, TBlockType::EShape::Many);
95-
const auto blockLenType = pb.NewBlockType(ui64Type, TBlockType::EShape::Scalar);
96-
const auto structType = pb.NewStructType({
97-
{"key", ui64BlockType},
98-
{"subkey", ui64BlockType},
99-
{"payload", strBlockType},
100-
{"_yql_block_length", blockLenType}
54+
TRuntimeNode::TList keyListItems;
55+
std::transform(keyValues.cbegin(), keyValues.cend(),
56+
std::back_inserter(keyListItems), [&pgmBuilder](const auto key) {
57+
return pgmBuilder.NewDataLiteral<TKey>(key);
58+
});
59+
60+
const auto keyList = pgmBuilder.NewList(keyType, keyListItems);
61+
return pgmBuilder.ToHashedDict(keyList, false,
62+
[&](TRuntimeNode item) {
63+
return item;
64+
}, [&](TRuntimeNode) {
65+
return pgmBuilder.NewVoid();
66+
});
67+
}
68+
69+
// Packs the rows [current, current + blockSize) of ksvVector into three Arrow arrays,
// one per TKSV column: arrays[0] = keys, arrays[1] = subkeys, arrays[2] = values.
//
// Parameters:
//   ksvVector  - source rows.
//   current    - index of the first row to pack.
//   blockSize  - number of rows to pack.
//   memoryPool - Arrow memory pool handed to every builder.
// Returns the finished ArrayData of each column.
// Throws (via Y_ENSURE) when the requested window does not lie fully inside ksvVector;
// previously such a request silently read past the end of the vector.
TArrays KSVToArrays(const TVector<TKSV>& ksvVector, size_t current,
    size_t blockSize, arrow::MemoryPool* memoryPool
) {
    // Written as two comparisons so that `size() - current` can never underflow.
    Y_ENSURE(current <= ksvVector.size() && blockSize <= ksvVector.size() - current,
        "Requested block is out of the input vector bounds");

    TArrays arrays;
    arrow::UInt64Builder keysBuilder(memoryPool);
    arrow::UInt64Builder subkeysBuilder(memoryPool);
    arrow::BinaryBuilder valuesBuilder(memoryPool);
    ARROW_OK(keysBuilder.Reserve(blockSize));
    ARROW_OK(subkeysBuilder.Reserve(blockSize));
    ARROW_OK(valuesBuilder.Reserve(blockSize));
    for (size_t i = 0; i < blockSize; i++) {
        const auto& [key, subkey, value] = ksvVector[current + i];
        // Slots for the fixed-width columns were reserved above, so the unchecked append is used.
        keysBuilder.UnsafeAppend(key);
        subkeysBuilder.UnsafeAppend(subkey);
        // The binary column goes through the checked Append, which copies the bytes.
        ARROW_OK(valuesBuilder.Append(value.data(), value.size()));
    }
    ARROW_OK(keysBuilder.FinishInternal(&arrays[0]));
    ARROW_OK(subkeysBuilder.FinishInternal(&arrays[1]));
    ARROW_OK(valuesBuilder.FinishInternal(&arrays[2]));
    return arrays;
}
90+
91+
// Converts three Arrow arrays (keys, subkeys, values) back into TKSV rows.
//
// Parameters:
//   arrays    - one ArrayData per TKSV column, in tuple order.
//   blockSize - expected number of rows; every array must have exactly this length.
// Returns blockSize rows in array order.
// Throws (via Y_ENSURE) when an array is missing, has a different length,
// contains nulls, or has an unexpected number of buffers.
//
// The value of every returned row is a TStringBuf pointing into the bytes of
// arrays[2]; nothing is copied.
// NOTE(review): callers therefore need that buffer to stay alive for as long as
// the rows are read — confirm at the call sites.
TVector<TKSV> ArraysToKSV(const TArrays& arrays, const int64_t blockSize) {
    for (size_t i = 0; i < std::tuple_size_v<TKSV>; i++) {
        Y_ENSURE(arrays[i] != nullptr,
            "Array data is missing");
        Y_ENSURE(arrays[i]->length == blockSize,
            "Array size differs from the given block size");
        Y_ENSURE(arrays[i]->GetNullCount() == 0,
            "Null values conversion is not supported");
        // Columns 0 and 1 are read from buffer 1 only; column 2 additionally reads
        // its bytes from buffer 2 (buffer 1 holds the offsets), hence one more buffer.
        Y_ENSURE(arrays[i]->buffers.size() == 2 + (i > 1),
            "Array layout doesn't respect the schema");
    }
    const ui64* keyBuffer = arrays[0]->GetValuesSafe<ui64>(1);
    const ui64* subkeyBuffer = arrays[1]->GetValuesSafe<ui64>(1);
    const int32_t* offsets = arrays[2]->GetValuesSafe<int32_t>(1);
    // Absolute offset 0: the offsets above index from the very start of the byte buffer.
    const char* valuesBuffer = arrays[2]->GetValuesSafe<char>(2, 0);

    TVector<TKSV> ksvVector;
    // blockSize equals a (non-negative) array length at this point, so the cast is safe.
    ksvVector.reserve(static_cast<size_t>(blockSize));
    for (int64_t i = 0; i < blockSize; i++) {
        const TStringBuf value(valuesBuffer + offsets[i], offsets[i + 1] - offsets[i]);
        ksvVector.emplace_back(keyBuffer[i], subkeyBuffer[i], value);
    }
    return ksvVector;
}
111+
112+
const TRuntimeNode BuildBlockJoin(TProgramBuilder& pgmBuilder, EJoinKind joinKind,
113+
TVector<ui32> keyColumns, TRuntimeNode& leftArg, TType* leftTuple,
114+
const TRuntimeNode& dictNode
115+
) {
116+
const auto tupleType = AS_TYPE(TTupleType, leftTuple);
117+
const auto listTupleType = pgmBuilder.NewListType(leftTuple);
118+
leftArg = pgmBuilder.Arg(listTupleType);
119+
120+
const auto leftWideFlow = pgmBuilder.ExpandMap(pgmBuilder.ToFlow(leftArg),
121+
[&](TRuntimeNode tupleNode) -> TRuntimeNode::TList {
122+
TRuntimeNode::TList wide;
123+
wide.reserve(tupleType->GetElementsCount());
124+
for (size_t i = 0; i < tupleType->GetElementsCount(); i++) {
125+
wide.emplace_back(pgmBuilder.Nth(tupleNode, i));
126+
}
127+
return wide;
101128
});
102-
const auto fields = NameToIndex(AS_TYPE(TStructType, structType));
103-
const auto listStructType = pb.NewListType(structType);
104-
105-
const auto leftArg = pb.Arg(listStructType);
106-
107-
const auto leftWideFlow = pb.ExpandMap(pb.ToFlow(leftArg),
108-
[&](TRuntimeNode item) -> TRuntimeNode::TList {
109-
return {
110-
pb.Member(item, "key"),
111-
pb.Member(item, "subkey"),
112-
pb.Member(item, "payload"),
113-
pb.Member(item, "_yql_block_length")
114-
};
115-
});
116-
117-
const auto joinNode = pb.BlockMapJoinCore(leftWideFlow, dict, joinKind, {0});
118-
119-
const auto rootNode = pb.Collect(pb.NarrowMap(joinNode,
120-
[&](TRuntimeNode::TList items) -> TRuntimeNode {
121-
return pb.NewStruct(structType, {
122-
{"key", items[0]},
123-
{"subkey", items[1]},
124-
{"payload", items[2]},
125-
{"_yql_block_length", items[3]}
126-
});
127-
}));
128-
129-
const auto graph = setup.BuildGraph(rootNode, {leftArg.GetNode()});
130-
const auto& leftBlocks = graph->GetEntryPoint(0, true);
131-
const auto& holderFactory = graph->GetHolderFactory();
132-
auto& ctx = graph->GetContext();
133-
134-
TVector<ui64> keys(testSize);
135-
TVector<ui64> subkeys;
136-
std::iota(keys.begin(), keys.end(), 1);
137-
std::transform(keys.cbegin(), keys.cend(), std::back_inserter(subkeys),
138-
[](const auto& value) { return value * 1001; });
139-
140-
TVector<const char*> payloads;
141-
std::transform(keys.cbegin(), keys.cend(), std::back_inserter(payloads),
142-
[](const auto& value) { return twoLetterPayloads[value].c_str(); });
143-
144-
size_t current = 0;
145-
TDefaultListRepresentation leftListValues;
146-
while (current < testSize) {
147-
arrow::UInt64Builder keysBuilder(&ctx.ArrowMemoryPool);
148-
arrow::UInt64Builder subkeysBuilder(&ctx.ArrowMemoryPool);
149-
arrow::BinaryBuilder payloadsBuilder(&ctx.ArrowMemoryPool);
150-
ARROW_OK(keysBuilder.Reserve(blockSize));
151-
ARROW_OK(subkeysBuilder.Reserve(blockSize));
152-
ARROW_OK(payloadsBuilder.Reserve(blockSize));
153-
for (size_t i = 0; i < blockSize; i++, current++) {
154-
keysBuilder.UnsafeAppend(keys[current]);
155-
subkeysBuilder.UnsafeAppend(subkeys[current]);
156-
ARROW_OK(payloadsBuilder.Append(payloads[current], payloadSize));
129+
130+
const auto joinNode = pgmBuilder.BlockMapJoinCore(leftWideFlow, dictNode, joinKind, keyColumns);
131+
132+
const auto rootNode = pgmBuilder.Collect(pgmBuilder.NarrowMap(joinNode,
133+
[&](TRuntimeNode::TList items) -> TRuntimeNode {
134+
TVector<TRuntimeNode> tupleElements;
135+
tupleElements.reserve(tupleType->GetElementsCount());
136+
for (size_t i = 0; i < tupleType->GetElementsCount(); i++) {
137+
tupleElements.emplace_back(items[i]);
157138
}
158-
std::shared_ptr<arrow::ArrayData> keysData;
159-
ARROW_OK(keysBuilder.FinishInternal(&keysData));
160-
std::shared_ptr<arrow::ArrayData> subkeysData;
161-
ARROW_OK(subkeysBuilder.FinishInternal(&subkeysData));
162-
std::shared_ptr<arrow::ArrayData> payloadsData;
163-
ARROW_OK(payloadsBuilder.FinishInternal(&payloadsData));
164-
165-
NUdf::TUnboxedValue* items = nullptr;
166-
const auto structObj = holderFactory.CreateDirectArrayHolder(fields.size(), items);
167-
items[fields.at("key")] = holderFactory.CreateArrowBlock(keysData);
168-
items[fields.at("subkey")] = holderFactory.CreateArrowBlock(subkeysData);
169-
items[fields.at("payload")] = holderFactory.CreateArrowBlock(payloadsData);
170-
items[fields.at("_yql_block_length")] = MakeBlockCount(holderFactory, blockSize);
171-
leftListValues = leftListValues.Append(std::move(structObj));
139+
return pgmBuilder.NewTuple(tupleElements);
140+
}));
141+
142+
return rootNode;
143+
}
144+
145+
TVector<TKSV> DoTestBlockJoinOnUint64(EJoinKind joinKind, TVector<TKSV> values,
146+
TSet<ui64> set, size_t blockSize
147+
) {
148+
TSetup<false> setup;
149+
TProgramBuilder& pb = *setup.PgmBuilder;
150+
151+
const auto dict = MakeSet(pb, set);
152+
153+
const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id);
154+
const auto strType = pb.NewDataType(NUdf::EDataSlot::String);
155+
const auto ui64BlockType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many);
156+
const auto strBlockType = pb.NewBlockType(strType, TBlockType::EShape::Many);
157+
const auto blockLenType = pb.NewBlockType(ui64Type, TBlockType::EShape::Scalar);
158+
const auto ksvType = pb.NewTupleType({
159+
ui64BlockType, ui64BlockType, strBlockType, blockLenType
160+
});
161+
// Mind the last block length column.
162+
const auto ksvWidth = AS_TYPE(TTupleType, ksvType)->GetElementsCount() - 1;
163+
164+
TRuntimeNode leftArg;
165+
const auto rootNode = BuildBlockJoin(pb, joinKind, {0}, leftArg, ksvType, dict);
166+
167+
const auto graph = setup.BuildGraph(rootNode, {leftArg.GetNode()});
168+
const auto& leftBlocks = graph->GetEntryPoint(0, true);
169+
const auto& holderFactory = graph->GetHolderFactory();
170+
auto& ctx = graph->GetContext();
171+
172+
const size_t testSize = values.size();
173+
size_t current = 0;
174+
TDefaultListRepresentation leftListValues;
175+
while (current < testSize) {
176+
const auto arrays = KSVToArrays(values, current, blockSize, &ctx.ArrowMemoryPool);
177+
current += blockSize;
178+
179+
NUdf::TUnboxedValue* items = nullptr;
180+
const auto tuple = holderFactory.CreateDirectArrayHolder(ksvWidth + 1, items);
181+
for (size_t i = 0; i < ksvWidth; i++) {
182+
items[i] = holderFactory.CreateArrowBlock(arrays[i]);
172183
}
173-
leftBlocks->SetValue(ctx, holderFactory.CreateDirectListHolder(std::move(leftListValues)));
174-
const auto joinIterator = graph->GetValue().GetListIterator();
175-
176-
NUdf::TUnboxedValue item;
177-
TVector<NUdf::TUnboxedValue> joinResult;
178-
while (joinIterator.Next(item)) {
179-
joinResult.push_back(item);
184+
items[ksvWidth] = MakeBlockCount(holderFactory, blockSize);
185+
leftListValues = leftListValues.Append(std::move(tuple));
186+
}
187+
leftBlocks->SetValue(ctx, holderFactory.CreateDirectListHolder(std::move(leftListValues)));
188+
const auto joinIterator = graph->GetValue().GetListIterator();
189+
190+
TVector<TKSV> resultKSV;
191+
TArrays arrays;
192+
NUdf::TUnboxedValue value;
193+
while (joinIterator.Next(value)) {
194+
for (size_t i = 0; i < ksvWidth; i++) {
195+
const auto arrayValue = value.GetElement(i);
196+
const auto arrayDatum = TArrowBlock::From(arrayValue).GetDatum();
197+
UNIT_ASSERT(arrayDatum.is_array());
198+
arrays[i] = arrayDatum.array();
180199
}
181-
182-
UNIT_ASSERT_VALUES_EQUAL(joinResult.size(), 1);
183-
const auto blocks = joinResult.front();
184-
const auto blockLengthValue = blocks.GetElement(fields.at("_yql_block_length"));
200+
const auto blockLengthValue = value.GetElement(ksvWidth);
185201
const auto blockLengthDatum = TArrowBlock::From(blockLengthValue).GetDatum();
186-
UNIT_ASSERT(blockLengthDatum.is_scalar());
202+
Y_ENSURE(blockLengthDatum.is_scalar());
187203
const auto blockLength = blockLengthDatum.scalar_as<arrow::UInt64Scalar>().value;
188-
const auto dictSize = std::count_if(dictKeys.cbegin(), dictKeys.cend(),
189-
[testSize](ui64 key) { return key < testSize; });
190-
const auto expectedLength = joinKind == EJoinKind::LeftSemi ? dictSize
191-
: joinKind == EJoinKind::LeftOnly ? testSize - dictSize
192-
: -1;
193-
UNIT_ASSERT_VALUES_EQUAL(expectedLength, blockLength);
204+
const auto blockKSV = ArraysToKSV(arrays, blockLength);
205+
resultKSV.insert(resultKSV.end(), blockKSV.cbegin(), blockKSV.cend());
206+
}
207+
std::sort(resultKSV.begin(), resultKSV.end());
208+
return resultKSV;
209+
}
210+
211+
void TestBlockJoinOnUint64(EJoinKind joinKind) {
212+
constexpr size_t testSize = 1 << 14;
213+
constexpr size_t valueSize = 3;
214+
static const TVector<TString> threeLetterValues = GenerateValues(valueSize);
215+
static const TSet<ui64> fib = {1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144,
216+
233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711};
217+
218+
TVector<TKSV> testKSV;
219+
for (size_t k = 0; k < testSize; k++) {
220+
testKSV.push_back(std::make_tuple(k, k * 1001, threeLetterValues[k]));
194221
}
222+
TVector<TKSV> expectedKSV;
223+
std::copy_if(testKSV.cbegin(), testKSV.cend(), std::back_inserter(expectedKSV),
224+
[&joinKind](const auto& ksv) {
225+
const auto contains = fib.contains(std::get<0>(ksv));
226+
return joinKind == EJoinKind::LeftSemi ? contains : !contains;
227+
});
195228

196-
void TestBlockJoinOnUint64(EJoinKind joinKind) {
197-
const size_t testSize = 512;
198-
for (size_t blockSize = 8; blockSize <= testSize; blockSize <<= 2) {
199-
DoTestBlockJoinOnUint64(joinKind, blockSize, testSize);
200-
}
229+
for (size_t blockSize = 8; blockSize <= testSize; blockSize <<= 1) {
230+
const auto gotKSV = DoTestBlockJoinOnUint64(joinKind, testKSV, fib, blockSize);
231+
UNIT_ASSERT_EQUAL(expectedKSV, gotKSV);
201232
}
233+
}
234+
202235
} // namespace
203236

204-
Y_UNIT_TEST_SUITE(TMiniKQLBlockMapJoinTest) {
237+
Y_UNIT_TEST_SUITE(TMiniKQLBlockMapJoinBasicTest) {
205238
Y_UNIT_TEST(TestLeftSemiOnUint64) {
206239
TestBlockJoinOnUint64(EJoinKind::LeftSemi);
207240
}

0 commit comments

Comments
 (0)