@@ -60,22 +60,20 @@ void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::s
60
60
for (auto neighbor : *neighbors) {
61
61
V.insert (neighbor);
62
62
}
63
- // Remove p_node itself from V
63
+
64
+ // Remove p_node itself from V, and clear the neighbors of p_node
64
65
V.erase (p);
65
- // Clear the neighbors of p_node
66
66
p_node.clearNeighbors ();
67
67
68
68
// Continue pruning until V is empty or the desired number of neighbors is reached
69
69
while (!V.empty ()) {
70
70
71
- // Find the closest neighbor to p_node in V
71
+ // Find the closest neighbor to p_node in V, and initialize the distance to p_star
72
72
graph_t p_star = getSetItemAtIndex (0 , V);
73
- // float p_star_distance = euclideanDistance(p, p_star);
74
73
float p_star_distance = index.getDistanceMatrix ()[p.getIndex ()][p_star.getIndex ()];
75
74
76
75
// Update p_star if a closer neighbor is found
77
76
for (auto p_tone : V) {
78
- // float currentDistance = euclideanDistance(p, p_tone);
79
77
float currentDistance = index.getDistanceMatrix ()[p.getIndex ()][p_tone.getIndex ()];
80
78
81
79
if (currentDistance < p_star_distance) {
@@ -95,17 +93,34 @@ void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::s
95
93
// Create a copy of V to avoid modifying the original set during iteration
96
94
std::set<graph_t > V_copy = V;
97
95
for (auto p_tone : V_copy) {
96
+
98
97
// Remove neighbors that are too far from p_star based on alpha and euclideanDistance
99
- // if ((alpha * euclideanDistance(p_star, p_tone)) <= euclideanDistance(p, p_tone)) {
100
- // V.erase( p_tone) ;
101
- // }
102
- if ((alpha * index. getDistanceMatrix ()[p_star. getIndex ()][p_tone. getIndex ()]) <= index. getDistanceMatrix ()[p. getIndex ()][p_tone. getIndex ()] ) {
98
+ double distance1 = index. getDistanceMatrix ()[p_star. getIndex ()][ p_tone. getIndex ()];
99
+ double distance2 = index. getDistanceMatrix ()[p. getIndex ()][ p_tone. getIndex ()] ;
100
+
101
+ if ((alpha * distance1) <= distance2 ) {
103
102
V.erase (p_tone);
104
103
}
104
+
105
105
}
106
106
}
107
+
107
108
}
108
109
110
+ /* *
111
+ * @brief Prunes the neighbors of a given node in a graph based on a robust pruning algorithm with filtering.
112
+ *
113
+ * This function modifies the neighbors of the given node `p_node` in the graph `G` by selecting
114
+ * a subset of neighbors that are within a certain distance threshold defined by `alpha` and `R`,
115
+ * while also applying additional filtering criteria.
116
+ *
117
+ * @tparam graph_t The type of the graph nodes.
118
+ * @param G The graph containing the node to be pruned.
119
+ * @param p_node The node whose neighbors are to be pruned.
120
+ * @param V A set of graph nodes to be considered for pruning.
121
+ * @param alpha A float value used as a multiplier for the distance threshold.
122
+ * @param R An integer specifying the maximum number of neighbors to retain.
123
+ */
109
124
template <typename graph_t >
110
125
void FilteredRobustPrune (FilteredVamanaIndex<graph_t >& index, GraphNode<graph_t >& p_node, std::set<graph_t >& V, float alpha, int R) {
111
126
@@ -118,23 +133,19 @@ void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>
118
133
V.insert (neighbor);
119
134
}
120
135
121
- // Remove p_node itself from V
136
+ // Remove p_node itself from V, and clear the neighbors of p_node
122
137
V.erase (p);
123
-
124
- // Clear the neighbors of p_node
125
138
p_node.clearNeighbors ();
126
139
127
140
// Continue pruning until V is empty or the desired number of neighbors is reached
128
141
while (!V.empty ()) {
129
142
130
143
// Find the closest neighbor to p_node in V
131
144
graph_t p_star = getSetItemAtIndex (0 , V);
132
- // float p_star_distance = euclideanDistance(p, p_star);
133
145
float p_star_distance = index.getDistanceMatrix ()[p.getIndex ()][p_star.getIndex ()];
134
146
135
147
// Update p_star if a closer neighbor is found
136
148
for (auto p_tone : V) {
137
- // float currentDistance = euclideanDistance(p, p_tone);
138
149
float currentDistance = index.getDistanceMatrix ()[p.getIndex ()][p_tone.getIndex ()];
139
150
if (currentDistance < p_star_distance) {
140
151
p_star_distance = currentDistance;
@@ -166,18 +177,24 @@ void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>
166
177
}
167
178
}
168
179
169
- // Remove nodes that do NOT satisfy the filtering condition
170
- // if ((alpha * euclideanDistance(p_star, p_tone)) <= euclideanDistance(p, p_tone)) {
171
- // V.erase( p_tone) ;
172
- // }
173
- if ((alpha * index. getDistanceMatrix ()[p_star. getIndex ()][p_tone. getIndex ()]) <= index. getDistanceMatrix ()[p. getIndex ()][p_tone. getIndex ()] ) {
180
+ // Remove neighbors that are too far from p_star based on alpha and euclideanDistance
181
+ double distance1 = index. getDistanceMatrix ()[p_star. getIndex ()][ p_tone. getIndex ()];
182
+ double distance2 = index. getDistanceMatrix ()[p. getIndex ()][ p_tone. getIndex ()] ;
183
+
184
+ if ((alpha * distance1) <= distance2 ) {
174
185
V.erase (p_tone);
175
186
}
176
187
177
188
}
178
189
}
190
+
179
191
}
180
192
193
+
194
+ // / Explicit instantiations for RobustPrune and FilteredRobustPrune
195
+
196
+
197
+ // Explicit instantiation for RobustPrune with float data type and DataVector query type
181
198
template void RobustPrune<DataVector<float >>(
182
199
VamanaIndex<DataVector<float >>& index,
183
200
GraphNode<DataVector<float >>& p_node,
@@ -186,6 +203,7 @@ template void RobustPrune<DataVector<float>>(
186
203
int R
187
204
);
188
205
206
+ // Explicit instantiation for FilteredRobustPrune with float data type and DataVector query type
189
207
template void RobustPrune<BaseDataVector<float >>(
190
208
VamanaIndex<BaseDataVector<float >>& index,
191
209
GraphNode<BaseDataVector<float >>& p_node,
@@ -194,6 +212,7 @@ template void RobustPrune<BaseDataVector<float>>(
194
212
int R
195
213
);
196
214
215
+ // Explicit instantiation for FilteredRobustPrune with float data type and DataVector query type
197
216
template void FilteredRobustPrune<BaseDataVector<float >>(
198
217
FilteredVamanaIndex<BaseDataVector<float >>& index,
199
218
GraphNode<BaseDataVector<float >>& p_node,
0 commit comments