Skip to content

[MOD-8206] INT8 flow tests #573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ class PyHNSWLibIndex : public PyVecSimIndex {
} else if (type == VecSimType_FLOAT16) {
auto *hnsw = dynamic_cast<HNSWIndex<float16, float> *>(index.get());
hnsw->saveIndex(location);
} else if (type == VecSimType_INT8) {
auto *hnsw = dynamic_cast<HNSWIndex<int8_t, float> *>(index.get());
hnsw->saveIndex(location);
} else {
throw std::runtime_error("Invalid index data type");
}
Expand Down Expand Up @@ -432,6 +435,10 @@ class PyHNSWLibIndex : public PyVecSimIndex {
return dynamic_cast<HNSWIndex<float16, float> *>(this->index.get())
->checkIntegrity()
.valid_state;
} else if (type == VecSimType_INT8) {
return dynamic_cast<HNSWIndex<int8_t, float> *>(this->index.get())
->checkIntegrity()
.valid_state;
} else {
throw std::runtime_error("Invalid index data type");
}
Expand Down
29 changes: 29 additions & 0 deletions tests/flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def create_hnsw_params(dim, num_elements, metric, data_type, ef_construction=200
hnsw_params.multi = is_multi

return hnsw_params

# Helper function for creating an index,uses the default HNSW parameters if not specified.
def create_hnsw_index(dim, num_elements, metric, data_type, ef_construction=200, m=16, ef_runtime=10, epsilon=0.01,
is_multi=False):
Expand All @@ -40,6 +41,23 @@ def create_hnsw_index(dim, num_elements, metric, data_type, ef_construction=200,

return HNSWIndex(hnsw_params)

# Helper function for creating an index, uses the default flat parameters if not specified.
def create_flat_index(dim, metric, data_type, is_multi=False):
bfparams = BFParams()

bfparams.dim = dim
bfparams.type = data_type
bfparams.metric = metric
bfparams.multi = is_multi

return BFIndex(bfparams)

def create_add_vectors(index, vectors):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it called "create"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It creates also the list of (key, vector) tuples

label_to_vec_list = []
for i, vector in enumerate(vectors):
index.add_vector(vector, i)
label_to_vec_list.append((i, vector))
return label_to_vec_list

# Compute the expected speedup as a function of the expected parallel section rate of the code by Amdahl's law
def expected_speedup(expected_parallel_rate, n_threads):
Expand All @@ -61,9 +79,20 @@ def vec_to_bfloat16(vec):
def vec_to_float16(vec):
return vec.astype(np.float16)

def create_int8_vectors(shape, rng: np.random.Generator = None):
rng = np.random.default_rng(seed=42) if rng is None else rng
return rng.integers(low=-128, high=127, size=shape, dtype=np.int8)

def get_ground_truth_results(dist_func, query, vectors, k):
results = [{"dist": dist_func(query, vec), "label": key} for key, vec in vectors]
results = sorted(results, key=lambda x: x["dist"])
keys = [res["label"] for res in results[:k]]

return results, keys

def fp32_expand_and_calc_cosine_dist(a, b):
# stupid numpy doesn't make any intermediate conversions when handling small types
# so we might get overflow. We need to convert to float32 ourselves.
a_float32 = a.astype(np.float32)
b_float32 = b.astype(np.float32)
return spatial.distance.cosine(a_float32, b_float32)
198 changes: 198 additions & 0 deletions tests/flow/test_bruteforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,201 @@ def test_bf_float16_multivalue():

assert_allclose(bf_labels, [keys], rtol=1e-5, atol=0)
assert_allclose(bf_distances, [dists], rtol=1e-5, atol=0)

'''
A Class to run common tests for BF index

The following tests will *automatically* run if the class is inherited:
* test_serialization - single L2 index
* test_L2 - single L2 index
* test_batch_iterator - single L2 index

The following tests should be *explicitly* called from a method prefixed with test_*
# range_query(dist_func) - single cosine index

@param create_data_func is a function expects num_elements, dim, [and optional np.random.Generator] as input and
returns a (num_elements, dim) numpy array of vectors
uses multi L2 index
# multi_value(create_data_func, num_per_label) -
'''
class GeneralTest():
dim = 128
num_elements = 10_000
num_queries = 1

data_type = None

rng = np.random.default_rng(seed=42)
vectors_data = None
query_data = None

# single FLAT index with L2 metric
cache_flat_index_L2_single = None
cached_label_to_vec_list = None

@classmethod
def create_index(cls, metric = VecSimMetric_L2, is_multi=False):
assert cls.data_type is not None
return create_flat_index(cls.dim, metric, cls.data_type, is_multi=is_multi)

@classmethod
def create_add_vectors(cls, index):
assert cls.vectors_data is not None
return create_add_vectors(index, cls.vectors_data)

@classmethod
def get_cached_single_L2_index(cls):
if cls.cache_flat_index_L2_single is None:
cls.cache_flat_index_L2_single = cls.create_index()
cls.cached_label_to_vec_list = cls.create_add_vectors(cls.cache_flat_index_L2_single)
return cls.cache_flat_index_L2_single, cls.cached_label_to_vec_list

@staticmethod
def compute_correct(res_labels, res_dist, gt_labels, gt_dist_label_list):
correct = 0
for i, label in enumerate(res_labels):
for j, correct_label in enumerate(gt_labels):
if label == correct_label:
correct += 1
assert math.isclose(res_dist[i], gt_dist_label_list[j]["dist"], rel_tol=1e-5)
break

return correct

@classmethod
def knn(cls, index, label_vec_list, dist_func):
k = 10

results, keys = get_ground_truth_results(dist_func, cls.query_data[0], label_vec_list, k)
dists = [res["dist"] for res in results]
bf_labels, bf_distances = index.knn_query(cls.query_data, k=k)
assert_allclose(bf_labels, [keys], rtol=1e-5, atol=0)
assert_allclose(bf_distances, [dists[:k]], rtol=1e-5, atol=0)
print(f"\nsanity test for L2 and {cls.data_type} pass")

def test_L2(self):
index, label_to_vec_list = self.get_cached_single_L2_index()
self.knn(index, label_to_vec_list, spatial.distance.sqeuclidean)

def test_batch_iterator(self):
index, _ = self.get_cached_single_L2_index()
# num_elements = self.num_labels
batch_size = 10


batch_iterator = index.create_batch_iterator(self.query_data)
labels_first_batch, distances_first_batch = batch_iterator.get_next_results(batch_size, BY_ID)
for i, _ in enumerate(labels_first_batch[0][:-1]):
# assert sorting by id
assert(labels_first_batch[0][i] < labels_first_batch[0][i+1])

_, distances_second_batch = batch_iterator.get_next_results(batch_size, BY_SCORE)
for i, dist in enumerate(distances_second_batch[0][:-1]):
# assert sorting by score
assert(distances_second_batch[0][i] < distances_second_batch[0][i+1])
# assert that every distance in the second batch is higher than any distance of the first batch
assert(len(distances_first_batch[0][np.where(distances_first_batch[0] > dist)]) == 0)

# reset
batch_iterator.reset()

# Run again in batches until depleted
batch_size = 1500
returned_results_num = 0
iterations = 0
start = time.time()
while batch_iterator.has_next():
iterations += 1
labels, distances = batch_iterator.get_next_results(batch_size, BY_SCORE)
returned_results_num += len(labels[0])

print(f'Total search time for running batches of size {batch_size} for index with {self.num_elements} of dim={self.dim}: {time.time() - start}')
assert (returned_results_num == self.num_elements)
assert (iterations == np.ceil(self.num_elements/batch_size))

##### Should be explicitly called #####
def range_query(self, dist_func):
bfindex = self.create_index(VecSimMetric_Cosine)
label_to_vec_list = self.create_add_vectors(bfindex)
radius = 0.7

start = time.time()
bf_labels, bf_distances = bfindex.range_query(self.query_data[0], radius=radius)
end = time.time()
res_num = len(bf_labels[0])
print(f'\nlookup time for {self.num_elements} vectors with dim={self.dim} took {end - start} seconds, got {res_num} results')

# Verify that we got exactly all vectors within the range
results, keys = get_ground_truth_results(dist_func, self.query_data[0], label_to_vec_list, res_num)

assert_allclose(max(bf_distances[0]), results[res_num-1]["dist"], rtol=1e-05)
assert np.array_equal(np.array(bf_labels[0]), np.array(keys))
assert max(bf_distances[0]) <= radius
# Verify that the next closest vector that hasn't returned is not within the range
assert results[res_num]["dist"] > radius

# Expect zero results for radius==0
bf_labels, bf_distances = bfindex.range_query(self.query_data[0], radius=0)
assert len(bf_labels[0]) == 0

def multi_value(self, create_data_func, num_per_label = 5):
# num_labels=5_000
# num_per_label=20
# num_elements = num_labels * num_per_label
num_labels = self.num_elements // num_per_label
k = 10

data = create_data_func((num_labels, self.dim), self.rng)

index = self.create_index(is_multi=True)

vectors = []
for i, vector in enumerate(data):
for _ in range(num_per_label):
index.add_vector(vector, i)
vectors.append((i, vector))

dists = {}
for key, vec in vectors:
# Setting or updating the score for each label.
# If it's the first time we calculate a score for a label dists.get(key, dist)
# will return dist so we will choose the actual score the first time.
dist = spatial.distance.sqeuclidean(self.query_data[0], vec)
dists[key] = min(dist, dists.get(key, dist))

dists = list(dists.items())
dists = sorted(dists, key=lambda pair: pair[1])[:k]
keys = [key for key, _ in dists[:k]]
dists = [dist for _, dist in dists[:k]]

start = time.time()
bf_labels, bf_distances = index.knn_query(self.query_data[0], k=10)
end = time.time()

print(f'\nlookup time for {self.num_elements} vectors ({num_labels} labels and {num_per_label} vectors per label) with dim={self.dim} took {end - start} seconds')

assert_allclose(bf_labels, [keys], rtol=1e-5, atol=0)
assert_allclose(bf_distances, [dists], rtol=1e-5, atol=0)

class TestINT8(GeneralTest):

GeneralTest.data_type = VecSimType_INT8

#### Create vectors
GeneralTest.vectors_data = create_int8_vectors((GeneralTest.num_elements, GeneralTest.dim), GeneralTest.rng)

#### Create queries
GeneralTest.query_data = create_int8_vectors((GeneralTest.num_queries, GeneralTest.dim), GeneralTest.rng)

def test_Cosine(self):

index = self.create_index(VecSimMetric_Cosine)
label_to_vec_list = self.create_add_vectors(index)

self.knn(index, label_to_vec_list, fp32_expand_and_calc_cosine_dist)

def test_range_query(self):
self.range_query(fp32_expand_and_calc_cosine_dist)

def test_multi_value(self):
self.multi_value(create_int8_vectors)
Loading
Loading