Skip to content

Commit 9f7158a

Browse files
authored
[Benchmark] Add sum to TritonBench integration (#257)
- Add sum_tritonbench wrapper function that handles 1D input
- Add sum to KERNEL_MAPPINGS in benchmark/run.py
- Include kernel reset logic to ensure clean state before benchmarking
1 parent 573fc23 commit 9f7158a

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

benchmark/run.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
"vector_exp": ("examples.exp", "exp_tritonbench"),
2828
# TODO(yf225): reduction dim size = 8192 currently throws an error. Once that is fixed, we can remove the "num_inputs" extra arg.
2929
"rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}),
30+
"sum": ("examples.sum", "sum_tritonbench"),
3031
}
3132

3233

@@ -236,6 +237,15 @@ def helion_method( # pyre-ignore[3]
236237
) -> Callable[..., Any]:
237238
"""Helion implementation."""
238239

240+
# Reset all Helion kernels before creating the benchmark function
241+
# so that each input size can go through its own autotuning.
242+
from helion.runtime.kernel import Kernel
243+
244+
for attr_name in dir(module):
245+
attr = getattr(module, attr_name)
246+
if isinstance(attr, Kernel):
247+
attr.reset()
248+
239249
def _inner() -> Callable[..., Any]: # pyre-ignore[3]
240250
return kernel_func(*args)
241251

examples/sum.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,16 @@ def sum_kernel(x: torch.Tensor) -> torch.Tensor:
1919
return out
2020

2121

22+
def sum_tritonbench(x: torch.Tensor) -> torch.Tensor:
    """Adapt ``sum_kernel`` for tritonbench so that 1D inputs are accepted.

    A 1D tensor is given a leading dimension with ``unsqueeze(0)`` before
    being handed to ``sum_kernel``, and the kernel's result is squeezed
    before it is returned. Any other input is forwarded to ``sum_kernel``
    unchanged.

    Args:
        x: Input tensor to reduce.

    Returns:
        The output of ``sum_kernel``; squeezed when ``x`` was 1D.
    """
    # Guard clause: anything that is not 1D goes straight to the kernel.
    if x.ndim != 1:
        return sum_kernel(x)

    # Present the 1D tensor to sum_kernel as 2D, then drop the size-1
    # dimensions from the result.
    return sum_kernel(x.unsqueeze(0)).squeeze()
30+
31+
2232
def check(m: int, n: int) -> None:
2333
x = torch.randn([m, n], device="cuda", dtype=torch.float32)
2434
kernels = {"helion": sum_kernel}

0 commit comments

Comments
 (0)