Skip to content

Commit d32c80c

Browse files
authored
Implement computation node for BlockMapJoinCore (Left and Inner) with multi dict (#8501)
1 parent 03395eb commit d32c80c

File tree

3 files changed

+374
-95
lines changed

3 files changed

+374
-95
lines changed

ydb/library/yql/minikql/comp_nodes/mkql_block_map_join.cpp

Lines changed: 183 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <ydb/library/yql/minikql/mkql_node_cast.h>
99
#include <ydb/library/yql/minikql/mkql_program_builder.h>
1010

11+
#include <util/generic/serialized_enum.h>
12+
1113
namespace NKikimr {
1214
namespace NMiniKQL {
1315

@@ -184,51 +186,51 @@ using TState = TBlockJoinState<RightRequired>;
184186
{}
185187

186188
EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
187-
auto& s = GetState(state, ctx);
189+
auto& blockState = GetState(state, ctx);
188190
auto** fields = ctx.WideFields.data() + WideFieldsIndex_;
189191
const auto dict = Dict_->GetValue(ctx);
190192

191193
do {
192-
while (s.IsNotFull() && s.NextRow()) {
193-
const auto key = MakeKeysTuple(ctx, s, LeftKeyColumns_);
194+
while (blockState.IsNotFull() && blockState.NextRow()) {
195+
const auto key = MakeKeysTuple(ctx, blockState, LeftKeyColumns_);
194196
if constexpr (WithoutRight) {
195197
if (key && dict.Contains(key) == RightRequired) {
196-
s.CopyRow();
198+
blockState.CopyRow();
197199
}
198200
} else if constexpr (RightRequired) {
199201
if (NUdf::TUnboxedValue lookup; key && (lookup = dict.Lookup(key))) {
200-
s.MakeRow(lookup);
202+
blockState.MakeRow(lookup);
201203
}
202204
} else {
203-
s.MakeRow(dict.Lookup(key));
205+
blockState.MakeRow(dict.Lookup(key));
204206
}
205207
}
206-
if (!s.IsFinished()) {
208+
if (!blockState.IsFinished()) {
207209
switch (Flow_->FetchValues(ctx, fields)) {
208210
case EFetchResult::Yield:
209211
return EFetchResult::Yield;
210212
case EFetchResult::One:
211-
s.Reset();
213+
blockState.Reset();
212214
continue;
213215
case EFetchResult::Finish:
214-
s.Finish();
216+
blockState.Finish();
215217
break;
216218
}
217219
}
218220
// Leave the outer loop, if no values left in the flow.
219-
Y_DEBUG_ABORT_UNLESS(s.IsFinished());
221+
Y_DEBUG_ABORT_UNLESS(blockState.IsFinished());
220222
break;
221223
} while (true);
222224

223-
if (s.IsEmpty()) {
225+
if (blockState.IsEmpty()) {
224226
return EFetchResult::Finish;
225227
}
226-
s.MakeBlocks(ctx.HolderFactory);
227-
const auto sliceSize = s.Slice();
228+
blockState.MakeBlocks(ctx.HolderFactory);
229+
const auto sliceSize = blockState.Slice();
228230

229231
for (size_t i = 0; i < ResultJoinItems_.size(); i++) {
230232
if (const auto out = output[i]) {
231-
*out = s.Get(sliceSize, ctx.HolderFactory, i);
233+
*out = blockState.Get(sliceSize, ctx.HolderFactory, i);
232234
}
233235
}
234236

@@ -267,6 +269,154 @@ using TState = TBlockJoinState<RightRequired>;
267269
ui32 WideFieldsIndex_;
268270
};
269271

272+
template<bool RightRequired>
273+
class TBlockWideMultiMapJoinWrapper : public TPairStateWideFlowComputationNode<TBlockWideMultiMapJoinWrapper<RightRequired>>
274+
{
275+
using TBaseComputation = TPairStateWideFlowComputationNode<TBlockWideMultiMapJoinWrapper<RightRequired>>;
276+
using TState = TBlockJoinState<RightRequired>;
277+
public:
278+
TBlockWideMultiMapJoinWrapper(TComputationMutables& mutables,
279+
const TVector<TType*>&& resultJoinItems, const TVector<TType*>&& leftFlowItems,
280+
TVector<ui32>&& leftKeyColumns,
281+
IComputationWideFlowNode* flow, IComputationNode* dict)
282+
: TBaseComputation(mutables, flow, EValueRepresentation::Boxed, EValueRepresentation::Boxed)
283+
, ResultJoinItems_(std::move(resultJoinItems))
284+
, LeftFlowItems_(std::move(leftFlowItems))
285+
, LeftKeyColumns_(std::move(leftKeyColumns))
286+
, Flow_(flow)
287+
, Dict_(dict)
288+
, WideFieldsIndex_(mutables.IncrementWideFieldsIndex(LeftFlowItems_.size()))
289+
{}
290+
291+
EFetchResult DoCalculate(NUdf::TUnboxedValue& state, NUdf::TUnboxedValue& iterator, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
292+
auto& blockState = GetState(state, ctx);
293+
auto& iterState = GetIterator(iterator, ctx);
294+
auto** fields = ctx.WideFields.data() + WideFieldsIndex_;
295+
const auto dict = Dict_->GetValue(ctx);
296+
297+
do {
298+
if (iterState) {
299+
NUdf::TUnboxedValue lookupItem;
300+
// Process the remaining items from the iterator.
301+
while (blockState.IsNotFull() && iterState.Next(lookupItem)) {
302+
blockState.MakeRow(lookupItem);
303+
}
304+
}
305+
if (blockState.IsNotFull() && blockState.NextRow()) {
306+
const auto key = MakeKeysTuple(ctx, blockState, LeftKeyColumns_);
307+
// Lookup the item in the right dict. If the lookup succeeds,
308+
// reset the iterator and proceed the execution from the
309+
// beginning of the outer loop. Otherwise, the iterState is
310+
// already invalidated (i.e. finished), so the execution will
311+
// process the next tuple from the left flow.
312+
if (NUdf::TUnboxedValue lookup; key && (lookup = dict.Lookup(key))) {
313+
iterState.Reset(std::move(lookup));
314+
} else if constexpr (!RightRequired) {
315+
blockState.MakeRow(NUdf::TUnboxedValue());
316+
}
317+
continue;
318+
}
319+
if (blockState.IsNotFull() && !blockState.IsFinished()) {
320+
switch (Flow_->FetchValues(ctx, fields)) {
321+
case EFetchResult::Yield:
322+
return EFetchResult::Yield;
323+
case EFetchResult::One:
324+
blockState.Reset();
325+
continue;
326+
case EFetchResult::Finish:
327+
blockState.Finish();
328+
break;
329+
}
330+
// Leave the loop, if no values left in the flow.
331+
Y_DEBUG_ABORT_UNLESS(blockState.IsFinished());
332+
break;
333+
}
334+
break;
335+
} while(true);
336+
337+
if (blockState.IsEmpty()) {
338+
return EFetchResult::Finish;
339+
}
340+
blockState.MakeBlocks(ctx.HolderFactory);
341+
const auto sliceSize = blockState.Slice();
342+
343+
for (size_t i = 0; i < ResultJoinItems_.size(); i++) {
344+
if (const auto out = output[i]) {
345+
*out = blockState.Get(sliceSize, ctx.HolderFactory, i);
346+
}
347+
}
348+
349+
return EFetchResult::One;
350+
}
351+
352+
private:
353+
void RegisterDependencies() const final {
354+
if (const auto flow = this->FlowDependsOn(Flow_))
355+
this->DependsOn(flow, Dict_);
356+
}
357+
358+
void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
359+
state = ctx.HolderFactory.Create<TState>(ctx, LeftFlowItems_, ResultJoinItems_, ctx.WideFields.data() + WideFieldsIndex_);
360+
}
361+
362+
TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
363+
if (state.IsInvalid()) {
364+
MakeState(ctx, state);
365+
}
366+
return *static_cast<TState*>(state.AsBoxed().Get());
367+
}
368+
369+
class TIterator : public TComputationValue<TIterator> {
370+
using TBase = TComputationValue<TIterator>;
371+
NUdf::TUnboxedValue List_;
372+
NUdf::TUnboxedValue Iterator_;
373+
NUdf::TUnboxedValue Current_;
374+
375+
public:
376+
TIterator(TMemoryUsageInfo* memInfo)
377+
: TBase(memInfo)
378+
, List_(NUdf::TUnboxedValue::Invalid())
379+
, Iterator_(NUdf::TUnboxedValue::Invalid())
380+
, Current_(NUdf::TUnboxedValue::Invalid())
381+
{}
382+
383+
inline explicit operator bool() const { return !Iterator_.IsInvalid(); }
384+
void Reset(const NUdf::TUnboxedValue&& list) {
385+
List_ = std::move(list);
386+
Iterator_ = List_.GetListIterator();
387+
}
388+
bool Next(NUdf::TUnboxedValue& item) {
389+
const auto found = Iterator_.Next(Current_);
390+
item = Current_;
391+
return found;
392+
}
393+
};
394+
395+
void MakeIterator(TComputationContext& ctx, NUdf::TUnboxedValue& iterator) const {
396+
iterator = ctx.HolderFactory.Create<TIterator>();
397+
}
398+
399+
TIterator& GetIterator(NUdf::TUnboxedValue& iterator, TComputationContext& ctx) const {
400+
if (iterator.IsInvalid()) {
401+
MakeIterator(ctx, iterator);
402+
}
403+
return *static_cast<TIterator*>(iterator.AsBoxed().Get());
404+
}
405+
406+
NUdf::TUnboxedValue MakeKeysTuple(const TComputationContext& ctx, const TState& state, const TVector<ui32>& keyColumns) const {
407+
// TODO: Handle complex key.
408+
// TODO: Handle converters.
409+
return state.GetValue(ctx.HolderFactory, keyColumns.front());
410+
}
411+
412+
const TVector<TType*> ResultJoinItems_;
413+
const TVector<TType*> LeftFlowItems_;
414+
const TVector<ui32> LeftKeyColumns_;
415+
IComputationWideFlowNode* const Flow_;
416+
IComputationNode* const Dict_;
417+
ui32 WideFieldsIndex_;
418+
};
419+
270420
} // namespace
271421

272422
IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -294,9 +444,12 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
294444
const auto rightDictNode = callable.GetInput(1);
295445
MKQL_ENSURE(rightDictNode.GetStaticType()->IsDict(),
296446
"Expected Dict as a right join part");
297-
const auto rightDictType = AS_TYPE(TDictType, rightDictNode);
298-
MKQL_ENSURE(rightDictType->GetPayloadType()->IsVoid() ||
299-
rightDictType->GetPayloadType()->IsTuple(),
447+
const auto rightDictType = AS_TYPE(TDictType, rightDictNode)->GetPayloadType();
448+
const auto isMulti = rightDictType->IsList();
449+
const auto rightDictItemType = isMulti
450+
? AS_TYPE(TListType, rightDictType)->GetItemType()
451+
: rightDictType;
452+
MKQL_ENSURE(rightDictItemType->IsVoid() || rightDictItemType->IsTuple(),
300453
"Expected Void or Tuple as a right dict item type");
301454

302455
const auto joinKindNode = callable.GetInput(2);
@@ -319,11 +472,22 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
319472
const auto dict = LocateNode(ctx.NodeLocator, callable, 1);
320473

321474
switch (joinKind) {
475+
static const auto joinNames = GetEnumNames<EJoinKind>();
322476
case EJoinKind::Inner:
477+
if (isMulti) {
478+
return new TBlockWideMultiMapJoinWrapper<true>(ctx.Mutables,
479+
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
480+
static_cast<IComputationWideFlowNode*>(flow), dict);
481+
}
323482
return new TBlockWideMapJoinWrapper<false, true>(ctx.Mutables,
324483
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
325484
static_cast<IComputationWideFlowNode*>(flow), dict);
326485
case EJoinKind::Left:
486+
if (isMulti) {
487+
return new TBlockWideMultiMapJoinWrapper<false>(ctx.Mutables,
488+
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
489+
static_cast<IComputationWideFlowNode*>(flow), dict);
490+
}
327491
return new TBlockWideMapJoinWrapper<false, false>(ctx.Mutables,
328492
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
329493
static_cast<IComputationWideFlowNode*>(flow), dict);
@@ -336,7 +500,8 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
336500
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
337501
static_cast<IComputationWideFlowNode*>(flow), dict);
338502
default:
339-
Y_ABORT();
503+
MKQL_ENSURE(false, "BlockMapJoinCore doesn't support %s join type"
504+
<< joinNames.at(joinKind));
340505
}
341506
}
342507

0 commit comments

Comments
 (0)