Skip to content

Commit 9a9f3e7

Browse files
authored
More fair comparison (#146)
1 parent eb7be14 commit 9a9f3e7

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

examples/attention.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def test(
89 89
90 90   # flex attention version
91 91   # TODO(jansel): turn the above kernel into a flex attention kernel
92    - flex_out = flex_attention(q, k, v)
   92 + flex_compiled = torch.compile(flex_attention, fullgraph=True)
   93 + flex_out = flex_compiled(q, k, v)
93 94   torch.testing.assert_close(flex_out, ref_out, atol=1e-2, rtol=1e-2)
94 95
95 96   # sdpa version
@@ -106,7 +107,7 @@ def test(
106 107   spda_sec = do_bench(
107 108   lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)
108 109   )
109     - flex_sec = do_bench(lambda: flex_attention(q, k, v))
    110 + flex_sec = do_bench(lambda: flex_compiled(q, k, v))
110 111   helion_sec = do_bench(lambda: attention(q, k, v))
111 112   print(
112 113   f"Helion time: {helion_sec:.4f}ms, flex time: {flex_sec:.4f}, torch time: {spda_sec:.4f}"

0 commit comments

Comments
 (0)