Skip to content

Commit a46adce

Browse files
committed
Added documentation and commented the Filtered Robust on Stiched Vamana
1 parent 3cca80a commit a46adce

File tree

8 files changed

+256
-230
lines changed

8 files changed

+256
-230
lines changed

include/GreedySearch.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ template <typename graph_t, typename query_t> std::pair<std::set<graph_t>, std::
5555
const EXEC_MODE execMode = CREATE
5656
);
5757

58+
/**
59+
* @brief Greedy search algorithm for finding the k nearest nodes in a graph relative to a query vector.
60+
*
61+
* This function implements a greedy search that iteratively explores the closest nodes to the query vector.
62+
* It maintains a candidate set of nodes to visit and a visited set of nodes already processed. This version
63+
* of the function is used with a FilteredVamanaIndex, which applies additional filtering criteria to the search.
64+
*
65+
* @tparam graph_t Type of data stored in the graph nodes
66+
* @tparam query_t Type of the query vector
67+
* @param index The FilteredVamanaIndex to search
68+
* @param S Starting nodes for the search
69+
* @param xq Query vector for distance computation
70+
* @param k Number of nearest nodes to return
71+
* @param L Maximum number of nodes in the candidate set
72+
* @param queryFilters A vector of CategoricalAttributeFilter objects to apply to the search
73+
* @param mode Execution mode for the algorithm
74+
*
75+
* @return Pair of sets: the first set contains the k nearest nodes, and the second set contains all visited nodes
76+
*
77+
*/
5878
template <typename graph_t, typename query_t> std::pair<std::set<graph_t>, std::set<graph_t>> FilteredGreedySearch(
5979
const FilteredVamanaIndex<graph_t>& index,
6080
const std::vector<GraphNode<graph_t>>& S,

include/RobustPrune.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,19 @@ template <typename graph_t> class FilteredVamanaIndex;
3030
template <typename graph_t>
3131
void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::set<graph_t>& V, float alpha, int R);
3232

33+
/**
34+
* @brief Prunes the neighbors of a given node in a graph based on a robust pruning algorithm with filtering.
35+
*
36+
* This function modifies the neighbors of the given node `p_node` in `index` by selecting
37+
* at most `R` neighbors, discarding candidates that fail the `alpha`-scaled distance test,
38+
* while also applying additional filtering criteria.
39+
*
40+
* @tparam graph_t The type of the graph nodes.
41+
* @param index The FilteredVamanaIndex containing the node to be pruned.
42+
* @param p_node The node whose neighbors are to be pruned.
43+
* @param V A set of graph nodes to be considered for pruning.
44+
* @param alpha A float value used as a multiplier for the distance threshold.
45+
* @param R An integer specifying the maximum number of neighbors to retain.
46+
*/
3347
template <typename graph_t>
3448
void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node,std::set<graph_t>& V, float alpha,int R);

src/VIA/Algorithms/GreedySearch.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ GreedySearch(const VamanaIndex<graph_t>& index, const GraphNode<graph_t>& s, con
118118

119119
// Limit the size of candidates to L by keeping the closest L elements to the query
120120
if (candidates.size() > static_cast<size_t>(L)) {
121-
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), execMode==CREATE)};
121+
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{
122+
EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), execMode==CREATE)
123+
};
124+
122125
for (auto candidate : candidates) {
123126
newCandidates.insert(candidate);
124127
}
@@ -137,7 +140,10 @@ GreedySearch(const VamanaIndex<graph_t>& index, const GraphNode<graph_t>& s, con
137140
}
138141

139142
// Final selection of k closest candidates after main loop
140-
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), execMode==CREATE)};
143+
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{
144+
EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), execMode==CREATE)
145+
};
146+
141147
for (auto candidate : candidates) {
142148
newCandidates.insert(candidate);
143149
}
@@ -153,6 +159,26 @@ GreedySearch(const VamanaIndex<graph_t>& index, const GraphNode<graph_t>& s, con
153159

154160
}
155161

162+
/**
163+
* @brief Greedy search algorithm for finding the k nearest nodes in a graph relative to a query vector.
164+
*
165+
* This function implements a greedy search that iteratively explores the closest nodes to the query vector.
166+
* It maintains a candidate set of nodes to visit and a visited set of nodes already processed. This version
167+
* of the function is used with a FilteredVamanaIndex, which applies additional filtering criteria to the search.
168+
*
169+
* @tparam graph_t Type of data stored in the graph nodes
170+
* @tparam query_t Type of the query vector
171+
* @param index The FilteredVamanaIndex to search
172+
* @param S Starting nodes for the search
173+
* @param xq Query vector for distance computation
174+
* @param k Number of nearest nodes to return
175+
* @param L Maximum number of nodes in the candidate set
176+
* @param queryFilters A vector of CategoricalAttributeFilter objects to apply to the search
177+
* @param mode Execution mode for the algorithm
178+
*
179+
* @return Pair of sets: the first set contains the k nearest nodes, and the second set contains all visited nodes
180+
*
181+
*/
156182
template <typename graph_t, typename query_t>
157183
std::pair<std::set<graph_t>, std::set<graph_t>> FilteredGreedySearch(
158184
const FilteredVamanaIndex<graph_t>& index, const std::vector<GraphNode<graph_t>>& S, const query_t& xq,
@@ -241,7 +267,10 @@ std::pair<std::set<graph_t>, std::set<graph_t>> FilteredGreedySearch(
241267
// Limit the size of candidates to L by keeping the closest L elements to the query
242268
if (candidates.size() > static_cast<size_t>(L)) {
243269

244-
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), mode==CREATE)};
270+
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{
271+
EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), mode==CREATE)
272+
};
273+
245274
for (auto candidate : candidates) {
246275
newCandidates.insert(candidate);
247276
}
@@ -261,7 +290,10 @@ std::pair<std::set<graph_t>, std::set<graph_t>> FilteredGreedySearch(
261290
}
262291

263292
// Final selection of k closest candidates after main loop
264-
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), mode==CREATE)};
293+
std::set<graph_t, EuclideanDistanceOrder<graph_t, query_t>> newCandidates{
294+
EuclideanDistanceOrder<graph_t, query_t>(xq, index.getDistanceMatrix(), mode==CREATE)
295+
};
296+
265297
for (auto candidate : candidates) {
266298
newCandidates.insert(candidate);
267299
}

src/VIA/Algorithms/RobustPrune.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,22 +60,20 @@ void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::s
6060
for (auto neighbor : *neighbors) {
6161
V.insert(neighbor);
6262
}
63-
// Remove p_node itself from V
63+
64+
// Remove p_node itself from V, and clear the neighbors of p_node
6465
V.erase(p);
65-
// Clear the neighbors of p_node
6666
p_node.clearNeighbors();
6767

6868
// Continue pruning until V is empty or the desired number of neighbors is reached
6969
while (!V.empty()) {
7070

71-
// Find the closest neighbor to p_node in V
71+
// Find the closest neighbor to p_node in V, and initialize the distance to p_star
7272
graph_t p_star = getSetItemAtIndex(0, V);
73-
// float p_star_distance = euclideanDistance(p, p_star);
7473
float p_star_distance = index.getDistanceMatrix()[p.getIndex()][p_star.getIndex()];
7574

7675
// Update p_star if a closer neighbor is found
7776
for (auto p_tone : V) {
78-
// float currentDistance = euclideanDistance(p, p_tone);
7977
float currentDistance = index.getDistanceMatrix()[p.getIndex()][p_tone.getIndex()];
8078

8179
if (currentDistance < p_star_distance) {
@@ -95,17 +93,34 @@ void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::s
9593
// Create a copy of V to avoid modifying the original set during iteration
9694
std::set<graph_t> V_copy = V;
9795
for (auto p_tone : V_copy) {
96+
9897
// Remove every candidate p_tone for which alpha * d(p_star, p_tone) <= d(p, p_tone), using the precomputed distance matrix
99-
// if ((alpha * euclideanDistance(p_star, p_tone)) <= euclideanDistance(p, p_tone)) {
100-
// V.erase(p_tone);
101-
// }
102-
if ((alpha * index.getDistanceMatrix()[p_star.getIndex()][p_tone.getIndex()]) <= index.getDistanceMatrix()[p.getIndex()][p_tone.getIndex()]) {
98+
double distance1 = index.getDistanceMatrix()[p_star.getIndex()][p_tone.getIndex()];
99+
double distance2 = index.getDistanceMatrix()[p.getIndex()][p_tone.getIndex()];
100+
101+
if ((alpha * distance1) <= distance2) {
103102
V.erase(p_tone);
104103
}
104+
105105
}
106106
}
107+
107108
}
108109

110+
/**
111+
* @brief Prunes the neighbors of a given node in a graph based on a robust pruning algorithm with filtering.
112+
*
113+
* This function modifies the neighbors of the given node `p_node` in `index` by selecting
114+
* at most `R` neighbors, discarding candidates that fail the `alpha`-scaled distance test,
115+
* while also applying additional filtering criteria.
116+
*
117+
* @tparam graph_t The type of the graph nodes.
118+
* @param index The FilteredVamanaIndex containing the node to be pruned.
119+
* @param p_node The node whose neighbors are to be pruned.
120+
* @param V A set of graph nodes to be considered for pruning.
121+
* @param alpha A float value used as a multiplier for the distance threshold.
122+
* @param R An integer specifying the maximum number of neighbors to retain.
123+
*/
109124
template <typename graph_t>
110125
void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::set<graph_t>& V, float alpha, int R) {
111126

@@ -118,23 +133,19 @@ void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>
118133
V.insert(neighbor);
119134
}
120135

121-
// Remove p_node itself from V
136+
// Remove p_node itself from V, and clear the neighbors of p_node
122137
V.erase(p);
123-
124-
// Clear the neighbors of p_node
125138
p_node.clearNeighbors();
126139

127140
// Continue pruning until V is empty or the desired number of neighbors is reached
128141
while (!V.empty()) {
129142

130143
// Find the closest neighbor to p_node in V
131144
graph_t p_star = getSetItemAtIndex(0, V);
132-
// float p_star_distance = euclideanDistance(p, p_star);
133145
float p_star_distance = index.getDistanceMatrix()[p.getIndex()][p_star.getIndex()];
134146

135147
// Update p_star if a closer neighbor is found
136148
for (auto p_tone : V) {
137-
// float currentDistance = euclideanDistance(p, p_tone);
138149
float currentDistance = index.getDistanceMatrix()[p.getIndex()][p_tone.getIndex()];
139150
if (currentDistance < p_star_distance) {
140151
p_star_distance = currentDistance;
@@ -166,18 +177,24 @@ void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>
166177
}
167178
}
168179

169-
// Remove nodes that do NOT satisfy the filtering condition
170-
// if ((alpha * euclideanDistance(p_star, p_tone)) <= euclideanDistance(p, p_tone)) {
171-
// V.erase(p_tone);
172-
// }
173-
if ((alpha * index.getDistanceMatrix()[p_star.getIndex()][p_tone.getIndex()]) <= index.getDistanceMatrix()[p.getIndex()][p_tone.getIndex()]) {
180+
// Remove every candidate p_tone for which alpha * d(p_star, p_tone) <= d(p, p_tone), using the precomputed distance matrix
181+
double distance1 = index.getDistanceMatrix()[p_star.getIndex()][p_tone.getIndex()];
182+
double distance2 = index.getDistanceMatrix()[p.getIndex()][p_tone.getIndex()];
183+
184+
if ((alpha * distance1) <= distance2) {
174185
V.erase(p_tone);
175186
}
176187

177188
}
178189
}
190+
179191
}
180192

193+
194+
/// Explicit instantiations for RobustPrune and FilteredRobustPrune
195+
196+
197+
// Explicit instantiation of RobustPrune for graph nodes of type DataVector<float>
181198
template void RobustPrune<DataVector<float>>(
182199
VamanaIndex<DataVector<float>>& index,
183200
GraphNode<DataVector<float>>& p_node,
@@ -186,6 +203,7 @@ template void RobustPrune<DataVector<float>>(
186203
int R
187204
);
188205

206+
// Explicit instantiation of RobustPrune for graph nodes of type BaseDataVector<float>
189207
template void RobustPrune<BaseDataVector<float>>(
190208
VamanaIndex<BaseDataVector<float>>& index,
191209
GraphNode<BaseDataVector<float>>& p_node,
@@ -194,6 +212,7 @@ template void RobustPrune<BaseDataVector<float>>(
194212
int R
195213
);
196214

215+
// Explicit instantiation of FilteredRobustPrune for graph nodes of type BaseDataVector<float>
197216
template void FilteredRobustPrune<BaseDataVector<float>>(
198217
FilteredVamanaIndex<BaseDataVector<float>>& index,
199218
GraphNode<BaseDataVector<float>>& p_node,

src/VIA/Algorithms/StichedVamanaIndex.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,21 @@ void StichedVamanaIndex<vamana_t>::createGraph(
8787

8888
});
8989

90-
// Filtered Robust Prune for every v in V (the nodes of the graph)
91-
for (unsigned int i = 0; i < this->G.getNodesCount(); i++) {
9290

93-
GraphNode<vamana_t>* currentNode = this->G.getNode(i);
94-
std::set<vamana_t> neighbors = currentNode->getNeighborsSet();
9591

96-
// Run Filtered Robust Prune for the current node and its neighbors
97-
FilteredRobustPrune(*this, *currentNode, neighbors, alpha, R_stiched);
92+
// NOTE: The final FilteredRobustPrune pass is intentionally disabled (commented out below) because we get better results without it.
9893

99-
}
94+
95+
// // Filtered Robust Prune for every v in V (the nodes of the graph)
96+
// for (unsigned int i = 0; i < this->G.getNodesCount(); i++) {
97+
98+
// GraphNode<vamana_t>* currentNode = this->G.getNode(i);
99+
// std::set<vamana_t> neighbors = currentNode->getNeighborsSet();
100+
101+
// // Run Filtered Robust Prune for the current node and its neighbors
102+
// FilteredRobustPrune(*this, *currentNode, neighbors, alpha, R_stiched);
103+
104+
// }
100105

101106
// Free up the memory allocated for the distance matrix
102107
for (unsigned int i = 0; i < n; i++) {

0 commit comments

Comments
 (0)