Skip to content

Commit 85ae546

Browse files
oulgenpytorchmergebot
authored andcommitted
Support namedtuple and dataclass (#41)
Pull Request resolved: #41 Approved by: https://github.com/yf225, https://github.com/jansel
1 parent 75fe166 commit 85ae546

File tree

4 files changed

+123
-29
lines changed

4 files changed

+123
-29
lines changed

helion/_compiler/compile_environment.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,26 @@ def to_fake(self, obj: object, origin: Origin) -> object:
164164
return obj.value
165165
if isinstance(obj, list):
166166
return [self.to_fake(e, origin) for e in obj]
167+
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
168+
return type(obj)(
169+
**{ # pyre-ignore[6]
170+
k: self.to_fake(e, origin)
171+
for k, e in obj._asdict().items() # pyre-ignore[16]
172+
}
173+
)
167174
if isinstance(obj, tuple):
168175
return tuple(self.to_fake(e, origin) for e in obj)
169176
if isinstance(obj, dict):
170177
return {k: self.to_fake(e, origin) for k, e in obj.items()}
171-
# TODO(jansel): support other types of args
178+
if dataclasses.is_dataclass(obj):
179+
return dataclasses.replace(
180+
obj,
181+
**{
182+
k: self.to_fake(getattr(obj, k), origin)
183+
for k in obj.__dataclass_fields__ # pyre-ignore[16]
184+
},
185+
)
186+
172187
raise TypeError(f"unsupported argument type {type(obj)} ({origin})")
173188

174189
def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:

helion/_compiler/type_propagation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,36 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
225225
zip(value.keys(), cls._unpack_example(items, origin), strict=False)
226226
),
227227
)
228+
if isinstance(value, tuple) and hasattr(value, "_asdict"):
229+
# namedtuple
230+
return ClassType(
231+
origin,
232+
dict(
233+
zip(
234+
value._fields, # pyre-ignore[16]
235+
cls._unpack_example(
236+
value._asdict().items(), # pyre-ignore[16]
237+
origin,
238+
),
239+
strict=False,
240+
)
241+
),
242+
)
243+
if dataclasses.is_dataclass(value):
244+
keys = value.__dataclass_fields__.keys() # pyre-ignore[16]
245+
return ClassType(
246+
origin,
247+
dict(
248+
zip(
249+
keys,
250+
cls._unpack_example(
251+
tuple((k, getattr(value, k)) for k in keys),
252+
origin,
253+
),
254+
strict=False,
255+
)
256+
),
257+
)
228258
return UnknownType(
229259
debug_msg=f"{type(value).__name__} is not supported",
230260
origin=origin,
@@ -1122,6 +1152,11 @@ def tree_map(self, fn: Callable[[TypeInfo], object]) -> dict[str | int, object]:
11221152
return {k: v.tree_map(fn) for k, v in self.element_types.items()}
11231153

11241154

1155+
class ClassType(DictType):
1156+
def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
1157+
return self.element_types[attr]
1158+
1159+
11251160
class SliceType(CollectionType):
11261161
element_types: slice
11271162

helion/runtime/kernel.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4+
import dataclasses
45
import functools
56
import inspect
67
import logging
@@ -140,9 +141,15 @@ def _specialization_key(self, obj: object) -> Hashable:
140141
try:
141142
extractor = _specialization_extractors[type(obj)]
142143
except KeyError:
143-
raise TypeError(
144-
f"unsupported argument type: {type(obj).__name__}"
145-
) from None
144+
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
145+
# this is a namedtuple
146+
extractor = _specialization_extractors["namedtuple"]
147+
elif dataclasses.is_dataclass(obj):
148+
extractor = _specialization_extractors["dataclass"]
149+
else:
150+
raise TypeError(
151+
f"unsupported argument type: {type(obj).__name__}"
152+
) from None
146153
return extractor(self, obj)
147154

148155
def normalize_args(self, *args: object, **kwargs: object) -> tuple[object, ...]:
@@ -462,6 +469,14 @@ def _sequence_key(fn: Kernel, obj: Sequence) -> Hashable:
462469
return type(obj), tuple([fn._specialization_key(item) for item in obj])
463470

464471

472+
def _mapping_key(
473+
fn: Kernel, obj: dict[str | int, object], real_type: type[object]
474+
) -> Hashable:
475+
return real_type, tuple(
476+
sorted((k, fn._specialization_key(v)) for k, v in obj.items())
477+
)
478+
479+
465480
def _number_key(fn: Kernel, n: float | bool) -> object:
466481
return type(n)
467482

@@ -475,7 +490,9 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
475490
return obj.__code__
476491

477492

478-
_specialization_extractors: dict[type[object], Callable[[Kernel, object], Hashable]] = {
493+
_specialization_extractors: dict[
494+
type[object] | str, Callable[[Kernel, object], Hashable]
495+
] = {
479496
torch.Tensor: _tensor_key,
480497
torch.nn.Parameter: _tensor_key,
481498
torch.dtype: lambda fn, x: x,
@@ -486,9 +503,9 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
486503
str: lambda fn, x: x,
487504
list: _sequence_key,
488505
tuple: _sequence_key,
489-
dict: lambda fn, x: tuple(
490-
sorted((k, fn._specialization_key(v)) for k, v in x.items())
491-
),
506+
dict: lambda fn, x: _mapping_key(fn, x, type(x)),
507+
"namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)),
508+
"dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)),
492509
types.FunctionType: _function_key,
493510
types.BuiltinFunctionType: lambda fn, x: x,
494511
ConstExpr: lambda fn, x: x.value,

test/test_misc.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from collections import namedtuple
4+
from dataclasses import dataclass
35
import unittest
46

57
from expecttest import TestCase
@@ -58,20 +60,33 @@ def add3(x, y):
5860

5961
def test_inputs(self):
6062
@helion.kernel
61-
def kernel(a_list, b_dict, b_tuple):
63+
def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
6264
a0, a1 = a_list
6365
b0 = b_dict["b0"]
6466
(b1,) = b_tuple
65-
c0, c1 = torch.empty_like(a0), torch.empty_like(a1)
67+
c0, c1 = c_named_tuple.x, c_named_tuple.y
68+
d0, d1 = d_dataclass.x, d_dataclass.y
69+
70+
o0, o1 = torch.empty_like(a0), torch.empty_like(a1)
6671
for tile in hl.tile(a0.size()):
67-
c0[tile] = a0[tile] + b0[tile]
68-
c1[tile] = a1[tile] + b1[tile]
69-
return [c0, c1]
72+
o0[tile] = a0[tile] + b0[tile] + c0[tile] + d0[tile]
73+
o1[tile] = a1[tile] + b1[tile] + c1[tile] + d1[tile]
74+
return [o0, o1]
7075

71-
x = torch.randn(4, device=DEVICE)
72-
code, result = code_and_output(kernel, ([x, x], {"b0": x}, (x,)))
73-
torch.testing.assert_close(result[0], 2 * x)
74-
torch.testing.assert_close(result[1], 2 * x)
76+
x = torch.ones(4, device=DEVICE)
77+
Point = namedtuple("Point", ["x", "y"]) # noqa: PYI024
78+
p = Point(x, x)
79+
80+
@dataclass(frozen=True)
81+
class Point2:
82+
x: torch.Tensor
83+
y: torch.Tensor
84+
85+
p2 = Point2(x, x)
86+
87+
code, result = code_and_output(kernel, ([x, x], {"b0": x}, (x,), p, p2))
88+
torch.testing.assert_close(result[0], 4 * x)
89+
torch.testing.assert_close(result[1], 4 * x)
7590
self.assertExpectedInline(
7691
code,
7792
"""\
@@ -82,37 +97,49 @@ def kernel(a_list, b_dict, b_tuple):
8297
import triton.language as tl
8398
8499
@triton.jit
85-
def _kernel_kernel(a0, c0, c1, a0_size_0, a0_stride_0, c0_stride_0, c1_stride_0, _BLOCK_SIZE_0: tl.constexpr):
100+
def _kernel_kernel(a0, o0, o1, a0_size_0, a0_stride_0, o0_stride_0, o1_stride_0, _BLOCK_SIZE_0: tl.constexpr):
86101
pid_0 = tl.program_id(0)
87102
offset_0 = pid_0 * _BLOCK_SIZE_0
88103
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
89104
mask_0 = indices_0 < a0_size_0
90105
load = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
91106
load_1 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
92107
v_0 = load + load_1
93-
tl.store(c0 + indices_0 * c0_stride_0, v_0, mask_0)
94108
load_2 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
109+
v_1 = v_0 + load_2
95110
load_3 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
96-
v_1 = load_2 + load_3
97-
tl.store(c1 + indices_0 * c1_stride_0, v_1, mask_0)
98-
99-
def kernel(a_list, b_dict, b_tuple):
111+
v_2 = v_1 + load_3
112+
tl.store(o0 + indices_0 * o0_stride_0, v_2, mask_0)
113+
load_4 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
114+
load_5 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
115+
v_3 = load_4 + load_5
116+
load_6 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
117+
v_4 = v_3 + load_6
118+
load_7 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
119+
v_5 = v_4 + load_7
120+
tl.store(o1 + indices_0 * o1_stride_0, v_5, mask_0)
121+
122+
def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
100123
a0, a1 = a_list
101124
b0 = b_dict['b0']
102125
b1, = b_tuple
103-
c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
126+
c0, c1 = (c_named_tuple.x, c_named_tuple.y)
127+
d0, d1 = (d_dataclass.x, d_dataclass.y)
128+
o0, o1 = (torch.empty_like(a0), torch.empty_like(a1))
104129
_BLOCK_SIZE_0 = 4
105-
_kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
106-
return [c0, c1]
130+
_kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
131+
return [o0, o1]
107132
108-
def _kernel_make_precompiler(a_list, b_dict, b_tuple):
133+
def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
109134
a0, a1 = a_list
110135
b0 = b_dict['b0']
111136
b1, = b_tuple
112-
c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
137+
c0, c1 = (c_named_tuple.x, c_named_tuple.y)
138+
d0, d1 = (d_dataclass.x, d_dataclass.y)
139+
o0, o1 = (torch.empty_like(a0), torch.empty_like(a1))
113140
_BLOCK_SIZE_0 = 4
114141
from helion.runtime.precompile_shim import make_precompiler
115-
return make_precompiler(_kernel_kernel)(a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
142+
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
116143
)
117144

118145

0 commit comments

Comments
 (0)