
Commit e533c3a

Refactor examples to use run_example helper (#225)
1 parent 4669fdc commit e533c3a

17 files changed: +235 −230 lines

examples/add.py

Lines changed: 2 additions & 9 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -23,17 +24,9 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


 def check(m: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, n], device="cuda", dtype=torch.float16)
     y = torch.randn([m, n], device="cuda", dtype=torch.float16)
-    result = add(x, y)
-    torch.testing.assert_close(result, x + y, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: add(x, y))
-    baseline_sec = do_bench(lambda: torch.add(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(add, torch.add, (x, y))


 def main() -> None:
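
Every example now funnels through helion._testing.run_example, whose call sites in this commit imply a signature of roughly run_example(kernel, baseline, args, rtol=..., atol=...), where either the kernel or the baseline may be a dict of named callables. The sketch below reconstructs plausible helper behavior from what the deleted code did (torch.testing.assert_close plus triton's do_bench); it is an assumption inferred from the call sites, not the actual helion._testing implementation, and the default tolerances are guesses taken from the old code.

from __future__ import annotations

from typing import Callable

import torch
from triton.testing import do_bench


def run_example_sketch(
    kernel: Callable | dict[str, Callable],
    baseline: Callable | dict[str, Callable],
    args: tuple,
    rtol: float = 1e-2,
    atol: float = 1e-1,
) -> None:
    # Normalize both arguments to name -> callable mappings.
    kernels = kernel if isinstance(kernel, dict) else {"helion": kernel}
    baselines = baseline if isinstance(baseline, dict) else {"torch": baseline}

    # Use the first baseline as the source of expected results.
    ref_name, ref_fn = next(iter(baselines.items()))
    expected = ref_fn(*args)

    # Correctness: every kernel variant must match the reference output.
    for name, fn in kernels.items():
        torch.testing.assert_close(fn(*args), expected, rtol=rtol, atol=atol)
        print(f"{name}: matches {ref_name}")

    # Timing: benchmark kernels and baselines alike.
    for name, fn in {**kernels, **baselines}.items():
        ms = do_bench(lambda fn=fn: fn(*args))
        print(f"{name}: {ms:.4f}ms")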

examples/attention.py

Lines changed: 25 additions & 32 deletions
@@ -1,21 +1,26 @@
 from __future__ import annotations

 import math
+from typing import Callable
+from typing import cast

 import torch
 from torch.nn.attention.flex_attention import flex_attention

 import helion
+from helion._testing import run_example
 import helion.language as hl


 @helion.kernel(
     config=helion.Config(
-        # This config was autotuned on a 3090, it won't be fast for other architectures
-        block_sizes=[128, 64],
-        num_warps=4,
+        # This config was autotuned on a 5090, it won't be fast for other cards
+        block_sizes=[128, 16],
+        loop_orders=[[0, 1]],
+        l2_groupings=[2],
+        num_warps=2,
         num_stages=3,
-        indexing="block_ptr",
+        indexing="pointer",
     ),
     # Static shapes provides a speedup for attention
     static_shapes=True,
@@ -82,36 +87,24 @@ def test(
         for _ in range(3)
     ]

-    # reference implementation
-    p = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
-    p = torch.softmax(p.float(), dim=-1).to(dtype)
-    ref_out = torch.matmul(p, v)
+    def ref_attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        """Reference manual attention implementation"""
+        p = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
+        p = torch.softmax(p.float(), dim=-1).to(dtype)
+        return torch.matmul(p, v)

-    # flex attention version
-    # TODO(jansel): turn the above kernel into a flex attention kernel
-    flex_compiled = torch.compile(flex_attention, fullgraph=True)
-    flex_out = flex_compiled(q, k, v)
-    torch.testing.assert_close(flex_out, ref_out, atol=1e-2, rtol=1e-2)
-
-    # sdpa version
-    sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    torch.testing.assert_close(sdpa_out, ref_out, atol=1e-2, rtol=1e-2)
-
-    # helion version
-    hl_out = attention(q, k, v)
-    torch.testing.assert_close(hl_out, ref_out, atol=1e-2, rtol=1e-2)
-
-    # benchmark
-    from triton.testing import do_bench
-
-    spda_sec = do_bench(
-        lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    )
-    flex_sec = do_bench(lambda: flex_compiled(q, k, v))
-    helion_sec = do_bench(lambda: attention(q, k, v))
-    print(
-        f"Helion time: {helion_sec:.4f}ms, flex time: {flex_sec:.4f}, torch time: {spda_sec:.4f}"
+    flex_compiled = cast(
+        "Callable[..., torch.Tensor]", torch.compile(flex_attention, fullgraph=True)
     )
+    baselines = {
+        "torch": torch.nn.functional.scaled_dot_product_attention,
+        "flex": flex_compiled,
+        "ref": ref_attention,
+    }
+
+    run_example(attention, baselines, (q, k, v))


 def main() -> None:
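
One detail worth noting in the rewrite above: torch.compile returns a value that static type checkers do not treat as a plain callable, so the diff narrows it with cast before storing it in the baselines dict. A minimal standalone version of that pattern follows; the reading of the type-checker behavior is my interpretation, not stated in the diff.

from typing import Callable
from typing import cast

import torch
from torch.nn.attention.flex_attention import flex_attention

# Narrow torch.compile's return type so the compiled function can sit
# alongside ordinary callables in a dict[str, Callable].
flex_compiled = cast(
    "Callable[..., torch.Tensor]", torch.compile(flex_attention, fullgraph=True)
)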

examples/bmm.py

Lines changed: 2 additions & 9 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -26,17 +27,9 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:


 def check(b: int, m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([b, m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([b, k, n], device="cuda", dtype=torch.float16)
-    result = bmm(x, y)
-    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: bmm(x, y))
-    baseline_sec = do_bench(lambda: torch.bmm(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(bmm, torch.bmm, (x, y))


 def main() -> None:

examples/concatenate.py

Lines changed: 2 additions & 10 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -31,18 +32,9 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


 def main() -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([1500, 400], device="cuda")
     y = torch.randn([1500, 600], device="cuda")
-    result = concat2d_dim1(x, y)
-    expected = torch.cat([x, y], dim=1)
-    torch.testing.assert_close(result, expected)
-    sec = do_bench(lambda: concat2d_dim1(x, y))
-    baseline_sec = do_bench(lambda: torch.cat([x, y], dim=1))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y))


 if __name__ == "__main__":

examples/embedding.py

Lines changed: 3 additions & 8 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -24,17 +25,11 @@ def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:


 def main() -> None:
-    from triton.testing import do_bench
-
     num_embeddings, embedding_dim = 16, 64
     x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32)
     weight = torch.randn([num_embeddings, embedding_dim], device="cuda")
-    result = embedding(x, weight)
-    torch.testing.assert_close(result, torch.nn.functional.embedding(x, weight))
-    sec = do_bench(lambda: embedding(x, weight))
-    baseline_sec = do_bench(lambda: torch.nn.functional.embedding(x, weight))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
+    run_example(
+        embedding, torch.nn.functional.embedding, (x, weight), atol=0.0, rtol=0.0
     )

examples/jagged_dense_add.py

Lines changed: 6 additions & 4 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl

 """
@@ -110,11 +111,12 @@ def random_jagged_2d(

 def main() -> None:
     rows, cols = 256, 5000
-    x = random_jagged_2d(rows, cols, device="cuda")
+    x_data, x_offsets = random_jagged_2d(rows, cols, device="cuda")
     y = torch.randn([rows, cols], device="cuda")
-    result = jagged_dense_add_2d(*x, y)
-    expected = jagged_dense_add_2d_reference(*x, y)
-    torch.testing.assert_close(result, expected)
+
+    run_example(
+        jagged_dense_add_2d, jagged_dense_add_2d_reference, (x_data, x_offsets, y)
+    )


 if __name__ == "__main__":

examples/long_sum.py

Lines changed: 8 additions & 22 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -72,31 +73,16 @@ def longsum_manual(x: torch.Tensor) -> torch.Tensor:


 def check(m: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, n], device="cuda", dtype=torch.float32)

-    helion_out = longsum(x)
-    torch.testing.assert_close(helion_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
-    print("✅ Results Match ✅ naive reduction")
+    # Test all three kernel variants against the baseline
+    kernels = {
+        "helion naive": longsum,
+        "helion loop": longsum_w_red_loop,
+        "helion manual": longsum_manual,
+    }

-    helion_red_loop_out = longsum_w_red_loop(x)
-    torch.testing.assert_close(
-        helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1
-    )
-    print("✅ Results Match ✅ Reduction Loop")
-
-    helion_manual_out = longsum_manual(x)
-    torch.testing.assert_close(helion_manual_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
-    print("✅ Results Match ✅ Manual Reduction Loop")
-
-    sec = do_bench(lambda: longsum(x))
-    loop_sec = do_bench(lambda: longsum_w_red_loop(x))
-    manual_loop_sec = do_bench(lambda: longsum_manual(x))
-    baseline_sec = do_bench(lambda: baseline_sum(x))
-    print(
-        f"Helion Naive time: {sec:.4f}ms, Helion Looped Time: {loop_sec:.4f}, Helion Manual Loop Time: {manual_loop_sec:.4f} torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x {baseline_sec / loop_sec:.2f}x {baseline_sec / manual_loop_sec:.2f}x"
-    )
+    run_example(kernels, baseline_sum, (x,))


 def main() -> None:
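
The first argument can likewise be a dict, so one call validates and times several kernel variants against a single baseline. The snippet below spells out the correctness half of that contract with self-contained stand-ins: the lambdas and the CPU-sized tensor are illustrative, baseline_sum here is a stand-in for the example file's own baseline, and the real helion._testing helper may behave differently.

import torch


def baseline_sum(x: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the example file's baseline reduction.
    return x.sum(-1)


x = torch.randn([512, 4096])
kernels = {
    # Two ways to reduce the long dimension; stand-ins for the Helion variants.
    "naive": lambda t: t.sum(-1),
    "chunked": lambda t: t.view(512, 64, 64).sum(-1).sum(-1),
}
for name, fn in kernels.items():
    torch.testing.assert_close(fn(x), baseline_sum(x), rtol=1e-2, atol=1e-1)
    print(f"{name}: matches baseline")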

examples/matmul.py

Lines changed: 2 additions & 9 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -24,17 +25,9 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


 def check(m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    result = matmul(x, y)
-    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: matmul(x, y))
-    baseline_sec = do_bench(lambda: torch.matmul(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(matmul, torch.matmul, (x, y))


 def main() -> None:

examples/matmul_layernorm.py

Lines changed: 2 additions & 10 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 import helion.language as hl


@@ -51,20 +52,11 @@ def matmul_layernorm_pytorch(


 def check(m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
     weight = torch.randn([n], device="cuda", dtype=torch.float16)
     bias = torch.randn([n], device="cuda", dtype=torch.float16)
-    result = matmul_layernorm(x, y, weight, bias)
-    expected = matmul_layernorm_pytorch(x, y, weight, bias)
-    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: matmul_layernorm(x, y, weight, bias))
-    baseline_sec = do_bench(lambda: matmul_layernorm_pytorch(x, y, weight, bias))
-    print(
-        f"Helion time: {sec:.4f}s, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(matmul_layernorm, matmul_layernorm_pytorch, (x, y, weight, bias))


 def main() -> None:

examples/matmul_split_k.py

Lines changed: 2 additions & 9 deletions
@@ -3,6 +3,7 @@
 import torch

 import helion
+from helion._testing import run_example
 from helion.autotuner import PowerOfTwoFragment
 import helion.language as hl

@@ -27,17 +28,9 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


 def check(m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    result = matmul_split_k(x, y)
-    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1)
-    sec = do_bench(lambda: matmul_split_k(x, y))
-    baseline_sec = do_bench(lambda: torch.matmul(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)


 def main() -> None:
