Skip to content

Distance saving functionality #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ FLAGS += $(DEPFLAGS)
# Execution Rules

create_simple_via:
./bin/main --create -index-type 'simple' -base-file 'data/siftsmall/siftsmall_base.fvecs' -L 120 -R 12 -alpha 1.0 -save 'simple_index.bin'
./bin/main --create -index-type 'simple' -base-file 'data/siftsmall/siftsmall_base.fvecs' -L 120 -R 12 -alpha 1.0 -save 'simple_index.bin' -distance-save 'matrix' -distance-threads 1

create_filtered_via:
./bin/main --create -index-type 'filtered' -base-file 'data/Dummy/dummy-data.bin' -L 120 -R 12 -alpha 1.0 -save 'filtered_index.bin'
./bin/main --create -index-type 'filtered' -base-file 'data/Dummy/dummy-data.bin' -L 120 -R 12 -alpha 1.0 -save 'filtered_index.bin' -distance-save 'matrix' -distance-threads 1

create_stiched_via:
./bin/main --create -index-type 'stiched' -base-file 'data/Dummy/dummy-data.bin' -L-small 150 -R-small 12 -R-stiched 20 -alpha 1.0 -save 'stiched_index.bin'
./bin/main --create -index-type 'stiched' -base-file 'data/Dummy/dummy-data.bin' -L-small 150 -R-small 12 -R-stiched 20 -alpha 1.0 -save 'stiched_index.bin' -distance-threads 1 -computing-threads 1

compute_groundtruth:
./bin/main --compute-gt -base-file 'data/Dummy/dummy-data.bin' -query-file 'data/Dummy/dummy-queries.bin' -gt-file 'data/Dummy/dummy-groundtruth.bin'
Expand Down
53 changes: 45 additions & 8 deletions app/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,21 @@ void Create(std::unordered_map<std::string, std::string> args) {
using BaseVectorVector = std::vector<BaseDataVector<float>>;
using BaseVectors = std::vector<DataVector<float>>;

std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode;
std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode, distanceSaveMethod;
std::string L_small, R_small, R_stiched;
bool save = false;
bool leaveEmpty = false;
int distanceThreads = 1; // Default value
int computingThreads = 1; // Default value

std::vector<std::string> validArguments = {"-index-type", "-base-file", "-L", "-L-small", "-R", "-R-small", "-R-stiched", "-alpha", "-save", "-random-edges", "-connection-mode", "-distance-threads", "-distance-save"};
if (args["-index-type"] == "stiched") {
validArguments.push_back("-computing-threads");
}

std::vector<std::string> validArguments = {"-index-type", "-base-file", "-L", "-L-small", "-R", "-R-small", "-R-stiched", "-alpha", "-save", "-random-edges", "-connection-mode", "-distance-threads"};
for (auto arg : args) {
if (std::find(validArguments.begin(), validArguments.end(), arg.first) == validArguments.end()) {
throw std::invalid_argument("Error: Invalid argument: " + arg.first + ". Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode, -distance-threads");
throw std::invalid_argument("Error: Invalid argument: " + arg.first + ". Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode, -distance-threads, -distance-save");
}
}

Expand All @@ -146,6 +151,8 @@ void Create(std::unordered_map<std::string, std::string> args) {
}

} else if (indexType == "stiched") {
validArguments.push_back("-computing-threads");

if (args.find("-L-small") == args.end()) {
throw std::invalid_argument("Error: Missing required argument: -L-small");
} else {
Expand All @@ -163,6 +170,10 @@ void Create(std::unordered_map<std::string, std::string> args) {
} else {
R_stiched = args["-R-stiched"];
}

if (args.find("-computing-threads") != args.end()) {
computingThreads = std::stoi(args["-computing-threads"]);
}
} else {
throw std::invalid_argument("Error: Invalid index type: " + indexType + ". Supported index types are: simple, filtered, stiched");
}
Expand Down Expand Up @@ -196,7 +207,19 @@ void Create(std::unordered_map<std::string, std::string> args) {
}
}

if (args.find("-distance-save") != args.end()) {
distanceSaveMethod = args["-distance-save"];
if (distanceSaveMethod != "none" && distanceSaveMethod != "matrix") {
throw std::invalid_argument("Error: Invalid value for -distance-save. Valid values are: none, matrix");
}
} else {
distanceSaveMethod = "none"; // Default value
}

if (args.find("-distance-threads") != args.end()) {
if (distanceSaveMethod != "matrix") {
throw std::invalid_argument("Error: -distance-threads can only be used if -distance-save is set to 'matrix'");
}
distanceThreads = std::stoi(args["-distance-threads"]);
}

Expand All @@ -206,9 +229,16 @@ void Create(std::unordered_map<std::string, std::string> args) {
std::cerr << "Error reading base file" << std::endl;
return;
}

DISTANCE_SAVE_METHOD distanceSaveMethodEnum = NONE;
if (distanceSaveMethod == "none") {
distanceSaveMethodEnum = NONE;
} else if (distanceSaveMethod == "matrix") {
distanceSaveMethodEnum = MATRIX;
}

VamanaIndex<DataVector<float>> vamanaIndex = VamanaIndex<DataVector<float>>();
vamanaIndex.createGraph(base_vectors, std::stof(alpha), std::stoi(L), std::stoi(R), distanceThreads, true);
vamanaIndex.createGraph(base_vectors, std::stof(alpha), std::stoi(L), std::stoi(R), distanceSaveMethodEnum, distanceThreads, true);

if (save) {
if (!vamanaIndex.saveGraph(outputFile)) {
Expand All @@ -226,17 +256,24 @@ void Create(std::unordered_map<std::string, std::string> args) {
filters.insert(filter);
}

DISTANCE_SAVE_METHOD distanceSaveMethodEnum = NONE;
if (distanceSaveMethod == "none") {
distanceSaveMethodEnum = NONE;
} else if (distanceSaveMethod == "matrix") {
distanceSaveMethodEnum = MATRIX;
}

if (indexType == "filtered") {
FilteredVamanaIndex<BaseDataVector<float>> index(filters);
index.createGraph(base_vectors, std::stoi(alpha), std::stoi(L), std::stoi(R), distanceThreads, true, leaveEmpty);
index.createGraph(base_vectors, std::stoi(alpha), std::stoi(L), std::stoi(R), distanceSaveMethodEnum, distanceThreads, true, leaveEmpty);

if (save) {
index.saveGraph(outputFile);
std::cout << std::endl << green << "Vamana Index was saved successfully to " << brightYellow << "`" << outputFile << "`" << reset << std::endl;
}
} else if (indexType == "stiched") {
StichedVamanaIndex<BaseDataVector<float>> index(filters);
index.createGraph(base_vectors, std::stof(alpha), std::stoi(L_small), std::stoi(R_small), std::stoi(R_stiched), distanceThreads, true, leaveEmpty);
index.createGraph(base_vectors, std::stof(alpha), std::stoi(L_small), std::stoi(R_small), std::stoi(R_stiched), distanceSaveMethodEnum, distanceThreads, computingThreads, true, leaveEmpty);

if (save) {
index.saveGraph(outputFile);
Expand Down Expand Up @@ -277,7 +314,7 @@ void TestSimple(std::unordered_map<std::string, std::string> args) {
GraphNode<DataVector<float>> s = vamanaIndex.findMedoid(vamanaIndex.getGraph(), 1000);

auto start = std::chrono::high_resolution_clock::now();
SimpleGreedyResult greedyResult = GreedySearch(vamanaIndex, s, query_vectors.at(std::stoi(queryNumber)), std::stoi(k), std::stoi(L), TEST);
SimpleGreedyResult greedyResult = GreedySearch(vamanaIndex, s, query_vectors.at(std::stoi(queryNumber)), std::stoi(k), std::stoi(L), NONE);
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;

Expand Down Expand Up @@ -363,7 +400,7 @@ void TestFilteredOrStiched(std::unordered_map<std::string, std::string> args) {
}

auto start = std::chrono::high_resolution_clock::now();
FilteredGreedyResult greedyResult = FilteredGreedySearch(index, start_nodes, xq, std::stoi(k), std::stoi(L), Fx, TEST);
FilteredGreedyResult greedyResult = FilteredGreedySearch(index, start_nodes, xq, std::stoi(k), std::stoi(L), Fx, NONE);
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;

Expand Down
11 changes: 10 additions & 1 deletion include/FilteredVamanaIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ template <typename vamana_t> class FilteredVamanaIndex : public VamanaIndex<vama
* @param L An unsigned int parameter.
* @param R An unsigned int parameter.
*/
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L, const unsigned int R, unsigned int distance_threads = 1, bool visualized = true, bool empty = true);
void createGraph(
const std::vector<vamana_t>& P,
const float& alpha,
const unsigned int L,
const unsigned int R,
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE,
unsigned int distance_threads = 1,
bool visualized = true,
bool empty = true
);

/**
* @brief Load a graph from a file. Specifically this method is used to receive the contents of a Vamana Index Graph
Expand Down
5 changes: 3 additions & 2 deletions include/GreedySearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ enum EXEC_MODE {
TEST = 1
};


template <typename vamana_t> class VamanaIndex;
template <typename vamana_t> class FilteredVamanaIndex;

Expand All @@ -52,7 +53,7 @@ template <typename graph_t, typename query_t> std::pair<std::set<graph_t>, std::
const query_t& xq,
unsigned int k,
unsigned int L,
const EXEC_MODE execMode = CREATE
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE
);

/**
Expand Down Expand Up @@ -82,7 +83,7 @@ template <typename graph_t, typename query_t> std::pair<std::set<graph_t>, std::
const unsigned int k,
const unsigned int L,
const std::vector<CategoricalAttributeFilter>& queryFilters,
const EXEC_MODE execMode = CREATE
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE
);


Expand Down
20 changes: 17 additions & 3 deletions include/RobustPrune.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "DataVector.h"
#include "BQDataVectors.h"
#include "VamanaIndex.h"
#include "distance.h"

template <typename graph_t> class VamanaIndex;
template <typename graph_t> class FilteredVamanaIndex;
Expand All @@ -27,8 +28,14 @@ template <typename graph_t> class FilteredVamanaIndex;
* 4. Removes nodes from `V` that do not satisfy the distance threshold defined by `alpha`.
* 5. Stops when the number of neighbors of `p_node` reaches `R` or `V` is empty.
*/
template <typename graph_t>
void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::set<graph_t>& V, float alpha, int R);
template <typename graph_t> void RobustPrune(
VamanaIndex<graph_t>& index,
GraphNode<graph_t>& p_node,
std::set<graph_t>& V,
float alpha,
int R,
const DISTANCE_SAVE_METHOD distanceSaveMethod
);

/**
* @brief Prunes the neighbors of a given node in a graph based on a robust pruning algorithm with filtering.
Expand All @@ -45,4 +52,11 @@ void RobustPrune(VamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node, std::s
* @param R An integer specifying the maximum number of neighbors to retain.
*/
template <typename graph_t>
void FilteredRobustPrune(FilteredVamanaIndex<graph_t>& index, GraphNode<graph_t>& p_node,std::set<graph_t>& V, float alpha,int R);
void FilteredRobustPrune(
FilteredVamanaIndex<graph_t>& index,
GraphNode<graph_t>& p_node,
std::set<graph_t>& V,
float alpha,
int R,
const DISTANCE_SAVE_METHOD distanceSaveMethod
);
20 changes: 16 additions & 4 deletions include/StichedVamanaIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ template <typename vamana_t> class StichedVamanaIndex : public FilteredVamanaInd
/**
* @brief Default constructor for the StichedVamanaIndex class.
*/
StichedVamanaIndex() : FilteredVamanaIndex<vamana_t>() {}
StichedVamanaIndex()
: FilteredVamanaIndex<vamana_t>() {}

/**
* @brief Constructor for the StichedVamanaIndex class with filters.
*
* @param filters A set of CategoricalAttributeFilter to initialize the index with.
*/
StichedVamanaIndex(std::set<CategoricalAttributeFilter> filters) : FilteredVamanaIndex<vamana_t>(filters) {}
StichedVamanaIndex(std::set<CategoricalAttributeFilter> filters)
: FilteredVamanaIndex<vamana_t>(filters) {}

/**
* @brief Create the graph with the given parameters.
Expand All @@ -30,8 +32,18 @@ template <typename vamana_t> class StichedVamanaIndex : public FilteredVamanaInd
* @param L An unsigned int parameter.
* @param R An unsigned int parameter.
*/
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L_small,
const unsigned int R_small, const unsigned int R_stiched, unsigned int distance_threads, bool visualized = true, bool empty = true);
void createGraph(
const std::vector<vamana_t>& P,
const float& alpha,
const unsigned int L_small,
const unsigned int R_small,
const unsigned int R_stiched,
const DISTANCE_SAVE_METHOD distanceSaveMethod,
unsigned int distance_threads,
unsigned int compute_threads = 500,
bool visualized = true,
bool empty = true
);

};

Expand Down
11 changes: 10 additions & 1 deletion include/VamanaIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,16 @@ template <typename vamana_t> class VamanaIndex {
* @param R the parameter R
*
*/
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L, const unsigned int& R, unsigned int distance_threads = 1, bool visualize = true, double** distanceMatrix = nullptr);
void createGraph(
const std::vector<vamana_t>& P,
const float& alpha,
const unsigned int L,
const unsigned int& R,
const DISTANCE_SAVE_METHOD distanceSaveMethod = NONE,
unsigned int distance_threads = 1,
bool visualize = true,
double** distanceMatrix = nullptr
);

/**
* @brief Saves a specific graph into a file. Specifically this method is used to save the contents of a Vamana
Expand Down
5 changes: 5 additions & 0 deletions include/distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
#include <vector>
#include "DataVector.h"

enum DISTANCE_SAVE_METHOD {
NONE = 0,
MATRIX = 1,
};


/**
* @brief Comparator structure for ordering elements by Euclidean distance.
Expand Down
26 changes: 20 additions & 6 deletions src/Graphics/ProgressBar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ bool isUtf8Supported() {
}

/**
* @brief Function to display a progress bar with a percentage.
* @brief Function to display a progress bar with a percentage and a loading animation.
*
* Loading Symbol: =>
*
Expand All @@ -30,19 +30,22 @@ void displayProgressBar(
const int current, const int total, const std::string& message, const std::chrono::steady_clock::time_point& startTime, const unsigned int barWidth) {

static bool utf8Supported = isUtf8Supported();
static const char loadingSymbols[] = {'-', '\\', '|', '/'};
static int loadingIndex = 0;
static int callCounter = 0; // Counter to slow down the animation

const std::string horizontalLineSymbol = "\u2500";
const std::string verticalLineSymbol = "\u2502";
const std::string crossSymbol = "\u253C";

if (firstTime) {
std::cout << brightMagenta << "Action" << std::setw(22) << reset << " " << verticalLineSymbol << " ";
std::cout << brightMagenta << "Action" << std::setw(24) << reset << " " << verticalLineSymbol << " ";
std::cout << brightMagenta << "Progress" << std::setw(36) << reset << " " << verticalLineSymbol << " ";
std::cout << brightMagenta << "Time Remaining" << reset << " | ";
std::cout << brightMagenta << "Time Elapsed" << reset << std::endl;


for (unsigned int i = 0; i < 25; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
for (unsigned int i = 0; i < 27; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
for (unsigned int i = 0; i < 42; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
for (unsigned int i = 0; i < 16; i++) { std::cout << horizontalLineSymbol; } std::cout << crossSymbol;
for (unsigned int i = 0; i < 15; i++) { std::cout << horizontalLineSymbol; }
Expand All @@ -67,8 +70,18 @@ void displayProgressBar(
int minutes = remainingSeconds / 60;
int seconds = remainingSeconds % 60;

// Display action message
std::cout << brightYellow << std::setw(24) << std::setfill(' ') << std::left << message << reset;
// Display action message with loading animation
if (current > 0 && current < total) {
std::cout << brightYellow << std::setw(24) << std::setfill(' ') << std::left << message;
if (callCounter % 8 == 0) { // Update loading symbol every 8 calls
loadingIndex++;
}
std::cout << " " << loadingSymbols[loadingIndex % 4] << reset;
}
else if (current == total) {
std::cout << brightGreen << std::setw(24) << std::setfill(' ') << std::left << message;
std::cout << " " << tickSymbol << reset;
}

// Display progress bar
std::cout << " " << verticalLineSymbol << " ";
Expand Down Expand Up @@ -104,7 +117,7 @@ void displayProgressBar(
}
else if (current == total) {
std::cout << " " << verticalLineSymbol << " " << yellow;
std::cout << brightGreen << "Done " << tickSymbol << std::setw(12) << std::setfill(' ') << reset;
std::cout << brightGreen << "Done" << std::setw(14) << std::setfill(' ') << reset;
}

// Display elapsed time
Expand All @@ -119,6 +132,7 @@ void displayProgressBar(
std::cout << "\r"; // Return the cursor to the start of the line
std::cout.flush();

callCounter++;
}

/**
Expand Down
Loading
Loading