Skip to content

Commit 2f2468c

Browse files
authored
Fix DQ input union/merge values with zero input channels (#7515)
1 parent 8ee38ab commit 2f2468c

File tree

3 files changed

+35
-36
lines changed

3 files changed

+35
-36
lines changed

ydb/library/yql/dq/runtime/dq_input_producer.cpp

Lines changed: 30 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -24,11 +24,11 @@ template<bool IsWide>
2424
class TDqInputUnionStreamValue : public TComputationValue<TDqInputUnionStreamValue<IsWide>> {
2525
using TBase = TComputationValue<TDqInputUnionStreamValue<IsWide>>;
2626
public:
27-
TDqInputUnionStreamValue(TMemoryUsageInfo* memInfo, TVector<IDqInput::TPtr>&& inputs, TDqMeteringStats::TInputStatsMeter stats)
27+
TDqInputUnionStreamValue(TMemoryUsageInfo* memInfo, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs, TDqMeteringStats::TInputStatsMeter stats)
2828
: TBase(memInfo)
2929
, Inputs(std::move(inputs))
3030
, Alive(Inputs.size())
31-
, Batch(Inputs.empty() ? nullptr : Inputs.front()->GetInputType())
31+
, Batch(type)
3232
, Stats(stats)
3333
{}
3434

@@ -114,13 +114,15 @@ template<bool IsWide>
114114
class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamValue<IsWide>> {
115115
using TBase = TComputationValue<TDqInputMergeStreamValue<IsWide>>;
116116
public:
117-
TDqInputMergeStreamValue(TMemoryUsageInfo* memInfo, TVector<IDqInput::TPtr>&& inputs,
117+
TDqInputMergeStreamValue(TMemoryUsageInfo* memInfo, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
118118
TVector<TSortColumnInfo>&& sortCols, TDqMeteringStats::TInputStatsMeter stats)
119119
: TBase(memInfo)
120120
, Inputs(std::move(inputs))
121+
, Width(type->IsMulti() ? static_cast<const NMiniKQL::TMultiType*>(type)->GetElementsCount() : TMaybe<ui32>())
121122
, SortCols(std::move(sortCols))
122123
, Stats(stats)
123124
{
125+
YQL_ENSURE(!IsWide ^ Width.Defined());
124126
CurrentBuffers.reserve(Inputs.size());
125127
CurrentItemIndexes.reserve(Inputs.size());
126128
for (ui32 idx = 0; idx < Inputs.size(); ++idx) {
@@ -216,7 +218,7 @@ class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamVal
216218
return status;
217219
}
218220

219-
YQL_ENSURE(!Inputs.empty() && *Inputs.front()->GetInputWidth() == width);
221+
YQL_ENSURE(*Width == width);
220222
CopyResult(result, width);
221223
if (Stats) {
222224
Stats.Add(result, width);
@@ -300,6 +302,7 @@ class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamVal
300302

301303
private:
302304
TVector<IDqInput::TPtr> Inputs;
305+
const TMaybe<ui32> Width;
303306
TVector<TSortColumnInfo> SortCols;
304307
TVector<TUnboxedValueBatch> CurrentBuffers;
305308
TVector<TUnboxedValuesIterator<IsWide>> CurrentItemIndexes;
@@ -308,20 +311,6 @@ class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamVal
308311
TDqMeteringStats::TInputStatsMeter Stats;
309312
};
310313

311-
bool IsWideInputs(const TVector<IDqInput::TPtr>& inputs) {
312-
NKikimr::NMiniKQL::TType* type = nullptr;
313-
bool isWide = false;
314-
for (auto& input : inputs) {
315-
if (!type) {
316-
type = input->GetInputType();
317-
isWide = input->GetInputWidth().Defined();
318-
} else {
319-
YQL_ENSURE(type->IsSameType(*input->GetInputType()));
320-
}
321-
}
322-
return isWide;
323-
}
324-
325314
TVector<NKikimr::NMiniKQL::TType*> ExtractBlockItemTypes(const NKikimr::NMiniKQL::TType* type) {
326315
TVector<NKikimr::NMiniKQL::TType*> result;
327316

@@ -390,18 +379,17 @@ TVector<IBlockItemComparator::TPtr> MakeComparators(const TVector<TSortColumnInf
390379
class TDqInputMergeBlockStreamValue : public TComputationValue<TDqInputMergeBlockStreamValue> {
391380
using TBase = TComputationValue<TDqInputMergeBlockStreamValue>;
392381
public:
393-
TDqInputMergeBlockStreamValue(TMemoryUsageInfo* memInfo, TVector<IDqInput::TPtr>&& inputs,
382+
TDqInputMergeBlockStreamValue(TMemoryUsageInfo* memInfo, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
394383
TVector<TSortColumnInfo>&& sortCols, const NKikimr::NMiniKQL::THolderFactory& factory, TDqMeteringStats::TInputStatsMeter stats)
395384
: TBase(memInfo)
396385
, SortCols_(std::move(sortCols))
397-
, ItemTypes_(ExtractBlockItemTypes(inputs.front()->GetInputType()))
386+
, ItemTypes_(ExtractBlockItemTypes(type))
398387
, MaxOutputBlockLen_(CalcMaxBlockLength(ItemTypes_.begin(), ItemTypes_.end(), TTypeInfoHelper()))
399388
, Comparators_(MakeComparators(SortCols_, ItemTypes_))
400389
, Builders_(MakeBuilders(MaxOutputBlockLen_, ItemTypes_))
401390
, Factory_(factory)
402391
, Stats_(stats)
403392
{
404-
YQL_ENSURE(!inputs.empty());
405393
YQL_ENSURE(MaxOutputBlockLen_ > 0);
406394
InputData_.reserve(inputs.size());
407395
for (auto& input : inputs) {
@@ -697,6 +685,15 @@ class TDqInputMergeBlockStreamValue : public TComputationValue<TDqInputMergeBloc
697685
bool IsFinished_ = false;
698686
};
699687

688+
void ValidateInputTypes(const NKikimr::NMiniKQL::TType* type, const TVector<IDqInput::TPtr>& inputs) {
689+
YQL_ENSURE(type);
690+
for (size_t i = 0; i < inputs.size(); ++i) {
691+
auto inputType = inputs[i]->GetInputType();
692+
YQL_ENSURE(inputType);
693+
YQL_ENSURE(type->IsSameType(*inputType), "Unexpected type for input #" << i << ": expected " << *type << ", got " << *inputType);
694+
}
695+
}
696+
700697
} // namespace
701698

702699
void TDqMeteringStats::TInputStatsMeter::Add(const NKikimr::NUdf::TUnboxedValue& val) {
@@ -737,31 +734,33 @@ void TDqMeteringStats::TInputStatsMeter::Add(const NKikimr::NUdf::TUnboxedValue*
737734
}
738735
}
739736

740-
NUdf::TUnboxedValue CreateInputUnionValue(TVector<IDqInput::TPtr>&& inputs,
737+
NUdf::TUnboxedValue CreateInputUnionValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
741738
const NMiniKQL::THolderFactory& factory, TDqMeteringStats::TInputStatsMeter stats)
742739
{
743-
if (IsWideInputs(inputs)) {
744-
return factory.Create<TDqInputUnionStreamValue<true>>(std::move(inputs), stats);
740+
ValidateInputTypes(type, inputs);
741+
if (type->IsMulti()) {
742+
return factory.Create<TDqInputUnionStreamValue<true>>(type, std::move(inputs), stats);
745743
}
746-
return factory.Create<TDqInputUnionStreamValue<false>>(std::move(inputs), stats);
744+
return factory.Create<TDqInputUnionStreamValue<false>>(type, std::move(inputs), stats);
747745
}
748746

749-
NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(TVector<IDqInput::TPtr>&& inputs,
747+
NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
750748
TVector<TSortColumnInfo>&& sortCols, const NKikimr::NMiniKQL::THolderFactory& factory, TDqMeteringStats::TInputStatsMeter stats)
751749
{
750+
ValidateInputTypes(type, inputs);
752751
YQL_ENSURE(!inputs.empty());
753-
if (IsWideInputs(inputs)) {
752+
if (type->IsMulti()) {
754753
if (AnyOf(sortCols, [](const auto& sortCol) { return sortCol.IsBlockOrScalar(); })) {
755754
// we can ignore scalar columns, since they all have exactly the same value in all inputs
756755
EraseIf(sortCols, [](const auto& sortCol) { return *sortCol.IsScalar; });
757756
if (sortCols.empty()) {
758-
return factory.Create<TDqInputUnionStreamValue<true>>(std::move(inputs), stats);
757+
return factory.Create<TDqInputUnionStreamValue<true>>(type, std::move(inputs), stats);
759758
}
760-
return factory.Create<TDqInputMergeBlockStreamValue>(std::move(inputs), std::move(sortCols), factory, stats);
759+
return factory.Create<TDqInputMergeBlockStreamValue>(type, std::move(inputs), std::move(sortCols), factory, stats);
761760
}
762-
return factory.Create<TDqInputMergeStreamValue<true>>(std::move(inputs), std::move(sortCols), stats);
761+
return factory.Create<TDqInputMergeStreamValue<true>>(type, std::move(inputs), std::move(sortCols), stats);
763762
}
764-
return factory.Create<TDqInputMergeStreamValue<false>>(std::move(inputs), std::move(sortCols), stats);
763+
return factory.Create<TDqInputMergeStreamValue<false>>(type, std::move(inputs), std::move(sortCols), stats);
765764
}
766765

767766
} // namespace NYql::NDq

ydb/library/yql/dq/runtime/dq_input_producer.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,10 @@ struct TDqMeteringStats {
2828
}
2929
};
3030

31-
NKikimr::NUdf::TUnboxedValue CreateInputUnionValue(TVector<IDqInput::TPtr>&& inputs,
31+
NKikimr::NUdf::TUnboxedValue CreateInputUnionValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
3232
const NKikimr::NMiniKQL::THolderFactory& holderFactory, TDqMeteringStats::TInputStatsMeter = {});
3333

34-
NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(TVector<IDqInput::TPtr>&& inputs,
34+
NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
3535
TVector<TSortColumnInfo>&& sortCols, const NKikimr::NMiniKQL::THolderFactory& factory,
3636
TDqMeteringStats::TInputStatsMeter = {});
3737

ydb/library/yql/dq/runtime/dq_tasks_runner.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -146,14 +146,14 @@ NUdf::TUnboxedValue DqBuildInputValue(const NDqProto::TTaskInput& inputDesc, con
146146
Y_ABORT_UNLESS(inputs.size() == 1);
147147
[[fallthrough]];
148148
case NYql::NDqProto::TTaskInput::kUnionAll:
149-
return CreateInputUnionValue(std::move(inputs), holderFactory, stats);
149+
return CreateInputUnionValue(type, std::move(inputs), holderFactory, stats);
150150
case NYql::NDqProto::TTaskInput::kMerge: {
151151
const auto& protoSortCols = inputDesc.GetMerge().GetSortColumns();
152152
TVector<TSortColumnInfo> sortColsInfo;
153153
GetSortColumnsInfo(type, protoSortCols, sortColsInfo);
154154
YQL_ENSURE(!sortColsInfo.empty());
155155

156-
return CreateInputMergeValue(std::move(inputs), std::move(sortColsInfo), holderFactory, stats);
156+
return CreateInputMergeValue(type, std::move(inputs), std::move(sortColsInfo), holderFactory, stats);
157157
}
158158
default:
159159
YQL_ENSURE(false, "Unknown input type: " << (ui32) inputDesc.GetTypeCase());
@@ -576,7 +576,7 @@ class TDqTaskRunner : public IDqTaskRunner {
576576
inputs.clear();
577577
inputs.emplace_back(transform->TransformOutput);
578578
entryNode->SetValue(AllocatedHolder->ProgramParsed.CompGraph->GetContext(),
579-
CreateInputUnionValue(std::move(inputs), holderFactory,
579+
CreateInputUnionValue(transform->TransformOutput->GetInputType(), std::move(inputs), holderFactory,
580580
{&inputStats, transform->TransformOutputType}));
581581
} else {
582582
entryNode->SetValue(AllocatedHolder->ProgramParsed.CompGraph->GetContext(),

0 commit comments

Comments
 (0)