Commit 9076129

Fix issue with _BLOCK_SIZE_0.to(torch.int32) (#254)
See #237 (comment)
1 parent: 6df7516

File tree

3 files changed: +84 −1 lines changed

helion/_compiler/inductor_lowering.py

Lines changed: 16 additions & 1 deletion
@@ -42,6 +42,7 @@
 from ..exc import InductorLoweringError
 from ..language._decorators import APIFunc
 from ..language._decorators import is_api_func
+from .ast_extension import ExtendedAST
 from .ast_extension import create
 from .ast_extension import expr_from_string
 from .ast_extension import statement_from_string
@@ -350,8 +351,22 @@ def visit(n: torch.fx.Node) -> None:
                 ast_val = expr_from_string(
                     "tensor[" + ", ".join(expand) + "]", tensor=ast_val
                 )
-            input_asts.append(ast_val)
+            if (
+                isinstance(ast_val, ast.Name)
+                and ast_val.id in device_function._constexpr_args
+            ):
+                # introduce a copy so triton doesn't complain about `id.to(...)` calls
+                assert isinstance(ast_val, ExtendedAST)
+                with ast_val:
+                    copy_var = device_function.new_var(f"{ast_val.id}_", dce=True)
+                    ctx.cg.add_statement(
+                        statement_from_string(f"{copy_var} = {ast_val.id}")
+                    )
+                    input_asts.append(expr_from_string(f"{copy_var}"))
+            else:
+                input_asts.append(ast_val)
 
+        device_function: DeviceFunction = ctx.cg.device_function
         ndim: int = max([x.ndim for x in self.input_fake_tensors(node)] or (0,))
         input_asts: list[ast.AST] = []
         map_arg((node.args, node.kwargs), visit)
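
The generated copy is easiest to see in a hand-written analogue. The sketch below is illustrative and not code from this commit: it assumes a Triton version that, as in #237, rejects `.to(...)` invoked directly on a `tl.constexpr` parameter, and it mirrors the `_BLOCK_SIZE_0_ = _BLOCK_SIZE_0` copy visible in the expected journal below; the kernel name and shapes are made up.

import triton
import triton.language as tl

@triton.jit
def last_in_tile_kernel(out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < n
    # `BLOCK.to(tl.int32)` is the pattern Triton complained about, because
    # BLOCK is a compile-time constexpr rather than a runtime value.
    # Copy it into a fresh variable first, as the lowering now does:
    block_copy = BLOCK
    block_i32 = block_copy.to(tl.int32)
    last = (idx % block_i32) == (block_i32 - 1)
    tl.store(out_ptr + idx, last.to(tl.int32), mask=mask)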

test/test_misc.expected

Lines changed: 48 additions & 0 deletions
@@ -53,6 +53,54 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
 
+--- assertExpectedJournal(TestMisc.test_tile_block_size_constexpr_fix)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+
+@triton.jit
+def _test_tile_block_size_usage_kernel(out, x_size_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    _BLOCK_SIZE_0_ = _BLOCK_SIZE_0
+    v_0 = _BLOCK_SIZE_0_.to(tl.int32)
+    v_1 = indices_0 % v_0
+    v_2 = tl.full([], 0, tl.int32)
+    v_3 = v_1 != v_2
+    v_4 = libdevice.signbit(v_1) != 0 if v_1.dtype is tl.float32 else v_1 < 0
+    v_5 = libdevice.signbit(v_0) != 0 if v_0.dtype is tl.float32 else v_0 < 0
+    v_6 = v_4 != v_5
+    v_7 = v_3 & v_6
+    v_8 = v_1 + v_0
+    v_9 = tl.where(v_7, v_8, v_1)
+    sub = -1 + _BLOCK_SIZE_0
+    v_10 = sub.to(tl.int32)
+    v_11 = v_9 == v_10
+    v_12 = tl.full([], 0, tl.int64)
+    v_13 = tl.full([], 1, tl.int64)
+    v_14 = v_13[None]
+    v_15 = v_12[None]
+    v_16 = tl.where(v_11, v_14, v_15)
+    v_17 = v_16.to(tl.int32)
+    tl.store(out + indices_0 * out_stride_0, v_17, mask_0)
+
+def test_tile_block_size_usage(x: torch.Tensor):
+    out = torch.zeros_like(x, dtype=torch.int32)
+    _BLOCK_SIZE_0 = 32
+    _test_tile_block_size_usage_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](out, x.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _test_tile_block_size_usage_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x, dtype=torch.int32)
+    _BLOCK_SIZE_0 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_test_tile_block_size_usage_kernel)(out, x.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestMisc.test_torch_alloc)
 from __future__ import annotations

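One detail of the journal worth noting: `v_1` through `v_9` are the usual expansion of Python's sign-following `%` on top of Triton's truncated remainder, and the `v_0 = _BLOCK_SIZE_0_.to(tl.int32)` line only compiles because of the copy on the line above it. In plain Python, the remainder adjustment those lines perform is roughly the following sketch (not from the commit):

def python_mod(a: int, b: int) -> int:
    # Mirrors v_1..v_9 above: start from the C-style truncated remainder,
    # then add the divisor back when the remainder is nonzero and its
    # sign differs from the divisor's sign.
    r = (abs(a) % abs(b)) * (1 if a >= 0 else -1)  # truncated remainder
    if r != 0 and (r < 0) != (b < 0):
        r += b
    return r

assert python_mod(-3, 32) == -3 % 32 == 29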
test/test_misc.py

Lines changed: 20 additions & 0 deletions
@@ -216,6 +216,26 @@ def test_tile_id(x: torch.Tensor) -> torch.Tensor:
         result = test_tile_id.bind((x,)).compile_config(config)(x)
         self.assertEqual(result.sum().item(), 16)
 
+    def test_tile_block_size_constexpr_fix(self):
+        """Test that tile.block_size can be used in expressions without compilation errors."""
+
+        @helion.kernel(use_default_config=True)
+        def test_tile_block_size_usage(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile in hl.tile(x.shape[0]):
+                # This should not cause a compilation error when tile.block_size is used
+                # in expressions that generate .to() calls
+                block_size_temp = tile.block_size
+                mask = tile.index % block_size_temp == block_size_temp - 1
+                out[tile] = torch.where(mask, 1, 0)
+            return out
+
+        x = torch.randn(32, device=DEVICE)
+        code, result = code_and_output(test_tile_block_size_usage, (x,))
+        self.assertExpectedJournal(code)
+        # The result should have 1s at positions that are last in their tile
+        self.assertTrue(result.sum().item() > 0)
+
 
 if __name__ == "__main__":
     unittest.main()
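
A note on what the new test computes (added here, not part of the commit): with `x` of length 32 and the journal's `_BLOCK_SIZE_0 = 32`, only index 31 is the last element of its tile, so `result.sum()` is 1 and the `> 0` assertion passes. A pure-PyTorch mirror of the kernel, assuming that block size:

import torch

def last_in_tile_reference(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # Marks the last index of each tile with 1 and everything else with 0,
    # matching `mask = tile.index % block_size_temp == block_size_temp - 1`.
    idx = torch.arange(x.shape[0])
    return torch.where(idx % block_size == block_size - 1, 1, 0).to(torch.int32)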
