Skip to content

Commit a2eed91

Browse files
authored
Merge pull request #93 from AntonisZks/Antonis_workstation
Distance saving functionality
2 parents b951777 + eeff6e8 commit a2eed91

14 files changed

+347
-132
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,13 @@ FLAGS += $(DEPFLAGS)
6767
# Execution Rules
6868

6969
create_simple_via:
70-
./bin/main --create -index-type 'simple' -base-file 'data/siftsmall/siftsmall_base.fvecs' -L 120 -R 12 -alpha 1.0 -save 'simple_index.bin'
70+
./bin/main --create -index-type 'simple' -base-file 'data/siftsmall/siftsmall_base.fvecs' -L 120 -R 12 -alpha 1.0 -save 'simple_index.bin' -distance-save 'matrix' -distance-threads 1
7171

7272
create_filtered_via:
73-
./bin/main --create -index-type 'filtered' -base-file 'data/Dummy/dummy-data.bin' -L 120 -R 12 -alpha 1.0 -save 'filtered_index.bin'
73+
./bin/main --create -index-type 'filtered' -base-file 'data/Dummy/dummy-data.bin' -L 120 -R 12 -alpha 1.0 -save 'filtered_index.bin' -distance-save 'matrix' -distance-threads 1
7474

7575
create_stiched_via:
76-
./bin/main --create -index-type 'stiched' -base-file 'data/Dummy/dummy-data.bin' -L-small 150 -R-small 12 -R-stiched 20 -alpha 1.0 -save 'stiched_index.bin'
76+
./bin/main --create -index-type 'stiched' -base-file 'data/Dummy/dummy-data.bin' -L-small 150 -R-small 12 -R-stiched 20 -alpha 1.0 -save 'stiched_index.bin' -distance-threads 1 -computing-threads 1
7777

7878
compute_groundtruth:
7979
./bin/main --compute-gt -base-file 'data/Dummy/dummy-data.bin' -query-file 'data/Dummy/dummy-queries.bin' -gt-file 'data/Dummy/dummy-groundtruth.bin'

app/main.cpp

Lines changed: 45 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -113,16 +113,21 @@ void Create(std::unordered_map<std::string, std::string> args) {
113113
using BaseVectorVector = std::vector<BaseDataVector<float>>;
114114
using BaseVectors = std::vector<DataVector<float>>;
115115

116-
std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode;
116+
std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode, distanceSaveMethod;
117117
std::string L_small, R_small, R_stiched;
118118
bool save = false;
119119
bool leaveEmpty = false;
120120
int distanceThreads = 1; // Default value
121+
int computingThreads = 1; // Default value
122+
123+
std::vector<std::string> validArguments = {"-index-type", "-base-file", "-L", "-L-small", "-R", "-R-small", "-R-stiched", "-alpha", "-save", "-random-edges", "-connection-mode", "-distance-threads", "-distance-save"};
124+
if (args["-index-type"] == "stiched") {
125+
validArguments.push_back("-computing-threads");
126+
}
121127

122-
std::vector<std::string> validArguments = {"-index-type", "-base-file", "-L", "-L-small", "-R", "-R-small", "-R-stiched", "-alpha", "-save", "-random-edges", "-connection-mode", "-distance-threads"};
123128
for (auto arg : args) {
124129
if (std::find(validArguments.begin(), validArguments.end(), arg.first) == validArguments.end()) {
125-
throw std::invalid_argument("Error: Invalid argument: " + arg.first + ". Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode, -distance-threads");
130+
throw std::invalid_argument("Error: Invalid argument: " + arg.first + ". Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode, -distance-threads, -distance-save");
126131
}
127132
}
128133

@@ -146,6 +151,8 @@ void Create(std::unordered_map<std::string, std::string> args) {
146151
}
147152

148153
} else if (indexType == "stiched") {
154+
validArguments.push_back("-computing-threads");
155+
149156
if (args.find("-L-small") == args.end()) {
150157
throw std::invalid_argument("Error: Missing required argument: -L-small");
151158
} else {
@@ -163,6 +170,10 @@ void Create(std::unordered_map<std::string, std::string> args) {
163170
} else {
164171
R_stiched = args["-R-stiched"];
165172
}
173+
174+
if (args.find("-computing-threads") != args.end()) {
175+
computingThreads = std::stoi(args["-computing-threads"]);
176+
}
166177
} else {
167178
throw std::invalid_argument("Error: Invalid index type: " + indexType + ". Supported index types are: simple, filtered, stiched");
168179
}
@@ -196,7 +207,19 @@ void Create(std::unordered_map<std::string, std::string> args) {
196207
}
197208
}
198209

210+
if (args.find("-distance-save") != args.end()) {
211+
distanceSaveMethod = args["-distance-save"];
212+
if (distanceSaveMethod != "none" && distanceSaveMethod != "matrix") {
213+
throw std::invalid_argument("Error: Invalid value for -distance-save. Valid values are: none, matrix");
214+
}
215+
} else {
216+
distanceSaveMethod = "none"; // Default value
217+
}
218+
199219
if (args.find("-distance-threads") != args.end()) {
220+
if (distanceSaveMethod != "matrix") {
221+
throw std::invalid_argument("Error: -distance-threads can only be used if -distance-save is set to 'matrix'");
222+
}
200223
distanceThreads = std::stoi(args["-distance-threads"]);
201224
}
202225

@@ -206,9 +229,16 @@ void Create(std::unordered_map<std::string, std::string> args) {
206229
std::cerr << "Error reading base file" << std::endl;
207230
return;
208231
}
232+
233+
DISTANCE_SAVE_METHOD distanceSaveMethodEnum = NONE;
234+
if (distanceSaveMethod == "none") {
235+
distanceSaveMethodEnum = NONE;
236+
} else if (distanceSaveMethod == "matrix") {
237+
distanceSaveMethodEnum = MATRIX;
238+
}
209239

210240
VamanaIndex<DataVector<float>> vamanaIndex = VamanaIndex<DataVector<float>>();
211-
vamanaIndex.createGraph(base_vectors, std::stof(alpha), std::stoi(L), std::stoi(R), distanceThreads, true);
241+
vamanaIndex.createGraph(base_vectors, std::stof(alpha), std::stoi(L), std::stoi(R), distanceSaveMethodEnum, distanceThreads, true);
212242

213243
if (save) {
214244
if (!vamanaIndex.saveGraph(outputFile)) {
@@ -226,17 +256,24 @@ void Create(std::unordered_map<std::string, std::string> args) {
226256
filters.insert(filter);
227257
}
228258

259+
DISTANCE_SAVE_METHOD distanceSaveMethodEnum = NONE;
260+
if (distanceSaveMethod == "none") {
261+
distanceSaveMethodEnum = NONE;
262+
} else if (distanceSaveMethod == "matrix") {
263+
distanceSaveMethodEnum = MATRIX;
264+
}
265+
229266
if (indexType == "filtered") {
230267
FilteredVamanaIndex<BaseDataVector<float>> index(filters);
231-
index.createGraph(base_vectors, std::stoi(alpha), std::stoi(L), std::stoi(R), distanceThreads, true, leaveEmpty);
268+
index.createGraph(base_vectors, std::stoi(alpha), std::stoi(L), std::stoi(R), distanceSaveMethodEnum, distanceThreads, true, leaveEmpty);
232269

233270
if (save) {
234271
index.saveGraph(outputFile);
235272
std::cout << std::endl << green << "Vamana Index was saved successfully to " << brightYellow << "`" << outputFile << "`" << reset << std::endl;
236273
}
237274
} else if (indexType == "stiched") {
238275
StichedVamanaIndex<BaseDataVector<float>> index(filters);
239-
index.createGraph(base_vectors, std::stof(alpha), std::stoi(L_small), std::stoi(R_small), std::stoi(R_stiched), distanceThreads, true, leaveEmpty);
276+
index.createGraph(base_vectors, std::stof(alpha), std::stoi(L_small), std::stoi(R_small), std::stoi(R_stiched), distanceSaveMethodEnum, distanceThreads, computingThreads, true, leaveEmpty);
240277

241278
if (save) {
242279
index.saveGraph(outputFile);
@@ -277,7 +314,7 @@ void TestSimple(std::unordered_map<std::string, std::string> args) {
277314
GraphNode<DataVector<float>> s = vamanaIndex.findMedoid(vamanaIndex.getGraph(), 1000);
278315

279316
auto start = std::chrono::high_resolution_clock::now();
280-
SimpleGreedyResult greedyResult = GreedySearch(vamanaIndex, s, query_vectors.at(std::stoi(queryNumber)), std::stoi(k), std::stoi(L), TEST);
317+
SimpleGreedyResult greedyResult = GreedySearch(vamanaIndex, s, query_vectors.at(std::stoi(queryNumber)), std::stoi(k), std::stoi(L), NONE);
281318
auto end = std::chrono::high_resolution_clock::now();
282319
std::chrono::duration<double> elapsed = end - start;
283320

@@ -363,7 +400,7 @@ void TestFilteredOrStiched(std::unordered_map<std::string, std::string> args) {
363400
}
364401

365402
auto start = std::chrono::high_resolution_clock::now();
366-
FilteredGreedyResult greedyResult = FilteredGreedySearch(index, start_nodes, xq, std::stoi(k), std::stoi(L), Fx, TEST);
403+
FilteredGreedyResult greedyResult = FilteredGreedySearch(index, start_nodes, xq, std::stoi(k), std::stoi(L), Fx, NONE);
367404
auto end = std::chrono::high_resolution_clock::now();
368405
std::chrono::duration<double> elapsed = end - start;
369406

include/FilteredVamanaIndex.h

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,16 @@ template <typename vamana_t> class FilteredVamanaIndex : public VamanaIndex<vama
6060
* @param L An unsigned int parameter.
6161
* @param R An unsigned int parameter.
6262
*/
63-
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L, const unsigned int R, unsigned int distance_threads = 1, bool visualized = true, bool empty = true);
63+
void createGraph(
64+
const std::vector<vamana_t>& P,
65+
const float& alpha,
66+
const unsigned int L,
67+
const unsigned int R,
68+
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE,
69+
unsigned int distance_threads = 1,
70+
bool visualized = true,
71+
bool empty = true
72+
);
6473

6574
/**
6675
* @brief Load a graph from a file. Specifically this method is used to receive the contents of a Vamana Index Graph

include/GreedySearch.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ enum EXEC_MODE {
2828
TEST = 1
2929
};
3030

31+
3132
template <typename vamana_t> class VamanaIndex;
3233
template <typename vamana_t> class FilteredVamanaIndex;
3334

@@ -52,7 +53,7 @@ template <typename graph_t, typename query_t> std::pair<std::set<graph_t>, std::
5253
const query_t& xq,
5354
unsigned int k,
5455
unsigned int L,
55-
const EXEC_MODE execMode = CREATE
56+
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE
5657
);
5758

5859
/**
@@ -82,7 +83,7 @@ template <typename graph_t, typename query_t> std::pair<std::set<graph_t>, std::
8283
const unsigned int k,
8384
const unsigned int L,
8485
const std::vector<CategoricalAttributeFilter>& queryFilters,
85-
const EXEC_MODE execMode = CREATE
86+
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE
8687
);
8788

8889

include/RobustPrune.h

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
#include "DataVector.h"
44
#include "BQDataVectors.h"
55
#include "VamanaIndex.h"
6+
#include "distance.h"
67

78
template <typename graph_t> class VamanaIndex;
89
template <typename graph_t> class FilteredVamanaIndex;
@@ -27,8 +28,14 @@ template <typename graph_t> class FilteredVamanaIndex;
2728
* 4. Removes nodes from `V` that do not satisfy the distance threshold defined by `alpha`.
2829
* 5. Stops when the number of neighbors of `p_node` reaches `R` or `V` is empty.
2930
*/
30-
template <typename graph_t>
31-
void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::set<graph_t>& V, float alpha, int R);
31+
template <typename graph_t> void RobustPrune(
32+
VamanaIndex<graph_t>& index,
33+
GraphNode<graph_t>& p_node,
34+
std::set<graph_t>& V,
35+
float alpha,
36+
int R,
37+
const DISTANCE_SAVE_METHOD distanceSaveMethod
38+
);
3239

3340
/**
3441
* @brief Prunes the neighbors of a given node in a graph based on a robust pruning algorithm with filtering.
@@ -45,4 +52,11 @@ void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::s
4552
* @param R An integer specifying the maximum number of neighbors to retain.
4653
*/
4754
template <typename graph_t>
48-
void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node,std::set<graph_t>& V, float alpha,int R);
55+
void FilteredRobustPrune(
56+
FilteredVamanaIndex<graph_t>& index,
57+
GraphNode<graph_t>& p_node,
58+
std::set<graph_t>& V,
59+
float alpha,
60+
int R,
61+
const DISTANCE_SAVE_METHOD distanceSaveMethod
62+
);

include/StichedVamanaIndex.h

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,16 @@ template <typename vamana_t> class StichedVamanaIndex : public FilteredVamanaInd
1313
/**
1414
* @brief Default constructor for the StichedVamanaIndex class.
1515
*/
16-
StichedVamanaIndex() : FilteredVamanaIndex<vamana_t>() {}
16+
StichedVamanaIndex()
17+
: FilteredVamanaIndex<vamana_t>() {}
1718

1819
/**
1920
* @brief Constructor for the StichedVamanaIndex class with filters.
2021
*
2122
* @param filters A set of CategoricalAttributeFilter to initialize the index with.
2223
*/
23-
StichedVamanaIndex(std::set<CategoricalAttributeFilter> filters) : FilteredVamanaIndex<vamana_t>(filters) {}
24+
StichedVamanaIndex(std::set<CategoricalAttributeFilter> filters)
25+
: FilteredVamanaIndex<vamana_t>(filters) {}
2426

2527
/**
2628
* @brief Create the graph with the given parameters.
@@ -30,8 +32,18 @@ template <typename vamana_t> class StichedVamanaIndex : public FilteredVamanaInd
3032
* @param L An unsigned int parameter.
3133
* @param R An unsigned int parameter.
3234
*/
33-
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L_small,
34-
const unsigned int R_small, const unsigned int R_stiched, unsigned int distance_threads, bool visualized = true, bool empty = true);
35+
void createGraph(
36+
const std::vector<vamana_t>& P,
37+
const float& alpha,
38+
const unsigned int L_small,
39+
const unsigned int R_small,
40+
const unsigned int R_stiched,
41+
const DISTANCE_SAVE_METHOD distanceSaveMethod,
42+
unsigned int distance_threads,
43+
unsigned int compute_threads = 500,
44+
bool visualized = true,
45+
bool empty = true
46+
);
3547

3648
};
3749

include/VamanaIndex.h

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,16 @@ template <typename vamana_t> class VamanaIndex {
103103
* @param R the parameter R
104104
*
105105
*/
106-
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L, const unsigned int& R, unsigned int distance_threads = 1, bool visualize = true, double** distanceMatrix = nullptr);
106+
void createGraph(
107+
const std::vector<vamana_t>& P,
108+
const float& alpha,
109+
const unsigned int L,
110+
const unsigned int& R,
111+
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE,
112+
unsigned int distance_threads = 1,
113+
bool visualize = true,
114+
double** distanceMatrix = nullptr
115+
);
107116

108117
/**
109118
* @brief Saves a specific graph into a file. Specifically this method is used to save the contents of a Vamana

include/distance.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,11 @@
77
#include <vector>
88
#include "DataVector.h"
99

10+
enum DISTANCE_SAVE_METHOD {
11+
NONE = 0,
12+
MATRIX = 1,
13+
};
14+
1015

1116
/**
1217
* @brief Comparator structure for ordering elements by Euclidean distance.

src/Graphics/ProgressBar.cpp

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ bool isUtf8Supported() {
1515
}
1616

1717
/**
18-
* @brief Function to display a progress bar with a percentage.
18+
* @brief Function to display a progress bar with a percentage and a loading animation.
1919
*
2020
* Loading Symbol: =>
2121
*
@@ -30,19 +30,22 @@ void displayProgressBar(
3030
const int current, const int total, const std::string& message, const std::chrono::steady_clock::time_point& startTime, const unsigned int barWidth) {
3131

3232
static bool utf8Supported = isUtf8Supported();
33+
static const char loadingSymbols[] = {'-', '\\', '|', '/'};
34+
static int loadingIndex = 0;
35+
static int callCounter = 0; // Counter to slow down the animation
3336

3437
const std::string horizontalLineSymbol = "\u2500";
3538
const std::string verticalLineSymbol = "\u2502";
3639
const std::string crossSymbol = "\u253C";
3740

3841
if (firstTime) {
39-
std::cout << brightMagenta << "Action" << std::setw(22) << reset << " " << verticalLineSymbol << " ";
42+
std::cout << brightMagenta << "Action" << std::setw(24) << reset << " " << verticalLineSymbol << " ";
4043
std::cout << brightMagenta << "Progress" << std::setw(36) << reset << " " << verticalLineSymbol << " ";
4144
std::cout << brightMagenta << "Time Remaining" << reset << " | ";
4245
std::cout << brightMagenta << "Time Elapsed" << reset << std::endl;
4346

4447

45-
for (unsigned int i = 0; i < 25; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
48+
for (unsigned int i = 0; i < 27; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
4649
for (unsigned int i = 0; i < 42; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
4750
for (unsigned int i = 0; i < 16; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
4851
for (unsigned int i = 0; i < 15; i++) { std::cout << horizontalLineSymbol; }
@@ -67,8 +70,18 @@ void displayProgressBar(
6770
int minutes = remainingSeconds / 60;
6871
int seconds = remainingSeconds % 60;
6972

70-
// Display action message
71-
std::cout << brightYellow << std::setw(24) << std::setfill(' ') << std::left << message << reset;
73+
// Display action message with loading animation
74+
if (current > 0 && current < total) {
75+
std::cout << brightYellow << std::setw(24) << std::setfill(' ') << std::left << message;
76+
if (callCounter % 8 == 0) { // Update loading symbol every 8 calls
77+
loadingIndex++;
78+
}
79+
std::cout << " " << loadingSymbols[loadingIndex % 4] << reset;
80+
}
81+
else if (current == total) {
82+
std::cout << brightGreen << std::setw(24) << std::setfill(' ') << std::left << message;
83+
std::cout << " " << tickSymbol << reset;
84+
}
7285

7386
// Display progress bar
7487
std::cout << " " << verticalLineSymbol << " ";
@@ -104,7 +117,7 @@ void displayProgressBar(
104117
}
105118
else if (current == total) {
106119
std::cout << " " << verticalLineSymbol << " " << yellow;
107-
std::cout << brightGreen << "Done " << tickSymbol << std::setw(12) << std::setfill(' ') << reset;
120+
std::cout << brightGreen << "Done" << std::setw(14) << std::setfill(' ') << reset;
108121
}
109122

110123
// Display elapsed time
@@ -119,6 +132,7 @@ void displayProgressBar(
119132
std::cout << "\r"; // Return the cursor to the start of the line
120133
std::cout.flush();
121134

135+
callCounter++;
122136
}
123137

124138
/**

0 commit comments

Comments
 (0)