File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -89,7 +89,8 @@ def test(
89
89
90
90
# flex attention version
91
91
# TODO(jansel): turn the above kernel into a flex attention kernel
92
- flex_out = flex_attention (q , k , v )
92
+ flex_compiled = torch .compile (flex_attention , fullgraph = True )
93
+ flex_out = flex_compiled (q , k , v )
93
94
torch .testing .assert_close (flex_out , ref_out , atol = 1e-2 , rtol = 1e-2 )
94
95
95
96
# sdpa version
@@ -106,7 +107,7 @@ def test(
106
107
spda_sec = do_bench (
107
108
lambda : torch .nn .functional .scaled_dot_product_attention (q , k , v )
108
109
)
109
- flex_sec = do_bench (lambda : flex_attention (q , k , v ))
110
+ flex_sec = do_bench (lambda : flex_compiled (q , k , v ))
110
111
helion_sec = do_bench (lambda : attention (q , k , v ))
111
112
print (
112
113
f"Helion time: { helion_sec :.4f} ms, flex time: { flex_sec :.4f} , torch time: { spda_sec :.4f} "
You can’t perform that action at this time.
0 commit comments