|
20 | 20 | from typing import Callable
|
21 | 21 |
|
22 | 22 | # Maps tritonbench op names to Helion kernel examples
|
23 |
| -KERNEL_MAPPINGS: dict[str, tuple[str, str]] = { |
24 |
| - # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>) |
| 23 | +KERNEL_MAPPINGS: dict[str, tuple[str, str] | tuple[str, str, dict[str, Any]]] = { |
| 24 | + # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>, <optional_extra_args>) |
25 | 25 | "vector_add": ("examples.add", "add"),
|
26 | 26 | "embedding": ("examples.embedding", "embedding_tritonbench"),
|
27 | 27 | "vector_exp": ("examples.exp", "exp_tritonbench"),
|
| 28 | + # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. |
| 29 | + "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}), |
28 | 30 | }
|
29 | 31 |
|
30 | 32 |
|
@@ -165,7 +167,14 @@ def main() -> None:
|
165 | 167 |
|
166 | 168 | # Check if kernel is in the mapping table
|
167 | 169 | assert kernel_name in KERNEL_MAPPINGS
|
168 |
| - module_path, func_name = KERNEL_MAPPINGS[kernel_name] |
| 170 | + mapping = KERNEL_MAPPINGS[kernel_name] |
| 171 | + |
| 172 | + # Parse mapping - can be (module, func) or (module, func, extra_args) |
| 173 | + if len(mapping) == 2: |
| 174 | + module_path, func_name = mapping |
| 175 | + kernel_extra_args = {} |
| 176 | + else: |
| 177 | + module_path, func_name, kernel_extra_args = mapping |
169 | 178 | # Import from the mapped module
|
170 | 179 | try:
|
171 | 180 | module = importlib.import_module(module_path)
|
@@ -203,6 +212,13 @@ def main() -> None:
|
203 | 212 | assert "--op" not in tritonbench_args
|
204 | 213 | tritonbench_args = ["--op", operator_name, *tritonbench_args]
|
205 | 214 |
|
| 215 | + # Apply kernel-specific default arguments if not already specified by user |
| 216 | + for arg_name, arg_value in kernel_extra_args.items(): |
| 217 | + # Convert underscore to hyphen for CLI args (e.g., num_inputs -> --num-inputs) |
| 218 | + cli_arg = f"--{arg_name.replace('_', '-')}" |
| 219 | + if cli_arg not in tritonbench_args: |
| 220 | + tritonbench_args.extend([cli_arg, str(arg_value)]) |
| 221 | + |
206 | 222 | tb_args = tb_parser.parse_args(tritonbench_args)
|
207 | 223 |
|
208 | 224 | # Register the Helion kernel with tritonbench BEFORE importing the operator
|
|
0 commit comments