Skip to content

Commit 7edec7d

Browse files
committed
Use vector index recompute in scheme shard (#19154) (#19854)
1 parent 1c2dbea commit 7edec7d

File tree

17 files changed

+517
-110
lines changed

17 files changed

+517
-110
lines changed

ydb/core/base/kmeans_clusters.cpp

Lines changed: 77 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -125,43 +125,35 @@ class TClusters: public IClusters {
125125
using TSum = TMetric::TSum;
126126
using TEmbedding = TVector<TSum>;
127127

128-
ui32 InitK = 0;
129-
ui32 K = 0;
130128
const ui32 Dimensions = 0;
129+
const ui32 MaxRounds = 0;
130+
const ui8 TypeByte = 0;
131131

132132
TVector<TString> Clusters;
133133
TVector<ui64> ClusterSizes;
134-
135-
struct TAggregatedCluster {
136-
TEmbedding Cluster;
137-
ui64 Size = 0;
138-
};
139-
TVector<TAggregatedCluster> AggregatedClusters;
134+
TVector<TEmbedding> NextClusters;
135+
TVector<ui64> NextClusterSizes;
140136

141137
ui32 Round = 0;
142-
ui32 MaxRounds = 0;
143138

144139
public:
145-
TClusters(ui32 dimensions)
140+
TClusters(ui32 dimensions, ui32 maxRounds, ui8 typeByte)
146141
: Dimensions(dimensions)
142+
, MaxRounds(maxRounds)
143+
, TypeByte(typeByte)
147144
{
148145
}
149146

150-
void Init(ui32 k, ui32 maxRounds) override {
151-
InitK = k;
152-
K = k;
153-
MaxRounds = maxRounds;
154-
}
155-
156-
ui32 GetK() const override {
157-
return K;
147+
void SetRound(ui32 round) override {
148+
Round = round;
158149
}
159150

160151
TString Debug() const override {
161-
if (!MaxRounds) {
162-
return TStringBuilder() << "K: " << K;
152+
auto sb = TStringBuilder() << "K: " << Clusters.size();
153+
if (MaxRounds) {
154+
sb << " Round: " << Round << " / " << MaxRounds;
163155
}
164-
return TStringBuilder() << "K: " << K << " Round: " << Round << " / " << MaxRounds;
156+
return sb;
165157
}
166158

167159
const TVector<TString>& GetClusters() const override {
@@ -172,11 +164,19 @@ class TClusters: public IClusters {
172164
return ClusterSizes;
173165
}
174166

167+
const TVector<ui64>& GetNextClusterSizes() const override {
168+
return NextClusterSizes;
169+
}
170+
171+
virtual void SetClusterSize(ui32 num, ui64 size) override {
172+
ClusterSizes.at(num) = size;
173+
}
174+
175175
void Clear() override {
176-
K = InitK;
177176
Clusters.clear();
178177
ClusterSizes.clear();
179-
AggregatedClusters.clear();
178+
NextClusterSizes.clear();
179+
NextClusters.clear();
180180
Round = 0;
181181
}
182182

@@ -189,40 +189,37 @@ class TClusters: public IClusters {
189189
return false;
190190
}
191191
}
192-
Clusters = newClusters;
193-
K = newClusters.size();
194-
return true;
195-
}
196-
197-
void InitAggregatedClusters() override {
198-
AggregatedClusters.resize(K);
199-
ClusterSizes.resize(K, 0);
200-
for (auto& aggregate : AggregatedClusters) {
201-
aggregate.Cluster.resize(Dimensions, 0);
192+
Clusters = std::move(newClusters);
193+
ClusterSizes.clear();
194+
ClusterSizes.resize(Clusters.size());
195+
NextClusterSizes.clear();
196+
NextClusterSizes.resize(Clusters.size());
197+
NextClusters.clear();
198+
NextClusters.resize(Clusters.size());
199+
for (auto& aggregate : NextClusters) {
200+
aggregate.resize(Dimensions, 0);
202201
}
203-
Round = 1;
202+
return true;
204203
}
205204

206205
bool RecomputeClusters() override {
207-
Y_ENSURE(K >= 1);
208206
ui64 vectorCount = 0;
209207
ui64 reassignedCount = 0;
210-
for (size_t i = 0; auto& aggregate : AggregatedClusters) {
211-
vectorCount += aggregate.Size;
208+
for (size_t i = 0; auto& aggregate : NextClusters) {
209+
auto newSize = NextClusterSizes[i];
210+
vectorCount += newSize;
212211

213-
auto& clusterSize = ClusterSizes[i];
214-
reassignedCount += clusterSize < aggregate.Size ? aggregate.Size - clusterSize : 0;
215-
clusterSize = aggregate.Size;
212+
auto clusterSize = ClusterSizes[i];
213+
reassignedCount += clusterSize < newSize ? newSize - clusterSize : 0;
216214

217-
if (aggregate.Size != 0) {
218-
this->Fill(Clusters[i], aggregate.Cluster.data(), aggregate.Size);
219-
Y_ENSURE(aggregate.Size == 0);
215+
if (newSize != 0) {
216+
this->Fill(Clusters[i], aggregate.data(), newSize);
220217
}
221218
++i;
222219
}
223-
Y_ENSURE(vectorCount >= K);
220+
224221
Y_ENSURE(reassignedCount <= vectorCount);
225-
if (K == 1) {
222+
if (Clusters.size() == 1) {
226223
return true;
227224
}
228225

@@ -232,7 +229,6 @@ class TClusters: public IClusters {
232229
last = changes < MinVectorsNeedsReassigned;
233230
}
234231
if (!last) {
235-
++Round;
236232
return false;
237233
}
238234
return true;
@@ -251,6 +247,25 @@ class TClusters: public IClusters {
251247
Clusters.erase(Clusters.begin() + w, Clusters.end());
252248
}
253249

250+
bool NextRound() override {
251+
bool isLast = RecomputeClusters();
252+
ClusterSizes = std::move(NextClusterSizes);
253+
RemoveEmptyClusters();
254+
if (isLast) {
255+
NextClusters.clear();
256+
return true;
257+
}
258+
++Round;
259+
NextClusterSizes.clear();
260+
NextClusterSizes.resize(Clusters.size());
261+
NextClusters.clear();
262+
NextClusters.resize(Clusters.size());
263+
for (auto& aggregate : NextClusters) {
264+
aggregate.resize(Dimensions, 0);
265+
}
266+
return false;
267+
}
268+
254269
std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) override {
255270
Y_ENSURE(embeddingPos < row.size());
256271
const auto embedding = row.at(embeddingPos).AsRef();
@@ -271,16 +286,17 @@ class TClusters: public IClusters {
271286
return closest;
272287
}
273288

274-
void AggregateToCluster(ui32 pos, const char* embedding) override {
275-
auto& aggregate = AggregatedClusters[pos];
276-
auto* coords = aggregate.Cluster.data();
277-
for (auto coord : this->GetCoords(embedding)) {
278-
*coords++ += coord;
289+
void AggregateToCluster(ui32 pos, const TArrayRef<const char>& embedding, ui64 weight) override {
290+
auto& aggregate = NextClusters.at(pos);
291+
auto* coords = aggregate.data();
292+
Y_ENSURE(IsExpectedSize(embedding));
293+
for (auto coord : this->GetCoords(embedding.data())) {
294+
*coords++ += (TSum)coord * weight;
279295
}
280-
++aggregate.Size;
296+
NextClusterSizes.at(pos) += weight;
281297
}
282298

283-
bool IsExpectedSize(TArrayRef<const char> data) override {
299+
bool IsExpectedSize(const TArrayRef<const char>& data) override {
284300
return data.size() == 1 + sizeof(TCoord) * Dimensions;
285301
}
286302

@@ -295,36 +311,37 @@ class TClusters: public IClusters {
295311

296312
void Fill(TString& d, TSum* embedding, ui64& c) {
297313
Y_ENSURE(c > 0);
298-
const auto count = static_cast<TSum>(std::exchange(c, 0));
314+
const auto count = static_cast<TSum>(c);
299315
auto data = GetData(d.MutRef().data());
300316
for (auto& coord : data) {
301317
coord = *embedding / count;
302-
*embedding++ = 0;
318+
embedding++;
303319
}
304320
}
305321
};
306322

307-
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, TString& error) {
323+
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, ui32 maxRounds, TString& error) {
308324
if (settings.vector_dimension() < 1) {
309325
error = "Dimension of vector should be at least one";
310326
return nullptr;
311327
}
312328

329+
const ui8 typeVal = (ui8)settings.vector_type();
313330
const ui32 dim = settings.vector_dimension();
314331

315332
auto handleMetric = [&]<typename T>() -> std::unique_ptr<IClusters> {
316333
switch (settings.metric()) {
317334
case Ydb::Table::VectorIndexSettings::SIMILARITY_INNER_PRODUCT:
318-
return std::make_unique<TClusters<TMaxInnerProductSimilarity<T>>>(dim);
335+
return std::make_unique<TClusters<TMaxInnerProductSimilarity<T>>>(dim, maxRounds, typeVal);
319336
case Ydb::Table::VectorIndexSettings::SIMILARITY_COSINE:
320337
case Ydb::Table::VectorIndexSettings::DISTANCE_COSINE:
321338
// We don't need to have separate implementation for distance,
322339
// because clusters will be same as for similarity
323-
return std::make_unique<TClusters<TCosineSimilarity<T>>>(dim);
340+
return std::make_unique<TClusters<TCosineSimilarity<T>>>(dim, maxRounds, typeVal);
324341
case Ydb::Table::VectorIndexSettings::DISTANCE_MANHATTAN:
325-
return std::make_unique<TClusters<TL1Distance<T>>>(dim);
342+
return std::make_unique<TClusters<TL1Distance<T>>>(dim, maxRounds, typeVal);
326343
case Ydb::Table::VectorIndexSettings::DISTANCE_EUCLIDEAN:
327-
return std::make_unique<TClusters<TL2Distance<T>>>(dim);
344+
return std::make_unique<TClusters<TL2Distance<T>>>(dim, maxRounds, typeVal);
328345
default:
329346
error = "Wrong similarity";
330347
break;

ydb/core/base/kmeans_clusters.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,35 @@ class IClusters {
1010
public:
1111
virtual ~IClusters() = default;
1212

13-
virtual void Init(ui32 k, ui32 maxRounds) = 0;
14-
15-
virtual ui32 GetK() const = 0;
13+
virtual void SetRound(ui32 round) = 0;
1614

1715
virtual TString Debug() const = 0;
1816

1917
virtual const TVector<TString>& GetClusters() const = 0;
2018

2119
virtual const TVector<ui64>& GetClusterSizes() const = 0;
2220

21+
virtual const TVector<ui64>& GetNextClusterSizes() const = 0;
22+
23+
virtual void SetClusterSize(ui32 num, ui64 size) = 0;
24+
2325
virtual void Clear() = 0;
2426

2527
virtual bool SetClusters(TVector<TString> && newClusters) = 0;
2628

27-
virtual void InitAggregatedClusters() = 0;
28-
2929
virtual bool RecomputeClusters() = 0;
3030

31+
virtual bool NextRound() = 0;
32+
3133
virtual void RemoveEmptyClusters() = 0;
3234

3335
virtual std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) = 0;
3436

35-
virtual void AggregateToCluster(ui32 pos, const char* embedding) = 0;
37+
virtual void AggregateToCluster(ui32 pos, const TArrayRef<const char>& embedding, ui64 weight = 1) = 0;
3638

37-
virtual bool IsExpectedSize(TArrayRef<const char> data) = 0;
39+
virtual bool IsExpectedSize(const TArrayRef<const char>& data) = 0;
3840
};
3941

40-
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, TString& error);
42+
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, ui32 maxRounds, TString& error);
4143

4244
}

ydb/core/base/table_vector_index.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,6 @@ inline constexpr const char* BuildSuffix1 = "1build";
2121
// Prefix table
2222
inline constexpr const char* PrefixTable = "indexImplPrefixTable";
2323

24+
inline constexpr const int DefaultKMeansRounds = 3;
25+
2426
}

ydb/core/tx/datashard/build_index/local_kmeans.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
132132
, Clusters(std::move(clusters))
133133
{
134134
LOG_I("Create " << Debug());
135-
Clusters->Init(request.GetK(), request.GetNeedsRounds());
136135

137136
const auto& embedding = request.GetEmbeddingColumn();
138137
const auto& data = request.GetDataColumns();
@@ -231,7 +230,7 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
231230
if (PrefixColumns && !Prefix) {
232231
Prefix = TSerializedCellVec{key.subspan(0, PrefixColumns)};
233232
auto newParent = key.at(0).template AsValue<ui64>();
234-
Child += (newParent - Parent) * Clusters->GetK();
233+
Child += (newParent - Parent) * K;
235234
Parent = newParent;
236235
}
237236

@@ -380,13 +379,11 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
380379
}
381380
bool ok = Clusters->SetClusters(std::move(rows));
382381
Y_ENSURE(ok);
383-
Clusters->InitAggregatedClusters();
384382
return false; // do KMEANS
385383
}
386384

387385
if (State == EState::KMEANS) {
388-
if (Clusters->RecomputeClusters()) {
389-
Clusters->RemoveEmptyClusters();
386+
if (Clusters->NextRound()) {
390387
FormLevelRows();
391388
State = UploadState;
392389
return false; // do UPLOAD_*
@@ -444,7 +441,7 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
444441
void FeedKMeans(TArrayRef<const TCell> row) noexcept
445442
{
446443
if (auto pos = Clusters->FindCluster(row, EmbeddingPos); pos) {
447-
Clusters->AggregateToCluster(*pos, row.at(EmbeddingPos).Data());
444+
Clusters->AggregateToCluster(*pos, row.at(EmbeddingPos).AsRef());
448445
}
449446
}
450447

@@ -661,7 +658,7 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const
661658

662659
// 3. Validating vector index settings
663660
TString error;
664-
auto clusters = NKikimr::NKMeans::CreateClusters(request.GetSettings(), error);
661+
auto clusters = NKikimr::NKMeans::CreateClusters(request.GetSettings(), request.GetNeedsRounds(), error);
665662
if (!clusters) {
666663
badRequest(error);
667664
auto sent = trySendBadRequest();

ydb/core/tx/datashard/build_index/prefix_kmeans.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ class TPrefixKMeansScan: public TActor<TPrefixKMeansScan>, public NTable::IScan
134134
, Clusters(std::move(clusters))
135135
{
136136
LOG_I("Create " << Debug());
137-
Clusters->Init(request.GetK(), request.GetNeedsRounds());
138137

139138
const auto& embedding = request.GetEmbeddingColumn();
140139
TVector<TString> data{request.GetDataColumns().begin(), request.GetDataColumns().end()};
@@ -355,7 +354,7 @@ class TPrefixKMeansScan: public TActor<TPrefixKMeansScan>, public NTable::IScan
355354
}
356355

357356
void StartNewPrefix() {
358-
Parent = Child + Clusters->GetK();
357+
Parent = Child + K;
359358
Child = Parent + 1;
360359
State = EState::SAMPLE;
361360
Lead.To(Prefix.GetCells(), NTable::ESeek::Upper); // seek to (prefix, inf)
@@ -426,13 +425,11 @@ class TPrefixKMeansScan: public TActor<TPrefixKMeansScan>, public NTable::IScan
426425
}
427426
bool ok = Clusters->SetClusters(std::move(rows));
428427
Y_ENSURE(ok);
429-
Clusters->InitAggregatedClusters();
430428
return false; // do KMEANS
431429
}
432430

433431
if (State == EState::KMEANS) {
434-
if (Clusters->RecomputeClusters()) {
435-
Clusters->RemoveEmptyClusters();
432+
if (Clusters->NextRound()) {
436433
FormLevelRows();
437434
State = UploadState;
438435
return false; // do UPLOAD_*
@@ -484,7 +481,7 @@ class TPrefixKMeansScan: public TActor<TPrefixKMeansScan>, public NTable::IScan
484481
void FeedKMeans(TArrayRef<const TCell> row)
485482
{
486483
if (auto pos = Clusters->FindCluster(row, EmbeddingPos); pos) {
487-
Clusters->AggregateToCluster(*pos, row.at(EmbeddingPos).Data());
484+
Clusters->AggregateToCluster(*pos, row.at(EmbeddingPos).AsRef());
488485
}
489486
}
490487

@@ -650,7 +647,7 @@ void TDataShard::HandleSafe(TEvDataShard::TEvPrefixKMeansRequest::TPtr& ev, cons
650647

651648
// 3. Validating vector index settings
652649
TString error;
653-
auto clusters = NKikimr::NKMeans::CreateClusters(request.GetSettings(), error);
650+
auto clusters = NKikimr::NKMeans::CreateClusters(request.GetSettings(), request.GetNeedsRounds(), error);
654651
if (!clusters) {
655652
badRequest(error);
656653
auto sent = trySendBadRequest();

0 commit comments

Comments
 (0)