@@ -616,18 +616,6 @@ def _inner(self, *args, **kwargs):
616
616
return decorator
617
617
618
618
619
- def register_benchmark_manually (
620
- operator_name : str ,
621
- func_name : str ,
622
- baseline : bool = False ,
623
- enabled : bool = True ,
624
- label : Optional [str ] = None ,
625
- ):
626
- return register_benchmark (
627
- operator_name , func_name , baseline , enabled , fwd_only = False , label = label
628
- )
629
-
630
-
631
619
def register_metric (
632
620
# Metrics that only apply to non-baseline impls
633
621
# E.g., accuracy, speedup
@@ -1097,8 +1085,35 @@ def get_example_inputs(self):
1097
1085
except StopIteration :
1098
1086
return None
1099
1087
1100
- def get_temp_path (self , path : Union [str , Path ]) -> Path :
1101
- return Path (tempfile .gettempdir ()) / "tritonbench" / self .name / Path (path )
1088
+ def get_temp_path (
1089
+ self ,
1090
+ fn_name : Optional [str ] = None ,
1091
+ ) -> Path :
1092
+ unix_user : Optional [str ] = os .environ .get ("USER" , None )
1093
+ logging_group : Optional [str ] = self .logging_group
1094
+ parts = [x for x in ["tritonbench" , unix_user , logging_group ] if x ]
1095
+ tritonbench_dir_name = "_" .join (parts )
1096
+ benchmark_name = self .benchmark_name
1097
+ fn_part = f"{ fn_name } _{ self ._input_id } " if fn_name else ""
1098
+ out_part = Path (tempfile .gettempdir ()) / tritonbench_dir_name / benchmark_name
1099
+ return out_part / fn_part if fn_part else out_part
1100
+
1101
+ @property
1102
+ def precision (self ) -> str :
1103
+ if self .tb_args .precision == "bypass" or self .tb_args .precision == "fp8" :
1104
+ return ""
1105
+ return self .tb_args .precision
1106
+
1107
+ @property
1108
+ def benchmark_name (self , default : bool = False ) -> str :
1109
+ if not default and self .tb_args .benchmark_name :
1110
+ return self .tb_args .benchmark_name
1111
+ parts = [x for x in [self .precision , self .name , self .mode .value ] if x ]
1112
+ return "_" .join (parts )
1113
+
1114
+ @property
1115
+ def logging_group (self ) -> Optional [str ]:
1116
+ return self .tb_args .logging_group
1102
1117
1103
1118
def accuracy (self , fn : Callable , baseline_fn : Callable ) -> bool :
1104
1119
output = fn ()
@@ -1331,7 +1346,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
1331
1346
do_compile_kineto_trace_in_task ,
1332
1347
)
1333
1348
1334
- kineto_trace_output_dir = self .get_temp_path ("kineto_trace" )
1349
+ kineto_trace_output_dir = self .get_temp_path (fn_name )
1335
1350
kineto_trace_output_dir .mkdir (parents = True , exist_ok = True )
1336
1351
metrics .extra_metrics ["_compile_time_kineto_trace_in_task" ] = (
1337
1352
do_compile_kineto_trace_in_task (
@@ -1562,10 +1577,10 @@ def _get_op_task_args(
1562
1577
1563
1578
def nsys_rep (self , input_id : int , fn_name : str ) -> str :
1564
1579
op_task_args = self ._get_op_task_args (input_id , fn_name , "_nsys_rep_in_task" )
1565
- nsys_output_dir = self .get_temp_path (f"nsys_traces/ { fn_name } _ { input_id } " )
1580
+ nsys_output_dir = self .get_temp_path (fn_name )
1566
1581
nsys_output_dir .mkdir (parents = True , exist_ok = True )
1567
1582
ext = ".nsys-rep"
1568
- nsys_output_file = nsys_output_dir .joinpath (f"nsys_output { ext } " ).resolve ()
1583
+ nsys_output_file = nsys_output_dir .joinpath (f"nsys_rep { ext } " ).resolve ()
1569
1584
nsys_trace_cmd = [
1570
1585
"nsys" ,
1571
1586
"profile" ,
@@ -1639,11 +1654,11 @@ def service_exists(service_name):
1639
1654
logger .warn (
1640
1655
"DCGM may not have been successfully disabled. Proceeding to collect NCU trace anyway..."
1641
1656
)
1642
- ncu_output_dir = self .get_temp_path (f"ncu_traces/ { fn_name } _ { input_id } " )
1657
+ ncu_output_dir = self .get_temp_path (fn_name )
1643
1658
ncu_output_dir .mkdir (parents = True , exist_ok = True )
1644
1659
ext = ".csv" if not replay else ".ncu-rep"
1645
1660
ncu_output_file = ncu_output_dir .joinpath (
1646
- f"ncu_output { '_ir' if profile_ir else '' } { ext } "
1661
+ f"ncu_rep { '_ir' if profile_ir else '' } { ext } "
1647
1662
).resolve ()
1648
1663
ncu_args = [
1649
1664
"ncu" ,
@@ -1686,14 +1701,14 @@ def service_exists(service_name):
1686
1701
1687
1702
def att_trace (self , input_id : int , fn_name : str ) -> str :
1688
1703
op_task_args = self ._get_op_task_args (input_id , fn_name , "_ncu_trace_in_task" )
1689
- att_output_dir = self .get_temp_path (f"att_traces/ { fn_name } _ { input_id } " )
1704
+ att_output_dir = self .get_temp_path (fn_name )
1690
1705
att_trace_dir = launch_att (att_output_dir , op_task_args )
1691
1706
return att_trace_dir
1692
1707
1693
1708
def kineto_trace (self , input_id : int , fn : Callable ) -> str :
1694
1709
from tritonbench .components .kineto import do_bench_kineto
1695
1710
1696
- kineto_output_dir = self .get_temp_path (f"kineto_traces/ { fn ._name } _ { input_id } " )
1711
+ kineto_output_dir = self .get_temp_path (fn ._name )
1697
1712
kineto_output_dir .mkdir (parents = True , exist_ok = True )
1698
1713
return do_bench_kineto (
1699
1714
fn = fn ,
@@ -1834,7 +1849,7 @@ def run_and_capture(self, *args, **kwargs):
1834
1849
fn ()
1835
1850
1836
1851
if len (compiled_kernels ) > 0 :
1837
- ir_dir = self .get_temp_path ("ir" )
1852
+ ir_dir = self .get_temp_path (fn . _name )
1838
1853
ir_dir .mkdir (parents = True , exist_ok = True )
1839
1854
logger .info (
1840
1855
"Writing %s Triton IRs to %s" ,
0 commit comments