Skip to content

Commit 3ff927d

Browse files
authored
[Benchmark] Add vector_exp benchmark (#249)
1 parent 8567309 commit 3ff927d

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

benchmark/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>)
2525
"vector_add": ("examples.add", "add"),
2626
"embedding": ("examples.embedding", "embedding_tritonbench"),
27+
"vector_exp": ("examples.exp", "exp_tritonbench"),
2728
}
2829

2930

examples/exp.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
5+
import helion
6+
from helion._testing import run_example
7+
import helion.language as hl
8+
9+
10+
@helion.kernel()
11+
def exp(x: torch.Tensor) -> torch.Tensor:
12+
out = torch.empty_like(x)
13+
for tile in hl.tile(x.size()):
14+
out[tile] = torch.exp(x[tile])
15+
return out
16+
17+
18+
def exp_tritonbench(x: torch.Tensor) -> dict[str, torch.Tensor]:
19+
"""Wrapper for tritonbench that returns output in expected format."""
20+
return {"output": exp(x)}
21+
22+
23+
def check(n: int) -> None:
24+
x = torch.randn(n, device="cuda", dtype=torch.float32)
25+
run_example(exp, torch.exp, (x,))
26+
27+
28+
def main() -> None:
29+
check(1024 * 1024)
30+
31+
32+
if __name__ == "__main__":
33+
main()

0 commit comments

Comments
 (0)