Skip to content

Commit 7d6e252

Browse files
vitalifkunga
authored and committed
Fix prefixed vector index with PK columns (#18196) (#18889)
1 parent 0b5ca19 commit 7d6e252

File tree

9 files changed

+141
-74
lines changed

9 files changed

+141
-74
lines changed

ydb/core/kqp/ut/indexes/kqp_indexes_prefixed_vector_ut.cpp

Lines changed: 70 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) {
173173
DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineSimilarity", "DESC", covered);
174174
}
175175

176-
TSession DoCreateTableForPrefixedVectorIndex(TTableClient& db, bool nullable) {
176+
TSession DoCreateTableForPrefixedVectorIndex(TTableClient& db, bool nullable, bool suffixPk = false) {
177177
auto session = db.CreateSession().GetValueSync().GetSession();
178178

179179
{
@@ -191,14 +191,25 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) {
191191
.AddNonNullableColumn("emb", EPrimitiveType::String)
192192
.AddNonNullableColumn("data", EPrimitiveType::String);
193193
}
194-
tableBuilder.SetPrimaryKeyColumns({"pk"});
194+
if (suffixPk) {
195+
tableBuilder.SetPrimaryKeyColumns({"pk", "user"});
196+
} else {
197+
tableBuilder.SetPrimaryKeyColumns({"pk"});
198+
}
195199
tableBuilder.BeginPartitioningSettings()
196200
.SetMinPartitionsCount(3)
197-
.EndPartitioningSettings();
198-
auto partitions = TExplicitPartitions{}
199-
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(40).EndTuple().Build())
200-
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(60).EndTuple().Build());
201-
tableBuilder.SetPartitionAtKeys(partitions);
201+
.EndPartitioningSettings();
202+
if (suffixPk) {
203+
auto partitions = TExplicitPartitions{}
204+
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(40).AddElement().OptionalString("").EndTuple().Build())
205+
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(60).AddElement().OptionalString("").EndTuple().Build());
206+
tableBuilder.SetPartitionAtKeys(partitions);
207+
} else {
208+
auto partitions = TExplicitPartitions{}
209+
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(40).EndTuple().Build())
210+
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(60).EndTuple().Build());
211+
tableBuilder.SetPartitionAtKeys(partitions);
212+
}
202213
auto result = session.CreateTable("/Root/TestTable", tableBuilder.Build()).ExtractValueSync();
203214
UNIT_ASSERT_VALUES_EQUAL(result.IsTransportError(), false);
204215
UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToString());
@@ -488,6 +499,58 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) {
488499
DoPositiveQueriesPrefixedVectorIndexOrderByCosine(session, true /*covered*/);
489500
}
490501

502+
Y_UNIT_TEST_QUAD(CosineDistanceWithPkPrefix, Nullable, Covered) {
503+
NKikimrConfig::TFeatureFlags featureFlags;
504+
featureFlags.SetEnableVectorIndex(true);
505+
auto setting = NKikimrKqp::TKqpSetting();
506+
auto serverSettings = TKikimrSettings()
507+
.SetFeatureFlags(featureFlags)
508+
.SetKqpSettings({setting});
509+
510+
TKikimrRunner kikimr(serverSettings);
511+
kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::BUILD_INDEX, NActors::NLog::PRI_TRACE);
512+
kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::FLAT_TX_SCHEMESHARD, NActors::NLog::PRI_TRACE);
513+
514+
auto db = kikimr.GetTableClient();
515+
516+
auto session = DoCreateTableForPrefixedVectorIndex(db, Nullable, true);
517+
{
518+
const TString createIndex(Q_(Sprintf(R"(
519+
ALTER TABLE `/Root/TestTable`
520+
ADD INDEX index
521+
GLOBAL USING vector_kmeans_tree
522+
ON (user, emb) %s
523+
WITH (distance=cosine, vector_type="uint8", vector_dimension=2, levels=2, clusters=2);
524+
)", (Covered ? "COVER (emb, data)" : ""))));
525+
526+
auto result = session.ExecuteSchemeQuery(createIndex)
527+
.ExtractValueSync();
528+
529+
UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString());
530+
}
531+
{
532+
auto result = session.DescribeTable("/Root/TestTable").ExtractValueSync();
533+
UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), NYdb::EStatus::SUCCESS);
534+
const auto& indexes = result.GetTableDescription().GetIndexDescriptions();
535+
UNIT_ASSERT_EQUAL(indexes.size(), 1);
536+
UNIT_ASSERT_EQUAL(indexes[0].GetIndexName(), "index");
537+
std::vector<std::string> indexKeyColumns{"user", "emb"};
538+
UNIT_ASSERT_EQUAL(indexes[0].GetIndexColumns(), indexKeyColumns);
539+
std::vector<std::string> indexDataColumns;
540+
if (Covered) {
541+
indexDataColumns = {"emb", "data"};
542+
}
543+
UNIT_ASSERT_EQUAL(indexes[0].GetDataColumns(), indexDataColumns);
544+
const auto& settings = std::get<TKMeansTreeSettings>(indexes[0].GetIndexSettings());
545+
UNIT_ASSERT_EQUAL(settings.Settings.Metric, NYdb::NTable::TVectorIndexSettings::EMetric::CosineDistance);
546+
UNIT_ASSERT_EQUAL(settings.Settings.VectorType, NYdb::NTable::TVectorIndexSettings::EVectorType::Uint8);
547+
UNIT_ASSERT_EQUAL(settings.Settings.VectorDimension, 2);
548+
UNIT_ASSERT_EQUAL(settings.Levels, 2);
549+
UNIT_ASSERT_EQUAL(settings.Clusters, 2);
550+
}
551+
DoPositiveQueriesPrefixedVectorIndexOrderByCosine(session, Covered);
552+
}
553+
491554
}
492555

493556
}

ydb/core/protos/tx_datashard.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,6 +1662,8 @@ message TEvPrefixKMeansRequest {
16621662
optional uint32 PrefixColumns = 17;
16631663

16641664
optional NKikimrIndexBuilder.TIndexBuildScanSettings ScanSettings = 18;
1665+
1666+
repeated string SourcePrimaryKeyColumns = 19;
16651667
}
16661668

16671669
message TEvPrefixKMeansResponse {

ydb/core/tx/datashard/build_index/kmeans_helper.cpp

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -45,51 +45,22 @@ void AddRowToLevel(TBufferData& buffer, TClusterId parent, TClusterId child, con
4545
buffer.AddRow(TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data));
4646
}
4747

48-
void AddRowMainToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row) {
49-
EnsureNoPostingParentFlag(parent);
50-
51-
std::array<TCell, 1> cells;
52-
cells[0] = TCell::Make(parent);
53-
auto pk = TSerializedCellVec::Serialize(cells);
54-
TSerializedCellVec::UnsafeAppendCells(key, pk);
55-
buffer.AddRow(TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(row),
56-
TSerializedCellVec{key});
57-
}
58-
59-
void AddRowMainToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos)
60-
{
61-
parent = SetPostingParentFlag(parent);
62-
63-
std::array<TCell, 1> cells;
64-
cells[0] = TCell::Make(parent);
65-
auto pk = TSerializedCellVec::Serialize(cells);
66-
TSerializedCellVec::UnsafeAppendCells(key, pk);
67-
buffer.AddRow(TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(row.Slice(dataPos)),
68-
TSerializedCellVec{key});
69-
}
70-
71-
void AddRowBuildToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 prefixColumns)
72-
{
73-
EnsureNoPostingParentFlag(parent);
48+
void AddRowToData(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> sourcePk,
49+
TArrayRef<const TCell> dataColumns, TArrayRef<const TCell> origKey, bool isPostingLevel) {
50+
if (isPostingLevel) {
51+
parent = SetPostingParentFlag(parent);
52+
} else {
53+
EnsureNoPostingParentFlag(parent);
54+
}
7455

7556
std::array<TCell, 1> cells;
7657
cells[0] = TCell::Make(parent);
7758
auto pk = TSerializedCellVec::Serialize(cells);
78-
TSerializedCellVec::UnsafeAppendCells(key.Slice(prefixColumns), pk);
79-
buffer.AddRow(TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(row),
80-
TSerializedCellVec{key});
81-
}
82-
83-
void AddRowBuildToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos, ui32 prefixColumns)
84-
{
85-
parent = SetPostingParentFlag(parent);
59+
TSerializedCellVec::UnsafeAppendCells(sourcePk, pk);
8660

87-
std::array<TCell, 1> cells;
88-
cells[0] = TCell::Make(parent);
89-
auto pk = TSerializedCellVec::Serialize(cells);
90-
TSerializedCellVec::UnsafeAppendCells(key.Slice(prefixColumns), pk);
91-
buffer.AddRow(TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(row.Slice(dataPos)),
92-
TSerializedCellVec{key});
61+
buffer.AddRow(TSerializedCellVec{std::move(pk)},
62+
TSerializedCellVec::Serialize(dataColumns),
63+
TSerializedCellVec{origKey});
9364
}
9465

9566
TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
@@ -114,12 +85,11 @@ TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
11485

11586
std::shared_ptr<NTxProxy::TUploadTypes> MakeOutputTypes(const TUserTable& table, NKikimrTxDataShard::EKMeansState uploadState,
11687
const TProtoStringType& embedding, const google::protobuf::RepeatedPtrField<TProtoStringType>& data,
117-
ui32 prefixColumns)
88+
const google::protobuf::RepeatedPtrField<TProtoStringType>& pkColumns)
11889
{
11990
auto types = GetAllTypes(table);
12091

12192
auto result = std::make_shared<NTxProxy::TUploadTypes>();
122-
result->reserve(1 + 1 + std::min((table.KeyColumnTypes.size() - prefixColumns) + data.size(), types.size()));
12393

12494
Ydb::Type type;
12595
type.set_type_id(NTableIndex::ClusterIdType);
@@ -133,8 +103,14 @@ std::shared_ptr<NTxProxy::TUploadTypes> MakeOutputTypes(const TUserTable& table,
133103
types.erase(it);
134104
}
135105
};
136-
for (const auto& column : table.KeyColumnIds | std::views::drop(prefixColumns)) {
137-
addType(table.Columns.at(column).Name);
106+
if (pkColumns.size()) {
107+
for (const auto& column : pkColumns) {
108+
addType(column);
109+
}
110+
} else {
111+
for (const auto& column : table.KeyColumnIds) {
112+
addType(table.Columns.at(column).Name);
113+
}
138114
}
139115
switch (uploadState) {
140116
case NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD:

ydb/core/tx/datashard/build_index/kmeans_helper.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,21 +140,16 @@ struct TMaxInnerProductSimilarity : TMetric<TCoord> {
140140

141141
void AddRowToLevel(TBufferData& buffer, TClusterId parent, TClusterId child, const TString& embedding, bool isPostingLevel);
142142

143-
void AddRowMainToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row);
143+
void AddRowToData(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> sourcePk,
144+
TArrayRef<const TCell> dataColumns, TArrayRef<const TCell> origKey, bool isPostingLevel);
144145

145-
void AddRowMainToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos);
146-
147-
void AddRowBuildToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 prefixColumns = 1);
148-
149-
void AddRowBuildToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos, ui32 prefixColumns = 1);
150-
151-
TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
146+
TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
152147
const google::protobuf::RepeatedPtrField<TProtoStringType>& data, ui32& embeddingPos,
153148
ui32& dataPos, NTable::TTag& embeddingTag);
154149

155150
std::shared_ptr<NTxProxy::TUploadTypes> MakeOutputTypes(const TUserTable& table, NKikimrTxDataShard::EKMeansState uploadState,
156151
const TProtoStringType& embedding, const google::protobuf::RepeatedPtrField<TProtoStringType>& data,
157-
ui32 prefixColumns = 0);
152+
const google::protobuf::RepeatedPtrField<TProtoStringType>& pkColumns = {});
158153

159154
void MakeScan(auto& record, const auto& createScan, const auto& badRequest)
160155
{

ydb/core/tx/datashard/build_index/local_kmeans.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,28 +447,28 @@ class TLocalKMeansScan final : public TLocalKMeansScanBase {
447447
void FeedMainToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
448448
{
449449
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
450-
AddRowMainToBuild(*OutputBuf, Child + *pos, key, row);
450+
AddRowToData(*OutputBuf, Child + *pos, key, row, key, false);
451451
}
452452
}
453453

454454
void FeedMainToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
455455
{
456456
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
457-
AddRowMainToPosting(*OutputBuf, Child + *pos, key, row, DataPos);
457+
AddRowToData(*OutputBuf, Child + *pos, key, row.Slice(DataPos), key, true);
458458
}
459459
}
460460

461461
void FeedBuildToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
462462
{
463463
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
464-
AddRowBuildToBuild(*OutputBuf, Child + *pos, key, row);
464+
AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row, key, false);
465465
}
466466
}
467467

468468
void FeedBuildToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
469469
{
470470
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
471-
AddRowBuildToPosting(*OutputBuf, Child + *pos, key, row, DataPos);
471+
AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row.Slice(DataPos), key, true);
472472
}
473473
}
474474

ydb/core/tx/datashard/build_index/prefix_kmeans.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using namespace NKMeans;
2929
*
3030
* Request:
3131
* - The client sends TEvPrefixKMeansRequest with:
32-
* - Child: base ID from which new cluster IDs are assigned within this request.
32+
* - Child: base ID from which new cluster IDs are assigned within this request.
3333
* - Each prefix group processed will be assigned cluster IDs starting at Child + 1.
3434
* - For a request with K clusters per prefix, the IDs used for the first prefix group are
3535
* (Child + 1) to (Child + K), and the parent ID for these is Child.
@@ -93,6 +93,9 @@ class TPrefixKMeansScanBase: public TActor<TPrefixKMeansScanBase>, public NTable
9393

9494
// FIXME: save PrefixRows as std::vector<std::pair<TSerializedCellVec, TSerializedCellVec>> to avoid parsing
9595
const ui32 PrefixColumns;
96+
// For PrefixKMeans, the original table's primary key columns are passed separately,
97+
// because the prefix table contains them in a different order if they are both in the PK and in the prefix; DataColumnCount is the offset (past DataPos) at which they start in the scanned row.
98+
const ui32 DataColumnCount;
9699
TSerializedCellVec Prefix;
97100
TBufferData PrefixRows;
98101
bool IsFirstPrefixFeed = true;
@@ -126,10 +129,14 @@ class TPrefixKMeansScanBase: public TActor<TPrefixKMeansScanBase>, public NTable
126129
, ResponseActorId{responseActorId}
127130
, Response{std::move(response)}
128131
, PrefixColumns{request.GetPrefixColumns()}
132+
, DataColumnCount{(ui32)request.GetDataColumns().size()}
129133
{
130134
const auto& embedding = request.GetEmbeddingColumn();
131-
const auto& data = request.GetDataColumns();
132-
ScanTags = MakeScanTags(table, embedding, data, EmbeddingPos, DataPos, EmbeddingTag);
135+
TVector<TString> data{request.GetDataColumns().begin(), request.GetDataColumns().end()};
136+
for (auto & col: request.GetSourcePrimaryKeyColumns()) {
137+
data.push_back(col);
138+
}
139+
ScanTags = MakeScanTags(table, embedding, {data.begin(), data.end()}, EmbeddingPos, DataPos, EmbeddingTag);
133140
Lead.To(ScanTags, {}, NTable::ESeek::Lower);
134141
{
135142
Ydb::Type type;
@@ -141,7 +148,11 @@ class TPrefixKMeansScanBase: public TActor<TPrefixKMeansScanBase>, public NTable
141148
(*levelTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type};
142149
LevelBuf = Uploader.AddDestination(request.GetLevelName(), std::move(levelTypes));
143150
}
144-
OutputBuf = Uploader.AddDestination(request.GetOutputName(), MakeOutputTypes(table, UploadState, embedding, data, PrefixColumns));
151+
{
152+
auto outputTypes = MakeOutputTypes(table, UploadState, embedding,
153+
{data.begin(), data.begin()+request.GetDataColumns().size()}, request.GetSourcePrimaryKeyColumns());
154+
OutputBuf = Uploader.AddDestination(request.GetOutputName(), outputTypes);
155+
}
145156
{
146157
auto types = GetAllTypes(table);
147158

@@ -465,14 +476,14 @@ class TPrefixKMeansScan final : public TPrefixKMeansScanBase {
465476
void FeedBuildToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
466477
{
467478
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
468-
AddRowBuildToBuild(*OutputBuf, Child + *pos, key, row, PrefixColumns);
479+
AddRowToData(*OutputBuf, Child + *pos, row.Slice(DataPos+DataColumnCount), row.Slice(0, DataPos+DataColumnCount), key, false);
469480
}
470481
}
471482

472483
void FeedBuildToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
473484
{
474485
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
475-
AddRowBuildToPosting(*OutputBuf, Child + *pos, key, row, DataPos, PrefixColumns);
486+
AddRowToData(*OutputBuf, Child + *pos, row.Slice(DataPos+DataColumnCount), row.Slice(DataPos, DataColumnCount), key, true);
476487
}
477488
}
478489

@@ -609,6 +620,14 @@ void TDataShard::HandleSafe(TEvDataShard::TEvPrefixKMeansRequest::TPtr& ev, cons
609620
if (request.GetPrefixColumns() > userTable.KeyColumnIds.size()) {
610621
badRequest(TStringBuilder() << "Should not be requested on more than " << userTable.KeyColumnIds.size() << " prefix columns");
611622
}
623+
if (request.GetSourcePrimaryKeyColumns().size() == 0) {
624+
badRequest("Request should include source primary key columns");
625+
}
626+
for (auto pkColumn : request.GetSourcePrimaryKeyColumns()) {
627+
if (!tags.contains(pkColumn)) {
628+
badRequest(TStringBuilder() << "Unknown source primary key column: " << pkColumn);
629+
}
630+
}
612631

613632
if (trySendBadRequest()) {
614633
return;

ydb/core/tx/datashard/build_index/reshuffle_kmeans.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,28 +293,28 @@ class TReshuffleKMeansScan final : public TReshuffleKMeansScanBase {
293293
void FeedMainToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
294294
{
295295
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
296-
AddRowMainToBuild(*OutputBuf, Child + *pos, key, row);
296+
AddRowToData(*OutputBuf, Child + *pos, key, row, key, false);
297297
}
298298
}
299299

300300
void FeedMainToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
301301
{
302302
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
303-
AddRowMainToPosting(*OutputBuf, Child + *pos, key, row, DataPos);
303+
AddRowToData(*OutputBuf, Child + *pos, key, row.Slice(DataPos), key, true);
304304
}
305305
}
306306

307307
void FeedBuildToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
308308
{
309309
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
310-
AddRowBuildToBuild(*OutputBuf, Child + *pos, key, row);
310+
AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row, key, false);
311311
}
312312
}
313313

314314
void FeedBuildToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) noexcept
315315
{
316316
if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
317-
AddRowBuildToPosting(*OutputBuf, Child + *pos, key, row, DataPos);
317+
AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row.Slice(DataPos), key, true);
318318
}
319319
}
320320
};

0 commit comments

Comments
 (0)