Skip to content

Commit a2159f6

Browse files
committed
Add literal index into tuple
stack-info: PR: #327, branch: joydddd/stack/16
1 parent ab85dfe commit a2159f6

File tree

5 files changed

+216
-0
lines changed

5 files changed

+216
-0
lines changed

helion/_compiler/device_ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,10 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
814814
value = node.value
815815
assert isinstance(value, ExtendedAST)
816816
type_info = value._type_info
817+
if isinstance(type_info, SequenceType):
818+
if isinstance(node.slice, ast.Constant):
819+
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
820+
raise exc.InvalidSequenceSubscription(node.slice)
817821
if type_info is not None and type_info.origin.is_host():
818822
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
819823
return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,9 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12211221
for i, subtype in enumerate(self.element_types):
12221222
subtype.populate_symbol_origins(GetItemOrigin(origin, i))
12231223

1224+
def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
1225+
return super().propagate_getitem(key, origin)
1226+
12241227
def merge(self, other: TypeInfo) -> TypeInfo:
12251228
if isinstance(other, SequenceType):
12261229
self_elements = self.element_types

helion/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,7 @@ class CannotReadDeviceVariableOnHost(BaseError):
323323

324324
class DeviceTensorSubscriptAssignmentNotAllowed(BaseError):
325325
message = "Cannot assign to subscript of device tensor '{0}'."
326+
327+
328+
class InvalidSequenceSubscription(BaseError):
329+
message = "Cannot subscript a sequence with non-constant indices. Got '{0!s}'."

test/test_misc.expected

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,131 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
143143
_BLOCK_SIZE_0 = 64
144144
_launcher(_fn_kernel, (triton.cdiv(m, _BLOCK_SIZE_1),), x, out, out.stride(0), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
145145
return out
146+
147+
--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript)
148+
from __future__ import annotations
149+
150+
import torch
151+
import triton
152+
import triton.language as tl
153+
from helion.runtime import default_launcher as _default_launcher
154+
155+
@triton.jit
156+
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
157+
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
158+
pid_0 = tl.program_id(0) % num_blocks_0
159+
pid_1 = tl.program_id(0) // num_blocks_0
160+
offset_0 = pid_0 * _BLOCK_SIZE_0
161+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
162+
mask_0 = indices_0 < out_size_0
163+
offset_1 = pid_1 * _BLOCK_SIZE_1
164+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
165+
mask_1 = indices_1 < out_size_1
166+
load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
167+
load_1 = tl.load(inp_tuple_item_1 + (indices_0[:, None] * inp_tuple_item_1_stride_0 + indices_1[None, :] * inp_tuple_item_1_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
168+
v_0 = load_1.to(tl.float32)
169+
v_1 = load + v_0
170+
v_2 = inp_tuple_item_2.to(tl.float32)
171+
v_3 = v_1 * v_2
172+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
173+
174+
def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
175+
out = torch.empty_like(inp_tuple[0])
176+
_BLOCK_SIZE_0 = 8
177+
_BLOCK_SIZE_1 = 8
178+
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
179+
return outfrom __future__ import annotations
180+
181+
import torch
182+
import triton
183+
import triton.language as tl
184+
from helion.runtime import default_launcher as _default_launcher
185+
186+
@triton.jit
187+
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_0_size_0, inp_tuple_item_0_size_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
188+
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
189+
pid_0 = tl.program_id(0) % num_blocks_0
190+
pid_1 = tl.program_id(0) // num_blocks_0
191+
offset_0 = pid_0 * _BLOCK_SIZE_0
192+
offset_1 = pid_1 * _BLOCK_SIZE_1
193+
load = tl.load(tl.make_block_ptr(inp_tuple_item_0, [inp_tuple_item_0_size_0, inp_tuple_item_0_size_1], [inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
194+
load_1 = tl.load(tl.make_block_ptr(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
195+
v_0 = load_1.to(tl.float32)
196+
v_1 = load + v_0
197+
v_2 = inp_tuple_item_2.to(tl.float32)
198+
v_3 = v_1 * v_2
199+
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
200+
201+
def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
202+
out = torch.empty_like(inp_tuple[0])
203+
_BLOCK_SIZE_0 = 8
204+
_BLOCK_SIZE_1 = 8
205+
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[0].size(0), inp_tuple[0].size(1), inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
206+
return out
207+
208+
--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript_w_descriptor)
209+
from __future__ import annotations
210+
211+
import torch
212+
import helion
213+
import triton
214+
import triton.language as tl
215+
from helion.runtime import default_launcher as _default_launcher
216+
217+
helion.runtime.set_triton_allocator()
218+
219+
@triton.jit
220+
def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
221+
inp_tuple_item_1_desc = tl.make_tensor_descriptor(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1])
222+
num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
223+
pid_0 = tl.program_id(0) % num_blocks_0
224+
pid_1 = tl.program_id(0) // num_blocks_0
225+
offset_0 = pid_0 * _BLOCK_SIZE_0
226+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
227+
mask_0 = indices_0 < out_size_0
228+
offset_1 = pid_1 * _BLOCK_SIZE_1
229+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
230+
mask_1 = indices_1 < out_size_1
231+
load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
232+
load_1 = inp_tuple_item_1_desc.load([offset_0, offset_1])
233+
v_0 = load_1.to(tl.float32)
234+
v_1 = load + v_0
235+
v_2 = inp_tuple_item_2.to(tl.float32)
236+
v_3 = v_1 * v_2
237+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
238+
239+
def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
240+
out = torch.empty_like(inp_tuple[0])
241+
_BLOCK_SIZE_0 = 8
242+
_BLOCK_SIZE_1 = 8
243+
_launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
244+
return out
245+
246+
--- assertExpectedJournal(TestMisc.test_tuple_unpack)
247+
from __future__ import annotations
248+
249+
import torch
250+
import triton
251+
import triton.language as tl
252+
from helion.runtime import default_launcher as _default_launcher
253+
254+
@triton.jit
255+
def _tuple_unpack_kernel_kernel(a, b, out, a_size_0, a_stride_0, b_stride_0, out_stride_0, x, _BLOCK_SIZE_0: tl.constexpr):
256+
pid_0 = tl.program_id(0)
257+
offset_0 = pid_0 * _BLOCK_SIZE_0
258+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
259+
mask_0 = indices_0 < a_size_0
260+
load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
261+
load_1 = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
262+
v_0 = load_1.to(tl.float32)
263+
v_1 = load + v_0
264+
v_2 = x.to(tl.float32)
265+
v_3 = v_1 + v_2
266+
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
267+
268+
def tuple_unpack_kernel(inp_tuple, *, _launcher=_default_launcher):
269+
a, b, x = inp_tuple
270+
out = torch.empty_like(a)
271+
_BLOCK_SIZE_0 = 4
272+
_launcher(_tuple_unpack_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out, a.size(0), a.stride(0), b.stride(0), out.stride(0), x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
273+
return out

test/test_misc.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010

1111
import helion
12+
from helion._compat import supports_tensor_descriptor
1213
from helion._testing import DEVICE
1314
from helion._testing import TestCase
1415
from helion._testing import code_and_output
@@ -313,6 +314,82 @@ def kernel_with_scalar_item(
313314
self.assertEqual(code, code2)
314315
torch.testing.assert_close(result2, x + 10)
315316

317+
def test_tuple_literal_subscript(self):
318+
@helion.kernel
319+
def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
320+
out = torch.empty_like(inp_tuple[0])
321+
for tile in hl.tile(out.size()):
322+
out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
323+
return out
324+
325+
inp_tuple = (
326+
torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
327+
torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
328+
3,
329+
)
330+
code_pointer, result = code_and_output(
331+
tuple_literal_index_kernel,
332+
(inp_tuple,),
333+
block_size=[8, 8],
334+
indexing="pointer",
335+
)
336+
torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
337+
338+
code_block, result = code_and_output(
339+
tuple_literal_index_kernel,
340+
(inp_tuple,),
341+
block_size=[8, 8],
342+
indexing="block_ptr",
343+
)
344+
torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
345+
346+
self.assertNotEqual(code_pointer, code_block)
347+
self.assertExpectedJournal(code_pointer + code_block)
348+
349+
@unittest.skipUnless(
350+
supports_tensor_descriptor(), "Tensor descriptor support is required"
351+
)
352+
def test_tuple_literal_subscript_w_descriptor(self):
353+
@helion.kernel
354+
def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
355+
out = torch.empty_like(inp_tuple[0])
356+
for tile in hl.tile(out.size()):
357+
out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2]
358+
return out
359+
360+
inp_tuple = (
361+
torch.randn(8, 30, device=DEVICE, dtype=torch.float32),
362+
torch.randn(8, 32, device=DEVICE, dtype=torch.bfloat16),
363+
3,
364+
)
365+
code, result = code_and_output(
366+
tuple_literal_index_kernel,
367+
(inp_tuple,),
368+
block_size=[8, 8],
369+
indexing="tensor_descriptor",
370+
)
371+
torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
372+
self.assertExpectedJournal(code)
373+
374+
def test_tuple_unpack(self):
375+
@helion.kernel
376+
def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
377+
a, b, x = inp_tuple
378+
out = torch.empty_like(a)
379+
for tile in hl.tile(out.size(0)):
380+
out[tile] = a[tile] + b[tile] + x
381+
return out
382+
383+
inp_tuple = (
384+
torch.randn(16, device=DEVICE, dtype=torch.float32),
385+
torch.randn(16, device=DEVICE, dtype=torch.bfloat16),
386+
5,
387+
)
388+
code, result = code_and_output(tuple_unpack_kernel, (inp_tuple,), block_size=4)
389+
torch.testing.assert_close(result, inp_tuple[0] + inp_tuple[1] + 5)
390+
391+
self.assertExpectedJournal(code)
392+
316393

317394
if __name__ == "__main__":
318395
unittest.main()

0 commit comments

Comments
 (0)