Skip to content

Commit a0f65f0

Browse files
committed
allow one benchmark maps to multiple kernels; add gemm kernel
1 parent 311fca2 commit a0f65f0

File tree

1 file changed

+51
-8
lines changed

1 file changed

+51
-8
lines changed

benchmarks/run.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from typing import Callable
2727

2828
# Maps tritonbench op names to Helion kernel examples
29-
KERNEL_MAPPINGS: dict[str, tuple[str, str, str]] = {
29+
# Can map to a single kernel or a list of kernels
30+
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | list[tuple[str, str, str]]] = {
3031
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
3132
# "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
3233
# "embedding": (
@@ -75,6 +76,19 @@
7576
"examples.fp8_gemm",
7677
"fp8_gemm_tritonbench",
7778
),
79+
"gemm": [
80+
# List of gemm variants
81+
(
82+
"tritonbench.operators.gemm.operator",
83+
"examples.matmul",
84+
"matmul",
85+
),
86+
(
87+
"tritonbench.operators.gemm.operator",
88+
"examples.matmul_split_k",
89+
"matmul_split_k",
90+
),
91+
],
7892
}
7993

8094

@@ -196,16 +210,36 @@ def check_and_setup_tritonbench() -> None:
196210

197211

198212
def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
199-
"""Run a single kernel benchmark."""
213+
"""Run a kernel benchmark, handling both single and multiple variants."""
200214
# Check if kernel is in the mapping table
201215
if kernel_name not in KERNEL_MAPPINGS:
202216
print(f"Error: Unknown kernel '{kernel_name}'", file=sys.stderr)
203217
print(
204218
f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", file=sys.stderr
205219
)
206220
sys.exit(1)
221+
222+
mapping = KERNEL_MAPPINGS[kernel_name]
223+
224+
# Check if it's a list of variants or a single kernel
225+
if isinstance(mapping, list):
226+
# Run each variant
227+
for i, (tritonbench_module, module_path, func_name) in enumerate(mapping):
228+
# Extract variant name from func_name for display
229+
variant_name = func_name
230+
if i > 0:
231+
print(f"\n{'=' * 60}", file=sys.stderr)
232+
print(f"Kernel: {kernel_name} (variant: {variant_name})", file=sys.stderr)
233+
print(f"{'=' * 60}\n", file=sys.stderr)
234+
run_single_kernel_variant(kernel_name, tritonbench_module, module_path, func_name, tritonbench_args.copy(), variant_name)
235+
else:
236+
# Single kernel
237+
tritonbench_module, module_path, func_name = mapping
238+
run_single_kernel_variant(kernel_name, tritonbench_module, module_path, func_name, tritonbench_args)
239+
207240

208-
tritonbench_module, module_path, func_name = KERNEL_MAPPINGS[kernel_name]
241+
def run_single_kernel_variant(kernel_name: str, tritonbench_module: str, module_path: str, func_name: str, tritonbench_args: list[str], variant_name: str | None = None) -> None:
242+
"""Run a single kernel variant."""
209243

210244
# Import from the mapped module
211245
try:
@@ -305,7 +339,10 @@ def _inner() -> Callable[..., Any] | object:
305339
return _inner
306340

307341
# Method name for the benchmark
308-
helion_method_name = f"helion_{kernel_name}"
342+
if variant_name:
343+
helion_method_name = f"helion_{kernel_name}_{variant_name}"
344+
else:
345+
helion_method_name = f"helion_{kernel_name}"
309346

310347
# Import register_benchmark API
311348
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
@@ -325,10 +362,16 @@ def _inner() -> Callable[..., Any] | object:
325362
# Set the decorated method on the Operator class
326363
setattr(Operator, helion_method_name, decorated_method)
327364

328-
print(
329-
f"Running {operator_name} benchmark with Helion implementation...\n",
330-
file=sys.stderr,
331-
)
365+
if variant_name:
366+
print(
367+
f"Running {operator_name} benchmark with Helion implementation (variant: {variant_name})...\n",
368+
file=sys.stderr,
369+
)
370+
else:
371+
print(
372+
f"Running {operator_name} benchmark with Helion implementation...\n",
373+
file=sys.stderr,
374+
)
332375

333376
# Create and run the operator with unknown args
334377
op = Operator(tb_args=tb_args, extra_args=unknown_args)

0 commit comments

Comments
 (0)