diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py index 6efefb33..9e0666cb 100644 --- a/tritonbench/utils/parser.py +++ b/tritonbench/utils/parser.py @@ -223,15 +223,15 @@ def get_parser(args=None): type=str, help="Load input file from Tritonbench data JSON.", ) + parser.add_argument( + "--logging-group", + type=str, + default=None, + help="Name of group for benchmarking.", + ) if is_fbcode(): parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.") - parser.add_argument( - "--logging-group", - type=str, - default=None, - help="Override default name for logging in scuba.", - ) parser.add_argument( "--production-shapes", action="store_true", diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index d60348e3..5ae2fe48 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -616,18 +616,6 @@ def _inner(self, *args, **kwargs): return decorator -def register_benchmark_manually( - operator_name: str, - func_name: str, - baseline: bool = False, - enabled: bool = True, - label: Optional[str] = None, -): - return register_benchmark( - operator_name, func_name, baseline, enabled, fwd_only=False, label=label - ) - - def register_metric( # Metrics that only apply to non-baseline impls # E.g., accuracy, speedup @@ -1097,8 +1085,35 @@ def get_example_inputs(self): except StopIteration: return None - def get_temp_path(self, path: Union[str, Path]) -> Path: - return Path(tempfile.gettempdir()) / "tritonbench" / self.name / Path(path) + def get_temp_path( + self, + fn_name: Optional[str] = None, + ) -> Path: + unix_user: Optional[str] = os.environ.get("USER", None) + logging_group: Optional[str] = self.logging_group + parts = [x for x in ["tritonbench", unix_user, logging_group] if x] + tritonbench_dir_name = "_".join(parts) + benchmark_name = self.benchmark_name + fn_part = f"{fn_name}_{self._input_id}" if fn_name else "" + out_part = Path(tempfile.gettempdir()) / tritonbench_dir_name / benchmark_name + return out_part / fn_part if fn_part else out_part + + @property + def precision(self) -> str: + if self.tb_args.precision == "bypass" or self.tb_args.precision == "fp8": + return "" + return self.tb_args.precision + + @property + def benchmark_name(self, default: bool = False) -> str: + if not default and self.tb_args.benchmark_name: + return self.tb_args.benchmark_name + parts = [x for x in [self.precision, self.name, self.mode.value] if x] + return "_".join(parts) + + @property + def logging_group(self) -> Optional[str]: + return self.tb_args.logging_group def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: output = fn() @@ -1331,7 +1346,7 @@ def _init_extra_metrics() -> Dict[str, Any]: do_compile_kineto_trace_in_task, ) - kineto_trace_output_dir = self.get_temp_path("kineto_trace") + kineto_trace_output_dir = self.get_temp_path(fn_name) kineto_trace_output_dir.mkdir(parents=True, exist_ok=True) metrics.extra_metrics["_compile_time_kineto_trace_in_task"] = ( do_compile_kineto_trace_in_task( @@ -1562,10 +1577,10 @@ def _get_op_task_args( def nsys_rep(self, input_id: int, fn_name: str) -> str: op_task_args = self._get_op_task_args(input_id, fn_name, "_nsys_rep_in_task") - nsys_output_dir = self.get_temp_path(f"nsys_traces/{fn_name}_{input_id}") + nsys_output_dir = self.get_temp_path(fn_name) nsys_output_dir.mkdir(parents=True, exist_ok=True) ext = ".nsys-rep" - nsys_output_file = nsys_output_dir.joinpath(f"nsys_output{ext}").resolve() + nsys_output_file = nsys_output_dir.joinpath(f"nsys_rep{ext}").resolve() nsys_trace_cmd = [ "nsys", "profile", @@ -1639,11 +1654,11 @@ def service_exists(service_name): logger.warn( "DCGM may not have been successfully disabled. Proceeding to collect NCU trace anyway..." ) - ncu_output_dir = self.get_temp_path(f"ncu_traces/{fn_name}_{input_id}") + ncu_output_dir = self.get_temp_path(fn_name) ncu_output_dir.mkdir(parents=True, exist_ok=True) ext = ".csv" if not replay else ".ncu-rep" ncu_output_file = ncu_output_dir.joinpath( - f"ncu_output{'_ir' if profile_ir else ''}{ext}" + f"ncu_rep{'_ir' if profile_ir else ''}{ext}" ).resolve() ncu_args = [ "ncu", @@ -1686,14 +1701,14 @@ def service_exists(service_name): def att_trace(self, input_id: int, fn_name: str) -> str: op_task_args = self._get_op_task_args(input_id, fn_name, "_ncu_trace_in_task") - att_output_dir = self.get_temp_path(f"att_traces/{fn_name}_{input_id}") + att_output_dir = self.get_temp_path(fn_name) att_trace_dir = launch_att(att_output_dir, op_task_args) return att_trace_dir def kineto_trace(self, input_id: int, fn: Callable) -> str: from tritonbench.components.kineto import do_bench_kineto - kineto_output_dir = self.get_temp_path(f"kineto_traces/{fn._name}_{input_id}") + kineto_output_dir = self.get_temp_path(fn._name) kineto_output_dir.mkdir(parents=True, exist_ok=True) return do_bench_kineto( fn=fn, @@ -1834,7 +1849,7 @@ def run_and_capture(self, *args, **kwargs): fn() if len(compiled_kernels) > 0: - ir_dir = self.get_temp_path("ir") + ir_dir = self.get_temp_path(fn._name) ir_dir.mkdir(parents=True, exist_ok=True) logger.info( "Writing %s Triton IRs to %s",