Skip to content

Commit a184bbe

Browse files
committed
Implement RecomputeKMeans scan at the data shard side (#19154) (#19894)
1 parent e079321 commit a184bbe

19 files changed

+903
-72
lines changed

ydb/core/base/kmeans_clusters.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ class TClusters: public IClusters {
168168
return Clusters;
169169
}
170170

171+
const TVector<ui64>& GetClusterSizes() const override {
172+
return ClusterSizes;
173+
}
174+
171175
void Clear() override {
172176
K = InitK;
173177
Clusters.clear();
@@ -177,20 +181,16 @@ class TClusters: public IClusters {
177181
}
178182

179183
bool SetClusters(TVector<TString> && newClusters) override {
180-
Clusters = newClusters;
181-
if (Clusters.size() == 0) {
184+
if (newClusters.size() == 0) {
182185
return false;
183186
}
184-
if (Clusters.size() < K) {
185-
// if this datashard have less than K valid embeddings for this parent
186-
// lets make single centroid for it
187-
K = 1;
188-
Clusters.resize(K);
189-
}
190-
if (!K) {
191-
K = InitK = newClusters.size();
187+
for (const auto& cluster: newClusters) {
188+
if (!IsExpectedSize(cluster)) {
189+
return false;
190+
}
192191
}
193-
Y_ENSURE(Clusters.size() == K);
192+
Clusters = newClusters;
193+
K = newClusters.size();
194194
return true;
195195
}
196196

@@ -235,7 +235,10 @@ class TClusters: public IClusters {
235235
++Round;
236236
return false;
237237
}
238+
return true;
239+
}
238240

241+
void RemoveEmptyClusters() override {
239242
size_t w = 0;
240243
for (size_t r = 0; r < ClusterSizes.size(); ++r) {
241244
if (ClusterSizes[r] != 0) {
@@ -246,7 +249,6 @@ class TClusters: public IClusters {
246249
}
247250
ClusterSizes.erase(ClusterSizes.begin() + w, ClusterSizes.end());
248251
Clusters.erase(Clusters.begin() + w, Clusters.end());
249-
return true;
250252
}
251253

252254
std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) override {

ydb/core/base/kmeans_clusters.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class IClusters {
1818

1919
virtual const TVector<TString>& GetClusters() const = 0;
2020

21+
virtual const TVector<ui64>& GetClusterSizes() const = 0;
22+
2123
virtual void Clear() = 0;
2224

2325
virtual bool SetClusters(TVector<TString> && newClusters) = 0;
@@ -26,6 +28,8 @@ class IClusters {
2628

2729
virtual bool RecomputeClusters() = 0;
2830

31+
virtual void RemoveEmptyClusters() = 0;
32+
2933
virtual std::optional<ui32> FindCluster(TArrayRef<const TCell> row, ui32 embeddingPos) = 0;
3034

3135
virtual void AggregateToCluster(ui32 pos, const char* embedding) = 0;

ydb/core/base/table_index.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,27 @@ TString ToShortDebugString(const NKikimrTxDataShard::TEvReshuffleKMeansRequest&
212212
return result;
213213
}
214214

215+
TString ToShortDebugString(const NKikimrTxDataShard::TEvRecomputeKMeansRequest& record) {
216+
auto copy = record;
217+
TStringBuilder result;
218+
// clusters are not human readable and can be large like 100Kb+
219+
copy.ClearClusters();
220+
result << copy.ShortDebugString();
221+
result << " Clusters: " << record.ClustersSize();
222+
return result;
223+
}
224+
225+
TString ToShortDebugString(const NKikimrTxDataShard::TEvRecomputeKMeansResponse& record) {
226+
auto copy = record;
227+
TStringBuilder result;
228+
// clusters are not human readable and can be large like 100Kb+
229+
copy.ClearClusters();
230+
copy.ClearClusterSizes();
231+
result << copy.ShortDebugString();
232+
result << " Clusters: " << record.ClustersSize();
233+
return result;
234+
}
235+
215236
TString ToShortDebugString(const NKikimrTxDataShard::TEvSampleKResponse& record) {
216237
auto copy = record;
217238
TStringBuilder result;

ydb/core/base/table_index.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
namespace NKikimrTxDataShard {
1616
class TEvReshuffleKMeansRequest;
17+
class TEvRecomputeKMeansRequest;
18+
class TEvRecomputeKMeansResponse;
1719
class TEvSampleKResponse;
1820
}
1921

@@ -51,6 +53,8 @@ void EnsureNoPostingParentFlag(TClusterId parent);
5153
TClusterId SetPostingParentFlag(TClusterId parent);
5254

5355
TString ToShortDebugString(const NKikimrTxDataShard::TEvReshuffleKMeansRequest& record);
56+
TString ToShortDebugString(const NKikimrTxDataShard::TEvRecomputeKMeansRequest& record);
57+
TString ToShortDebugString(const NKikimrTxDataShard::TEvRecomputeKMeansResponse& record);
5458
TString ToShortDebugString(const NKikimrTxDataShard::TEvSampleKResponse& record);
5559

5660
}

ydb/core/kqp/ut/common/kqp_ut_common.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -273,28 +273,6 @@ inline NYdb::NTable::EIndexType IndexTypeSqlToIndexType(EIndexTypeSql type) {
273273
}
274274
}
275275

276-
inline constexpr TStringBuf IndexSubtypeSqlString(EIndexTypeSql type) {
277-
switch (type) {
278-
case EIndexTypeSql::Global:
279-
case EIndexTypeSql::GlobalSync:
280-
case EIndexTypeSql::GlobalAsync:
281-
return "";
282-
case NKqp::EIndexTypeSql::GlobalVectorKMeansTree:
283-
return "USING vector_kmeans_tree";
284-
}
285-
}
286-
287-
inline constexpr TStringBuf IndexWithSqlString(EIndexTypeSql type) {
288-
switch (type) {
289-
case EIndexTypeSql::Global:
290-
case EIndexTypeSql::GlobalSync:
291-
case EIndexTypeSql::GlobalAsync:
292-
return "";
293-
case NKqp::EIndexTypeSql::GlobalVectorKMeansTree:
294-
return "WITH (similarity=inner_product, vector_type=float, vector_dimension=1024)";
295-
}
296-
}
297-
298276
TString ReformatYson(const TString& yson);
299277
void CompareYson(const TString& expected, const TString& actual, const TString& message = {});
300278
void CompareYson(const TString& expected, const NKikimrMiniKQL::TResult& actual, const TString& message = {});

ydb/core/kqp/ut/scheme/kqp_scheme_ut.cpp

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
255255
}, 0, Inflight + 1, NPar::TLocalExecutor::WAIT_COMPLETE | NPar::TLocalExecutor::MED_PRIORITY);
256256
}
257257

258-
void SchemaVersionMissmatchWithTest(bool write) {
258+
void SchemaVersionMismatchWithTest(bool write) {
259259
TKikimrRunner kikimr;
260260

261261
auto db = kikimr.GetTableClient();
@@ -317,7 +317,7 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
317317
}
318318
}
319319

320-
void SchemaVersionMissmatchWithIndexTest(bool write) {
320+
void SchemaVersionMismatchWithIndexTest(bool write) {
321321
//KIKIMR-14282
322322
//YDBREQUESTS-1324
323323
//some cases fail
@@ -406,20 +406,20 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
406406
}
407407
}
408408

409-
Y_UNIT_TEST(SchemaVersionMissmatchWithRead) {
410-
SchemaVersionMissmatchWithTest(false);
409+
Y_UNIT_TEST(SchemaVersionMismatchWithRead) {
410+
SchemaVersionMismatchWithTest(false);
411411
}
412412

413-
Y_UNIT_TEST(SchemaVersionMissmatchWithWrite) {
414-
SchemaVersionMissmatchWithTest(true);
413+
Y_UNIT_TEST(SchemaVersionMismatchWithWrite) {
414+
SchemaVersionMismatchWithTest(true);
415415
}
416416

417-
Y_UNIT_TEST(SchemaVersionMissmatchWithIndexRead) {
418-
SchemaVersionMissmatchWithIndexTest(false);
417+
Y_UNIT_TEST(SchemaVersionMismatchWithIndexRead) {
418+
SchemaVersionMismatchWithIndexTest(false);
419419
}
420420

421-
Y_UNIT_TEST(SchemaVersionMissmatchWithIndexWrite) {
422-
SchemaVersionMissmatchWithIndexTest(true);
421+
Y_UNIT_TEST(SchemaVersionMismatchWithIndexWrite) {
422+
SchemaVersionMismatchWithIndexTest(true);
423423
}
424424

425425
void TouchIndexAfterMoveIndex(bool write, bool replace) {
@@ -2582,11 +2582,26 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
25822582

25832583
auto db = kikimr.GetTableClient();
25842584
auto session = db.CreateSession().GetValueSync().GetSession();
2585-
CreateSampleTablesWithIndex(session);
2585+
2586+
if (type == EIndexTypeSql::GlobalVectorKMeansTree) {
2587+
auto result = session.ExecuteDataQuery(R"(
2588+
REPLACE INTO `Test` (Group, Name, Amount, Comment) VALUES
2589+
(1u, "Jack", 100500ul, "Just Jack"),
2590+
(3u, "Harr", 5600ul, "Not Potter"),
2591+
(3u, "Josh", 8202ul, "Very popular name in GB"),
2592+
(3u, "Anna", 887773ul, "Just Anna"),
2593+
(4u, "Hugo", 77, "Boss");
2594+
)", TTxControl::BeginTx().CommitTx()).GetValueSync();
2595+
UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString());
2596+
} else {
2597+
CreateSampleTablesWithIndex(session);
2598+
}
25862599

25872600
const auto typeStr = IndexTypeSqlString(type).data();
2588-
const auto subtypeStr = IndexSubtypeSqlString(type).data();
2589-
const auto withStr = IndexWithSqlString(type).data();
2601+
const auto subtypeStr = (type == EIndexTypeSql::GlobalVectorKMeansTree
2602+
? "USING vector_kmeans_tree" : "");
2603+
const auto withStr = (type == EIndexTypeSql::GlobalVectorKMeansTree
2604+
? "WITH (similarity=inner_product, vector_type=uint8, vector_dimension=3)" : "");
25902605

25912606
// Non-covered index, single column
25922607
{
@@ -2611,8 +2626,8 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
26112626
if (type == EIndexTypeSql::GlobalVectorKMeansTree) {
26122627
const auto& vectorIndexSettings = std::get<TKMeansTreeSettings>(indexDesc.back().GetIndexSettings()).Settings;
26132628
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.Metric, TVectorIndexSettings::EMetric::InnerProduct);
2614-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Float);
2615-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 1024);
2629+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Uint8);
2630+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 3); // test names are all 4 bytes
26162631

26172632
describe = session.DescribeTable(TString{"/Root/Test/NameIndex/"} + NTableIndex::NTableVectorKmeansTreeIndex::LevelTable).GetValueSync();
26182633
UNIT_ASSERT_EQUAL(describe.GetStatus(), EStatus::SUCCESS);
@@ -2660,8 +2675,8 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
26602675
if (type == EIndexTypeSql::GlobalVectorKMeansTree) {
26612676
const auto& vectorIndexSettings = std::get<TKMeansTreeSettings>(indexDesc.back().GetIndexSettings()).Settings;
26622677
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.Metric, TVectorIndexSettings::EMetric::InnerProduct);
2663-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Float);
2664-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 1024);
2678+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Uint8);
2679+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 3); // test names are all 4 bytes
26652680

26662681
describe = session.DescribeTable(TString{"/Root/Test/CommentIndex/"} + NTableIndex::NTableVectorKmeansTreeIndex::LevelTable).GetValueSync();
26672682
UNIT_ASSERT_EQUAL(describe.GetStatus(), EStatus::SUCCESS);
@@ -2709,8 +2724,8 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
27092724
if (type == EIndexTypeSql::GlobalVectorKMeansTree) {
27102725
const auto& vectorIndexSettings = std::get<TKMeansTreeSettings>(indexDesc.back().GetIndexSettings()).Settings;
27112726
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.Metric, TVectorIndexSettings::EMetric::InnerProduct);
2712-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Float);
2713-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 1024);
2727+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Uint8);
2728+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 3); // test names are all 4 bytes
27142729

27152730
describe = session.DescribeTable(TString{"/Root/Test/NameIndex/"} + NTableIndex::NTableVectorKmeansTreeIndex::LevelTable).GetValueSync();
27162731
UNIT_ASSERT_EQUAL(describe.GetStatus(), EStatus::SUCCESS);
@@ -2758,8 +2773,8 @@ Y_UNIT_TEST_SUITE(KqpScheme) {
27582773
if (type == EIndexTypeSql::GlobalVectorKMeansTree) {
27592774
const auto& vectorIndexSettings = std::get<TKMeansTreeSettings>(indexDesc.back().GetIndexSettings()).Settings;
27602775
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.Metric, TVectorIndexSettings::EMetric::InnerProduct);
2761-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Float);
2762-
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 1024);
2776+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorType, TVectorIndexSettings::EVectorType::Uint8);
2777+
UNIT_ASSERT_VALUES_EQUAL(vectorIndexSettings.VectorDimension, 3); // test names are all 4 bytes
27632778

27642779
describe = session.DescribeTable(TString{"/Root/Test/CommentIndex/"} + NTableIndex::NTableVectorKmeansTreeIndex::LevelTable).GetValueSync();
27652780
UNIT_ASSERT_EQUAL(describe.GetStatus(), EStatus::SUCCESS);

ydb/core/protos/tx_datashard.proto

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,48 @@ message TEvReshuffleKMeansResponse {
16321632
optional uint64 ReadBytes = 11;
16331633
}
16341634

1635+
message TEvRecomputeKMeansRequest {
1636+
optional uint64 Id = 1;
1637+
1638+
optional uint64 TabletId = 2;
1639+
optional NKikimrProto.TPathID PathId = 3;
1640+
1641+
optional uint64 SnapshotTxId = 4;
1642+
optional uint64 SnapshotStep = 5;
1643+
1644+
optional uint64 SeqNoGeneration = 6;
1645+
optional uint64 SeqNoRound = 7;
1646+
1647+
optional Ydb.Table.VectorIndexSettings Settings = 8;
1648+
1649+
// id of parent cluster
1650+
optional uint64 Parent = 9;
1651+
// centroids of clusters
1652+
repeated string Clusters = 10;
1653+
1654+
optional string EmbeddingColumn = 11;
1655+
}
1656+
1657+
message TEvRecomputeKMeansResponse {
1658+
optional uint64 Id = 1;
1659+
1660+
optional uint64 TabletId = 2;
1661+
optional NKikimrProto.TPathID PathId = 3;
1662+
1663+
optional uint64 RequestSeqNoGeneration = 4;
1664+
optional uint64 RequestSeqNoRound = 5;
1665+
1666+
optional NKikimrIndexBuilder.EBuildStatus Status = 6;
1667+
repeated Ydb.Issue.IssueMessage Issues = 7;
1668+
1669+
optional uint64 ReadRows = 8;
1670+
optional uint64 ReadBytes = 9;
1671+
1672+
// recomputed clusters and cluster sizes (row counts for every cluster)
1673+
repeated bytes Clusters = 10;
1674+
repeated uint64 ClusterSizes = 11;
1675+
}
1676+
16351677
message TEvPrefixKMeansRequest {
16361678
optional uint64 Id = 1;
16371679

ydb/core/tx/datashard/build_index/local_kmeans.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
7676
TBufferData* UploadBuf = nullptr;
7777

7878
const ui32 Dimensions = 0;
79+
const ui32 K = 0;
7980
NTable::TPos EmbeddingPos = 0;
8081
NTable::TPos DataPos = 1;
8182

@@ -123,6 +124,7 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
123124
, BuildId{request.GetId()}
124125
, Uploader(request.GetScanSettings())
125126
, Dimensions(request.GetSettings().vector_dimension())
127+
, K(request.GetK())
126128
, ScanSettings(request.GetScanSettings())
127129
, ResponseActorId{responseActorId}
128130
, Response{std::move(response)}
@@ -365,17 +367,26 @@ class TLocalKMeansScan: public TActor<TLocalKMeansScan>, public NTable::IScan {
365367
{
366368
if (State == EState::SAMPLE) {
367369
State = EState::KMEANS;
368-
if (!Clusters->SetClusters(Sampler.Finish().second)) {
370+
auto rows = Sampler.Finish().second;
371+
if (rows.size() == 0) {
369372
// We don't need to do anything,
370373
// because this datashard doesn't have valid embeddings for this prefix
371374
return true;
372375
}
376+
if (rows.size() < K) {
377+
// if this datashard have less than K valid embeddings for this parent
378+
// lets make single centroid for it
379+
rows.resize(1);
380+
}
381+
bool ok = Clusters->SetClusters(std::move(rows));
382+
Y_ENSURE(ok);
373383
Clusters->InitAggregatedClusters();
374384
return false; // do KMEANS
375385
}
376386

377387
if (State == EState::KMEANS) {
378388
if (Clusters->RecomputeClusters()) {
389+
Clusters->RemoveEmptyClusters();
379390
FormLevelRows();
380391
State = UploadState;
381392
return false; // do UPLOAD_*

0 commit comments

Comments
 (0)