@@ -80,9 +80,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    qk_scale = sm_scale * 1.44269504
     _BLOCK_SIZE_1 = 16
-    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 16
     _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
@@ -98,9 +96,7 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    qk_scale = sm_scale * 1.44269504
     _BLOCK_SIZE_1 = 16
-    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 16
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
@@ -182,9 +178,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    qk_scale = sm_scale * 1.44269504
     _BLOCK_SIZE_1 = 128
-    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 64
     _attention_kernel[64 * triton.cdiv(1024, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
@@ -200,9 +194,7 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    qk_scale = sm_scale * 1.44269504
     _BLOCK_SIZE_1 = 128
-    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 64
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)