Skip to content

Commit 6aa44c2

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Clean up ncu_trace and ncu metrics
Summary: Remove `ncu_trace` metric as we have decided to integrate NCU metric analysis within the framework. Now we always save the replay file, analyze the replay file, and aggregate the metrics. Reviewed By: FindHao Differential Revision: D78164136 fbshipit-source-id: 399106cb92fabd2e708e94168e1fde85ef17c627
1 parent 7740c6d commit 6aa44c2

File tree

3 files changed

+61
-41
lines changed

3 files changed

+61
-41
lines changed

tritonbench/components/ncu/ncu_analyzer.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,29 @@
5959
}
6060

6161

62-
def import_ncu_python_path():
62+
def get_ncu_metrics(metrics: List[str]) -> List[str]:
63+
"""
64+
This function returns a list of all the NCU metrics used in the benchmark.
65+
66+
Returns:
67+
list: A list of all the NCU metrics used in the benchmark.
68+
"""
69+
ncu_metrics = []
70+
for (
71+
bench_metric,
72+
short_ncu_metrics,
73+
) in bench_metric_to_short_ncu_metric.items():
74+
# Only process metrics that are required
75+
if bench_metric in metrics:
76+
# For each short metric name in the list of metrics for this benchmark metric
77+
for short_ncu_metric in short_ncu_metrics:
78+
# Get the full NCU metric name and add it to our list
79+
full_metric_name = short_ncu_metric_name[short_ncu_metric]
80+
ncu_metrics.append(full_metric_name)
81+
return ncu_metrics
82+
83+
84+
def _import_ncu_python_path():
6385
"""
6486
This function modifies the Python path to include the NVIDIA Nsight Compute (NCU) Python modules.
6587
It searches for the 'ncu' command in the system PATH, determines its location, and appends the
@@ -153,7 +175,7 @@ def read_ncu_report(report_path: str, required_metrics: List[str]):
153175
assert os.path.exists(
154176
report_path
155177
), f"The NCU report at {report_path} does not exist."
156-
import_ncu_python_path()
178+
_import_ncu_python_path()
157179
import ncu_report
158180

159181
# save all kernels' metrics. {metric_name: [kernel1_metric_value, kernel2_metric_value, ...]}

tritonbench/components/ncu/nsys_analyzer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020
}
2121

2222

23+
def get_nsys_metrics(metrics: List[str]) -> List[str]:
24+
nsys_metrics = []
25+
for metric_name in nsys_metrics_to_reports.keys():
26+
if metric_name in metrics:
27+
nsys_metrics.append(metric_name)
28+
return nsys_metrics
29+
30+
2331
def read_nsys_report(
2432
report_path: str, required_metrics: List[str]
2533
) -> Dict[str, List[float]]:

tritonbench/utils/triton_op.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,6 @@ class BenchmarkOperatorMetrics:
228228
compile_trace: Optional[str] = None
229229
# att trace directory
230230
att_trace: Optional[str] = None
231-
# ncu trace file
232-
ncu_trace: Optional[str] = None
233231
# ncu replay file
234232
ncu_rep: Optional[str] = None
235233
# ncu replay file with TTGIR line numbers
@@ -1227,53 +1225,41 @@ def _init_extra_metrics() -> Dict[str, Any]:
12271225
metrics.compile_trace = self.compile_time(
12281226
input_id, fn_name, metrics, kineto_trace=True
12291227
)
1230-
# Collect NCU metrics if any required metrics match the ncu analyzer
1231-
# metrics. Only profile with the necessary metrics to avoid excessive
1232-
# overhead.
12331228
if not is_hip():
1234-
if "ncu_trace" in self.required_metrics:
1235-
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
1236-
ncu_metrics = []
1237-
for (
1238-
bench_metric,
1239-
short_ncu_metrics,
1240-
) in ncu_analyzer.bench_metric_to_short_ncu_metric.items():
1241-
# Only process metrics that are required
1242-
if bench_metric in self.required_metrics:
1243-
# For each short metric name in the list of metrics for this benchmark metric
1244-
for short_ncu_metric in short_ncu_metrics:
1245-
# Get the full NCU metric name and add it to our list
1246-
full_metric_name = ncu_analyzer.short_ncu_metric_name[
1247-
short_ncu_metric
1248-
]
1249-
ncu_metrics.append(full_metric_name)
1250-
extend_ncu_args = (
1251-
["--metrics", ",".join(ncu_metrics)] if ncu_metrics else None
1229+
# ncu metrics (ncu_rep, ncu_rep_ir, or ncu_analyzer metrics)
1230+
ncu_metrics: List[str] = ncu_analyzer.get_ncu_metrics(
1231+
self.required_metrics
12521232
)
1253-
if ncu_metrics or "ncu_rep" in self.required_metrics:
1254-
metrics.ncu_rep = self.ncu_trace(
1255-
input_id, fn_name, replay=True, extend_ncu_args=extend_ncu_args
1233+
if (
1234+
ncu_metrics
1235+
or "ncu_rep" in self.required_metrics
1236+
or "ncu_rep_ir" in self.required_metrics
1237+
):
1238+
profile_ir = "ncu_rep_ir" in self.required_metrics
1239+
out = self.ncu_trace(
1240+
input_id,
1241+
fn_name,
1242+
replay=True,
1243+
extend_ncu_args=ncu_metrics,
1244+
profile_ir=profile_ir,
12561245
)
12571246
# Read and update NCU metrics if any required metrics match the NCU metrics
12581247
if ncu_metrics:
12591248
ncu_analyzer_results = ncu_analyzer.read_ncu_report(
1260-
metrics.ncu_rep, self.required_metrics
1249+
out, self.required_metrics
12611250
)
12621251
for metric_name, metric_value in ncu_analyzer_results.items():
12631252
metrics.extra_metrics[metric_name] = metric_value
12641253
if "arithmetic_intensity" in self.required_metrics:
12651254
logger.warning(
12661255
"Arithmetic intensity only supports FP32 and FP64 for now."
12671256
)
1257+
if "ncu_rep" in self.required_metrics:
1258+
metrics.ncu_rep = out
12681259
if "ncu_rep_ir" in self.required_metrics:
1269-
metrics.ncu_rep_ir = self.ncu_trace(
1270-
input_id, fn_name, replay=True, profile_ir=True
1271-
)
1272-
nsys_metrics = []
1273-
for metric_name in nsys_analyzer.nsys_metrics_to_reports.keys():
1274-
if metric_name in self.required_metrics:
1275-
nsys_metrics.append(metric_name)
1276-
1260+
metrics.ncu_rep_ir = out
1261+
# nsys metrics
1262+
nsys_metrics = nsys_analyzer.get_nsys_metrics(self.required_metrics)
12771263
if "nsys_rep" in self.required_metrics or nsys_metrics:
12781264
nsys_rep_path = self.nsys_rep(input_id, fn_name)
12791265
metrics.nsys_rep = nsys_rep_path
@@ -1602,10 +1588,14 @@ def ncu_trace(
16021588
profile_ir=False,
16031589
extend_ncu_args: List[str] = None,
16041590
) -> str:
1605-
extend_ncu_args = extend_ncu_args or [
1606-
"--set",
1607-
"full",
1608-
]
1591+
extend_ncu_args = (
1592+
["--metrics", ",".join(extend_ncu_args)]
1593+
if extend_ncu_args
1594+
else [
1595+
"--set",
1596+
"full",
1597+
]
1598+
)
16091599
op_task_args = self._get_op_task_args(input_id, fn_name, "_ncu_trace_in_task")
16101600
# Disable DCGM
16111601
disable_dyno_dcgm = [

0 commit comments

Comments
 (0)