Skip to content

Add literal index into tuple #327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,10 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
value = node.value
assert isinstance(value, ExtendedAST)
type_info = value._type_info
if isinstance(type_info, SequenceType):
if isinstance(node.slice, ast.Constant):
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
raise exc.InvalidSequenceSubscription(node.slice)
if type_info is not None and type_info.origin.is_host():
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
Expand Down
3 changes: 3 additions & 0 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,9 @@ def populate_symbol_origins(self, origin: Origin) -> None:
for i, subtype in enumerate(self.element_types):
subtype.populate_symbol_origins(GetItemOrigin(origin, i))

def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
    """Type-propagate ``self[key]`` for a sequence type.

    NOTE(review): this override delegates straight to the base-class
    implementation, so it adds no behavior on its own — presumably it exists
    as an explicit extension point (or to satisfy an interface) for sequence
    subscription added in this change; confirm against the base class.
    """
    return super().propagate_getitem(key, origin)

def merge(self, other: TypeInfo) -> TypeInfo:
if isinstance(other, SequenceType):
self_elements = self.element_types
Expand Down
4 changes: 4 additions & 0 deletions helion/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,7 @@ class CannotReadDeviceVariableOnHost(BaseError):

class DeviceTensorSubscriptAssignmentNotAllowed(BaseError):
    # Raised when code attempts to assign into a subscript of a device
    # tensor; {0} is filled with the tensor's variable name.
    message = "Cannot assign to subscript of device tensor '{0}'."


class InvalidSequenceSubscription(BaseError):
    # Raised when a sequence (e.g. a tuple argument) is subscripted with
    # anything other than a constant literal index; {0!s} is filled with the
    # offending AST slice node.
    # Fix: hyphenate "non-constant" and drop the stray trailing space that
    # the original message carried after the final period.
    message = "Cannot subscript a sequence with non-constant indices. Got '{0!s}'."
128 changes: 128 additions & 0 deletions test/test_misc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,131 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
_BLOCK_SIZE_0 = 64
_launcher(_fn_kernel, (triton.cdiv(m, _BLOCK_SIZE_1),), x, out, out.stride(0), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < out_size_0
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < out_size_1
load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
load_1 = tl.load(inp_tuple_item_1 + (indices_0[:, None] * inp_tuple_item_1_stride_0 + indices_1[None, :] * inp_tuple_item_1_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = load_1.to(tl.float32)
v_1 = load + v_0
v_2 = inp_tuple_item_2.to(tl.float32)
v_3 = v_1 * v_2
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])

def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
out = torch.empty_like(inp_tuple[0])
_BLOCK_SIZE_0 = 8
_BLOCK_SIZE_1 = 8
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return outfrom __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_0_size_0, inp_tuple_item_0_size_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
offset_1 = pid_1 * _BLOCK_SIZE_1
load = tl.load(tl.make_block_ptr(inp_tuple_item_0, [inp_tuple_item_0_size_0, inp_tuple_item_0_size_1], [inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
load_1 = tl.load(tl.make_block_ptr(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
v_0 = load_1.to(tl.float32)
v_1 = load + v_0
v_2 = inp_tuple_item_2.to(tl.float32)
v_3 = v_1 * v_2
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])

def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
out = torch.empty_like(inp_tuple[0])
_BLOCK_SIZE_0 = 8
_BLOCK_SIZE_1 = 8
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[0].size(0), inp_tuple[0].size(1), inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript_w_descriptor)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

helion.runtime.set_triton_allocator()

@triton.jit
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
inp_tuple_item_1_desc = tl.make_tensor_descriptor(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1])
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < out_size_0
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < out_size_1
load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
load_1 = inp_tuple_item_1_desc.load([offset_0, offset_1])
v_0 = load_1.to(tl.float32)
v_1 = load + v_0
v_2 = inp_tuple_item_2.to(tl.float32)
v_3 = v_1 * v_2
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])

def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
out = torch.empty_like(inp_tuple[0])
_BLOCK_SIZE_0 = 8
_BLOCK_SIZE_1 = 8
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestMisc.test_tuple_unpack)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _tuple_unpack_kernel_kernel(a, b, out, a_size_0, a_stride_0, b_stride_0, out_stride_0, x, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < a_size_0
load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
load_1 = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
v_0 = load_1.to(tl.float32)
v_1 = load + v_0
v_2 = x.to(tl.float32)
v_3 = v_1 + v_2
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)

def tuple_unpack_kernel(inp_tuple, *, _launcher=_default_launcher):
a, b, x = inp_tuple
out = torch.empty_like(a)
_BLOCK_SIZE_0 = 4
_launcher(_tuple_unpack_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out, a.size(0), a.stride(0), b.stride(0), out.stride(0), x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out
77 changes: 77 additions & 0 deletions test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch

import helion
from helion._compat import supports_tensor_descriptor
from helion._testing import DEVICE
from helion._testing import TestCase
from helion._testing import code_and_output
Expand Down Expand Up @@ -313,6 +314,82 @@ def kernel_with_scalar_item(
self.assertEqual(code, code2)
torch.testing.assert_close(result2, x + 10)

def test_tuple_literal_subscript(self):
    """Tuple arguments can be indexed with constant literals inside a kernel.

    Compiles the same kernel with two indexing strategies ("pointer" and
    "block_ptr"), checks both produce the correct result, verifies the two
    code generations differ, and records both in the expected journal.
    """

    @helion.kernel
    def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
        # inp_tuple[0], inp_tuple[1], inp_tuple[2] are constant-literal
        # subscripts of the tuple argument; inp_tuple[2] is a plain scalar.
        out = torch.empty_like(inp_tuple[0])
        for tile in hl.tile(out.size()):
            out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
        return out

    # Deliberately mismatched widths (30 vs. 32) so tiling over out.size()
    # reads only the first 30 columns of the second tensor; mixed dtypes
    # (float32 + bfloat16) exercise type promotion in the kernel.
    inp_tuple = (
        torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
        torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
        3,
    )
    code_pointer, result = code_and_output(
        tuple_literal_index_kernel,
        (inp_tuple,),
        block_size=[8, 8],
        indexing="pointer",
    )
    # Reference: only the first 30 columns of inp_tuple[1] participate.
    torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)

    code_block, result = code_and_output(
        tuple_literal_index_kernel,
        (inp_tuple,),
        block_size=[8, 8],
        indexing="block_ptr",
    )
    torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)

    # The two indexing strategies must emit distinct code; both variants are
    # captured in a single journal entry (concatenated back to back).
    self.assertNotEqual(code_pointer, code_block)
    self.assertExpectedJournal(code_pointer + code_block)

@unittest.skipUnless(
    supports_tensor_descriptor(), "Tensor descriptor support is required"
)
def test_tuple_literal_subscript_w_descriptor(self):
    """Same kernel as test_tuple_literal_subscript, compiled with
    ``indexing="tensor_descriptor"`` (hardware-gated via skipUnless)."""

    @helion.kernel
    def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
        # Constant-literal tuple subscripts, identical to the pointer /
        # block_ptr variant above.
        out = torch.empty_like(inp_tuple[0])
        for tile in hl.tile(out.size()):
            out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
        return out

    # Mismatched widths (30 vs. 32) + mixed dtypes, as in the non-descriptor
    # test; only the first 30 columns of inp_tuple[1] are read.
    inp_tuple = (
        torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
        torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
        3,
    )
    code, result = code_and_output(
        tuple_literal_index_kernel,
        (inp_tuple,),
        block_size=[8, 8],
        indexing="tensor_descriptor",
    )
    torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
    self.assertExpectedJournal(code)

def test_tuple_unpack(self):
    """Tuple arguments can be destructured (``a, b, x = inp_tuple``) on the
    host before device code uses the unpacked names."""

    @helion.kernel
    def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
        # Host-side unpack: two tensors plus a scalar.
        a, b, x = inp_tuple
        out = torch.empty_like(a)
        for tile in hl.tile(out.size(0)):
            out[tile] = a[tile] + b[tile] + x
        return out

    # Mixed dtypes (float32 + bfloat16) plus a Python int scalar.
    inp_tuple = (
        torch.randn(16, device=DEVICE, dtype=torch.float32),
        torch.randn(16, device=DEVICE, dtype=torch.bfloat16),
        5,
    )
    code, result = code_and_output(tuple_unpack_kernel, (inp_tuple,), block_size=4)
    torch.testing.assert_close(result, inp_tuple[0] + inp_tuple[1] + 5)

    self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()
Loading