Skip to content

Commit 7c64ae1

Browse files
authored
[Benchmark] Add rms_norm benchmark (#253)
- Add rms_norm to KERNEL_MAPPINGS with custom num_inputs=3 to avoid reduction dim size 8192 error - Update KERNEL_MAPPINGS to support optional extra arguments for kernels - Parse and apply kernel-specific default arguments to tritonbench
1 parent da9cf12 commit 7c64ae1

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

benchmark/run.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from typing import Callable
2121

2222
# Maps tritonbench op names to Helion kernel examples
23-
KERNEL_MAPPINGS: dict[str, tuple[str, str]] = {
24-
# <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>)
23+
KERNEL_MAPPINGS: dict[str, tuple[str, str] | tuple[str, str, dict[str, Any]]] = {
24+
# <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>, <optional_extra_args>)
2525
"vector_add": ("examples.add", "add"),
2626
"embedding": ("examples.embedding", "embedding_tritonbench"),
2727
"vector_exp": ("examples.exp", "exp_tritonbench"),
28+
# TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
29+
"rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}),
2830
}
2931

3032

@@ -165,7 +167,14 @@ def main() -> None:
165167

166168
# Check if kernel is in the mapping table
167169
assert kernel_name in KERNEL_MAPPINGS
168-
module_path, func_name = KERNEL_MAPPINGS[kernel_name]
170+
mapping = KERNEL_MAPPINGS[kernel_name]
171+
172+
# Parse mapping - can be (module, func) or (module, func, extra_args)
173+
if len(mapping) == 2:
174+
module_path, func_name = mapping
175+
kernel_extra_args = {}
176+
else:
177+
module_path, func_name, kernel_extra_args = mapping
169178
# Import from the mapped module
170179
try:
171180
module = importlib.import_module(module_path)
@@ -203,6 +212,13 @@ def main() -> None:
203212
assert "--op" not in tritonbench_args
204213
tritonbench_args = ["--op", operator_name, *tritonbench_args]
205214

215+
# Apply kernel-specific default arguments if not already specified by user
216+
for arg_name, arg_value in kernel_extra_args.items():
217+
# Convert underscore to hyphen for CLI args (e.g., num_inputs -> --num-inputs)
218+
cli_arg = f"--{arg_name.replace('_', '-')}"
219+
if cli_arg not in tritonbench_args:
220+
tritonbench_args.extend([cli_arg, str(arg_value)])
221+
206222
tb_args = tb_parser.parse_args(tritonbench_args)
207223

208224
# Register the Helion kernel with tritonbench BEFORE importing the operator

0 commit comments

Comments
 (0)