@@ -113,16 +113,21 @@ void Create(std::unordered_map<std::string, std::string> args) {
113
113
using BaseVectorVector = std::vector<BaseDataVector<float >>;
114
114
using BaseVectors = std::vector<DataVector<float >>;
115
115
116
- std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode;
116
+ std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode, distanceSaveMethod ;
117
117
std::string L_small, R_small, R_stiched;
118
118
bool save = false ;
119
119
bool leaveEmpty = false ;
120
120
int distanceThreads = 1 ; // Default value
121
+ int computingThreads = 1 ; // Default value
122
+
123
+ std::vector<std::string> validArguments = {" -index-type" , " -base-file" , " -L" , " -L-small" , " -R" , " -R-small" , " -R-stiched" , " -alpha" , " -save" , " -random-edges" , " -connection-mode" , " -distance-threads" , " -distance-save" };
124
+ if (args[" -index-type" ] == " stiched" ) {
125
+ validArguments.push_back (" -computing-threads" );
126
+ }
121
127
122
- std::vector<std::string> validArguments = {" -index-type" , " -base-file" , " -L" , " -L-small" , " -R" , " -R-small" , " -R-stiched" , " -alpha" , " -save" , " -random-edges" , " -connection-mode" , " -distance-threads" };
123
128
for (auto arg : args) {
124
129
if (std::find (validArguments.begin (), validArguments.end (), arg.first ) == validArguments.end ()) {
125
- throw std::invalid_argument (" Error: Invalid argument: " + arg.first + " . Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode, -distance-threads" );
130
+ throw std::invalid_argument (" Error: Invalid argument: " + arg.first + " . Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode, -distance-threads, -distance-save " );
126
131
}
127
132
}
128
133
@@ -146,6 +151,8 @@ void Create(std::unordered_map<std::string, std::string> args) {
146
151
}
147
152
148
153
} else if (indexType == " stiched" ) {
154
+ validArguments.push_back (" -computing-threads" );
155
+
149
156
if (args.find (" -L-small" ) == args.end ()) {
150
157
throw std::invalid_argument (" Error: Missing required argument: -L-small" );
151
158
} else {
@@ -163,6 +170,10 @@ void Create(std::unordered_map<std::string, std::string> args) {
163
170
} else {
164
171
R_stiched = args[" -R-stiched" ];
165
172
}
173
+
174
+ if (args.find (" -computing-threads" ) != args.end ()) {
175
+ computingThreads = std::stoi (args[" -computing-threads" ]);
176
+ }
166
177
} else {
167
178
throw std::invalid_argument (" Error: Invalid index type: " + indexType + " . Supported index types are: simple, filtered, stiched" );
168
179
}
@@ -196,7 +207,19 @@ void Create(std::unordered_map<std::string, std::string> args) {
196
207
}
197
208
}
198
209
210
+ if (args.find (" -distance-save" ) != args.end ()) {
211
+ distanceSaveMethod = args[" -distance-save" ];
212
+ if (distanceSaveMethod != " none" && distanceSaveMethod != " matrix" ) {
213
+ throw std::invalid_argument (" Error: Invalid value for -distance-save. Valid values are: none, matrix" );
214
+ }
215
+ } else {
216
+ distanceSaveMethod = " none" ; // Default value
217
+ }
218
+
199
219
if (args.find (" -distance-threads" ) != args.end ()) {
220
+ if (distanceSaveMethod != " matrix" ) {
221
+ throw std::invalid_argument (" Error: -distance-threads can only be used if -distance-save is set to 'matrix'" );
222
+ }
200
223
distanceThreads = std::stoi (args[" -distance-threads" ]);
201
224
}
202
225
@@ -206,9 +229,16 @@ void Create(std::unordered_map<std::string, std::string> args) {
206
229
std::cerr << " Error reading base file" << std::endl;
207
230
return ;
208
231
}
232
+
233
+ DISTANCE_SAVE_METHOD distanceSaveMethodEnum = NONE;
234
+ if (distanceSaveMethod == " none" ) {
235
+ distanceSaveMethodEnum = NONE;
236
+ } else if (distanceSaveMethod == " matrix" ) {
237
+ distanceSaveMethodEnum = MATRIX;
238
+ }
209
239
210
240
VamanaIndex<DataVector<float >> vamanaIndex = VamanaIndex<DataVector<float >>();
211
- vamanaIndex.createGraph (base_vectors, std::stof (alpha), std::stoi (L), std::stoi (R), distanceThreads, true );
241
+ vamanaIndex.createGraph (base_vectors, std::stof (alpha), std::stoi (L), std::stoi (R), distanceSaveMethodEnum, distanceThreads, true );
212
242
213
243
if (save) {
214
244
if (!vamanaIndex.saveGraph (outputFile)) {
@@ -226,17 +256,24 @@ void Create(std::unordered_map<std::string, std::string> args) {
226
256
filters.insert (filter);
227
257
}
228
258
259
+ DISTANCE_SAVE_METHOD distanceSaveMethodEnum = NONE;
260
+ if (distanceSaveMethod == " none" ) {
261
+ distanceSaveMethodEnum = NONE;
262
+ } else if (distanceSaveMethod == " matrix" ) {
263
+ distanceSaveMethodEnum = MATRIX;
264
+ }
265
+
229
266
if (indexType == " filtered" ) {
230
267
FilteredVamanaIndex<BaseDataVector<float >> index (filters);
231
- index.createGraph (base_vectors, std::stoi (alpha), std::stoi (L), std::stoi (R), distanceThreads, true , leaveEmpty);
268
+ index.createGraph (base_vectors, std::stoi (alpha), std::stoi (L), std::stoi (R), distanceSaveMethodEnum, distanceThreads, true , leaveEmpty);
232
269
233
270
if (save) {
234
271
index.saveGraph (outputFile);
235
272
std::cout << std::endl << green << " Vamana Index was saved successfully to " << brightYellow << " `" << outputFile << " `" << reset << std::endl;
236
273
}
237
274
} else if (indexType == " stiched" ) {
238
275
StichedVamanaIndex<BaseDataVector<float >> index (filters);
239
- index.createGraph (base_vectors, std::stof (alpha), std::stoi (L_small), std::stoi (R_small), std::stoi (R_stiched), distanceThreads, true , leaveEmpty);
276
+ index.createGraph (base_vectors, std::stof (alpha), std::stoi (L_small), std::stoi (R_small), std::stoi (R_stiched), distanceSaveMethodEnum, distanceThreads, computingThreads , true , leaveEmpty);
240
277
241
278
if (save) {
242
279
index.saveGraph (outputFile);
@@ -277,7 +314,7 @@ void TestSimple(std::unordered_map<std::string, std::string> args) {
277
314
GraphNode<DataVector<float >> s = vamanaIndex.findMedoid (vamanaIndex.getGraph (), 1000 );
278
315
279
316
auto start = std::chrono::high_resolution_clock::now ();
280
- SimpleGreedyResult greedyResult = GreedySearch (vamanaIndex, s, query_vectors.at (std::stoi (queryNumber)), std::stoi (k), std::stoi (L), TEST );
317
+ SimpleGreedyResult greedyResult = GreedySearch (vamanaIndex, s, query_vectors.at (std::stoi (queryNumber)), std::stoi (k), std::stoi (L), NONE );
281
318
auto end = std::chrono::high_resolution_clock::now ();
282
319
std::chrono::duration<double > elapsed = end - start;
283
320
@@ -363,7 +400,7 @@ void TestFilteredOrStiched(std::unordered_map<std::string, std::string> args) {
363
400
}
364
401
365
402
auto start = std::chrono::high_resolution_clock::now ();
366
- FilteredGreedyResult greedyResult = FilteredGreedySearch (index, start_nodes, xq, std::stoi (k), std::stoi (L), Fx, TEST );
403
+ FilteredGreedyResult greedyResult = FilteredGreedySearch (index, start_nodes, xq, std::stoi (k), std::stoi (L), Fx, NONE );
367
404
auto end = std::chrono::high_resolution_clock::now ();
368
405
std::chrono::duration<double > elapsed = end - start;
369
406
0 commit comments