Skip to content

Commit 02ce860

Browse files
authored
Add k_factor to local-benchmarks (#517)
1 parent 701b9cc commit 02ce860

File tree

1 file changed

+41
-32
lines changed

1 file changed

+41
-32
lines changed

apis/python/test/local-benchmarks.py

Lines changed: 41 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,10 @@ class RemoteURIType(Enum):
8686
)
8787

8888

89+
def sift_string():
    """Return the chart-title suffix naming the SIFT dataset in use.

    Gives "(SIFT 10K)" when the module-level ``USE_SIFT_SMALL`` flag is
    truthy and "(SIFT 1M)" otherwise.
    """
    if USE_SIFT_SMALL:
        return "(SIFT 10K)"
    return "(SIFT 1M)"
91+
92+
8993
class TimerMode(Enum):
9094
INGESTION = "ingestion"
9195
QUERY = "query"
@@ -202,7 +206,9 @@ def save_charts(self):
202206
plt.figure(figsize=(20, 12))
203207
plt.xlabel("Average Query Accuracy")
204208
plt.ylabel("Time (seconds)")
205-
plt.title(f"{self.name}: Ingestion Time vs Average Query Accuracy")
209+
plt.title(
210+
f"{self.name}: Ingestion Time vs Average Query Accuracy {sift_string()}"
211+
)
206212
self.add_data_to_ingestion_time_vs_average_query_accuracy()
207213
plt.legend()
208214
plt.savefig(
@@ -214,7 +220,7 @@ def save_charts(self):
214220
plt.figure(figsize=(20, 12))
215221
plt.xlabel("Accuracy")
216222
plt.ylabel("Time (seconds)")
217-
plt.title(f"{self.name}: Query Time vs Accuracy")
223+
plt.title(f"{self.name}: Query Time vs Accuracy {sift_string()}")
218224
self.add_data_to_query_time_vs_accuracy()
219225
plt.legend()
220226
plt.savefig(
@@ -245,7 +251,7 @@ def save_charts(self):
245251
plt.figure(figsize=(20, 12))
246252
plt.xlabel("Average Query Accuracy")
247253
plt.ylabel("Time (seconds)")
248-
plt.title("Ingestion Time vs Average Query Accuracy")
254+
plt.title(f"Ingestion Time vs Average Query Accuracy {sift_string()}")
249255
for idx, timer in self.timers:
250256
timer.add_data_to_ingestion_time_vs_average_query_accuracy(
251257
markers[idx % len(markers)]
@@ -258,7 +264,7 @@ def save_charts(self):
258264
plt.figure(figsize=(20, 12))
259265
plt.xlabel("Accuracy")
260266
plt.ylabel("Time (seconds)")
261-
plt.title("Query Time vs Accuracy")
267+
plt.title(f"Query Time vs Accuracy {sift_string()}")
262268
for idx, timer in self.timers:
263269
timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
264270
plt.legend()
@@ -414,35 +420,38 @@ def benchmark_ivf_pq():
414420
dimensions = queries.shape[1]
415421
gt_i, gt_d = get_groundtruth_ivec(SIFT_GROUNDTRUTH_PATH, k=k, nqueries=len(queries))
416422

417-
for partitions in [50]:
418-
for num_subspaces in [dimensions / 2, dimensions / 4, dimensions / 8]:
419-
tag = f"{index_type}_partitions={partitions}_num_subspaces={num_subspaces}"
420-
logger.info(f"Running {tag}")
421-
422-
index_uri = get_uri(tag)
423-
424-
timer.start(tag, TimerMode.INGESTION)
425-
index = ingest(
426-
index_type=index_type,
427-
index_uri=index_uri,
428-
source_uri=SIFT_BASE_PATH,
429-
config=config,
430-
partitions=partitions,
431-
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
432-
num_subspaces=num_subspaces,
433-
)
434-
ingest_time = timer.stop(tag, TimerMode.INGESTION)
435-
436-
for nprobe in [5, 10, 20, 40, 60]:
437-
timer.start(tag, TimerMode.QUERY)
438-
_, result = index.query(queries, k=k, nprobe=nprobe)
439-
query_time = timer.stop(tag, TimerMode.QUERY)
440-
acc = timer.accuracy(tag, accuracy(result, gt_i))
441-
logger.info(
442-
f"Finished {tag} with nprobe={nprobe}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
423+
for partitions in [200]:
424+
for num_subspaces in [dimensions / 4]:
425+
for k_factor in [1, 1.5, 2, 4, 8, 16]:
426+
tag = f"{index_type}_partitions={partitions}_num_subspaces={num_subspaces}_k_factor={k_factor}"
427+
logger.info(f"Running {tag}")
428+
429+
index_uri = get_uri(tag)
430+
431+
timer.start(tag, TimerMode.INGESTION)
432+
index = ingest(
433+
index_type=index_type,
434+
index_uri=index_uri,
435+
source_uri=SIFT_BASE_PATH,
436+
config=config,
437+
partitions=partitions,
438+
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
439+
num_subspaces=num_subspaces,
443440
)
444-
445-
cleanup_uri(index_uri)
441+
ingest_time = timer.stop(tag, TimerMode.INGESTION)
442+
443+
for nprobe in [5, 10, 20, 40, 60]:
444+
timer.start(tag, TimerMode.QUERY)
445+
_, result = index.query(
446+
queries, k=k, nprobe=nprobe, k_factor=k_factor
447+
)
448+
query_time = timer.stop(tag, TimerMode.QUERY)
449+
acc = timer.accuracy(tag, accuracy(result, gt_i))
450+
logger.info(
451+
f"Finished {tag} with nprobe={nprobe}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
452+
)
453+
454+
cleanup_uri(index_uri)
446455

447456
timer.save_and_print_results()
448457

0 commit comments

Comments
 (0)