@@ -228,8 +228,6 @@ class BenchmarkOperatorMetrics:
228
228
compile_trace : Optional [str ] = None
229
229
# att trace directory
230
230
att_trace : Optional [str ] = None
231
- # ncu trace file
232
- ncu_trace : Optional [str ] = None
233
231
# ncu replay file
234
232
ncu_rep : Optional [str ] = None
235
233
# ncu replay file with TTGIR line numbers
@@ -1227,53 +1225,41 @@ def _init_extra_metrics() -> Dict[str, Any]:
1227
1225
metrics .compile_trace = self .compile_time (
1228
1226
input_id , fn_name , metrics , kineto_trace = True
1229
1227
)
1230
- # Collect NCU metrics if any required metrics match the ncu analyzer
1231
- # metrics. Only profile with the necessary metrics to avoid excessive
1232
- # overhead.
1233
1228
if not is_hip ():
1234
- if "ncu_trace" in self .required_metrics :
1235
- metrics .ncu_trace = self .ncu_trace (input_id , fn_name )
1236
- ncu_metrics = []
1237
- for (
1238
- bench_metric ,
1239
- short_ncu_metrics ,
1240
- ) in ncu_analyzer .bench_metric_to_short_ncu_metric .items ():
1241
- # Only process metrics that are required
1242
- if bench_metric in self .required_metrics :
1243
- # For each short metric name in the list of metrics for this benchmark metric
1244
- for short_ncu_metric in short_ncu_metrics :
1245
- # Get the full NCU metric name and add it to our list
1246
- full_metric_name = ncu_analyzer .short_ncu_metric_name [
1247
- short_ncu_metric
1248
- ]
1249
- ncu_metrics .append (full_metric_name )
1250
- extend_ncu_args = (
1251
- ["--metrics" , "," .join (ncu_metrics )] if ncu_metrics else None
1229
+ # ncu metrics (ncu_rep, ncu_rep_ir, or ncu_analyzer metrics)
1230
+ ncu_metrics : List [str ] = ncu_analyzer .get_ncu_metrics (
1231
+ self .required_metrics
1252
1232
)
1253
- if ncu_metrics or "ncu_rep" in self .required_metrics :
1254
- metrics .ncu_rep = self .ncu_trace (
1255
- input_id , fn_name , replay = True , extend_ncu_args = extend_ncu_args
1233
+ if (
1234
+ ncu_metrics
1235
+ or "ncu_rep" in self .required_metrics
1236
+ or "ncu_rep_ir" in self .required_metrics
1237
+ ):
1238
+ profile_ir = "ncu_rep_ir" in self .required_metrics
1239
+ out = self .ncu_trace (
1240
+ input_id ,
1241
+ fn_name ,
1242
+ replay = True ,
1243
+ extend_ncu_args = ncu_metrics ,
1244
+ profile_ir = profile_ir ,
1256
1245
)
1257
1246
# Read and update NCU metrics if any required metrics match the NCU metrics
1258
1247
if ncu_metrics :
1259
1248
ncu_analyzer_results = ncu_analyzer .read_ncu_report (
1260
- metrics . ncu_rep , self .required_metrics
1249
+ out , self .required_metrics
1261
1250
)
1262
1251
for metric_name , metric_value in ncu_analyzer_results .items ():
1263
1252
metrics .extra_metrics [metric_name ] = metric_value
1264
1253
if "arithmetic_intensity" in self .required_metrics :
1265
1254
logger .warning (
1266
1255
"Arithmetic intensity only supports FP32 and FP64 for now."
1267
1256
)
1257
+ if "ncu_rep" in self .required_metrics :
1258
+ metrics .ncu_rep = out
1268
1259
if "ncu_rep_ir" in self .required_metrics :
1269
- metrics .ncu_rep_ir = self .ncu_trace (
1270
- input_id , fn_name , replay = True , profile_ir = True
1271
- )
1272
- nsys_metrics = []
1273
- for metric_name in nsys_analyzer .nsys_metrics_to_reports .keys ():
1274
- if metric_name in self .required_metrics :
1275
- nsys_metrics .append (metric_name )
1276
-
1260
+ metrics .ncu_rep_ir = out
1261
+ # nsys metrics
1262
+ nsys_metrics = nsys_analyzer .get_nsys_metrics (self .required_metrics )
1277
1263
if "nsys_rep" in self .required_metrics or nsys_metrics :
1278
1264
nsys_rep_path = self .nsys_rep (input_id , fn_name )
1279
1265
metrics .nsys_rep = nsys_rep_path
@@ -1602,10 +1588,14 @@ def ncu_trace(
1602
1588
profile_ir = False ,
1603
1589
extend_ncu_args : List [str ] = None ,
1604
1590
) -> str :
1605
- extend_ncu_args = extend_ncu_args or [
1606
- "--set" ,
1607
- "full" ,
1608
- ]
1591
+ extend_ncu_args = (
1592
+ ["--metrics" , "," .join (extend_ncu_args )]
1593
+ if extend_ncu_args
1594
+ else [
1595
+ "--set" ,
1596
+ "full" ,
1597
+ ]
1598
+ )
1609
1599
op_task_args = self ._get_op_task_args (input_id , fn_name , "_ncu_trace_in_task" )
1610
1600
# Disable DCGM
1611
1601
disable_dyno_dcgm = [
0 commit comments