Skip to content

Commit 81ea1de

Browse files
committed
Initialization of the medoid in Vamana with random points
1 parent 485e91f commit 81ea1de

File tree

1 file changed

+38
-30
lines changed

1 file changed

+38
-30
lines changed

src/VIA/Algorithms/VamanaIndex.cpp

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <fstream>
1313
#include <iostream>
1414

15+
// Mutex for synchronizing distance calculations
1516
std::mutex distanceMutex;
1617

1718
/**
@@ -24,13 +25,36 @@ std::mutex distanceMutex;
2425
* @return A vector containing a shuffled sequence of integers from `start` to `end`
2526
*/
2627
static std::vector<int> generateRandomPermutation(const unsigned int start, const unsigned int end) {
27-
28+
// Create a vector containing all integers from start to end
2829
std::vector<int> permutation(end - start + 1);
30+
31+
// Fill the vector with sequential values starting from 'start'
2932
std::iota(permutation.begin(), permutation.end(), start);
33+
34+
// Shuffle the vector randomly using a random number generator
3035
std::shuffle(permutation.begin(), permutation.end(), std::mt19937{std::random_device{}()});
31-
36+
37+
// Return the shuffled vector
3238
return permutation;
39+
}
3340

41+
/**
42+
* @brief Generates a random index within a specified range.
43+
*
44+
* @param start The starting integer of the range (inclusive)
45+
* @param end The ending integer of the range (inclusive)
46+
*
47+
* @return A random integer within the specified range
48+
*/
49+
static int generateRandomIndex(const unsigned int start, const unsigned int end) {
50+
// Create a random number generator
51+
std::mt19937 generator(std::random_device{}());
52+
53+
// Create a uniform distribution within the specified range
54+
std::uniform_int_distribution<unsigned int> distribution(start, end);
55+
56+
// Generate and return a random integer within the range
57+
return distribution(generator);
3458
}
3559

3660
/**
@@ -43,9 +67,13 @@ static std::vector<int> generateRandomPermutation(const unsigned int start, cons
4367
* @return A set of unique random indices of the specified length, excluding index i
4468
*/
4569
static std::set<int> generateRandomIndices(const unsigned int max, const unsigned int i, unsigned int length) {
46-
70+
// Create a set to store unique random indices
4771
std::set<int> indices;
72+
73+
// Create a random number generator
4874
std::mt19937 generator(std::random_device{}());
75+
76+
// Create a uniform distribution within the specified range
4977
std::uniform_int_distribution<unsigned int> distribution(0, max - 1);
5078

5179
// Generate random indices until the set reaches the desired length
@@ -56,8 +84,8 @@ static std::set<int> generateRandomIndices(const unsigned int max, const unsigne
5684
}
5785
}
5886

87+
// Return the set of unique random indices
5988
return indices;
60-
6189
}
6290

6391
/**
@@ -168,16 +196,12 @@ void VamanaIndex<vamana_t>::createGraph(
168196
using GreedyResult = std::pair<std::set<vamana_t>, std::set<vamana_t>>;
169197
GreedyResult greedyResult;
170198

171-
// Check if the dataset is empty or it has only one point
172199
if (P.size() <= 1) return;
173200

174-
// Initialize graph memory
175201
unsigned int n = P.size();
176202
this->P = P;
177203

178204
if (distanceSaveMethod == MATRIX) {
179-
180-
// If the distance matrix is provided, use it, otherwise compute the distances
181205
if (distanceMatrix != nullptr) {
182206
this->distanceMatrix = distanceMatrix;
183207
} else {
@@ -187,21 +211,17 @@ void VamanaIndex<vamana_t>::createGraph(
187211
}
188212
this->computeDistances(visualize, distance_threads);
189213
}
190-
191214
}
192215

193-
this->G.setNodesCount(n);
194-
195-
// Set the number of nodes in the graph, fill the nodes with the dataset points, and create random edges for the nodes
196216
this->G.setNodesCount(n);
197217
this->fillGraphNodes();
198218
this->createRandomEdges(R);
199219

200-
// Find the medoid node in the graph, and generate a random permutation of node indices
201-
GraphNode<vamana_t> s = findMedoid(this->G, visualize, 1000);
220+
// Replace the call to findMedoid with the selection of a random point as the medoid
221+
GraphNode<vamana_t> s = *(this->G.getNode(generateRandomIndex(0, n-1)));
222+
202223
std::vector<int> sigma = generateRandomPermutation(0, n-1);
203224

204-
// Define a lambda function to process each node in the sigma permutation
205225
auto processNode = [&](int i) {
206226
GraphNode<vamana_t>* sigma_i_node = this->G.getNode(sigma.at(i));
207227
vamana_t sigma_i = sigma_i_node->getData();
@@ -227,7 +247,6 @@ void VamanaIndex<vamana_t>::createGraph(
227247
}
228248
};
229249

230-
// Run the lambda process function if visualization is enabled, otherwise run it without progress visualization
231250
if (visualize) {
232251
withProgress(0, n, "Creating Vamana", processNode);
233252
} else {
@@ -236,14 +255,12 @@ void VamanaIndex<vamana_t>::createGraph(
236255
}
237256
}
238257

239-
// Free up the memory allocated for the distance matrix, if it was computed
240258
if (distanceSaveMethod == MATRIX && distanceMatrix == nullptr) {
241259
for (unsigned int i = 0; i < n; i++) {
242260
delete[] this->distanceMatrix[i];
243261
}
244262
delete[] this->distanceMatrix;
245263
}
246-
247264
}
248265

249266
/**
@@ -374,22 +391,13 @@ template <typename vamana_t> GraphNode<vamana_t> VamanaIndex<vamana_t>::findMedo
374391
}
375392
}
376393

377-
// Find the medoid node among the sampled nodes by calculating the average distance for each one
378-
float min_average_distance = std::numeric_limits<float>::max();
379-
GraphNode<vamana_t>* medoid_node = nullptr;
380-
381-
for (int i = 0; i < sample_size; ++i) {
382-
float total_distance = std::accumulate(distance_matrix[i].begin(), distance_matrix[i].end(), 0.0f);
383-
float average_distance = total_distance / (sample_size - 1);
384-
if (average_distance < min_average_distance) {
385-
min_average_distance = average_distance;
386-
medoid_node = graph.getNode(sampled_indices[i]);
387-
}
388-
}
394+
// Randomly select a point as the medoid
395+
GraphNode<vamana_t>* medoid_node = graph.getNode(generateRandomIndex(0, graph.getNodesCount() - 1));
389396

390397
return *medoid_node;
391398

392399
}
393400

401+
// Explicit template instantiation for specific types
394402
template class VamanaIndex<DataVector<float>>;
395403
template class VamanaIndex<BaseDataVector<float>>;

0 commit comments

Comments
 (0)