Skip to content

Commit d28e8e3

Browse files
authored
[be] Reorganize logging dir (#287)
1 parent 91cc742 commit d28e8e3

File tree

2 files changed

+43
-28
lines changed

2 files changed

+43
-28
lines changed

tritonbench/utils/parser.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,15 @@ def get_parser(args=None):
223223
type=str,
224224
help="Load input file from Tritonbench data JSON.",
225225
)
226+
parser.add_argument(
227+
"--logging-group",
228+
type=str,
229+
default=None,
230+
help="Name of group for benchmarking.",
231+
)
226232

227233
if is_fbcode():
228234
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
229-
parser.add_argument(
230-
"--logging-group",
231-
type=str,
232-
default=None,
233-
help="Override default name for logging in scuba.",
234-
)
235235
parser.add_argument(
236236
"--production-shapes",
237237
action="store_true",

tritonbench/utils/triton_op.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -616,18 +616,6 @@ def _inner(self, *args, **kwargs):
616616
return decorator
617617

618618

619-
def register_benchmark_manually(
620-
operator_name: str,
621-
func_name: str,
622-
baseline: bool = False,
623-
enabled: bool = True,
624-
label: Optional[str] = None,
625-
):
626-
return register_benchmark(
627-
operator_name, func_name, baseline, enabled, fwd_only=False, label=label
628-
)
629-
630-
631619
def register_metric(
632620
# Metrics that only apply to non-baseline impls
633621
# E.g., accuracy, speedup
@@ -1097,8 +1085,35 @@ def get_example_inputs(self):
10971085
except StopIteration:
10981086
return None
10991087

1100-
def get_temp_path(self, path: Union[str, Path]) -> Path:
1101-
return Path(tempfile.gettempdir()) / "tritonbench" / self.name / Path(path)
1088+
def get_temp_path(
1089+
self,
1090+
fn_name: Optional[str] = None,
1091+
) -> Path:
1092+
unix_user: Optional[str] = os.environ.get("USER", None)
1093+
logging_group: Optional[str] = self.logging_group
1094+
parts = [x for x in ["tritonbench", unix_user, logging_group] if x]
1095+
tritonbench_dir_name = "_".join(parts)
1096+
benchmark_name = self.benchmark_name
1097+
fn_part = f"{fn_name}_{self._input_id}" if fn_name else ""
1098+
out_part = Path(tempfile.gettempdir()) / tritonbench_dir_name / benchmark_name
1099+
return out_part / fn_part if fn_part else out_part
1100+
1101+
@property
1102+
def precision(self) -> str:
1103+
if self.tb_args.precision == "bypass" or self.tb_args.precision == "fp8":
1104+
return ""
1105+
return self.tb_args.precision
1106+
1107+
@property
1108+
def benchmark_name(self, default: bool = False) -> str:
1109+
if not default and self.tb_args.benchmark_name:
1110+
return self.tb_args.benchmark_name
1111+
parts = [x for x in [self.precision, self.name, self.mode.value] if x]
1112+
return "_".join(parts)
1113+
1114+
@property
1115+
def logging_group(self) -> Optional[str]:
1116+
return self.tb_args.logging_group
11021117

11031118
def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
11041119
output = fn()
@@ -1331,7 +1346,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
13311346
do_compile_kineto_trace_in_task,
13321347
)
13331348

1334-
kineto_trace_output_dir = self.get_temp_path("kineto_trace")
1349+
kineto_trace_output_dir = self.get_temp_path(fn_name)
13351350
kineto_trace_output_dir.mkdir(parents=True, exist_ok=True)
13361351
metrics.extra_metrics["_compile_time_kineto_trace_in_task"] = (
13371352
do_compile_kineto_trace_in_task(
@@ -1562,10 +1577,10 @@ def _get_op_task_args(
15621577

15631578
def nsys_rep(self, input_id: int, fn_name: str) -> str:
15641579
op_task_args = self._get_op_task_args(input_id, fn_name, "_nsys_rep_in_task")
1565-
nsys_output_dir = self.get_temp_path(f"nsys_traces/{fn_name}_{input_id}")
1580+
nsys_output_dir = self.get_temp_path(fn_name)
15661581
nsys_output_dir.mkdir(parents=True, exist_ok=True)
15671582
ext = ".nsys-rep"
1568-
nsys_output_file = nsys_output_dir.joinpath(f"nsys_output{ext}").resolve()
1583+
nsys_output_file = nsys_output_dir.joinpath(f"nsys_rep{ext}").resolve()
15691584
nsys_trace_cmd = [
15701585
"nsys",
15711586
"profile",
@@ -1639,11 +1654,11 @@ def service_exists(service_name):
16391654
logger.warn(
16401655
"DCGM may not have been successfully disabled. Proceeding to collect NCU trace anyway..."
16411656
)
1642-
ncu_output_dir = self.get_temp_path(f"ncu_traces/{fn_name}_{input_id}")
1657+
ncu_output_dir = self.get_temp_path(fn_name)
16431658
ncu_output_dir.mkdir(parents=True, exist_ok=True)
16441659
ext = ".csv" if not replay else ".ncu-rep"
16451660
ncu_output_file = ncu_output_dir.joinpath(
1646-
f"ncu_output{'_ir' if profile_ir else ''}{ext}"
1661+
f"ncu_rep{'_ir' if profile_ir else ''}{ext}"
16471662
).resolve()
16481663
ncu_args = [
16491664
"ncu",
@@ -1686,14 +1701,14 @@ def service_exists(service_name):
16861701

16871702
def att_trace(self, input_id: int, fn_name: str) -> str:
16881703
op_task_args = self._get_op_task_args(input_id, fn_name, "_ncu_trace_in_task")
1689-
att_output_dir = self.get_temp_path(f"att_traces/{fn_name}_{input_id}")
1704+
att_output_dir = self.get_temp_path(fn_name)
16901705
att_trace_dir = launch_att(att_output_dir, op_task_args)
16911706
return att_trace_dir
16921707

16931708
def kineto_trace(self, input_id: int, fn: Callable) -> str:
16941709
from tritonbench.components.kineto import do_bench_kineto
16951710

1696-
kineto_output_dir = self.get_temp_path(f"kineto_traces/{fn._name}_{input_id}")
1711+
kineto_output_dir = self.get_temp_path(fn._name)
16971712
kineto_output_dir.mkdir(parents=True, exist_ok=True)
16981713
return do_bench_kineto(
16991714
fn=fn,
@@ -1834,7 +1849,7 @@ def run_and_capture(self, *args, **kwargs):
18341849
fn()
18351850

18361851
if len(compiled_kernels) > 0:
1837-
ir_dir = self.get_temp_path("ir")
1852+
ir_dir = self.get_temp_path(fn._name)
18381853
ir_dir.mkdir(parents=True, exist_ok=True)
18391854
logger.info(
18401855
"Writing %s Triton IRs to %s",

0 commit comments

Comments
 (0)