Skip to content

Commit c9b538a

Browse files
authored
Add --list-metrics command to display available metrics (#297)
1 parent 608f961 commit c9b538a

File tree

4 files changed

+150
-0
lines changed

4 files changed

+150
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ build/
2020
/*.csv
2121
*.hatchet
2222
autotuner.log
23+
.venv/

run.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from tritonbench.operators_collection import list_operators_by_collection
1717
from tritonbench.utils.env_utils import is_fbcode
1818
from tritonbench.utils.gpu_utils import gpu_lockdown
19+
from tritonbench.utils.list_metrics import list_metrics
1920
from tritonbench.utils.parser import get_parser
2021
from tritonbench.utils.run_utils import run_config, run_in_task
2122

@@ -106,12 +107,18 @@ def run(args: List[str] = []):
106107
usage_report_logger(benchmark_name="tritonbench")
107108
parser = get_parser()
108109
args, extra_args = parser.parse_known_args(args)
110+
109111
tritonparse_init(args.tritonparse)
110112
if args.op:
111113
ops = args.op.split(",")
112114
else:
113115
ops = list_operators_by_collection(args.op_collection)
114116

117+
# Handle --list-metrics after determining operators list
118+
if args.list_metrics:
119+
print(list_metrics(operators=ops if ops else []))
120+
return
121+
115122
# Force isolation in subprocess if testing more than one op.
116123
if len(ops) >= 2:
117124
args.isolate = True

tritonbench/utils/list_metrics.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Utilities for listing available metrics in tritonbench.
3+
"""
4+
5+
import sys
6+
from dataclasses import fields
7+
from typing import Dict, List, Set
8+
9+
from tritonbench.operators import load_opbench_by_name
10+
from tritonbench.operators_collection import list_operators_by_collection
11+
12+
from tritonbench.utils.triton_op import (
13+
BenchmarkOperatorMetrics,
14+
OVERRIDDEN_METRICS,
15+
REGISTERED_METRICS,
16+
)
17+
18+
19+
def get_builtin_metrics() -> List[str]:
    """Return the field names of BenchmarkOperatorMetrics, minus ``extra_metrics``.

    Names are returned in the order ``dataclasses.fields`` reports them.
    """
    names = []
    for spec in fields(BenchmarkOperatorMetrics):
        # "extra_metrics" is filtered out of the built-in list.
        if spec.name == "extra_metrics":
            continue
        names.append(spec.name)
    return names
26+
27+
28+
def load_operators_to_register_metrics(operators: List[str]) -> None:
    """Load each named operator so that its metrics get registered.

    Loading is best-effort: any exception raised while loading an operator
    is reported as a warning on stderr and the remaining operators are
    still attempted.

    Args:
        operators: Operator names passed one by one to ``load_opbench_by_name``.
    """
    for name in operators:
        try:
            load_opbench_by_name(name)
        except Exception as err:
            # Keep going; one broken operator must not hide the others.
            warning = f"Warning: Failed to load operator '{name}': {err}"
            print(warning, file=sys.stderr)
35+
36+
37+
def get_custom_metrics_for_operators(operators: List[str]) -> Dict[str, List[str]]:
    """Map each requested operator to its entry in ``REGISTERED_METRICS``.

    The operators are loaded first so that their metrics are registered.
    Operators without an entry map to an empty list.
    """
    load_operators_to_register_metrics(operators)
    return {name: REGISTERED_METRICS.get(name, []) for name in operators}
46+
47+
48+
def get_overridden_metrics_for_operators(operators: List[str]) -> Dict[str, List[str]]:
    """Map each requested operator to its entry in ``OVERRIDDEN_METRICS``.

    The operators are loaded first so that their metrics are registered.
    Operators without an entry map to an empty list.
    """
    load_operators_to_register_metrics(operators)
    return {name: OVERRIDDEN_METRICS.get(name, []) for name in operators}
57+
58+
59+
def get_all_metrics_for_collection(
    collection_name: str,
) -> Dict[str, Dict[str, List[str]]]:
    """Collect custom and overridden metrics for every operator in a collection.

    Args:
        collection_name: Name handed to ``list_operators_by_collection``.

    Returns:
        A dict keyed by operator name; each value holds a ``"custom"`` list
        (from ``REGISTERED_METRICS``) and an ``"overridden"`` list (from
        ``OVERRIDDEN_METRICS``), each empty when the operator has no entry.
    """
    op_names = list_operators_by_collection(collection_name)
    # Operators are loaded before the registries are read.
    load_operators_to_register_metrics(op_names)
    return {
        name: {
            "custom": REGISTERED_METRICS.get(name, []),
            "overridden": OVERRIDDEN_METRICS.get(name, []),
        }
        for name in op_names
    }
73+
74+
75+
def format_operator_specific_metrics(
    operators: List[str],
    builtin_metrics: List[str],
    custom_metrics: Dict[str, List[str]],
    overridden_metrics: Dict[str, List[str]],
) -> str:
    """Render the built-in metrics followed by one section per operator.

    Args:
        operators: Operator names to report on. Repeated names are reported
            only once.
        builtin_metrics: Metric names common to all operators.
        custom_metrics: Operator name -> custom metric names.
        overridden_metrics: Operator name -> overridden metric names.

    Returns:
        A newline-joined report. Metrics and operators are listed in sorted
        order; operators with neither custom nor overridden metrics are
        omitted.
    """
    output = ["Built-in metrics (available for all operators):"]
    output.extend(f" {metric}" for metric in sorted(builtin_metrics))

    # set() drops repeated names (e.g. "--op a,a") so that an operator's
    # section is never printed twice.
    for op_name in sorted(set(operators)):
        custom = custom_metrics.get(op_name, [])
        overridden = overridden_metrics.get(op_name, [])

        if not custom and not overridden:
            continue

        output.append(f"\nOperator: {op_name}")

        if custom:
            output.append(" Custom metrics:")
            output.extend(f" {metric}" for metric in sorted(custom))

        if overridden:
            output.append(" Overridden metrics:")
            output.extend(f" {metric}" for metric in sorted(overridden))

    return "\n".join(output)
110+
111+
112+
def list_metrics(operators: List[str] = None) -> str:
    """
    List available metrics based on the provided operators.

    Args:
        operators: Specific operators to show metrics for. ``None`` or an
            empty list shows the built-in metrics only.

    Returns:
        Formatted string with metrics information
    """
    builtin_metrics = get_builtin_metrics()

    # Drop repeated names while keeping the caller's order, so that each
    # operator is loaded (and, on failure, warned about) at most once.
    selected = list(dict.fromkeys(operators)) if operators else []

    if selected:
        # Load once up front; both registries are read afterwards. Going
        # through the per-registry getters would load every operator twice
        # and print each load-failure warning twice.
        load_operators_to_register_metrics(selected)

    custom_metrics = {name: REGISTERED_METRICS.get(name, []) for name in selected}
    overridden_metrics = {name: OVERRIDDEN_METRICS.get(name, []) for name in selected}

    # With no operators selected the formatter emits only the built-in
    # section, so the global case needs no separate code path.
    return format_operator_specific_metrics(
        selected, builtin_metrics, custom_metrics, overridden_metrics
    )

tritonbench/utils/parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def get_parser(args=None):
9696
default=None,
9797
help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
9898
)
99+
parser.add_argument(
100+
"--list-metrics",
101+
action="store_true",
102+
help="List all available metrics. Can be used with --op or --op-collection to show operator-specific metrics.",
103+
)
99104
parser.add_argument(
100105
"--metrics-gpu-backend",
101106
choices=["torch", "nvml"],

0 commit comments

Comments
 (0)