diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index cbb56240..53777b1e 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -814,6 +814,10 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
         value = node.value
         assert isinstance(value, ExtendedAST)
         type_info = value._type_info
+        if isinstance(type_info, SequenceType):
+            if isinstance(node.slice, ast.Constant):
+                return self.visit(value)[self.visit(node.slice)]  # pyright: ignore[reportIndexIssue]
+            raise exc.InvalidSequenceSubscription(node.slice)
         if type_info is not None and type_info.origin.is_host():
             return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
         return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 189bb3e8..7111ed10 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -1221,6 +1221,9 @@ def populate_symbol_origins(self, origin: Origin) -> None:
         for i, subtype in enumerate(self.element_types):
             subtype.populate_symbol_origins(GetItemOrigin(origin, i))
 
+    def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
+        return super().propagate_getitem(key, origin)
+
     def merge(self, other: TypeInfo) -> TypeInfo:
         if isinstance(other, SequenceType):
             self_elements = self.element_types
diff --git a/helion/exc.py b/helion/exc.py
index 08cae769..c1d979ca 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -323,3 +323,7 @@ class CannotReadDeviceVariableOnHost(BaseError):
 
 class DeviceTensorSubscriptAssignmentNotAllowed(BaseError):
     message = "Cannot assign to subscript of device tensor '{0}'."
+
+
+class InvalidSequenceSubscription(BaseError):
+    message = "Cannot subscript a sequence with non-constant indices. Got '{0!s}'."
" diff --git a/test/test_misc.expected b/test/test_misc.expected index 131b1172..2be76a39 100644 --- a/test/test_misc.expected +++ b/test/test_misc.expected @@ -143,3 +143,131 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 64 _launcher(_fn_kernel, (triton.cdiv(m, _BLOCK_SIZE_1),), x, out, out.stride(0), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) return out + +--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < out_size_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < out_size_1 + load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + load_1 = tl.load(inp_tuple_item_1 + (indices_0[:, None] * inp_tuple_item_1_stride_0 + indices_1[None, :] * inp_tuple_item_1_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_0 = load_1.to(tl.float32) + v_1 = load + v_0 + v_2 = inp_tuple_item_2.to(tl.float32) + v_3 = v_1 * v_2 + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :]) + +def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher): + out = torch.empty_like(inp_tuple[0]) + _BLOCK_SIZE_0 = 8 + _BLOCK_SIZE_1 = 8 + _launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return outfrom __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_0_size_0, inp_tuple_item_0_size_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + load = tl.load(tl.make_block_ptr(inp_tuple_item_0, [inp_tuple_item_0_size_0, inp_tuple_item_0_size_1], [inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1], [offset_0, offset_1], 
+    load_1 = tl.load(tl.make_block_ptr(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+    v_0 = load_1.to(tl.float32)
+    v_1 = load + v_0
+    v_2 = inp_tuple_item_2.to(tl.float32)
+    v_3 = v_1 * v_2
+    tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
+
+def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
+    out = torch.empty_like(inp_tuple[0])
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[0].size(0), inp_tuple[0].size(1), inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript_w_descriptor)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+helion.runtime.set_triton_allocator()
+
+@triton.jit
+def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    inp_tuple_item_1_desc = tl.make_tensor_descriptor(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1])
+    num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < out_size_1
+    load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    load_1 = inp_tuple_item_1_desc.load([offset_0, offset_1])
+    v_0 = load_1.to(tl.float32)
+    v_1 = load + v_0
+    v_2 = inp_tuple_item_2.to(tl.float32)
+    v_3 = v_1 * v_2
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
+
+def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
+    out = torch.empty_like(inp_tuple[0])
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestMisc.test_tuple_unpack)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _tuple_unpack_kernel_kernel(a, b, out, a_size_0, a_stride_0, b_stride_0, out_stride_0, x, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < a_size_0
+    load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
+    load_1 = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
+    v_0 = load_1.to(tl.float32)
+    v_1 = load + v_0
+    v_2 = x.to(tl.float32)
+    v_3 = v_1 + v_2
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+
+def tuple_unpack_kernel(inp_tuple, *, _launcher=_default_launcher):
+    a, b, x = inp_tuple
+    out = torch.empty_like(a)
+    _BLOCK_SIZE_0 = 4
+    _launcher(_tuple_unpack_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out, a.size(0), a.stride(0), b.stride(0), out.stride(0), x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
diff --git a/test/test_misc.py b/test/test_misc.py
index 66fd6755..c8b92336 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -9,6 +9,7 @@
 import torch
 
 import helion
+from helion._compat import supports_tensor_descriptor
 from helion._testing import DEVICE
 from helion._testing import TestCase
 from helion._testing import code_and_output
@@ -313,6 +314,82 @@ def kernel_with_scalar_item(
         self.assertEqual(code, code2)
         torch.testing.assert_close(result2, x + 10)
 
+    def test_tuple_literal_subscript(self):
+        @helion.kernel
+        def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
+            out = torch.empty_like(inp_tuple[0])
+            for tile in hl.tile(out.size()):
+                out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
+            return out
+
+        inp_tuple = (
+            torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
+            torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
+            3,
+        )
+        code_pointer, result = code_and_output(
+            tuple_literal_index_kernel,
+            (inp_tuple,),
+            block_size=[8, 8],
+            indexing="pointer",
+        )
+        torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
+
+        code_block, result = code_and_output(
+            tuple_literal_index_kernel,
+            (inp_tuple,),
+            block_size=[8, 8],
+            indexing="block_ptr",
+        )
+        torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
+
+        self.assertNotEqual(code_pointer, code_block)
+        self.assertExpectedJournal(code_pointer + code_block)
+
+    @unittest.skipUnless(
+        supports_tensor_descriptor(), "Tensor descriptor support is required"
+    )
+    def test_tuple_literal_subscript_w_descriptor(self):
+        @helion.kernel
+        def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
+            out = torch.empty_like(inp_tuple[0])
+            for tile in hl.tile(out.size()):
+                out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
+            return out
+
+        inp_tuple = (
+            torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
+            torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
+            3,
+        )
+        code, result = code_and_output(
+            tuple_literal_index_kernel,
+            (inp_tuple,),
+            block_size=[8, 8],
+            indexing="tensor_descriptor",
+        )
+        torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
+        self.assertExpectedJournal(code)
+
+    def test_tuple_unpack(self):
+        @helion.kernel
+        def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
+            a, b, x = inp_tuple
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size(0)):
+                out[tile] = a[tile] + b[tile] + x
+            return out
+
+        inp_tuple = (
+            torch.randn(16, device=DEVICE, dtype=torch.float32),
+            torch.randn(16, device=DEVICE, dtype=torch.bfloat16),
+            5,
+        )
+        code, result = code_and_output(tuple_unpack_kernel, (inp_tuple,), block_size=4)
+        torch.testing.assert_close(result, inp_tuple[0] + inp_tuple[1] + 5)
+
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
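
For reviewers, a minimal sketch (not part of the diff) of the behavior the device_ir.py change enables. It assumes the same public API used in the tests above (helion.kernel, hl.tile); the kernel names and the integer argument i are hypothetical, and the second kernel illustrates the new error path rather than a tested case.

    import torch
    import helion
    import helion.language as hl

    @helion.kernel
    def tuple_constant_index(inp_tuple) -> torch.Tensor:
        out = torch.empty_like(inp_tuple[0])
        for tile in hl.tile(out.size()):
            # Constant tuple indices are now resolved at compile time by
            # visit_Subscript, mirroring test_tuple_literal_subscript above.
            out[tile] = inp_tuple[0][tile] + inp_tuple[1][tile]
        return out

    @helion.kernel
    def tuple_dynamic_index(inp_tuple, i: int) -> torch.Tensor:
        out = torch.empty_like(inp_tuple[0])
        for tile in hl.tile(out.size()):
            # A non-constant index should hit the new
            # helion.exc.InvalidSequenceSubscription error when the kernel
            # is compiled (assumption based on the raise in visit_Subscript).
            out[tile] = inp_tuple[i][tile]
        return out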