from __future__ import annotations

from collections import namedtuple
from dataclasses import dataclass
import unittest

from expecttest import TestCase
@@ -58,20 +60,33 @@ def add3(x, y):
58
60
59
61
def test_inputs (self ):
60
62
@helion.kernel
def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
    # Exercises every supported aggregate argument kind — list, dict,
    # tuple, namedtuple, and dataclass — each carrying tensors that are
    # unpacked before the tiled loop.
    # NOTE: local names are load-bearing: the surrounding test pins the
    # generated source with assertExpectedInline, so renaming would
    # change the expected codegen output.
    a0, a1 = a_list
    b0 = b_dict["b0"]
    (b1,) = b_tuple
    c0, c1 = c_named_tuple.x, c_named_tuple.y
    d0, d1 = d_dataclass.x, d_dataclass.y

    # Outputs mirror the shape/dtype of the list inputs.
    o0, o1 = torch.empty_like(a0), torch.empty_like(a1)

    for tile in hl.tile(a0.size()):
        # Each output element sums one tensor drawn from each container.
        o0[tile] = a0[tile] + b0[tile] + c0[tile] + d0[tile]
        o1[tile] = a1[tile] + b1[tile] + c1[tile] + d1[tile]
    return [o0, o1]
70
75
71
- x = torch .randn (4 , device = DEVICE )
72
- code , result = code_and_output (kernel , ([x , x ], {"b0" : x }, (x ,)))
73
- torch .testing .assert_close (result [0 ], 2 * x )
74
- torch .testing .assert_close (result [1 ], 2 * x )
76
+ x = torch .ones (4 , device = DEVICE )
77
+ Point = namedtuple ("Point" , ["x" , "y" ]) # noqa: PYI024
78
+ p = Point (x , x )
79
+
80
+ @dataclass (frozen = True )
81
+ class Point2 :
82
+ x : torch .Tensor
83
+ y : torch .Tensor
84
+
85
+ p2 = Point2 (x , x )
86
+
87
+ code , result = code_and_output (kernel , ([x , x ], {"b0" : x }, (x ,), p , p2 ))
88
+ torch .testing .assert_close (result [0 ], 4 * x )
89
+ torch .testing .assert_close (result [1 ], 4 * x )
75
90
self .assertExpectedInline (
76
91
code ,
77
92
"""\
@@ -82,37 +97,49 @@ def kernel(a_list, b_dict, b_tuple):
82
97
import triton.language as tl
83
98
84
99
@triton.jit
85
- def _kernel_kernel(a0, c0, c1 , a0_size_0, a0_stride_0, c0_stride_0, c1_stride_0 , _BLOCK_SIZE_0: tl.constexpr):
100
+ def _kernel_kernel(a0, o0, o1 , a0_size_0, a0_stride_0, o0_stride_0, o1_stride_0 , _BLOCK_SIZE_0: tl.constexpr):
86
101
pid_0 = tl.program_id(0)
87
102
offset_0 = pid_0 * _BLOCK_SIZE_0
88
103
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
89
104
mask_0 = indices_0 < a0_size_0
90
105
load = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
91
106
load_1 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
92
107
v_0 = load + load_1
93
- tl.store(c0 + indices_0 * c0_stride_0, v_0, mask_0)
94
108
load_2 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
109
+ v_1 = v_0 + load_2
95
110
load_3 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
96
- v_1 = load_2 + load_3
97
- tl.store(c1 + indices_0 * c1_stride_0, v_1, mask_0)
98
-
99
- def kernel(a_list, b_dict, b_tuple):
111
+ v_2 = v_1 + load_3
112
+ tl.store(o0 + indices_0 * o0_stride_0, v_2, mask_0)
113
+ load_4 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
114
+ load_5 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
115
+ v_3 = load_4 + load_5
116
+ load_6 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
117
+ v_4 = v_3 + load_6
118
+ load_7 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
119
+ v_5 = v_4 + load_7
120
+ tl.store(o1 + indices_0 * o1_stride_0, v_5, mask_0)
121
+
122
+ def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass):
100
123
a0, a1 = a_list
101
124
b0 = b_dict['b0']
102
125
b1, = b_tuple
103
- c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
126
+ c0, c1 = (c_named_tuple.x, c_named_tuple.y)
127
+ d0, d1 = (d_dataclass.x, d_dataclass.y)
128
+ o0, o1 = (torch.empty_like(a0), torch.empty_like(a1))
104
129
_BLOCK_SIZE_0 = 4
105
- _kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, c0, c1 , a0.size(0), a0.stride(0), c0 .stride(0), c1 .stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
106
- return [c0, c1 ]
130
+ _kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, o0, o1 , a0.size(0), a0.stride(0), o0 .stride(0), o1 .stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
131
+ return [o0, o1 ]
107
132
108
- def _kernel_make_precompiler(a_list, b_dict, b_tuple):
133
+ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass ):
109
134
a0, a1 = a_list
110
135
b0 = b_dict['b0']
111
136
b1, = b_tuple
112
- c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
137
+ c0, c1 = (c_named_tuple.x, c_named_tuple.y)
138
+ d0, d1 = (d_dataclass.x, d_dataclass.y)
139
+ o0, o1 = (torch.empty_like(a0), torch.empty_like(a1))
113
140
_BLOCK_SIZE_0 = 4
114
141
from helion.runtime.precompile_shim import make_precompiler
115
- return make_precompiler(_kernel_kernel)(a0, c0, c1 , a0.size(0), a0.stride(0), c0 .stride(0), c1 .stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""" ,
142
+ return make_precompiler(_kernel_kernel)(a0, o0, o1 , a0.size(0), a0.stride(0), o0 .stride(0), o1 .stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""" ,
116
143
)
117
144
118
145
0 commit comments