Skip to content

[be] Reorganize logging dir #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,15 @@ def get_parser(args=None):
type=str,
help="Load input file from Tritonbench data JSON.",
)
parser.add_argument(
"--logging-group",
type=str,
default=None,
help="Name of group for benchmarking.",
)

if is_fbcode():
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
parser.add_argument(
"--logging-group",
type=str,
default=None,
help="Override default name for logging in scuba.",
)
parser.add_argument(
"--production-shapes",
action="store_true",
Expand Down
59 changes: 37 additions & 22 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,18 +616,6 @@ def _inner(self, *args, **kwargs):
return decorator


def register_benchmark_manually(
operator_name: str,
func_name: str,
baseline: bool = False,
enabled: bool = True,
label: Optional[str] = None,
):
return register_benchmark(
operator_name, func_name, baseline, enabled, fwd_only=False, label=label
)


def register_metric(
# Metrics that only apply to non-baseline impls
# E.g., accuracy, speedup
Expand Down Expand Up @@ -1097,8 +1085,35 @@ def get_example_inputs(self):
except StopIteration:
return None

def get_temp_path(self, path: Union[str, Path]) -> Path:
return Path(tempfile.gettempdir()) / "tritonbench" / self.name / Path(path)
def get_temp_path(
self,
fn_name: Optional[str] = None,
) -> Path:
unix_user: Optional[str] = os.environ.get("USER", None)
logging_group: Optional[str] = self.logging_group
parts = [x for x in ["tritonbench", unix_user, logging_group] if x]
tritonbench_dir_name = "_".join(parts)
benchmark_name = self.benchmark_name
fn_part = f"{fn_name}_{self._input_id}" if fn_name else ""
out_part = Path(tempfile.gettempdir()) / tritonbench_dir_name / benchmark_name
return out_part / fn_part if fn_part else out_part

@property
def precision(self) -> str:
if self.tb_args.precision == "bypass" or self.tb_args.precision == "fp8":
return ""
return self.tb_args.precision

@property
def benchmark_name(self, default: bool = False) -> str:
if not default and self.tb_args.benchmark_name:
return self.tb_args.benchmark_name
parts = [x for x in [self.precision, self.name, self.mode.value] if x]
return "_".join(parts)

@property
def logging_group(self) -> Optional[str]:
return self.tb_args.logging_group

def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
Expand Down Expand Up @@ -1331,7 +1346,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
do_compile_kineto_trace_in_task,
)

kineto_trace_output_dir = self.get_temp_path("kineto_trace")
kineto_trace_output_dir = self.get_temp_path(fn_name)
kineto_trace_output_dir.mkdir(parents=True, exist_ok=True)
metrics.extra_metrics["_compile_time_kineto_trace_in_task"] = (
do_compile_kineto_trace_in_task(
Expand Down Expand Up @@ -1562,10 +1577,10 @@ def _get_op_task_args(

def nsys_rep(self, input_id: int, fn_name: str) -> str:
op_task_args = self._get_op_task_args(input_id, fn_name, "_nsys_rep_in_task")
nsys_output_dir = self.get_temp_path(f"nsys_traces/{fn_name}_{input_id}")
nsys_output_dir = self.get_temp_path(fn_name)
nsys_output_dir.mkdir(parents=True, exist_ok=True)
ext = ".nsys-rep"
nsys_output_file = nsys_output_dir.joinpath(f"nsys_output{ext}").resolve()
nsys_output_file = nsys_output_dir.joinpath(f"nsys_rep{ext}").resolve()
nsys_trace_cmd = [
"nsys",
"profile",
Expand Down Expand Up @@ -1639,11 +1654,11 @@ def service_exists(service_name):
logger.warn(
"DCGM may not have been successfully disabled. Proceeding to collect NCU trace anyway..."
)
ncu_output_dir = self.get_temp_path(f"ncu_traces/{fn_name}_{input_id}")
ncu_output_dir = self.get_temp_path(fn_name)
ncu_output_dir.mkdir(parents=True, exist_ok=True)
ext = ".csv" if not replay else ".ncu-rep"
ncu_output_file = ncu_output_dir.joinpath(
f"ncu_output{'_ir' if profile_ir else ''}{ext}"
f"ncu_rep{'_ir' if profile_ir else ''}{ext}"
).resolve()
ncu_args = [
"ncu",
Expand Down Expand Up @@ -1686,14 +1701,14 @@ def service_exists(service_name):

def att_trace(self, input_id: int, fn_name: str) -> str:
op_task_args = self._get_op_task_args(input_id, fn_name, "_ncu_trace_in_task")
att_output_dir = self.get_temp_path(f"att_traces/{fn_name}_{input_id}")
att_output_dir = self.get_temp_path(fn_name)
att_trace_dir = launch_att(att_output_dir, op_task_args)
return att_trace_dir

def kineto_trace(self, input_id: int, fn: Callable) -> str:
from tritonbench.components.kineto import do_bench_kineto

kineto_output_dir = self.get_temp_path(f"kineto_traces/{fn._name}_{input_id}")
kineto_output_dir = self.get_temp_path(fn._name)
kineto_output_dir.mkdir(parents=True, exist_ok=True)
return do_bench_kineto(
fn=fn,
Expand Down Expand Up @@ -1834,7 +1849,7 @@ def run_and_capture(self, *args, **kwargs):
fn()

if len(compiled_kernels) > 0:
ir_dir = self.get_temp_path("ir")
ir_dir = self.get_temp_path(fn._name)
ir_dir.mkdir(parents=True, exist_ok=True)
logger.info(
"Writing %s Triton IRs to %s",
Expand Down