Skip to content

Started data analysis #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ test.vig
data/Dummy/dummy-groundtruth*.bin
*.bin
.idea
venv
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ test_filtered_via:
# Run bin/main in --test mode on a single query (-query 1) against the stitched
# index loaded from stiched_index.bin, with L=150 and k=100, using the Dummy
# ground-truth and query files.
test_stiched_via:
./bin/main --test -index-type 'stiched' -load 'stiched_index.bin' -L 150 -k 100 -gt-file 'data/Dummy/dummy-groundtruth.bin' -query-file 'data/Dummy/dummy-queries.bin' -query 1

# Run bin/main in --test mode over every query (-query -1), restricted to
# unfiltered queries (-test-on unfiltered), against the index stored in
# models/stiched/stiched_index_empty.bin, and save the per-query recalls to
# results/empty/empty_stiched_index_unfiltered_recalls.txt (-save-recalls).
# NOTE(review): assumes the index file and the results/empty/ directory already exist — confirm.
test_and_save_stiched_empty_unfiltered_via:
./bin/main --test -index-type 'stiched' -load 'models/stiched/stiched_index_empty.bin' -L 150 -k 100 -gt-file 'data/Dummy/dummy-groundtruth.bin' -query-file 'data/Dummy/dummy-queries.bin' -query -1 -test-on unfiltered -save-recalls results/empty/empty_stiched_index_unfiltered_recalls.txt

# Run bin/main in --test mode over every query (-query -1), restricted to
# filtered queries (-test-on filtered), against the index stored in
# models/stiched/stiched_index_empty.bin, and save the per-query recalls to
# results/empty/empty_stiched_index_filtered_recalls.txt (-save-recalls).
# NOTE(review): assumes the index file and the results/empty/ directory already exist — confirm.
test_and_save_stiched_empty_filtered_via:
./bin/main --test -index-type 'stiched' -load 'models/stiched/stiched_index_empty.bin' -L 150 -k 100 -gt-file 'data/Dummy/dummy-groundtruth.bin' -query-file 'data/Dummy/dummy-queries.bin' -query -1 -test-on filtered -save-recalls results/empty/empty_stiched_index_filtered_recalls.txt

# Run bin/main in --test mode over every query (-query -1), restricted to
# unfiltered queries (-test-on unfiltered), against the index stored in
# models/stiched/stiched_index_filled.bin, and save the per-query recalls to
# results/filled/filled_stiched_index_unfiltered_recalls.txt (-save-recalls).
# NOTE(review): assumes the index file and the results/filled/ directory already exist — confirm.
test_and_save_stiched_filled_unfiltered_via:
./bin/main --test -index-type 'stiched' -load 'models/stiched/stiched_index_filled.bin' -L 150 -k 100 -gt-file 'data/Dummy/dummy-groundtruth.bin' -query-file 'data/Dummy/dummy-queries.bin' -query -1 -test-on unfiltered -save-recalls results/filled/filled_stiched_index_unfiltered_recalls.txt

# Run bin/main in --test mode over every query (-query -1), restricted to
# filtered queries (-test-on filtered), against the index stored in
# models/stiched/stiched_index_filled.bin, and save the per-query recalls to
# results/filled/filled_stiched_index_filtered_recalls.txt (-save-recalls).
# NOTE(review): assumes the index file and the results/filled/ directory already exist — confirm.
test_and_save_stiched_filled_filtered_via:
./bin/main --test -index-type 'stiched' -load 'models/stiched/stiched_index_filled.bin' -L 150 -k 100 -gt-file 'data/Dummy/dummy-groundtruth.bin' -query-file 'data/Dummy/dummy-queries.bin' -query -1 -test-on filtered -save-recalls results/filled/filled_stiched_index_filtered_recalls.txt


run_tests:
./bin/graph_node_test
Expand Down
62 changes: 62 additions & 0 deletions analysis/analize_empty_stiched_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import numpy as np
import matplotlib.pyplot as plt

def read_recalls(file_path):
    """Read per-query recall percentages from a text file.

    Every non-blank line must contain a ``:`` followed by a number and a ``%``
    sign (for example ``Query 12: 87.5%``). The text between the first ``:``
    and the next ``%`` is converted with ``float``, so ``nan`` values are kept
    as NaN and left for the caller to filter.

    Args:
        file_path: Path of the recalls file to read.

    Returns:
        A list of floats, one per non-blank line, in file order.

    Raises:
        IndexError: If a non-blank line contains no ``:``.
        ValueError: If the text after the ``:`` is not a valid float.
    """
    recalls = []
    with open(file_path, "r") as file:
        for line in file:
            # A blank line (typically a trailing newline at the end of the
            # file) carries no recall; skip it instead of failing on the split.
            if not line.strip():
                continue
            recalls.append(float(line.split(":")[1].split("%")[0]))
    return recalls

def plot_histogram(recalls, title, file_path):
    """Save a 10-bin histogram of the recall values as an image.

    Args:
        recalls: Sequence of recall values to plot.
        title: Title drawn above the histogram.
        file_path: Destination image path. Missing parent directories are
            created before saving.
    """
    import os  # local import: keeps this function self-contained

    # savefig does not create directories, so make sure the target one exists.
    output_dir = os.path.dirname(file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    figure = plt.figure(figsize=(10, 7))
    try:
        plt.hist(recalls, edgecolor='black', bins=10)
        plt.title(title)
        plt.xlabel("Recall")
        plt.ylabel("Frequency")
        plt.tight_layout()
        plt.savefig(file_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(figure)

def plot_pie_chart(recalls, title, file_path):
    """Save a pie chart showing how the recalls spread over 10-point ranges.

    Recalls are grouped into the ranges ``0 - 10`` ... ``90 - 100``. Every
    range is half-open (``start <= recall < end``) except the last one, which
    also includes exactly 100 so that perfect recalls are counted. Ranges
    without any recall are left out of the chart.

    Args:
        recalls: Sequence of recall percentages.
        title: Title drawn above the chart.
        file_path: Destination image path. Missing parent directories are
            created before saving.

    Raises:
        ValueError: If no recall falls inside ``[0, 100]`` (for example when
            ``recalls`` is empty), because there is nothing to plot.
    """
    import os  # local import: keeps this function self-contained

    ranges = ['0 - 10', '10 - 20', '20 - 30', '30 - 40', '40 - 50', '50 - 60', '60 - 70', '70 - 80', '80 - 90', '90 - 100']
    bounds = zip(range(0, 100, 10), range(10, 110, 10))
    range_counts = [
        # The extra clause closes the upper end of the last range; without it
        # a recall of exactly 100 would fall in no range at all.
        sum(1 for recall in recalls if start <= recall < end or (end == 100 and recall == 100))
        for start, end in bounds
    ]
    filtered_ranges_counts = [(ranges[i], range_counts[i]) for i in range(len(ranges)) if range_counts[i] > 0]
    if not filtered_ranges_counts:
        raise ValueError("plot_pie_chart: no recall values in [0, 100] to plot")
    filtered_ranges, filtered_counts = zip(*filtered_ranges_counts)
    colors = plt.cm.RdYlGn(np.linspace(0, 1, len(filtered_ranges)))

    # savefig does not create directories, so make sure the target one exists.
    output_dir = os.path.dirname(file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    figure = plt.figure(figsize=(10, 7))
    try:
        plt.pie(filtered_counts, labels=filtered_ranges, autopct='%1.1f%%', startangle=140, colors=colors)
        plt.title(title)
        plt.legend(filtered_ranges, title="Recall Ranges", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
        plt.tight_layout()
        plt.savefig(file_path, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(figure)

def main():
    """Summarise and plot the saved recalls of the empty-initialized index.

    For the unfiltered and then the filtered recalls file: load the values,
    drop NaN entries, print their mean, and write a histogram and a pie chart.
    """
    # One row per run:
    # (label printed before the mean, recalls file, plot title, histogram path, pie-chart path)
    runs = [
        (
            "Unfiltered Recall mean: ",
            'results/empty/empty_stiched_index_unfiltered_recalls.txt',
            "Empty Initialized Stiched Index: Unfiltered Recalls",
            "analysis/plots/empty/unfiltered/empty_index_unfiltered_recalls_hist.png",
            "analysis/plots/empty/unfiltered/empty_index_unfiltered_recalls_pie.png",
        ),
        (
            "Filtered Recall mean: ",
            'results/empty/empty_stiched_index_filtered_recalls.txt',
            "Empty Initialized Stiched Index: Filtered Recalls",
            "analysis/plots/empty/filtered/empty_index_filtered_recalls_hist.png",
            "analysis/plots/empty/filtered/empty_index_filtered_recalls_pie.png",
        ),
    ]

    for mean_label, recalls_path, plot_title, hist_path, pie_path in runs:
        # NaN recalls are excluded from both the mean and the plots.
        valid = [value for value in read_recalls(recalls_path) if not np.isnan(value)]
        print(mean_label, np.mean(valid))
        plot_histogram(valid, plot_title, hist_path)
        plot_pie_chart(valid, plot_title, pie_path)

if __name__ == "__main__":
main()
62 changes: 62 additions & 0 deletions analysis/analize_filled_stiched_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import numpy as np
import matplotlib.pyplot as plt

def read_recalls(file_path):
    """Read per-query recall percentages from a text file.

    Every non-blank line must contain a ``:`` followed by a number and a ``%``
    sign (for example ``Query 12: 87.5%``). The text between the first ``:``
    and the next ``%`` is converted with ``float``, so ``nan`` values are kept
    as NaN and left for the caller to filter.

    Args:
        file_path: Path of the recalls file to read.

    Returns:
        A list of floats, one per non-blank line, in file order.

    Raises:
        IndexError: If a non-blank line contains no ``:``.
        ValueError: If the text after the ``:`` is not a valid float.
    """
    recalls = []
    with open(file_path, "r") as file:
        for line in file:
            # A blank line (typically a trailing newline at the end of the
            # file) carries no recall; skip it instead of failing on the split.
            if not line.strip():
                continue
            recalls.append(float(line.split(":")[1].split("%")[0]))
    return recalls

def plot_histogram(recalls, title, file_path):
    """Save a 10-bin histogram of the recall values as an image.

    Args:
        recalls: Sequence of recall values to plot.
        title: Title drawn above the histogram.
        file_path: Destination image path. Missing parent directories are
            created before saving.
    """
    import os  # local import: keeps this function self-contained

    # savefig does not create directories, so make sure the target one exists.
    output_dir = os.path.dirname(file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    figure = plt.figure(figsize=(10, 7))
    try:
        plt.hist(recalls, edgecolor='black', bins=10)
        plt.title(title)
        plt.xlabel("Recall")
        plt.ylabel("Frequency")
        plt.tight_layout()
        plt.savefig(file_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(figure)

def plot_pie_chart(recalls, title, file_path):
    """Save a pie chart showing how the recalls spread over 10-point ranges.

    Recalls are grouped into the ranges ``0 - 10`` ... ``90 - 100``. Every
    range is half-open (``start <= recall < end``) except the last one, which
    also includes exactly 100 so that perfect recalls are counted. Ranges
    without any recall are left out of the chart.

    Args:
        recalls: Sequence of recall percentages.
        title: Title drawn above the chart.
        file_path: Destination image path. Missing parent directories are
            created before saving.

    Raises:
        ValueError: If no recall falls inside ``[0, 100]`` (for example when
            ``recalls`` is empty), because there is nothing to plot.
    """
    import os  # local import: keeps this function self-contained

    ranges = ['0 - 10', '10 - 20', '20 - 30', '30 - 40', '40 - 50', '50 - 60', '60 - 70', '70 - 80', '80 - 90', '90 - 100']
    bounds = zip(range(0, 100, 10), range(10, 110, 10))
    range_counts = [
        # The extra clause closes the upper end of the last range; without it
        # a recall of exactly 100 would fall in no range at all.
        sum(1 for recall in recalls if start <= recall < end or (end == 100 and recall == 100))
        for start, end in bounds
    ]
    filtered_ranges_counts = [(ranges[i], range_counts[i]) for i in range(len(ranges)) if range_counts[i] > 0]
    if not filtered_ranges_counts:
        raise ValueError("plot_pie_chart: no recall values in [0, 100] to plot")
    filtered_ranges, filtered_counts = zip(*filtered_ranges_counts)
    colors = plt.cm.RdYlGn(np.linspace(0, 1, len(filtered_ranges)))

    # savefig does not create directories, so make sure the target one exists.
    output_dir = os.path.dirname(file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    figure = plt.figure(figsize=(10, 7))
    try:
        plt.pie(filtered_counts, labels=filtered_ranges, autopct='%1.1f%%', startangle=140, colors=colors)
        plt.title(title)
        plt.legend(filtered_ranges, title="Recall Ranges", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
        plt.tight_layout()
        plt.savefig(file_path, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(figure)

def main():
    """Summarise and plot the saved recalls of the filled-initialized index.

    For the unfiltered and then the filtered recalls file: load the values,
    drop NaN entries, print their mean, and write a histogram and a pie chart.
    """
    # One row per run:
    # (label printed before the mean, recalls file, plot title, histogram path, pie-chart path)
    runs = [
        (
            "Unfiltered Recall mean: ",
            'results/filled/filled_stiched_index_unfiltered_recalls.txt',
            "Filled Initialized Stiched Index: Unfiltered Recalls",
            "analysis/plots/filled/unfiltered/filled_index_unfiltered_recalls_hist.png",
            "analysis/plots/filled/unfiltered/filled_index_unfiltered_recalls_pie.png",
        ),
        (
            "Filtered Recall mean: ",
            'results/filled/filled_stiched_index_filtered_recalls.txt',
            "Filled Initialized Stiched Index: Filtered Recalls",
            "analysis/plots/filled/filtered/filled_index_filtered_recalls_hist.png",
            "analysis/plots/filled/filtered/filled_index_filtered_recalls_pie.png",
        ),
    ]

    for mean_label, recalls_path, plot_title, hist_path, pie_path in runs:
        # NaN recalls are excluded from both the mean and the plots.
        valid = [value for value in read_recalls(recalls_path) if not np.isnan(value)]
        print(mean_label, np.mean(valid))
        plot_histogram(valid, plot_title, hist_path)
        plot_pie_chart(valid, plot_title, pie_path)

if __name__ == "__main__":
main()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
137 changes: 96 additions & 41 deletions app/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ctime>
#include <chrono>
#include <algorithm>
#include <fstream>
#include "../include/DataVector.h"
#include "../include/VamanaIndex.h"
#include "../include/FilteredVamanaIndex.h"
Expand Down Expand Up @@ -112,14 +113,15 @@ void Create(std::unordered_map<std::string, std::string> args) {
using BaseVectorVector = std::vector<BaseDataVector<float>>;
using BaseVectors = std::vector<DataVector<float>>;

std::string indexType, baseFile, L, R, alpha, outputFile;
std::string indexType, baseFile, L, R, alpha, outputFile, connectionMode;
std::string L_small, R_small, R_stiched;
bool save = false;
bool leaveEmpty = false;

std::vector<std::string> validArguments = {"-index-type", "-base-file", "-L", "-L-small", "-R", "-R-small", "-R-stiched", "-alpha", "-save"};
std::vector<std::string> validArguments = {"-index-type", "-base-file", "-L", "-L-small", "-R", "-R-small", "-R-stiched", "-alpha", "-save", "-random-edges", "-connection-mode"};
for (auto arg : args) {
if (std::find(validArguments.begin(), validArguments.end(), arg.first) == validArguments.end()) {
throw std::invalid_argument("Error: Invalid argument: " + arg.first + ". Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save");
throw std::invalid_argument("Error: Invalid argument: " + arg.first + ". Valid arguments are: -index-type, -base-file, -L, -L-small, -R, -R-small, -R-stiched, -alpha, -save, -connection-mode");
}
}

Expand All @@ -141,6 +143,7 @@ void Create(std::unordered_map<std::string, std::string> args) {
} else {
R = args["-R"];
}

} else if (indexType == "stiched") {
if (args.find("-L-small") == args.end()) {
throw std::invalid_argument("Error: Missing required argument: -L-small");
Expand Down Expand Up @@ -183,6 +186,15 @@ void Create(std::unordered_map<std::string, std::string> args) {
save = true;
}

if (args.find("-connection-mode") != args.end()) {
connectionMode = args["-connection-mode"];
if (connectionMode == "empty") {
leaveEmpty = true;
} else if (connectionMode != "filled") {
throw std::invalid_argument("Error: Invalid value for -connection-mode. Valid values are: empty, filled");
}
}

if (indexType == "simple") {
BaseVectors base_vectors = ReadVectorFile(baseFile);
if (base_vectors.empty()) {
Expand Down Expand Up @@ -211,15 +223,15 @@ void Create(std::unordered_map<std::string, std::string> args) {

if (indexType == "filtered") {
FilteredVamanaIndex<BaseDataVector<float>> index(filters);
index.createGraph(base_vectors, std::stoi(alpha), std::stoi(L), std::stoi(R));
index.createGraph(base_vectors, std::stoi(alpha), std::stoi(L), std::stoi(R), true, leaveEmpty);

if (save) {
index.saveGraph(outputFile);
std::cout << std::endl << green << "Vamana Index was saved successfully to " << brightYellow << "`" << outputFile << "`" << reset << std::endl;
}
} else if (indexType == "stiched") {
StichedVamanaIndex<BaseDataVector<float>> index(filters);
index.createGraph(base_vectors, std::stof(alpha), std::stoi(L_small), std::stoi(R_small), std::stoi(R_stiched));
index.createGraph(base_vectors, std::stof(alpha), std::stoi(L_small), std::stoi(R_small), std::stoi(R_stiched), true, leaveEmpty);

if (save) {
index.saveGraph(outputFile);
Expand Down Expand Up @@ -282,67 +294,110 @@ void TestSimple(std::unordered_map<std::string, std::string> args) {
void TestFilteredOrStiched(std::unordered_map<std::string, std::string> args) {
using QueryVectorVector = std::vector<QueryDataVector<float>>;

std::string indexFile, k, L, groundtruthFile, queryFile, queryNumber;
std::string indexFile, k, L, groundtruthFile, queryFile, queryNumber, testOn, saveRecallsFile;

if (!getParameterValue(args, "-load", indexFile)) return;
if (!getParameterValue(args, "-k", k)) return;
if (!getParameterValue(args, "-L", L)) return;
if (!getParameterValue(args, "-gt-file", groundtruthFile)) return;
if (!getParameterValue(args, "-query-file", queryFile)) return;
if (!getParameterValue(args, "-query", queryNumber)) return;

QueryVectorVector query_vectors = ReadFilteredQueryVectorFile(queryFile);
QueryDataVector<float> xq = query_vectors[std::stoi(queryNumber)];
if (xq.getQueryType() > 1) {
return;
if (args.find("-test-on") != args.end()) {
if (queryNumber != "-1") {
std::cerr << "Error: The -test-on argument can only be used when -query is set to -1." << std::endl;
return;
}
testOn = args["-test-on"];
}
if (args.find("-save-recalls") != args.end()) {
if (queryNumber != "-1") {
std::cerr << "Error: The -save-recalls argument can only be used when -query is set to -1." << std::endl;
return;
}
saveRecallsFile = args["-save-recalls"];
}

QueryVectorVector query_vectors = ReadFilteredQueryVectorFile(queryFile);
FilteredVamanaIndex<BaseDataVector<float>> index;
index.loadGraph(indexFile);
std::vector<std::vector<int>> groundtruth = readGroundtruthFromFile(groundtruthFile);

std::map<Filter, GraphNode<BaseDataVector<float>>> medoids = index.findFilteredMedoid(std::stoi(L));
std::vector<GraphNode<BaseDataVector<float>>> start_nodes;
for (auto filter : index.getFilters()) {
start_nodes.push_back(medoids[filter]);
}

std::vector<CategoricalAttributeFilter> Fx;
if (xq.getQueryType() == 1) {
Fx.push_back(CategoricalAttributeFilter(xq.getV()));
std::ofstream recallFile;
if (!saveRecallsFile.empty()) {
recallFile.open(saveRecallsFile);
if (!recallFile.is_open()) {
std::cerr << "Error: Could not open file " << saveRecallsFile << " for writing." << std::endl;
return;
}
}

std::vector<GraphNode<BaseDataVector<float>>> P = index.getNodes();
std::set<BaseDataVector<float>> exactNeighbors;
auto processQuery = [&](int queryIdx) {
QueryDataVector<float> xq = query_vectors[queryIdx];
if (xq.getQueryType() > 1) {
return;
}

for (auto index : groundtruth[std::stoi(queryNumber)]) {
exactNeighbors.insert(P[index].getData());
if ((int)exactNeighbors.size() >= std::stoi(k)) {
break;
std::vector<CategoricalAttributeFilter> Fx;
if (xq.getQueryType() == 1) {
Fx.push_back(CategoricalAttributeFilter(xq.getV()));
}
}

auto start = std::chrono::high_resolution_clock::now();
FilteredGreedyResult greedyResult = FilteredGreedySearch(index, start_nodes, xq, std::stoi(k), std::stoi(L), Fx, TEST);
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;
std::vector<GraphNode<BaseDataVector<float>>> P = index.getNodes();
std::set<BaseDataVector<float>> exactNeighbors;

std::set<BaseDataVector<float>> approximateNeighbors = greedyResult.first;
double recall = calculateRecallEvaluation(approximateNeighbors, exactNeighbors);
for (auto idx : groundtruth[queryIdx]) {
exactNeighbors.insert(P[idx].getData());
if ((int)exactNeighbors.size() >= std::stoi(k)) {
break;
}
}

std::cout << brightMagenta << std::endl << "Results:" << reset << std::endl;
std::cout << reset << "Current Query: " << brightCyan << queryNumber << reset << " | ";
std::cout << reset << "Query Type: ";
if (xq.getQueryType() == 0) std::cout << brightBlack << "Uniltered" << reset << " | ";
else std::cout << brightWhite << "Filtered" << reset << " | ";
std::cout << reset << "Recall: ";
if (recall < 0.2) std::cout << brightRed;
else if (recall < 0.4) std::cout << brightOrange;
else if (recall < 0.6) std::cout << brightYellow;
else if (recall < 0.8) std::cout << brightCyan;
else std::cout << brightGreen;
std::cout << recall*100 << "%" << reset << " | ";
std::cout << "Time: " << cyan << elapsed.count() << " seconds" << std::endl;
auto start = std::chrono::high_resolution_clock::now();
FilteredGreedyResult greedyResult = FilteredGreedySearch(index, start_nodes, xq, std::stoi(k), std::stoi(L), Fx, TEST);
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;

std::set<BaseDataVector<float>> approximateNeighbors = greedyResult.first;
double recall = calculateRecallEvaluation(approximateNeighbors, exactNeighbors);

// std::cout << brightMagenta << std::endl << "Results for query " << queryIdx << ":" << reset << std::endl;
std::cout << reset << "Current Query: " << brightCyan << queryIdx << reset << " | ";
std::cout << reset << "Query Type: ";
if (xq.getQueryType() == 0) std::cout << brightBlack << "Unfiltered" << reset << " | ";
else std::cout << brightWhite << "Filtered " << reset << " | ";
std::cout << reset << "Recall: ";
if (recall < 0.2) std::cout << brightRed;
else if (recall < 0.4) std::cout << brightOrange;
else if (recall < 0.6) std::cout << brightYellow;
else if (recall < 0.8) std::cout << brightCyan;
else std::cout << brightGreen;
std::cout << recall*100 << "%" << reset << " | ";
std::cout << "Time: " << cyan << elapsed.count() << " seconds" << std::endl;

if (recallFile.is_open()) {
recallFile << "Query " << queryIdx << ": " << recall * 100 << "%" << std::endl;
}
};

if (queryNumber == "-1") {
for (size_t i = 0; i < query_vectors.size(); ++i) {
if (testOn == "filtered" && query_vectors[i].getQueryType() != 1) continue;
if (testOn == "unfiltered" && query_vectors[i].getQueryType() != 0) continue;
processQuery(i);
}
} else {
processQuery(std::stoi(queryNumber));
}

if (recallFile.is_open()) {
recallFile.close();
std::cout << "Recalls saved to " << saveRecallsFile << std::endl;
}
}

void Test(std::unordered_map<std::string, std::string> args) {
Expand Down
2 changes: 1 addition & 1 deletion include/FilteredVamanaIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ template <typename vamana_t> class FilteredVamanaIndex : public VamanaIndex<vama
* @param L An unsigned int parameter.
* @param R An unsigned int parameter.
*/
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L, const unsigned int R, bool visualized = true);
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L, const unsigned int R, bool visualized = true, bool empty = true);

/**
* @brief Load a graph from a file. Specifically this method is used to receive the contents of a Vamana Index Graph
Expand Down
2 changes: 1 addition & 1 deletion include/StichedVamanaIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ template <typename vamana_t> class StichedVamanaIndex : public FilteredVamanaInd
* @param R An unsigned int parameter.
*/
void createGraph(const std::vector<vamana_t>& P, const float& alpha, const unsigned int L_small,
const unsigned int R_small, const unsigned int R_stiched, bool visualized = true);
const unsigned int R_small, const unsigned int R_stiched, bool visualized = true, bool empty = true);

};

Expand Down
Loading
Loading