Skip to content

Commit cd77a26

Browse files
committed
Add literal index into tuple
1 parent ab85dfe commit cd77a26

File tree

4 files changed

+225
-0
lines changed

4 files changed

+225
-0
lines changed

helion/_compiler/device_ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,10 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
814814
value = node.value
815815
assert isinstance(value, ExtendedAST)
816816
type_info = value._type_info
817+
if isinstance(type_info, SequenceType):
818+
if isinstance(node.slice, ast.Constant):
819+
return self.visit(value)[self.visit(node.slice)]
820+
raise NotImplementedError("Only literal index into tuple is supported. ")
817821
if type_info is not None and type_info.origin.is_host():
818822
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
819823
return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,9 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12211221
for i, subtype in enumerate(self.element_types):
12221222
subtype.populate_symbol_origins(GetItemOrigin(origin, i))
12231223

1224+
def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
1225+
return super().propagate_getitem(key, origin)
1226+
12241227
def merge(self, other: TypeInfo) -> TypeInfo:
12251228
if isinstance(other, SequenceType):
12261229
self_elements = self.element_types

test/test_tuple.expected

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
This file is automatically generated by assertExpectedJournal calls in test_tuple.py.
2+
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
3+
4+
--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript)
5+
from __future__ import annotations
6+
7+
import torch
8+
import triton
9+
import triton.language as tl
10+
from helion.runtime import default_launcher as _default_launcher
11+
12+
@triton.jit
13+
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
14+
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
15+
pid_0 = tl.program_id(0) % num_blocks_0
16+
pid_1 = tl.program_id(0) // num_blocks_0
17+
offset_0 = pid_0 * _BLOCK_SIZE_0
18+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
19+
mask_0 = indices_0 < out_size_0
20+
offset_1 = pid_1 * _BLOCK_SIZE_1
21+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
22+
mask_1 = indices_1 < out_size_1
23+
load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
24+
load_1 = tl.load(inp_tuple_item_1 + (indices_0[:, None] * inp_tuple_item_1_stride_0 + indices_1[None, :] * inp_tuple_item_1_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
25+
v_0 = load_1.to(tl.float32)
26+
v_1 = load + v_0
27+
v_2 = inp_tuple_item_2.to(tl.float32)
28+
v_3 = v_1 * v_2
29+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
30+
31+
def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
32+
out = torch.empty_like(inp_tuple[0])
33+
_BLOCK_SIZE_0 = 8
34+
_BLOCK_SIZE_1 = 8
35+
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
36+
return outfrom __future__ import annotations
37+
38+
import torch
39+
import triton
40+
import triton.language as tl
41+
from helion.runtime import default_launcher as _default_launcher
42+
43+
@triton.jit
44+
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_0_size_0, inp_tuple_item_0_size_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
45+
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
46+
pid_0 = tl.program_id(0) % num_blocks_0
47+
pid_1 = tl.program_id(0) // num_blocks_0
48+
offset_0 = pid_0 * _BLOCK_SIZE_0
49+
offset_1 = pid_1 * _BLOCK_SIZE_1
50+
load = tl.load(tl.make_block_ptr(inp_tuple_item_0, [inp_tuple_item_0_size_0, inp_tuple_item_0_size_1], [inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
51+
load_1 = tl.load(tl.make_block_ptr(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
52+
v_0 = load_1.to(tl.float32)
53+
v_1 = load + v_0
54+
v_2 = inp_tuple_item_2.to(tl.float32)
55+
v_3 = v_1 * v_2
56+
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
57+
58+
def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
59+
out = torch.empty_like(inp_tuple[0])
60+
_BLOCK_SIZE_0 = 8
61+
_BLOCK_SIZE_1 = 8
62+
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[0].size(0), inp_tuple[0].size(1), inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
63+
return out
64+
65+
--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript_w_descriptor)
66+
from __future__ import annotations
67+
68+
import torch
69+
import helion
70+
import triton
71+
import triton.language as tl
72+
from helion.runtime import default_launcher as _default_launcher
73+
74+
helion.runtime.set_triton_allocator()
75+
76+
@triton.jit
77+
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
78+
inp_tuple_item_1_desc = tl.make_tensor_descriptor(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1])
79+
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
80+
pid_0 = tl.program_id(0) % num_blocks_0
81+
pid_1 = tl.program_id(0) // num_blocks_0
82+
offset_0 = pid_0 * _BLOCK_SIZE_0
83+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
84+
mask_0 = indices_0 < out_size_0
85+
offset_1 = pid_1 * _BLOCK_SIZE_1
86+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
87+
mask_1 = indices_1 < out_size_1
88+
load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
89+
load_1 = inp_tuple_item_1_desc.load([offset_0, offset_1])
90+
v_0 = load_1.to(tl.float32)
91+
v_1 = load + v_0
92+
v_2 = inp_tuple_item_2.to(tl.float32)
93+
v_3 = v_1 * v_2
94+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
95+
96+
def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
97+
out = torch.empty_like(inp_tuple[0])
98+
_BLOCK_SIZE_0 = 8
99+
_BLOCK_SIZE_1 = 8
100+
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
101+
return out
102+
103+
--- assertExpectedJournal(TestMisc.test_tuple_unpack)
104+
from __future__ import annotations
105+
106+
import torch
107+
import triton
108+
import triton.language as tl
109+
from helion.runtime import default_launcher as _default_launcher
110+
111+
@triton.jit
112+
def _tuple_unpack_kernel_kernel(a, b, out, a_size_0, a_stride_0, b_stride_0, out_stride_0, x, _BLOCK_SIZE_0: tl.constexpr):
113+
pid_0 = tl.program_id(0)
114+
offset_0 = pid_0 * _BLOCK_SIZE_0
115+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
116+
mask_0 = indices_0 < a_size_0
117+
load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
118+
load_1 = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
119+
v_0 = load_1.to(tl.float32)
120+
v_1 = load + v_0
121+
v_2 = x.to(tl.float32)
122+
v_3 = v_1 + v_2
123+
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
124+
125+
def tuple_unpack_kernel(inp_tuple, *, _launcher=_default_launcher):
126+
a, b, x = inp_tuple
127+
out = torch.empty_like(a)
128+
_BLOCK_SIZE_0 = 4
129+
_launcher(_tuple_unpack_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out, a.size(0), a.stride(0), b.stride(0), out.stride(0), x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
130+
return out

test/test_tuple.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
import torch
6+
7+
import helion
8+
from helion._compat import supports_tensor_descriptor
9+
from helion._testing import DEVICE
10+
from helion._testing import TestCase
11+
from helion._testing import code_and_output
12+
import helion.language as hl
13+
14+
15+
@helion.kernel
16+
def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
17+
out = torch.empty_like(inp_tuple[0])
18+
for tile in hl.tile(out.size()):
19+
out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
20+
return out
21+
22+
23+
class TestMisc(TestCase):
24+
def test_tuple_literal_subscript(self):
25+
inp_tuple = (
26+
torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
27+
torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
28+
3,
29+
)
30+
code_pointer, result = code_and_output(
31+
tuple_literal_index_kernel,
32+
(inp_tuple,),
33+
block_size=[8, 8],
34+
indexing="pointer",
35+
)
36+
torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
37+
38+
code_block, result = code_and_output(
39+
tuple_literal_index_kernel,
40+
(inp_tuple,),
41+
block_size=[8, 8],
42+
indexing="block_ptr",
43+
)
44+
torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
45+
46+
self.assertNotEqual(code_pointer, code_block)
47+
self.assertExpectedJournal(code_pointer + code_block)
48+
49+
@unittest.skipUnless(
50+
supports_tensor_descriptor(), "Tensor descriptor support is required"
51+
)
52+
def test_tuple_literal_subscript_w_descriptor(self):
53+
inp_tuple = (
54+
torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
55+
torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
56+
3,
57+
)
58+
code, result = code_and_output(
59+
tuple_literal_index_kernel,
60+
(inp_tuple,),
61+
block_size=[8, 8],
62+
indexing="tensor_descriptor",
63+
)
64+
torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
65+
self.assertExpectedJournal(code)
66+
67+
def test_tuple_unpack(self):
68+
@helion.kernel
69+
def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
70+
a, b, x = inp_tuple
71+
out = torch.empty_like(a)
72+
for tile in hl.tile(out.size(0)):
73+
out[tile] = a[tile] + b[tile] + x
74+
return out
75+
76+
inp_tuple = (
77+
torch.randn(16, device=DEVICE, dtype=torch.float32),
78+
torch.randn(16, device=DEVICE, dtype=torch.bfloat16),
79+
5,
80+
)
81+
code, result = code_and_output(tuple_unpack_kernel, (inp_tuple,), block_size=4)
82+
torch.testing.assert_close(result, inp_tuple[0] + inp_tuple[1] + 5)
83+
84+
self.assertExpectedJournal(code)
85+
86+
87+
if __name__ == "__main__":
88+
unittest.main()

0 commit comments

Comments
 (0)