Skip to content

Commit 7641e0f

Browse files
committed
Support list, tuple and dict inputs
ghstack-source-id: 2d332b4 Pull Request resolved: #34
1 parent a1e9c99 commit 7641e0f

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

helion/_compiler/compile_environment.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ def to_fake(self, obj: object, origin: Origin) -> object:
162162
return lift_closures(obj, origin)
163163
if isinstance(obj, ConstExpr):
164164
return obj.value
165+
if isinstance(obj, list):
166+
return [self.to_fake(e, origin) for e in obj]
167+
if isinstance(obj, tuple):
168+
return tuple(self.to_fake(e, origin) for e in obj)
169+
if isinstance(obj, dict):
170+
return {k: self.to_fake(e, origin) for k, e in obj.items()}
165171
# TODO(jansel): support other types of args
166172
raise TypeError(f"unsupported argument type {type(obj)} ({origin})")
167173

test/test_misc.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import unittest
4+
35
from expecttest import TestCase
46
import pytest
57
import torch
@@ -53,3 +55,66 @@ def add3(x, y):
5355
code_and_output(add2, (x, x))
5456

5557
code_and_output(add3, (x, x))
58+
59+
def test_inputs(self):
    """Verify that a Helion kernel accepts list, dict and tuple arguments.

    Exercises the new ``to_fake`` container support: the kernel unpacks a
    list, a dict (by key) and a 1-tuple of tensors, adds them tile-wise,
    and the generated Triton code is pinned with ``assertExpectedInline``.
    NOTE(review): reconstructed from a diff view — the exact indentation
    inside the expected-code string is assumed to be the standard 4-space
    output of the code generator; confirm against the repository file.
    """

    @helion.kernel
    def kernel(a_list, b_dict, b_tuple):
        # Unpack each supported container kind once, up front.
        a0, a1 = a_list
        b0 = b_dict["b0"]
        (b1,) = b_tuple
        c0, c1 = torch.empty_like(a0), torch.empty_like(a1)
        for tile in hl.tile(a0.size()):
            c0[tile] = a0[tile] + b0[tile]
            c1[tile] = a1[tile] + b1[tile]
        return [c0, c1]

    x = torch.randn(4, device=DEVICE)
    # Same tensor is passed through all three containers, so every
    # output element is x + x == 2 * x.
    code, result = code_and_output(kernel, ([x, x], {"b0": x}, (x,)))
    torch.testing.assert_close(result[0], 2 * x)
    torch.testing.assert_close(result[1], 2 * x)
    self.assertExpectedInline(
        code,
        """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _kernel_kernel(a0, c0, c1, a0_size_0, a0_stride_0, c0_stride_0, c1_stride_0, _BLOCK_SIZE_0: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
    mask_0 = indices_0 < a0_size_0
    load = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
    load_1 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
    v_0 = load + load_1
    tl.store(c0 + indices_0 * c0_stride_0, v_0, mask_0)
    load_2 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
    load_3 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
    v_1 = load_2 + load_3
    tl.store(c1 + indices_0 * c1_stride_0, v_1, mask_0)

def kernel(a_list, b_dict, b_tuple):
    a0, a1 = a_list
    b0 = b_dict['b0']
    b1, = b_tuple
    c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
    _BLOCK_SIZE_0 = 4
    _kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    return [c0, c1]

def _kernel_make_precompiler(a_list, b_dict, b_tuple):
    a0, a1 = a_list
    b0 = b_dict['b0']
    b1, = b_tuple
    c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
    _BLOCK_SIZE_0 = 4
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_kernel_kernel)(a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
    )
117+
118+
119+
if __name__ == "__main__":
120+
unittest.main()

0 commit comments

Comments
 (0)