Skip to content

Commit 637353f

Browse files
vitalifkunga
authored and committed
Refactor FillVectorIndex state machine (#18579) (#19128)
1 parent 7d6e252 commit 637353f

File tree

3 files changed

+112
-121
lines changed

3 files changed

+112
-121
lines changed

ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp

Lines changed: 102 additions & 97 deletions
Original file line number | Diff line number | Diff line change
@@ -764,7 +764,6 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
764764
buildInfo.Sample.Rows, buildInfo.KMeans.Parent, buildInfo.KMeans.Child);
765765

766766
TActivationContext::AsActorContext().MakeFor(Self->SelfId()).Register(actor);
767-
buildInfo.Sample.State = TIndexBuildInfo::TSample::EState::Upload;
768767

769768
LOG_N("TTxBuildProgress: TUploadSampleK: " << buildInfo);
770769
}
@@ -807,6 +806,10 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
807806
}
808807
}
809808

809+
bool NoShardsAdded(TIndexBuildInfo& buildInfo) {
810+
return buildInfo.DoneShards.empty() && buildInfo.InProgressShards.empty() && buildInfo.ToUploadShards.empty();
811+
}
812+
810813
void AddAllShards(TIndexBuildInfo& buildInfo) {
811814
ToTabletSend.clear();
812815
Self->IndexBuildPipes.CloseAll(BuildId, Self->ActorContext());
@@ -816,10 +819,37 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
816819
}
817820
}
818821

822+
void AddGlobalShardsForCurrentParent(TIndexBuildInfo& buildInfo) {
823+
Y_ENSURE(NoShardsAdded(buildInfo));
824+
if (buildInfo.KMeans.Parent == 0) {
825+
AddAllShards(buildInfo);
826+
return;
827+
}
828+
auto it = buildInfo.Cluster2Shards.lower_bound(buildInfo.KMeans.Parent);
829+
Y_ENSURE(it != buildInfo.Cluster2Shards.end());
830+
if (it->second.Shards.size() > 1) {
831+
for (const auto& idx : it->second.Shards) {
832+
const auto& status = buildInfo.Shards.at(idx);
833+
AddShard(buildInfo, idx, status);
834+
}
835+
}
836+
}
837+
838+
void AddLocalClusters(TIndexBuildInfo& buildInfo) {
839+
Y_ENSURE(NoShardsAdded(buildInfo));
840+
for (const auto& [to, state] : buildInfo.Cluster2Shards) {
841+
if (state.Shards.size() == 1) {
842+
const auto* status = buildInfo.Shards.FindPtr(state.Shards[0]);
843+
Y_ENSURE(status);
844+
AddShard(buildInfo, state.Shards[0], *status);
845+
}
846+
}
847+
}
848+
819849
bool FillSecondaryIndex(TIndexBuildInfo& buildInfo) {
820850
LOG_D("FillSecondaryIndex Start");
821851

822-
if (buildInfo.DoneShards.empty() && buildInfo.ToUploadShards.empty() && buildInfo.InProgressShards.empty()) {
852+
if (NoShardsAdded(buildInfo)) {
823853
AddAllShards(buildInfo);
824854
}
825855
auto done = SendToShards(buildInfo, [&](TShardIdx shardIdx) { SendBuildSecondaryIndexRequest(shardIdx, buildInfo); }) &&
@@ -833,65 +863,21 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
833863
}
834864

835865
bool FillPrefixKMeans(TIndexBuildInfo& buildInfo) {
836-
if (buildInfo.DoneShards.empty() && buildInfo.ToUploadShards.empty() && buildInfo.InProgressShards.empty()) {
866+
if (NoShardsAdded(buildInfo)) {
837867
AddAllShards(buildInfo);
838868
}
839869
return SendToShards(buildInfo, [&](TShardIdx shardIdx) { SendPrefixKMeansRequest(shardIdx, buildInfo); }) &&
840870
buildInfo.DoneShards.size() == buildInfo.Shards.size();
841871
}
842872

843873
bool FillLocalKMeans(TIndexBuildInfo& buildInfo) {
844-
if (buildInfo.DoneShards.empty() && buildInfo.ToUploadShards.empty() && buildInfo.InProgressShards.empty()) {
874+
if (NoShardsAdded(buildInfo)) {
845875
AddAllShards(buildInfo);
846876
}
847877
return SendToShards(buildInfo, [&](TShardIdx shardIdx) { SendKMeansLocalRequest(shardIdx, buildInfo); }) &&
848878
buildInfo.DoneShards.size() == buildInfo.Shards.size();
849879
}
850880

851-
bool InitSingleKMeans(TIndexBuildInfo& buildInfo) {
852-
if (!buildInfo.DoneShards.empty() || !buildInfo.InProgressShards.empty() || !buildInfo.ToUploadShards.empty()) {
853-
return false;
854-
}
855-
if (buildInfo.KMeans.State == TIndexBuildInfo::TKMeans::MultiLocal) {
856-
InitMultiKMeans(buildInfo);
857-
return false;
858-
}
859-
if (buildInfo.KMeans.Parent == 0) {
860-
AddAllShards(buildInfo);
861-
} else {
862-
auto it = buildInfo.Cluster2Shards.lower_bound(buildInfo.KMeans.Parent);
863-
Y_ENSURE(it != buildInfo.Cluster2Shards.end());
864-
if (it->second.Local == InvalidShardIdx) {
865-
for (const auto& idx : it->second.Global) {
866-
const auto& status = buildInfo.Shards.at(idx);
867-
AddShard(buildInfo, idx, status);
868-
}
869-
}
870-
}
871-
if (buildInfo.DoneShards.size() + buildInfo.ToUploadShards.size() <= 1) {
872-
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::Local;
873-
}
874-
return true;
875-
}
876-
877-
bool InitMultiKMeans(TIndexBuildInfo& buildInfo) {
878-
if (buildInfo.Cluster2Shards.empty()) {
879-
return false;
880-
}
881-
Y_ENSURE(buildInfo.KMeans.Parent != 0);
882-
for (const auto& [to, state] : buildInfo.Cluster2Shards) {
883-
if (const auto& [from, local, global] = state; local != InvalidShardIdx) {
884-
if (const auto* status = buildInfo.Shards.FindPtr(local)) {
885-
AddShard(buildInfo, local, *status);
886-
}
887-
}
888-
}
889-
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::MultiLocal;
890-
buildInfo.Cluster2Shards.clear();
891-
Y_ENSURE(buildInfo.InProgressShards.empty());
892-
return !buildInfo.ToUploadShards.empty();
893-
}
894-
895881
bool SendKMeansSample(TIndexBuildInfo& buildInfo) {
896882
if (buildInfo.Sample.MaxProbability == 0) {
897883
buildInfo.ToUploadShards.clear();
@@ -910,22 +896,6 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
910896
return SendToShards(buildInfo, [&](TShardIdx shardIdx) { SendKMeansLocalRequest(shardIdx, buildInfo); });
911897
}
912898

913-
bool SendVectorIndex(TIndexBuildInfo& buildInfo) {
914-
switch (buildInfo.KMeans.State) {
915-
case TIndexBuildInfo::TKMeans::Sample:
916-
return SendKMeansSample(buildInfo);
917-
// TODO(mbkkt)
918-
// case TIndexBuildInfo::TKMeans::Recompute:
919-
// return SendKMeansRecompute(buildInfo);
920-
case TIndexBuildInfo::TKMeans::Reshuffle:
921-
return SendKMeansReshuffle(buildInfo);
922-
case TIndexBuildInfo::TKMeans::Local:
923-
case TIndexBuildInfo::TKMeans::MultiLocal:
924-
return SendKMeansLocal(buildInfo);
925-
}
926-
return true;
927-
}
928-
929899
void ClearDoneShards(TTransactionContext& txc, TIndexBuildInfo& buildInfo) {
930900
if (buildInfo.DoneShards.empty()) {
931901
return;
@@ -966,6 +936,7 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
966936
// it's approximate but upper bound, so it's ok
967937
buildInfo.KMeans.TableSize = std::max<ui64>(1, buildInfo.Processed.GetUploadRows());
968938
buildInfo.KMeans.PrefixIndexDone(doneShards);
939+
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::MultiLocal;
969940
LOG_D("FillPrefixedVectorIndex PrefixIndexDone " << buildInfo.DebugString());
970941

971942
PersistKMeansState(txc, buildInfo);
@@ -1006,60 +977,93 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
1006977
}
1007978

1008979
bool FillVectorIndex(TTransactionContext& txc, TIndexBuildInfo& buildInfo) {
1009-
// FIXME: Very non-intuitive state machine, rework it by adding an explicit vector index fill state
1010980
LOG_D("FillVectorIndex Start " << buildInfo.DebugString());
1011981

1012-
if (buildInfo.Sample.State == TIndexBuildInfo::TSample::EState::Upload) {
1013-
return false;
1014-
}
1015-
if (InitSingleKMeans(buildInfo)) {
1016-
LOG_D("FillVectorIndex SingleKMeans " << buildInfo.DebugString());
1017-
}
1018-
if (!SendVectorIndex(buildInfo)) {
1019-
return false;
1020-
}
1021-
1022-
if (buildInfo.KMeans.State == TIndexBuildInfo::TKMeans::Sample &&
1023-
!buildInfo.Sample.Rows.empty()) {
1024-
if (buildInfo.Sample.State == TIndexBuildInfo::TSample::EState::Collect) {
1025-
LOG_D("FillVectorIndex SendUploadSampleKRequest " << buildInfo.DebugString());
1026-
SendUploadSampleKRequest(buildInfo);
1027-
return false;
982+
// (Sample -> Reshuffle)* -> MultiLocal -> NextLevel
983+
if (buildInfo.KMeans.State == TIndexBuildInfo::TKMeans::Sample) {
984+
return FillVectorIndexSamples(txc, buildInfo);
985+
} else if (buildInfo.KMeans.State == TIndexBuildInfo::TKMeans::Reshuffle) {
986+
if (NoShardsAdded(buildInfo)) {
987+
AddGlobalShardsForCurrentParent(buildInfo);
1028988
}
1029-
}
1030-
1031-
LOG_D("FillVectorIndex DoneLevel " << buildInfo.DebugString());
1032-
ClearDoneShards(txc, buildInfo);
1033-
1034-
if (!buildInfo.Sample.Rows.empty()) {
1035-
if (buildInfo.KMeans.State == TIndexBuildInfo::TKMeans::Sample) {
1036-
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::Reshuffle;
1037-
LOG_D("FillVectorIndex NextState " << buildInfo.DebugString());
1038-
PersistKMeansState(txc, buildInfo);
1039-
Progress(BuildId);
989+
if (!SendKMeansReshuffle(buildInfo)) {
1040990
return false;
1041991
}
992+
ClearDoneShards(txc, buildInfo);
1042993
buildInfo.Sample.Clear();
1043994
NIceDb::TNiceDb db{txc.DB};
1044995
Self->PersistBuildIndexSampleForget(db, buildInfo);
1045-
LOG_D("FillVectorIndex DoneState " << buildInfo.DebugString());
996+
return FillVectorIndexNextParent(txc, buildInfo);
997+
} else if (buildInfo.KMeans.State == TIndexBuildInfo::TKMeans::MultiLocal) {
998+
if (!SendKMeansLocal(buildInfo)) {
999+
return false;
1000+
}
1001+
ClearDoneShards(txc, buildInfo);
1002+
return FillVectorIndexNextParent(txc, buildInfo);
10461003
}
1004+
Y_ENSURE(false);
1005+
}
10471006

1048-
if (buildInfo.KMeans.NextParent()) {
1049-
LOG_D("FillVectorIndex NextParent " << buildInfo.DebugString());
1007+
bool FillVectorIndexSamples(TTransactionContext& txc, TIndexBuildInfo& buildInfo) {
1008+
if (buildInfo.Sample.State == TIndexBuildInfo::TSample::EState::Collect) {
1009+
if (NoShardsAdded(buildInfo)) {
1010+
AddGlobalShardsForCurrentParent(buildInfo);
1011+
if (!buildInfo.DoneShards.size() && !buildInfo.ToUploadShards.size()) {
1012+
// No "global" shards to handle - parent only has 1 shard,
1013+
// it will be handled during the MultiLocal phase
1014+
return FillVectorIndexNextParent(txc, buildInfo);
1015+
}
1016+
// Otherwise, we collect samples
1017+
LOG_D("FillVectorIndex Samples " << buildInfo.DebugString());
1018+
}
1019+
if (!SendKMeansSample(buildInfo)) {
1020+
return false;
1021+
}
1022+
ClearDoneShards(txc, buildInfo);
1023+
if (buildInfo.Sample.Rows.empty()) {
1024+
// No samples => no data for this cluster
1025+
return FillVectorIndexNextParent(txc, buildInfo);
1026+
}
1027+
LOG_D("FillVectorIndex SendUploadSampleKRequest " << buildInfo.DebugString());
1028+
SendUploadSampleKRequest(buildInfo);
1029+
buildInfo.Sample.State = TIndexBuildInfo::TSample::EState::Upload;
1030+
return false;
1031+
} else if (buildInfo.Sample.State == TIndexBuildInfo::TSample::EState::Upload) {
1032+
// Just wait until samples are uploaded (saved)
1033+
return false;
1034+
} else if (buildInfo.Sample.State == TIndexBuildInfo::TSample::EState::Done) {
1035+
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::Reshuffle;
1036+
LOG_D("FillVectorIndex NextState " << buildInfo.DebugString());
10501037
PersistKMeansState(txc, buildInfo);
10511038
Progress(BuildId);
10521039
return false;
10531040
}
1041+
Y_ENSURE(false);
1042+
}
10541043

1055-
if (InitMultiKMeans(buildInfo)) {
1056-
LOG_D("FillVectorIndex MultiKMeans " << buildInfo.DebugString());
1044+
bool FillVectorIndexNextParent(TTransactionContext& txc, TIndexBuildInfo& buildInfo) {
1045+
if (buildInfo.KMeans.NextParent()) {
1046+
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::Sample;
1047+
LOG_D("FillVectorIndex NextParent " << buildInfo.DebugString());
10571048
PersistKMeansState(txc, buildInfo);
10581049
Progress(BuildId);
10591050
return false;
10601051
}
10611052

1053+
if (!buildInfo.Cluster2Shards.empty()) {
1054+
AddLocalClusters(buildInfo);
1055+
buildInfo.Cluster2Shards.clear();
1056+
if (!buildInfo.ToUploadShards.empty()) {
1057+
LOG_D("FillVectorIndex MultiKMeans " << buildInfo.DebugString());
1058+
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::MultiLocal;
1059+
PersistKMeansState(txc, buildInfo);
1060+
Progress(BuildId);
1061+
return false;
1062+
}
1063+
}
1064+
10621065
if (buildInfo.KMeans.NextLevel()) {
1066+
buildInfo.KMeans.State = TIndexBuildInfo::TKMeans::Sample;
10631067
LOG_D("FillVectorIndex NextLevel " << buildInfo.DebugString());
10641068
PersistKMeansState(txc, buildInfo);
10651069
NIceDb::TNiceDb db{txc.DB};
@@ -1070,6 +1074,7 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
10701074
Progress(BuildId);
10711075
return false;
10721076
}
1077+
10731078
LOG_D("FillVectorIndex Done " << buildInfo.DebugString());
10741079
return true;
10751080
}

ydb/core/tx/schemeshard/schemeshard_info_types.cpp

Lines changed: 8 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2225,14 +2225,14 @@ void TIndexBuildInfo::AddParent(const TSerializedTableRange& range, TShardIdx sh
22252225
const auto [parentFrom, parentTo] = KMeans.RangeToBorders(range);
22262226
// TODO(mbkkt) We can make it more granular
22272227

2228-
// if new range is not intersect with other ranges, it's local
2228+
// the new range does not intersect with other ranges, just add it with 1 shard
22292229
auto itFrom = Cluster2Shards.lower_bound(parentFrom);
22302230
if (itFrom == Cluster2Shards.end() || parentTo < itFrom->second.From) {
2231-
Cluster2Shards.emplace_hint(itFrom, parentTo, TClusterShards{.From = parentFrom, .Local = shard});
2231+
Cluster2Shards.emplace_hint(itFrom, parentTo, TClusterShards{.From = parentFrom, .Shards = {shard}});
22322232
return;
22332233
}
22342234

2235-
// otherwise, this range is global and we need to merge all intersecting ranges
2235+
// otherwise, this range has multiple shards and we need to merge all intersecting ranges
22362236
auto itTo = parentTo < itFrom->first ? itFrom : Cluster2Shards.lower_bound(parentTo);
22372237
if (itTo == Cluster2Shards.end()) {
22382238
itTo = Cluster2Shards.rbegin().base();
@@ -2244,25 +2244,16 @@ void TIndexBuildInfo::AddParent(const TSerializedTableRange& range, TShardIdx sh
22442244
itTo = Cluster2Shards.insert(Cluster2Shards.end(), std::move(node));
22452245
itFrom = needsToReplaceFrom ? itTo : itFrom;
22462246
}
2247-
auto& [toFrom, toLocal, toGlobal] = itTo->second;
2247+
auto& [toFrom, toShards] = itTo->second;
22482248

22492249
toFrom = std::min(toFrom, parentFrom);
2250-
if (toLocal != InvalidShardIdx) {
2251-
toGlobal.emplace_back(toLocal);
2252-
toLocal = InvalidShardIdx;
2253-
}
2254-
toGlobal.emplace_back(shard);
2250+
toShards.emplace_back(shard);
22552251

22562252
while (itFrom != itTo) {
2257-
const auto& [fromFrom, fromLocal, fromGlobal] = itFrom->second;
2253+
const auto& [fromFrom, fromShards] = itFrom->second;
22582254
toFrom = std::min(toFrom, fromFrom);
2259-
if (fromLocal != InvalidShardIdx) {
2260-
Y_ASSERT(fromGlobal.empty());
2261-
toGlobal.emplace_back(fromLocal);
2262-
} else {
2263-
Y_ASSERT(!fromGlobal.empty());
2264-
toGlobal.insert(toGlobal.end(), fromGlobal.begin(), fromGlobal.end());
2265-
}
2255+
Y_ASSERT(!fromShards.empty());
2256+
toShards.insert(toShards.end(), fromShards.begin(), fromShards.end());
22662257
itFrom = Cluster2Shards.erase(itFrom);
22672258
}
22682259
}

ydb/core/tx/schemeshard/schemeshard_info_types.h

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3091,7 +3091,6 @@ struct TIndexBuildInfo: public TSimpleRefCount<TIndexBuildInfo> {
30913091
Sample = 0,
30923092
// Recompute,
30933093
Reshuffle,
3094-
Local,
30953094
MultiLocal,
30963095
};
30973096
ui32 Level = 1;
@@ -3141,7 +3140,6 @@ struct TIndexBuildInfo: public TSimpleRefCount<TIndexBuildInfo> {
31413140
if (!NeedsAnotherParent()) {
31423141
return false;
31433142
}
3144-
State = Sample;
31453143
++Parent;
31463144
Child += K;
31473145
return true;
@@ -3151,14 +3149,12 @@ struct TIndexBuildInfo: public TSimpleRefCount<TIndexBuildInfo> {
31513149
if (!NeedsAnotherLevel()) {
31523150
return false;
31533151
}
3154-
State = Sample;
31553152
NextLevel(ChildCount());
31563153
return true;
31573154
}
31583155

31593156
void PrefixIndexDone(ui64 shards) {
31603157
Y_ENSURE(NeedsAnotherLevel());
3161-
State = MultiLocal;
31623158
// There's two worst cases, but in both one shard contains TableSize rows
31633159
// 1. all rows have unique prefix (*), in such case we need 1 id for each row (parent, id in prefix table)
31643160
// 2. all unique prefixes have size K, so we have TableSize/K parents + TableSize childs
@@ -3442,10 +3438,9 @@ struct TIndexBuildInfo: public TSimpleRefCount<TIndexBuildInfo> {
34423438

34433439
struct TClusterShards {
34443440
NTableIndex::TClusterId From = std::numeric_limits<NTableIndex::TClusterId>::max();
3445-
TShardIdx Local = InvalidShardIdx;
3446-
std::vector<TShardIdx> Global;
3441+
std::vector<TShardIdx> Shards;
34473442
};
3448-
TMap<NTableIndex::TClusterId, TClusterShards> Cluster2Shards;
3443+
TMap<NTableIndex::TClusterId, TClusterShards> Cluster2Shards; // To => { From, Shards }
34493444

34503445
void AddParent(const TSerializedTableRange& range, TShardIdx shard);
34513446

0 commit comments

Comments
 (0)