12
12
#include < fstream>
13
13
#include < iostream>
14
14
15
+ // Mutex for synchronizing distance calculations
15
16
std::mutex distanceMutex;
16
17
17
18
/* *
@@ -24,13 +25,36 @@ std::mutex distanceMutex;
24
25
* @return A vector containing a shuffled sequence of integers from `start` to `end`
25
26
*/
26
27
static std::vector<int > generateRandomPermutation (const unsigned int start, const unsigned int end) {
27
-
28
+ // Create a vector containing all integers from start to end
28
29
std::vector<int > permutation (end - start + 1 );
30
+
31
+ // Fill the vector with sequential values starting from 'start'
29
32
std::iota (permutation.begin (), permutation.end (), start);
33
+
34
+ // Shuffle the vector randomly using a random number generator
30
35
std::shuffle (permutation.begin (), permutation.end (), std::mt19937{std::random_device{}()});
31
-
36
+
37
+ // Return the shuffled vector
32
38
return permutation;
39
+ }
33
40
41
+ /* *
42
+ * @brief Generates a random index within a specified range.
43
+ *
44
+ * @param start The starting integer of the range (inclusive)
45
+ * @param end The ending integer of the range (inclusive)
46
+ *
47
+ * @return A random integer within the specified range
48
+ */
49
+ static int generateRandomIndex (const unsigned int start, const unsigned int end) {
50
+ // Create a random number generator
51
+ std::mt19937 generator (std::random_device{}());
52
+
53
+ // Create a uniform distribution within the specified range
54
+ std::uniform_int_distribution<unsigned int > distribution (start, end);
55
+
56
+ // Generate and return a random integer within the range
57
+ return distribution (generator);
34
58
}
35
59
36
60
/* *
@@ -43,9 +67,13 @@ static std::vector<int> generateRandomPermutation(const unsigned int start, cons
43
67
* @return A set of unique random indices of the specified length, excluding index i
44
68
*/
45
69
static std::set<int > generateRandomIndices (const unsigned int max, const unsigned int i, unsigned int length) {
46
-
70
+ // Create a set to store unique random indices
47
71
std::set<int > indices;
72
+
73
+ // Create a random number generator
48
74
std::mt19937 generator (std::random_device{}());
75
+
76
+ // Create a uniform distribution within the specified range
49
77
std::uniform_int_distribution<unsigned int > distribution (0 , max - 1 );
50
78
51
79
// Generate random indices until the set reaches the desired length
@@ -56,8 +84,8 @@ static std::set<int> generateRandomIndices(const unsigned int max, const unsigne
56
84
}
57
85
}
58
86
87
+ // Return the set of unique random indices
59
88
return indices;
60
-
61
89
}
62
90
63
91
/* *
@@ -168,16 +196,12 @@ void VamanaIndex<vamana_t>::createGraph(
168
196
using GreedyResult = std::pair<std::set<vamana_t >, std::set<vamana_t >>;
169
197
GreedyResult greedyResult;
170
198
171
- // Check if the dataset is empty or it has only one point
172
199
if (P.size () <= 1 ) return ;
173
200
174
- // Initialize graph memory
175
201
unsigned int n = P.size ();
176
202
this ->P = P;
177
203
178
204
if (distanceSaveMethod == MATRIX) {
179
-
180
- // If the distance matrix is provided, use it, otherwise compute the distances
181
205
if (distanceMatrix != nullptr ) {
182
206
this ->distanceMatrix = distanceMatrix;
183
207
} else {
@@ -187,21 +211,17 @@ void VamanaIndex<vamana_t>::createGraph(
187
211
}
188
212
this ->computeDistances (visualize, distance_threads);
189
213
}
190
-
191
214
}
192
215
193
- this ->G .setNodesCount (n);
194
-
195
- // Set the number of nodes in the graph, fill the nodes with the dataset points, and create random edges for the nodes
196
216
this ->G .setNodesCount (n);
197
217
this ->fillGraphNodes ();
198
218
this ->createRandomEdges (R);
199
219
200
- // Find the medoid node in the graph, and generate a random permutation of node indices
201
- GraphNode<vamana_t > s = findMedoid (this ->G , visualize, 1000 );
220
+ // Replace the call to findMedoid with the selection of a random point as the medoid
221
+ GraphNode<vamana_t > s = *(this ->G .getNode (generateRandomIndex (0 , n-1 )));
222
+
202
223
std::vector<int > sigma = generateRandomPermutation (0 , n-1 );
203
224
204
- // Define a lambda function to process each node in the sigma permutation
205
225
auto processNode = [&](int i) {
206
226
GraphNode<vamana_t >* sigma_i_node = this ->G .getNode (sigma.at (i));
207
227
vamana_t sigma_i = sigma_i_node->getData ();
@@ -227,7 +247,6 @@ void VamanaIndex<vamana_t>::createGraph(
227
247
}
228
248
};
229
249
230
- // Run the lambda process function if visualization is enabled, otherwise run it without progress visualization
231
250
if (visualize) {
232
251
withProgress (0 , n, " Creating Vamana" , processNode);
233
252
} else {
@@ -236,14 +255,12 @@ void VamanaIndex<vamana_t>::createGraph(
236
255
}
237
256
}
238
257
239
- // Free up the memory allocated for the distance matrix, if it was computed
240
258
if (distanceSaveMethod == MATRIX && distanceMatrix == nullptr ) {
241
259
for (unsigned int i = 0 ; i < n; i++) {
242
260
delete[] this ->distanceMatrix [i];
243
261
}
244
262
delete[] this ->distanceMatrix ;
245
263
}
246
-
247
264
}
248
265
249
266
/* *
@@ -374,22 +391,13 @@ template <typename vamana_t> GraphNode<vamana_t> VamanaIndex<vamana_t>::findMedo
374
391
}
375
392
}
376
393
377
- // Find the medoid node among the sampled nodes by calculating the average distance for each one
378
- float min_average_distance = std::numeric_limits<float >::max ();
379
- GraphNode<vamana_t >* medoid_node = nullptr ;
380
-
381
- for (int i = 0 ; i < sample_size; ++i) {
382
- float total_distance = std::accumulate (distance_matrix[i].begin (), distance_matrix[i].end (), 0 .0f );
383
- float average_distance = total_distance / (sample_size - 1 );
384
- if (average_distance < min_average_distance) {
385
- min_average_distance = average_distance;
386
- medoid_node = graph.getNode (sampled_indices[i]);
387
- }
388
- }
394
+ // Randomly select a point as the medoid
395
+ GraphNode<vamana_t >* medoid_node = graph.getNode (generateRandomIndex (0 , graph.getNodesCount () - 1 ));
389
396
390
397
return *medoid_node;
391
398
392
399
}
393
400
401
+ // Explicit template instantiation for specific types
394
402
template class VamanaIndex <DataVector<float >>;
395
403
template class VamanaIndex <BaseDataVector<float >>;
0 commit comments