15
15
16
16
namespace NYdbWorkload {
17
17
18
// Monotonic clock used for the elapsed-time figures printed by the evaluator.
using Clock = std::chrono::steady_clock;
19
+
18
20
// TVectorRecallEvaluator implementation
19
21
TVectorRecallEvaluator::TVectorRecallEvaluator (const TVectorWorkloadParams& params)
20
22
: Params(params)
21
23
{
22
24
}
23
25
26
+ void TVectorRecallEvaluator::SelectReferenceResults (const TVectorSampler& sampler) {
27
+ Cout << " Selecting reference results..." << Endl;
28
+ const auto startTime = Clock::now ();
29
+
30
+ auto [functionName, isAscending] = GetMetricInfo (Params.Metric );
31
+
32
+ // We'll select multiple reference results with one full scan using a window function
33
+ auto refQueryBuilder = TStringBuilder () << " --!syntax_v1\n "
34
+ << " DECLARE $Samples as List<Struct<id: uint64, embedding: string" ;
35
+ if (Params.PrefixColumn ) {
36
+ refQueryBuilder << " , prefix: " << Params.PrefixType ;
37
+ }
38
+ refQueryBuilder << " >>;\n "
39
+ << " SELECT * FROM ("
40
+ << " SELECT s.id AS id"
41
+ << " , UNWRAP(CAST(m." << Params.KeyColumn << " AS string)) AS result_id"
42
+ << " , UNWRAP(Knn::" << functionName << " (m." << Params.EmbeddingColumn << " , s.embedding)) AS distance"
43
+ << " , (ROW_NUMBER() OVER w) AS position"
44
+ << " FROM " << Params.TableName << " m"
45
+ << (Params.PrefixColumn ? " INNER JOIN " : " CROSS JOIN " ) << " AS_TABLE($Samples) AS s" ;
46
+ if (Params.PrefixColumn ) {
47
+ refQueryBuilder << " ON s.prefix = m." << *Params.PrefixColumn ;
48
+ }
49
+ refQueryBuilder << " WINDOW w AS (PARTITION BY s.id"
50
+ << " ORDER BY Knn::" << functionName << " (m." << Params.EmbeddingColumn << " , s.embedding)"
51
+ << (isAscending ? " ASC" : " DESC" ) << " )"
52
+ << " ) AS t WHERE position <= " << Params.Limit
53
+ << " ORDER BY id, position" ;
54
+
55
+ std::string refQuery = refQueryBuilder;
56
+
57
+ // Process targets in batches (batch size should be ~10000 / Limit)
58
+ const ui64 batchSize = 10000 / Params.Limit ;
59
+ for (ui64 batchStart = 0 ; batchStart < sampler.GetTargetCount (); batchStart += batchSize) {
60
+ const size_t batchEnd = (batchStart + batchSize < sampler.GetTargetCount () ? batchStart + batchSize : sampler.GetTargetCount ());
61
+ NYdb::TParamsBuilder paramsBuilder;
62
+
63
+ auto & builder = paramsBuilder.AddParam (" $Samples" ).BeginList ();
64
+ for (size_t i = batchStart; i < batchEnd; i++) {
65
+ builder.AddListItem ()
66
+ .BeginStruct ()
67
+ .AddMember (" id" ).Uint64 (i)
68
+ .AddMember (" embedding" ).String (sampler.GetTargetEmbedding (i));
69
+ if (Params.PrefixColumn ) {
70
+ builder.AddMember (" prefix" , sampler.GetPrefixValue (i));
71
+ }
72
+ builder.EndStruct ();
73
+ }
74
+ builder.EndList ().Build ();
75
+
76
+ std::optional<NYdb::TResultSet> resultSet;
77
+ NYdb::NStatusHelpers::ThrowOnError (Params.QueryClient->RetryQuerySync ([&](NYdb::NQuery::TSession session) {
78
+ auto result = session.ExecuteQuery (refQuery, NYdb::NQuery::TTxControl::NoTx (), paramsBuilder.Build ())
79
+ .GetValueSync ();
80
+ Y_ABORT_UNLESS (result.IsSuccess (), " Reference search result query failed: %s" , result.GetIssues ().ToString ().c_str ());
81
+ resultSet = result.GetResultSet (0 );
82
+ return result;
83
+ }));
84
+
85
+ ui64 refId = 0 ;
86
+ std::vector<std::string> refList;
87
+
88
+ NYdb::TResultSetParser parser (*resultSet);
89
+ while (parser.TryNextRow ()) {
90
+ ui64 id = parser.ColumnParser (" id" ).GetUint64 ();
91
+ std::string res = parser.ColumnParser (" result_id" ).GetString ();
92
+ if (id != refId) {
93
+ if (refList.size ()) {
94
+ References[refId] = refList;
95
+ }
96
+ refList.clear ();
97
+ refId = id;
98
+ }
99
+ refList.push_back (res);
100
+ }
101
+ if (refList.size ()) {
102
+ References[refId] = std::move (refList);
103
+ }
104
+ }
105
+ Cout << " Reference results for " << sampler.GetTargetCount ()
106
+ << " targets selected in " << (int )((Clock::now () - startTime) / std::chrono::seconds (1 )) << " seconds.\n " ;
107
+ }
108
+
24
109
void TVectorRecallEvaluator::MeasureRecall (const TVectorSampler& sampler) {
110
+ SelectReferenceResults (sampler);
111
+
25
112
Cout << " Recall measurement..." << Endl;
26
-
27
- // Prepare the query for scan
28
- std::string queryScan = MakeSelect (Params.TableName , {}, Params.KeyColumn , Params.EmbeddingColumn , Params.PrefixColumn , 0 , Params.Metric );
29
-
113
+
30
114
// Create the query for index search
31
- std::string queryIndex = MakeSelect (Params. TableName , Params.IndexName , Params. KeyColumn , Params. EmbeddingColumn , Params. PrefixColumn , Params. KmeansTreeSearchClusters , Params. Metric );
115
+ std::string queryIndex = MakeSelect (Params, Params.IndexName );
32
116
33
117
// Process targets in batches
118
+ const auto startTime = Clock::now ();
34
119
for (size_t batchStart = 0 ; batchStart < sampler.GetTargetCount (); batchStart += Params.RecallThreads ) {
35
120
size_t batchEnd = std::min (batchStart + Params.RecallThreads , sampler.GetTargetCount ());
36
-
37
- // Start async queries for this batch - both scan and index queries
38
- std::vector<std::pair<size_t , NYdb::NQuery::TAsyncExecuteQueryResult>> asyncScanQueries;
121
+
122
+ // Start async queries for this batch
39
123
std::vector<std::pair<size_t , NYdb::NQuery::TAsyncExecuteQueryResult>> asyncIndexQueries;
40
- asyncScanQueries.reserve (batchEnd - batchStart);
41
124
asyncIndexQueries.reserve (batchEnd - batchStart);
42
-
125
+
43
126
for (size_t i = batchStart; i < batchEnd; i++) {
44
127
const auto & targetEmbedding = sampler.GetTargetEmbedding (i);
45
- std::optional<i64 > prefixValue;
128
+ std::optional<NYdb::TValue > prefixValue;
46
129
if (Params.PrefixColumn ) {
47
130
prefixValue = sampler.GetPrefixValue (i);
48
131
}
49
132
50
133
NYdb::TParams params = MakeSelectParams (targetEmbedding, prefixValue, Params.Limit );
51
-
52
- // Execute scan query for ground truth
53
- auto asyncScanResult = Params.QueryClient ->RetryQuery ([queryScan, params](NYdb::NQuery::TSession session) {
54
- return session.ExecuteQuery (
55
- queryScan,
56
- NYdb::NQuery::TTxControl::NoTx (),
57
- params);
58
- });
59
-
134
+
60
135
// Execute index query for recall measurement
61
136
auto asyncIndexResult = Params.QueryClient ->RetryQuery ([queryIndex, params](NYdb::NQuery::TSession session) {
62
137
return session.ExecuteQuery (
63
138
queryIndex,
64
139
NYdb::NQuery::TTxControl::NoTx (),
65
140
params);
66
141
});
67
-
68
- asyncScanQueries.emplace_back (i, std::move (asyncScanResult));
142
+
69
143
asyncIndexQueries.emplace_back (i, std::move (asyncIndexResult));
70
144
}
71
-
72
- // Wait for all scan queries in this batch to complete and build ground truth
73
- std::unordered_map<size_t , std::vector<ui64>> batchEtalons;
74
- for (auto & [targetIndex, asyncResult] : asyncScanQueries) {
75
- auto result = asyncResult.GetValueSync ();
76
- Y_ABORT_UNLESS (result.IsSuccess (), " Scan query failed for target %zu: %s" ,
77
- targetIndex, result.GetIssues ().ToString ().c_str ());
78
-
79
- auto resultSet = result.GetResultSet (0 );
80
- NYdb::TResultSetParser parser (resultSet);
81
-
82
- // Build etalons for this target locally
83
- std::vector<ui64> etalons;
84
- etalons.reserve (Params.Limit );
85
-
86
- // Extract all IDs from the result set
87
- while (parser.TryNextRow ()) {
88
- ui64 id = parser.ColumnParser (Params.KeyColumn ).GetUint64 ();
89
- etalons.push_back (id);
90
- }
91
- if (etalons.empty ()) {
92
- Cerr << " Warning: target " << targetIndex << " have empty etalon sets" << Endl;
93
- }
94
-
95
- batchEtalons[targetIndex] = std::move (etalons);
96
- }
97
-
145
+
98
146
// Wait for all index queries in this batch to complete and measure recall
99
147
for (auto & [targetIndex, asyncResult] : asyncIndexQueries) {
100
148
auto result = asyncResult.GetValueSync ();
101
- // Process the index query result and calculate recall using etalon nearest neighbours
102
- ProcessIndexQueryResult (result, targetIndex, batchEtalons[targetIndex], false );
103
- }
104
-
105
- // Log progress for large datasets
106
- if (sampler.GetTargetCount () > 100 && (batchEnd % 100 == 0 || batchEnd == sampler.GetTargetCount ())) {
107
- Cout << " Processed " << batchEnd << " of " << sampler.GetTargetCount () << " targets..." << Endl;
149
+ // Process the index query result and calculate recall using reference nearest neighbours
150
+ ProcessIndexQueryResult (result, targetIndex, References[targetIndex], false );
108
151
}
109
152
}
110
-
111
- Cout << " Recall measurement completed for " << sampler.GetTargetCount () << " targets."
153
+
154
+ Cout << " Recall measurement completed for " << sampler.GetTargetCount ()
155
+ << " targets in " << (int )((Clock::now () - startTime) / std::chrono::seconds (1 )) << " seconds."
112
156
<< " \n Average recall: " << GetAverageRecall () << Endl;
113
157
}
114
158
@@ -130,55 +174,52 @@ size_t TVectorRecallEvaluator::GetProcessedTargets() const {
130
174
}
131
175
132
176
// Process index query results
133
- void TVectorRecallEvaluator::ProcessIndexQueryResult (const NYdb::NQuery::TExecuteQueryResult& queryResult, size_t targetIndex, const std::vector<ui64>& etalons, bool verbose) {
134
- if (!queryResult.IsSuccess ()) {
135
- // Ignore the error. It's printed in the verbose mode
136
- return ;
137
- }
138
-
177
+ void TVectorRecallEvaluator::ProcessIndexQueryResult (const NYdb::NQuery::TExecuteQueryResult& queryResult,
178
+ size_t targetIndex, const std::vector<std::string>& references, bool verbose) {
179
+ Y_ABORT_UNLESS (queryResult.IsSuccess (), " Query failed: %s" , queryResult.GetIssues ().ToString ().c_str ());
180
+
139
181
// Get the result set
140
182
auto resultSet = queryResult.GetResultSet (0 );
141
183
NYdb::TResultSetParser parser (resultSet);
142
-
184
+
143
185
// Extract IDs from index search results
144
- std::vector<ui64 > indexResults;
186
+ std::vector<std::string > indexResults;
145
187
while (parser.TryNextRow ()) {
146
- ui64 id = parser.ColumnParser (Params.KeyColumn ).GetUint64 ();
147
- indexResults.push_back (id);
188
+ indexResults.push_back (parser.ColumnParser (" id" ).GetString ());
148
189
}
149
-
150
- // Create etalon set for efficient lookup
151
- std::unordered_set<ui64> etalonSet (etalons.begin (), etalons.end ());
152
-
153
- // Check if target ID is first in results
154
- if (!indexResults.empty () && !etalons.empty ()) {
155
- ui64 targetId = etalons[0 ]; // First etalon is the target ID itself
156
-
190
+
191
+ // Create reference set for efficient lookup
192
+ std::unordered_set<std::string> referenceSet (references.begin (), references.end ());
193
+
194
+ if (references.empty ()) {
195
+ Cerr << " Warning: Empty references for target " << targetIndex << " \n " ;
196
+ } else if (indexResults.empty ()) {
197
+ Cerr << " Warning: Empty index results for target " << targetIndex << " \n " ;
198
+ } else {
199
+ // Check if the first reference is the target ID itself
200
+ const std::string & targetId = references[0 ];
201
+
157
202
if (verbose && indexResults[0 ] != targetId) {
158
- Cerr << " Warning: Target ID " << targetId << " is not the first result for target "
203
+ Cerr << " Warning: Target ID " << targetId << " is not the first result for target "
159
204
<< targetIndex << " . Found " << indexResults[0 ] << " instead." << Endl;
160
205
}
161
-
206
+
162
207
// Calculate recall
163
208
size_t relevantRetrieved = 0 ;
164
209
for (const auto & id : indexResults) {
165
- if (etalonSet .count (id)) {
210
+ if (referenceSet .count (id)) {
166
211
relevantRetrieved++;
167
212
}
168
213
}
169
-
214
+
170
215
// Calculate recall for this target
171
- double recall = etalons .empty () ? 0.0 : static_cast <double >(relevantRetrieved) / etalons .size ();
216
+ double recall = references .empty () ? 0.0 : static_cast <double >(relevantRetrieved) / references .size ();
172
217
AddRecall (recall);
173
218
174
219
// Add warning when zero relevant results found
175
220
if (verbose && relevantRetrieved == 0 && !indexResults.empty ()) {
176
221
Cerr << " Warning: Zero relevant results for target " << targetIndex << Endl;
177
222
}
178
- } else {
179
- // Handle empty results or empty etalons
180
- if (verbose)
181
- Cerr << " Warning: Empty results or etalons for target " << targetIndex << Endl;
182
223
}
183
224
}
184
225
0 commit comments