@@ -327,10 +327,12 @@ def __init__(self, tool: str = "llama-bench"):
327
327
self .table_name = "test"
328
328
db_fields = LLAMA_BENCH_DB_FIELDS
329
329
db_types = LLAMA_BENCH_DB_TYPES
330
- else : # test-backend-ops
330
+ elif self . tool == " test-backend-ops" :
331
331
self .table_name = "test_backend_ops"
332
332
db_fields = TEST_BACKEND_OPS_DB_FIELDS
333
333
db_types = TEST_BACKEND_OPS_DB_TYPES
334
+ else :
335
+ assert False
334
336
335
337
self .cursor .execute (f"CREATE TABLE { self .table_name } ({ ', ' .join (' ' .join (x ) for x in zip (db_fields , db_types ))} );" )
336
338
@@ -356,8 +358,10 @@ def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequ
356
358
def get_rows (self , properties : list [str ], hexsha8_baseline : str , hexsha8_compare : str ) -> Sequence [tuple ]:
357
359
if self .tool == "llama-bench" :
358
360
return self ._get_rows_llama_bench (properties , hexsha8_baseline , hexsha8_compare )
359
- else : # test-backend-ops
361
+ elif self . tool == " test-backend-ops" :
360
362
return self ._get_rows_test_backend_ops (properties , hexsha8_baseline , hexsha8_compare )
363
+ else :
364
+ assert False
361
365
362
366
def _get_rows_llama_bench (self , properties : list [str ], hexsha8_baseline : str , hexsha8_compare : str ) -> Sequence [tuple ]:
363
367
select_string = ", " .join (
@@ -1041,8 +1045,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
1041
1045
# Determine y-axis label based on tool type
1042
1046
if tool_type == "llama-bench" :
1043
1047
y_label = "Tokens per second (t/s)"
1044
- else : # test-backend-ops
1048
+ elif tool_type == " test-backend-ops" :
1045
1049
y_label = metric_name
1050
+ else :
1051
+ assert False
1046
1052
1047
1053
ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
1048
1054
ax .set_ylabel (y_label , fontsize = 12 , fontweight = 'bold' )
0 commit comments