Skip to content

Commit 37e8af3

Browse files
authored
Fix unstable CI (#299)
1 parent 55759ff commit 37e8af3

File tree

4 files changed

+7
-9
lines changed

4 files changed

+7
-9
lines changed

test/test_autotuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_config_fragment0(self):
4545
self.assertExpectedJournal("\n".join(map(repr, configs)))
4646

4747
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
48+
@patch.object(loops, "_supports_warp_specialize", lambda: True)
4849
def test_config_fragment1(self):
4950
args = (
5051
torch.randn([8, 512, 512], device=DEVICE),

test/test_register_tunable.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_regi
22
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
33

44
--- assertExpectedJournal(TestRegisterTunable.test_integer_fragment)
5-
helion.Config(block_sizes=[128], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat', multiplier=3)
5+
helion.Config(block_sizes=[128], range_unroll_factors=[0], range_warp_specializes=[], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat', multiplier=3)
66

77
--- assertExpectedJournal(TestRegisterTunable.test_integer_fragment)
88
from __future__ import annotations

test/test_register_tunable.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
from __future__ import annotations
22

33
import unittest
4+
from unittest.mock import patch
45

56
import torch
67

78
import helion
9+
from helion import _compat
810
from helion._testing import DEVICE
911
from helion._testing import TestCase
1012
from helion._testing import code_and_output
1113
from helion.autotuner import EnumFragment
1214
from helion.autotuner import IntegerFragment
1315
from helion.autotuner import PowerOfTwoFragment
1416
import helion.language as hl
17+
from helion.language import loops
1518

1619

1720
class TestRegisterTunable(TestCase):
@@ -41,6 +44,8 @@ def kernel_with_tunable(x: torch.Tensor) -> torch.Tensor:
4144
)
4245
self.assertExpectedJournal(code)
4346

47+
@patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
48+
@patch.object(loops, "_supports_warp_specialize", lambda: False)
4449
def test_integer_fragment(self):
4550
@helion.kernel()
4651
def kernel_with_int_param(x: torch.Tensor) -> torch.Tensor:

test/test_tensor_descriptor.expected

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
8080
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
8181
out = torch.empty_like(q_view)
8282
sm_scale = 1.0 / math.sqrt(head_dim)
83-
qk_scale = sm_scale * 1.44269504
8483
_BLOCK_SIZE_1 = 16
85-
_RDIM_SIZE_2 = 64
8684
_BLOCK_SIZE_3 = 16
8785
_attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
8886
return out.view(q_in.size())
@@ -98,9 +96,7 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
9896
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
9997
out = torch.empty_like(q_view)
10098
sm_scale = 1.0 / math.sqrt(head_dim)
101-
qk_scale = sm_scale * 1.44269504
10299
_BLOCK_SIZE_1 = 16
103-
_RDIM_SIZE_2 = 64
104100
_BLOCK_SIZE_3 = 16
105101
from helion.runtime.precompile_shim import make_precompiler
106102
return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
@@ -182,9 +178,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
182178
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
183179
out = torch.empty_like(q_view)
184180
sm_scale = 1.0 / math.sqrt(head_dim)
185-
qk_scale = sm_scale * 1.44269504
186181
_BLOCK_SIZE_1 = 128
187-
_RDIM_SIZE_2 = 64
188182
_BLOCK_SIZE_3 = 64
189183
_attention_kernel[64 * triton.cdiv(1024, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
190184
return out.view(q_in.size())
@@ -200,9 +194,7 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
200194
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
201195
out = torch.empty_like(q_view)
202196
sm_scale = 1.0 / math.sqrt(head_dim)
203-
qk_scale = sm_scale * 1.44269504
204197
_BLOCK_SIZE_1 = 128
205-
_RDIM_SIZE_2 = 64
206198
_BLOCK_SIZE_3 = 64
207199
from helion.runtime.precompile_shim import make_precompiler
208200
return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)

0 commit comments

Comments (0)