Skip to content

Commit c8a32ef

Browse files
authored
[YQ-3621] support AFTER MATCH SKIP PAST LAST ROW (#10597)
1 parent 4b74b39 commit c8a32ef

File tree

21 files changed

+199
-206
lines changed

21 files changed

+199
-206
lines changed

ydb/library/yql/core/sql_types/match_recognize.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88

99
namespace NYql::NMatchRecognize {
1010

11+
enum class EAfterMatchSkipTo {
12+
NextRow,
13+
PastLastRow,
14+
ToFirst,
15+
ToLast,
16+
To
17+
};
18+
19+
struct TAfterMatchSkipTo {
20+
EAfterMatchSkipTo To;
21+
TString Var;
22+
23+
[[nodiscard]] bool operator==(const TAfterMatchSkipTo&) const noexcept = default;
24+
};
25+
1126
constexpr size_t MaxPatternNesting = 20; //Limit recursion for patterns
1227
constexpr size_t MaxPermutedItems = 6;
1328

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp

Lines changed: 27 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -39,131 +39,13 @@ struct TMatchRecognizeProcessorParameters {
3939
TMeasureInputColumnOrder MeasureInputColumnOrder;
4040
TComputationNodePtrVector Measures;
4141
TOutputColumnOrder OutputColumnOrder;
42-
};
43-
44-
class TBackTrackingMatchRecognize {
45-
using TPartitionList = TSimpleList;
46-
using TRange = TPartitionList::TRange;
47-
using TMatchedVars = TMatchedVars<TRange>;
48-
public:
49-
//TODO(YQL-16486): create a tree for backtracking(replace var names with indexes)
50-
51-
struct TPatternConfiguration {
52-
void Save(TMrOutputSerializer& /*serializer*/) const {
53-
}
54-
55-
void Load(TMrInputSerializer& /*serializer*/) {
56-
}
57-
58-
friend bool operator==(const TPatternConfiguration&, const TPatternConfiguration&) {
59-
return true;
60-
}
61-
};
62-
63-
struct TPatternConfigurationBuilder {
64-
using TPatternConfigurationPtr = std::shared_ptr<TPatternConfiguration>;
65-
static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
66-
Y_UNUSED(pattern);
67-
Y_UNUSED(varNameToIndex);
68-
return std::make_shared<TPatternConfiguration>();
69-
}
70-
};
71-
72-
TBackTrackingMatchRecognize(
73-
NUdf::TUnboxedValue&& partitionKey,
74-
const TMatchRecognizeProcessorParameters& parameters,
75-
const TPatternConfigurationBuilder::TPatternConfigurationPtr pattern,
76-
const TContainerCacheOnContext& cache
77-
)
78-
: PartitionKey(std::move(partitionKey))
79-
, Parameters(parameters)
80-
, Cache(cache)
81-
, CurMatchedVars(parameters.Defines.size())
82-
, MatchNumber(0)
83-
{
84-
//TODO(YQL-16486)
85-
Y_UNUSED(pattern);
86-
}
87-
88-
bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
89-
Y_UNUSED(ctx);
90-
Rows.Append(std::move(row));
91-
return false;
92-
}
93-
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
94-
if (Matches.empty())
95-
return NUdf::TUnboxedValue{};
96-
Parameters.MatchedVarsArg->SetValue(ctx, ToValue(ctx.HolderFactory, std::move(Matches.front())));
97-
Matches.pop_front();
98-
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
99-
Parameters.InputDataArg->GetValue(ctx),
100-
Parameters.MeasureInputColumnOrder,
101-
Parameters.MatchedVarsArg->GetValue(ctx),
102-
Parameters.VarNames,
103-
++MatchNumber
104-
));
105-
NUdf::TUnboxedValue *itemsPtr = nullptr;
106-
const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr);
107-
for (auto const& c: Parameters.OutputColumnOrder) {
108-
switch(c.first) {
109-
case EOutputColumnSource::Measure:
110-
*itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx);
111-
break;
112-
case EOutputColumnSource::PartitionKey:
113-
*itemsPtr++ = PartitionKey.GetElement(c.second);
114-
break;
115-
}
116-
}
117-
return result;
118-
}
119-
bool ProcessEndOfData(TComputationContext& ctx) {
120-
//Assume, that data moved to IComputationExternalNode node, will not be modified or released
121-
//till the end of the current function
122-
auto rowsSize = Rows.Size();
123-
Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TPartitionList>>(Rows));
124-
for (size_t i = 0; i != rowsSize; ++i) {
125-
Parameters.CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(static_cast<ui64>(i)));
126-
for (size_t v = 0; v != Parameters.Defines.size(); ++v) {
127-
const auto &d = Parameters.Defines[v]->GetValue(ctx);
128-
if (d && d.GetOptionalValue().Get<bool>()) {
129-
Extend(CurMatchedVars[v], TRange{i});
130-
}
131-
}
132-
//for the sake of dummy usage assume non-overlapped matches at every 5th row of any partition
133-
if (i % 5 == 0) {
134-
TMatchedVars temp;
135-
temp.swap(CurMatchedVars);
136-
Matches.emplace_back(std::move(temp));
137-
CurMatchedVars.resize(Parameters.Defines.size());
138-
}
139-
}
140-
return not Matches.empty();
141-
}
142-
143-
void Save(TOutputSerializer& /*serializer*/) const {
144-
// Not used in not streaming mode.
145-
}
146-
147-
void Load(TMrInputSerializer& /*serializer*/) {
148-
// Not used in not streaming mode.
149-
}
150-
151-
private:
152-
const NUdf::TUnboxedValue PartitionKey;
153-
const TMatchRecognizeProcessorParameters& Parameters;
154-
const TContainerCacheOnContext& Cache;
155-
TSimpleList Rows;
156-
TMatchedVars CurMatchedVars;
157-
std::deque<TMatchedVars, TMKQLAllocator<TMatchedVars>> Matches;
158-
ui64 MatchNumber;
42+
TAfterMatchSkipTo SkipTo;
15943
};
16044

16145
class TStreamingMatchRecognize {
16246
using TPartitionList = TSparseList;
16347
using TRange = TPartitionList::TRange;
16448
public:
165-
using TPatternConfiguration = TNfaTransitionGraph;
166-
using TPatternConfigurationBuilder = TNfaTransitionGraphBuilder;
16749
TStreamingMatchRecognize(
16850
NUdf::TUnboxedValue&& partitionKey,
16951
const TMatchRecognizeProcessorParameters& parameters,
@@ -213,6 +95,9 @@ class TStreamingMatchRecognize {
21395
break;
21496
}
21597
}
98+
if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) {
99+
Nfa.Clear();
100+
}
216101
return result;
217102
}
218103
bool ProcessEndOfData(TComputationContext& ctx) {
@@ -243,11 +128,9 @@ class TStreamingMatchRecognize {
243128
ui64 MatchNumber = 0;
244129
};
245130

246-
template <typename Algo>
247131
class TStateForNonInterleavedPartitions
248-
: public TComputationValue<TStateForNonInterleavedPartitions<Algo>>
132+
: public TComputationValue<TStateForNonInterleavedPartitions>
249133
{
250-
using TRowPatternConfigurationBuilder = typename Algo::TPatternConfigurationBuilder;
251134
public:
252135
TStateForNonInterleavedPartitions(
253136
TMemoryUsageInfo* memInfo,
@@ -265,7 +148,7 @@ class TStateForNonInterleavedPartitions
265148
, PartitionKey(partitionKey)
266149
, PartitionKeyPacker(true, partitionKeyType)
267150
, Parameters(parameters)
268-
, RowPatternConfiguration(TRowPatternConfigurationBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
151+
, RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
269152
, Cache(cache)
270153
, Terminating(false)
271154
, SerializerContext(ctx, rowType, rowPacker)
@@ -301,7 +184,7 @@ class TStateForNonInterleavedPartitions
301184
bool validPartitionHandler = in.Read<bool>();
302185
if (validPartitionHandler) {
303186
NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(CurPartitionPackedKey, SerializerContext.Ctx.HolderFactory);
304-
PartitionHandler.reset(new Algo(
187+
PartitionHandler.reset(new TStreamingMatchRecognize(
305188
std::move(key),
306189
Parameters,
307190
RowPatternConfiguration,
@@ -313,7 +196,7 @@ class TStateForNonInterleavedPartitions
313196
if (validDelayedRow) {
314197
in(DelayedRow);
315198
}
316-
auto restoredRowPatternConfiguration = std::make_shared<typename Algo::TPatternConfiguration>();
199+
auto restoredRowPatternConfiguration = std::make_shared<TNfaTransitionGraph>();
317200
restoredRowPatternConfiguration->Load(in);
318201
MKQL_ENSURE(*restoredRowPatternConfiguration == *RowPatternConfiguration, "Restored and current RowPatternConfiguration is different");
319202
MKQL_ENSURE(in.Empty(), "State is corrupted");
@@ -367,12 +250,11 @@ class TStateForNonInterleavedPartitions
367250
InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(temp));
368251
auto partitionKey = PartitionKey->GetValue(ctx);
369252
CurPartitionPackedKey = PartitionKeyPacker.Pack(partitionKey);
370-
PartitionHandler.reset(new Algo(
253+
PartitionHandler.reset(new TStreamingMatchRecognize(
371254
std::move(partitionKey),
372255
Parameters,
373256
RowPatternConfiguration,
374-
Cache
375-
));
257+
Cache));
376258
PartitionHandler->ProcessInputRow(std::move(temp), ctx);
377259
}
378260
if (Terminating) {
@@ -382,12 +264,12 @@ class TStateForNonInterleavedPartitions
382264
}
383265
private:
384266
TString CurPartitionPackedKey;
385-
std::unique_ptr<Algo> PartitionHandler;
267+
std::unique_ptr<TStreamingMatchRecognize> PartitionHandler;
386268
IComputationExternalNode* InputRowArg;
387269
IComputationNode* PartitionKey;
388270
TValuePackerGeneric<false> PartitionKeyPacker;
389271
const TMatchRecognizeProcessorParameters& Parameters;
390-
const typename TRowPatternConfigurationBuilder::TPatternConfigurationPtr RowPatternConfiguration;
272+
const TNfaTransitionGraph::TPtr RowPatternConfiguration;
391273
const TContainerCacheOnContext& Cache;
392274
NUdf::TUnboxedValue DelayedRow;
393275
bool Terminating;
@@ -768,6 +650,11 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
768650
defines.push_back(callable.GetInput(inputIndex++));
769651
}
770652
const auto& streamingMode = callable.GetInput(inputIndex++);
653+
NYql::NMatchRecognize::TAfterMatchSkipTo skipTo = {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""};
654+
if (inputIndex + 2 <= callable.GetInputsCount()) {
655+
skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
656+
skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef();
657+
}
771658
MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");
772659

773660
const auto& [vars, varsLookup] = ConvertListOfStrings(varNames);
@@ -788,6 +675,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
788675
)
789676
, ConvertVectorOfCallables(measures, ctx)
790677
, GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
678+
, skipTo
791679
};
792680
if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) {
793681
return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables
@@ -800,28 +688,15 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
800688
, rowType
801689
);
802690
} else {
803-
const bool useNfaForTables = true; //TODO(YQL-16486) get this flag from an optimizer
804-
if (useNfaForTables) {
805-
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TStreamingMatchRecognize>>(ctx.Mutables
806-
, GetValueRepresentation(inputFlow.GetStaticType())
807-
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
808-
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
809-
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
810-
, partitionKeySelector.GetStaticType()
811-
, std::move(parameters)
812-
, rowType
813-
);
814-
} else {
815-
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TBackTrackingMatchRecognize>>(ctx.Mutables
816-
, GetValueRepresentation(inputFlow.GetStaticType())
817-
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
818-
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
819-
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
820-
, partitionKeySelector.GetStaticType()
821-
, std::move(parameters)
822-
, rowType
823-
);
824-
}
691+
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables
692+
, GetValueRepresentation(inputFlow.GetStaticType())
693+
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
694+
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
695+
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
696+
, partitionKeySelector.GetStaticType()
697+
, std::move(parameters)
698+
, rowType
699+
);
825700
}
826701
}
827702

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,7 @@ class TNfaTransitionGraphBuilder {
283283
return {input, output};
284284
}
285285
public:
286-
using TPatternConfigurationPtr = TNfaTransitionGraph::TPtr;
287-
static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
286+
static TNfaTransitionGraph::TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
288287
auto result = std::make_shared<TNfaTransitionGraph>();
289288
TNfaTransitionGraphBuilder builder(result);
290289
auto item = builder.BuildTerms(pattern, varNameToIndex);
@@ -455,6 +454,10 @@ class TNfa {
455454
serializer.Read(EpsilonTransitionsLastRow);
456455
}
457456

457+
void Clear() {
458+
ActiveStates.clear();
459+
}
460+
458461
private:
459462
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
460463
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;

ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ namespace NKikimr {
115115
{NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
116116
},
117117
getDefines,
118-
streamingMode);
118+
streamingMode,
119+
{NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}
120+
);
119121

120122
auto graph = setup.BuildGraph(pgmReturn);
121123
return graph;

ydb/library/yql/minikql/mkql_program_builder.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5973,7 +5973,8 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
59735973
const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
59745974
const NYql::NMatchRecognize::TRowPattern& pattern,
59755975
const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
5976-
bool streamingMode
5976+
bool streamingMode,
5977+
const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
59775978
) {
59785979
MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);
59795980

@@ -6127,6 +6128,10 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
61276128
callableBuilder.Add(d);
61286129
}
61296130
callableBuilder.Add(NewDataLiteral(streamingMode));
6131+
if (RuntimeVersion >= 52U) {
6132+
callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
6133+
callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
6134+
}
61306135
return TRuntimeNode(callableBuilder.Build(), false);
61316136
}
61326137

ydb/library/yql/minikql/mkql_program_builder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,8 @@ class TProgramBuilder : public TTypeBuilder {
713713
const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
714714
const NYql::NMatchRecognize::TRowPattern& pattern,
715715
const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
716-
bool streamingMode
716+
bool streamingMode,
717+
const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
717718
);
718719

719720
TRuntimeNode TimeOrderRecover(

ydb/library/yql/minikql/mkql_runtime_version.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace NMiniKQL {
2424
// 1. Bump this version every time incompatible runtime nodes are introduced.
2525
// 2. Make sure you provide runtime node generation for previous runtime versions.
2626
#ifndef MKQL_RUNTIME_VERSION
27-
#define MKQL_RUNTIME_VERSION 51U
27+
#define MKQL_RUNTIME_VERSION 52U
2828
#endif
2929

3030
// History:

0 commit comments

Comments
 (0)