Skip to content

Commit 88b3c64

Browse files
Apply suggestion from @JohannesGaessler
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent cb769e7 commit 88b3c64

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

scripts/compare-llama-bench.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,10 +327,12 @@ def __init__(self, tool: str = "llama-bench"):
327327
self.table_name = "test"
328328
db_fields = LLAMA_BENCH_DB_FIELDS
329329
db_types = LLAMA_BENCH_DB_TYPES
330-
else: # test-backend-ops
330+
elif self.tool == "test-backend-ops":
331331
self.table_name = "test_backend_ops"
332332
db_fields = TEST_BACKEND_OPS_DB_FIELDS
333333
db_types = TEST_BACKEND_OPS_DB_TYPES
334+
else:
335+
assert False
334336

335337
self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")
336338

@@ -356,8 +358,10 @@ def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequ
356358
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
357359
if self.tool == "llama-bench":
358360
return self._get_rows_llama_bench(properties, hexsha8_baseline, hexsha8_compare)
359-
else: # test-backend-ops
361+
elif self.tool == "test-backend-ops":
360362
return self._get_rows_test_backend_ops(properties, hexsha8_baseline, hexsha8_compare)
363+
else:
364+
assert False
361365

362366
def _get_rows_llama_bench(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
363367
select_string = ", ".join(
@@ -1041,8 +1045,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
10411045
# Determine y-axis label based on tool type
10421046
if tool_type == "llama-bench":
10431047
y_label = "Tokens per second (t/s)"
1044-
else: # test-backend-ops
1048+
elif tool_type == "test-backend-ops":
10451049
y_label = metric_name
1050+
else:
1051+
assert False
10461052

10471053
ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
10481054
ax.set_ylabel(y_label, fontsize=12, fontweight='bold')

0 commit comments

Comments
 (0)