Skip to content

Commit 750d3b7

Browse files
authored
Vector index workload improvements for large tables (#20734)
1 parent a54252c commit 750d3b7

10 files changed

+331
-281
lines changed
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
#include "vector_workload_params.h"
2-
#include <ydb/library/workload/abstract/workload_factory.h>
3-
4-
namespace NYdbWorkload {
5-
6-
TWorkloadFactory::TRegistrator<TVectorWorkloadParams> VectorRegistrar("vector");
7-
8-
}
1+
#include "vector_workload_params.h"
2+
#include <ydb/library/workload/abstract/workload_factory.h>
3+
4+
namespace NYdbWorkload {
5+
6+
TWorkloadFactory::TRegistrator<TVectorWorkloadParams> VectorRegistrar("vector");
7+
8+
}

ydb/library/workload/vector/vector_recall_evaluator.cpp

Lines changed: 126 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -15,100 +15,144 @@
1515

1616
namespace NYdbWorkload {
1717

18+
using Clock = std::chrono::steady_clock;
19+
1820
// TVectorRecallEvaluator implementation
1921
TVectorRecallEvaluator::TVectorRecallEvaluator(const TVectorWorkloadParams& params)
2022
: Params(params)
2123
{
2224
}
2325

26+
void TVectorRecallEvaluator::SelectReferenceResults(const TVectorSampler& sampler) {
27+
Cout << "Selecting reference results..." << Endl;
28+
const auto startTime = Clock::now();
29+
30+
auto [functionName, isAscending] = GetMetricInfo(Params.Metric);
31+
32+
// We'll select multiple reference results with one full scan using a window function
33+
auto refQueryBuilder = TStringBuilder() << "--!syntax_v1\n"
34+
<< "DECLARE $Samples as List<Struct<id: uint64, embedding: string";
35+
if (Params.PrefixColumn) {
36+
refQueryBuilder << ", prefix: " << Params.PrefixType;
37+
}
38+
refQueryBuilder << ">>;\n"
39+
<< "SELECT * FROM ("
40+
<< " SELECT s.id AS id"
41+
<< ", UNWRAP(CAST(m." << Params.KeyColumn << " AS string)) AS result_id"
42+
<< ", UNWRAP(Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding)) AS distance"
43+
<< ", (ROW_NUMBER() OVER w) AS position"
44+
<< " FROM " << Params.TableName << " m"
45+
<< (Params.PrefixColumn ? " INNER JOIN " : " CROSS JOIN ") << "AS_TABLE($Samples) AS s";
46+
if (Params.PrefixColumn) {
47+
refQueryBuilder << " ON s.prefix = m." << *Params.PrefixColumn;
48+
}
49+
refQueryBuilder << " WINDOW w AS (PARTITION BY s.id"
50+
<< " ORDER BY Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding)"
51+
<< (isAscending ? " ASC" : " DESC") << ")"
52+
<< ") AS t WHERE position <= " << Params.Limit
53+
<< " ORDER BY id, position";
54+
55+
std::string refQuery = refQueryBuilder;
56+
57+
// Process targets in batches (batch size should be ~10000 / Limit)
58+
const ui64 batchSize = 10000 / Params.Limit;
59+
for (ui64 batchStart = 0; batchStart < sampler.GetTargetCount(); batchStart += batchSize) {
60+
const size_t batchEnd = (batchStart + batchSize < sampler.GetTargetCount() ? batchStart + batchSize : sampler.GetTargetCount());
61+
NYdb::TParamsBuilder paramsBuilder;
62+
63+
auto & builder = paramsBuilder.AddParam("$Samples").BeginList();
64+
for (size_t i = batchStart; i < batchEnd; i++) {
65+
builder.AddListItem()
66+
.BeginStruct()
67+
.AddMember("id").Uint64(i)
68+
.AddMember("embedding").String(sampler.GetTargetEmbedding(i));
69+
if (Params.PrefixColumn) {
70+
builder.AddMember("prefix", sampler.GetPrefixValue(i));
71+
}
72+
builder.EndStruct();
73+
}
74+
builder.EndList().Build();
75+
76+
std::optional<NYdb::TResultSet> resultSet;
77+
NYdb::NStatusHelpers::ThrowOnError(Params.QueryClient->RetryQuerySync([&](NYdb::NQuery::TSession session) {
78+
auto result = session.ExecuteQuery(refQuery, NYdb::NQuery::TTxControl::NoTx(), paramsBuilder.Build())
79+
.GetValueSync();
80+
Y_ABORT_UNLESS(result.IsSuccess(), "Reference search result query failed: %s", result.GetIssues().ToString().c_str());
81+
resultSet = result.GetResultSet(0);
82+
return result;
83+
}));
84+
85+
ui64 refId = 0;
86+
std::vector<std::string> refList;
87+
88+
NYdb::TResultSetParser parser(*resultSet);
89+
while (parser.TryNextRow()) {
90+
ui64 id = parser.ColumnParser("id").GetUint64();
91+
std::string res = parser.ColumnParser("result_id").GetString();
92+
if (id != refId) {
93+
if (refList.size()) {
94+
References[refId] = refList;
95+
}
96+
refList.clear();
97+
refId = id;
98+
}
99+
refList.push_back(res);
100+
}
101+
if (refList.size()) {
102+
References[refId] = std::move(refList);
103+
}
104+
}
105+
Cout << "Reference results for " << sampler.GetTargetCount()
106+
<< " targets selected in " << (int)((Clock::now() - startTime) / std::chrono::seconds(1)) << " seconds.\n";
107+
}
108+
24109
void TVectorRecallEvaluator::MeasureRecall(const TVectorSampler& sampler) {
110+
SelectReferenceResults(sampler);
111+
25112
Cout << "Recall measurement..." << Endl;
26-
27-
// Prepare the query for scan
28-
std::string queryScan = MakeSelect(Params.TableName, {}, Params.KeyColumn, Params.EmbeddingColumn, Params.PrefixColumn, 0, Params.Metric);
29-
113+
30114
// Create the query for index search
31-
std::string queryIndex = MakeSelect(Params.TableName, Params.IndexName, Params.KeyColumn, Params.EmbeddingColumn, Params.PrefixColumn, Params.KmeansTreeSearchClusters, Params.Metric);
115+
std::string queryIndex = MakeSelect(Params, Params.IndexName);
32116

33117
// Process targets in batches
118+
const auto startTime = Clock::now();
34119
for (size_t batchStart = 0; batchStart < sampler.GetTargetCount(); batchStart += Params.RecallThreads) {
35120
size_t batchEnd = std::min(batchStart + Params.RecallThreads, sampler.GetTargetCount());
36-
37-
// Start async queries for this batch - both scan and index queries
38-
std::vector<std::pair<size_t, NYdb::NQuery::TAsyncExecuteQueryResult>> asyncScanQueries;
121+
122+
// Start async queries for this batch
39123
std::vector<std::pair<size_t, NYdb::NQuery::TAsyncExecuteQueryResult>> asyncIndexQueries;
40-
asyncScanQueries.reserve(batchEnd - batchStart);
41124
asyncIndexQueries.reserve(batchEnd - batchStart);
42-
125+
43126
for (size_t i = batchStart; i < batchEnd; i++) {
44127
const auto& targetEmbedding = sampler.GetTargetEmbedding(i);
45-
std::optional<i64> prefixValue;
128+
std::optional<NYdb::TValue> prefixValue;
46129
if (Params.PrefixColumn) {
47130
prefixValue = sampler.GetPrefixValue(i);
48131
}
49132

50133
NYdb::TParams params = MakeSelectParams(targetEmbedding, prefixValue, Params.Limit);
51-
52-
// Execute scan query for ground truth
53-
auto asyncScanResult = Params.QueryClient->RetryQuery([queryScan, params](NYdb::NQuery::TSession session) {
54-
return session.ExecuteQuery(
55-
queryScan,
56-
NYdb::NQuery::TTxControl::NoTx(),
57-
params);
58-
});
59-
134+
60135
// Execute index query for recall measurement
61136
auto asyncIndexResult = Params.QueryClient->RetryQuery([queryIndex, params](NYdb::NQuery::TSession session) {
62137
return session.ExecuteQuery(
63138
queryIndex,
64139
NYdb::NQuery::TTxControl::NoTx(),
65140
params);
66141
});
67-
68-
asyncScanQueries.emplace_back(i, std::move(asyncScanResult));
142+
69143
asyncIndexQueries.emplace_back(i, std::move(asyncIndexResult));
70144
}
71-
72-
// Wait for all scan queries in this batch to complete and build ground truth
73-
std::unordered_map<size_t, std::vector<ui64>> batchEtalons;
74-
for (auto& [targetIndex, asyncResult] : asyncScanQueries) {
75-
auto result = asyncResult.GetValueSync();
76-
Y_ABORT_UNLESS(result.IsSuccess(), "Scan query failed for target %zu: %s",
77-
targetIndex, result.GetIssues().ToString().c_str());
78-
79-
auto resultSet = result.GetResultSet(0);
80-
NYdb::TResultSetParser parser(resultSet);
81-
82-
// Build etalons for this target locally
83-
std::vector<ui64> etalons;
84-
etalons.reserve(Params.Limit);
85-
86-
// Extract all IDs from the result set
87-
while (parser.TryNextRow()) {
88-
ui64 id = parser.ColumnParser(Params.KeyColumn).GetUint64();
89-
etalons.push_back(id);
90-
}
91-
if (etalons.empty()) {
92-
Cerr << "Warning: target " << targetIndex << " have empty etalon sets" << Endl;
93-
}
94-
95-
batchEtalons[targetIndex] = std::move(etalons);
96-
}
97-
145+
98146
// Wait for all index queries in this batch to complete and measure recall
99147
for (auto& [targetIndex, asyncResult] : asyncIndexQueries) {
100148
auto result = asyncResult.GetValueSync();
101-
// Process the index query result and calculate recall using etalon nearest neighbours
102-
ProcessIndexQueryResult(result, targetIndex, batchEtalons[targetIndex], false);
103-
}
104-
105-
// Log progress for large datasets
106-
if (sampler.GetTargetCount() > 100 && (batchEnd % 100 == 0 || batchEnd == sampler.GetTargetCount())) {
107-
Cout << "Processed " << batchEnd << " of " << sampler.GetTargetCount() << " targets..." << Endl;
149+
// Process the index query result and calculate recall using reference nearest neighbours
150+
ProcessIndexQueryResult(result, targetIndex, References[targetIndex], false);
108151
}
109152
}
110-
111-
Cout << "Recall measurement completed for " << sampler.GetTargetCount() << " targets."
153+
154+
Cout << "Recall measurement completed for " << sampler.GetTargetCount()
155+
<< " targets in " << (int)((Clock::now() - startTime) / std::chrono::seconds(1)) << " seconds."
112156
<< "\nAverage recall: " << GetAverageRecall() << Endl;
113157
}
114158

@@ -130,55 +174,52 @@ size_t TVectorRecallEvaluator::GetProcessedTargets() const {
130174
}
131175

132176
// Process index query results
133-
void TVectorRecallEvaluator::ProcessIndexQueryResult(const NYdb::NQuery::TExecuteQueryResult& queryResult, size_t targetIndex, const std::vector<ui64>& etalons, bool verbose) {
134-
if (!queryResult.IsSuccess()) {
135-
// Ignore the error. It's printed in the verbose mode
136-
return;
137-
}
138-
177+
void TVectorRecallEvaluator::ProcessIndexQueryResult(const NYdb::NQuery::TExecuteQueryResult& queryResult,
178+
size_t targetIndex, const std::vector<std::string>& references, bool verbose) {
179+
Y_ABORT_UNLESS(queryResult.IsSuccess(), "Query failed: %s", queryResult.GetIssues().ToString().c_str());
180+
139181
// Get the result set
140182
auto resultSet = queryResult.GetResultSet(0);
141183
NYdb::TResultSetParser parser(resultSet);
142-
184+
143185
// Extract IDs from index search results
144-
std::vector<ui64> indexResults;
186+
std::vector<std::string> indexResults;
145187
while (parser.TryNextRow()) {
146-
ui64 id = parser.ColumnParser(Params.KeyColumn).GetUint64();
147-
indexResults.push_back(id);
188+
indexResults.push_back(parser.ColumnParser("id").GetString());
148189
}
149-
150-
// Create etalon set for efficient lookup
151-
std::unordered_set<ui64> etalonSet(etalons.begin(), etalons.end());
152-
153-
// Check if target ID is first in results
154-
if (!indexResults.empty() && !etalons.empty()) {
155-
ui64 targetId = etalons[0]; // First etalon is the target ID itself
156-
190+
191+
// Create reference set for efficient lookup
192+
std::unordered_set<std::string> referenceSet(references.begin(), references.end());
193+
194+
if (references.empty()) {
195+
Cerr << "Warning: Empty references for target " << targetIndex << "\n";
196+
} else if (indexResults.empty()) {
197+
Cerr << "Warning: Empty index results for target " << targetIndex << "\n";
198+
} else {
199+
// Check if the first reference is the target ID itself
200+
const std::string & targetId = references[0];
201+
157202
if (verbose && indexResults[0] != targetId) {
158-
Cerr << "Warning: Target ID " << targetId << " is not the first result for target "
203+
Cerr << "Warning: Target ID " << targetId << " is not the first result for target "
159204
<< targetIndex << ". Found " << indexResults[0] << " instead." << Endl;
160205
}
161-
206+
162207
// Calculate recall
163208
size_t relevantRetrieved = 0;
164209
for (const auto& id : indexResults) {
165-
if (etalonSet.count(id)) {
210+
if (referenceSet.count(id)) {
166211
relevantRetrieved++;
167212
}
168213
}
169-
214+
170215
// Calculate recall for this target
171-
double recall = etalons.empty() ? 0.0 : static_cast<double>(relevantRetrieved) / etalons.size();
216+
double recall = references.empty() ? 0.0 : static_cast<double>(relevantRetrieved) / references.size();
172217
AddRecall(recall);
173218

174219
// Add warning when zero relevant results found
175220
if (verbose && relevantRetrieved == 0 && !indexResults.empty()) {
176221
Cerr << "Warning: Zero relevant results for target " << targetIndex << Endl;
177222
}
178-
} else {
179-
// Handle empty results or empty etalons
180-
if (verbose)
181-
Cerr << "Warning: Empty results or etalons for target " << targetIndex << Endl;
182223
}
183224
}
184225

ydb/library/workload/vector/vector_recall_evaluator.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,28 @@ class TVectorSampler;
1414
class TVectorRecallEvaluator {
1515
public:
1616
TVectorRecallEvaluator(const TVectorWorkloadParams& params);
17-
17+
1818
// Core functionality for recall measurement using sampled vectors
1919
void MeasureRecall(const TVectorSampler& sampler);
20-
20+
2121
// Recall metrics methods
2222
void AddRecall(double recall);
2323
double GetAverageRecall() const;
2424
double GetTotalRecall() const;
2525
size_t GetProcessedTargets() const;
26-
26+
2727
private:
28+
void SelectReferenceResults(const TVectorSampler& sampler);
29+
2830
// Process index query results (internal method)
29-
void ProcessIndexQueryResult(const NYdb::NQuery::TExecuteQueryResult& result, size_t targetIndex, const std::vector<ui64>& etalons, bool verbose);
30-
31+
void ProcessIndexQueryResult(const NYdb::NQuery::TExecuteQueryResult& result, size_t targetIndex, const std::vector<std::string>& references, bool verbose);
32+
3133
const TVectorWorkloadParams& Params;
3234

3335
double TotalRecall = 0.0;
3436
size_t ProcessedTargets = 0;
37+
38+
std::unordered_map<ui64, std::vector<std::string>> References;
3539
};
3640

3741
} // namespace NYdbWorkload

0 commit comments

Comments
 (0)