Skip to content

Commit b9e76dc

Browse files
authored
Improve Tensor.item() handling (#307)
1 parent 6ec6f1d commit b9e76dc

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

helion/_compiler/type_propagation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,19 @@ def propagate_call(
581581
raise exc.TypeInferenceError(
582582
f"Tensor.{attr}() args must be literals"
583583
) from None
584+
if attr == "item" and not (args or kwargs):
585+
if origin.is_device():
586+
raise exc.NotAllowedOnDevice("Tensor.item()")
587+
if self.tensor.fake_value.numel() != 1:
588+
raise exc.TypeInferenceError("Tensor.item() requires numel() == 1")
589+
dtype = self.tensor.fake_value.dtype
590+
if dtype.is_complex:
591+
raise exc.TypeInferenceError("Complex tensors not supported")
592+
if dtype.is_floating_point:
593+
return SymFloatType.new_unbacked(origin)
594+
if dtype == torch.bool:
595+
return SymBoolType.new_unbacked(origin)
596+
return SymIntType.new_unbacked(origin)
584597

585598
proxy_args = [x.tree_map(_to_proxy) for x in args]
586599
proxy_kwargs = {k: v.tree_map(_to_proxy) for k, v in kwargs.items()}

test/test_misc.expected

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,37 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
5353
from helion.runtime.precompile_shim import make_precompiler
5454
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
5555

56+
--- assertExpectedJournal(TestMisc.test_scalar_tensor_item_method)
57+
from __future__ import annotations
58+
59+
import torch
60+
import triton
61+
import triton.language as tl
62+
63+
@triton.jit
64+
def _kernel_with_scalar_item_kernel(x, result, x_size_0, result_stride_0, x_stride_0, scalar_val, _BLOCK_SIZE_0: tl.constexpr):
65+
pid_0 = tl.program_id(0)
66+
offset_0 = pid_0 * _BLOCK_SIZE_0
67+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
68+
mask_0 = indices_0 < x_size_0
69+
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
70+
v_0 = load + scalar_val
71+
tl.store(result + indices_0 * result_stride_0, v_0, mask_0)
72+
73+
def kernel_with_scalar_item(x: torch.Tensor, scalar_tensor: torch.Tensor):
74+
result = torch.empty_like(x)
75+
scalar_val = scalar_tensor.item()
76+
_BLOCK_SIZE_0 = 128
77+
_kernel_with_scalar_item_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, result, x.size(0), result.stride(0), x.stride(0), scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
78+
return result
79+
80+
def _kernel_with_scalar_item_make_precompiler(x: torch.Tensor, scalar_tensor: torch.Tensor):
81+
result = torch.empty_like(x)
82+
scalar_val = scalar_tensor.item()
83+
_BLOCK_SIZE_0 = 128
84+
from helion.runtime.precompile_shim import make_precompiler
85+
return make_precompiler(_kernel_with_scalar_item_kernel)(x, result, x.size(0), result.stride(0), x.stride(0), scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
86+
5687
--- assertExpectedJournal(TestMisc.test_tile_block_size_constexpr_fix)
5788
from __future__ import annotations
5889

test/test_misc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,32 @@ def kernel_no_config(x: torch.Tensor) -> torch.Tensor:
287287
"no config provided and no implicit config available", str(cm.exception)
288288
)
289289

290+
def test_scalar_tensor_item_method(self):
291+
"""Test using scalar_tensor.item() to extract scalar value in kernel"""
292+
293+
@helion.kernel(use_default_config=True)
294+
def kernel_with_scalar_item(
295+
x: torch.Tensor, scalar_tensor: torch.Tensor
296+
) -> torch.Tensor:
297+
result = torch.empty_like(x)
298+
scalar_val = scalar_tensor.item()
299+
for tile in hl.tile(x.shape):
300+
result[tile] = x[tile] + scalar_val
301+
return result
302+
303+
x = torch.randn(100, device=DEVICE)
304+
code, result = code_and_output(
305+
kernel_with_scalar_item, (x, torch.tensor(5.0, device=DEVICE))
306+
)
307+
self.assertExpectedJournal(code)
308+
torch.testing.assert_close(result, x + 5)
309+
310+
code2, result2 = code_and_output(
311+
kernel_with_scalar_item, (x, torch.tensor(10.0, device=DEVICE))
312+
)
313+
self.assertEqual(code, code2)
314+
torch.testing.assert_close(result2, x + 10)
315+
290316

291317
if __name__ == "__main__":
292318
unittest.main()

0 commit comments

Comments
 (0)