@@ -143,3 +143,131 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
143
143
_BLOCK_SIZE_0 = 64
144
144
_launcher(_fn_kernel, (triton.cdiv(m, _BLOCK_SIZE_1),), x, out, out.stride(0), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
145
145
return out
146
+
147
+ --- assertExpectedJournal(TestMisc.test_tuple_literal_subscript)
148
+ from __future__ import annotations
149
+
150
+ import torch
151
+ import triton
152
+ import triton.language as tl
153
+ from helion.runtime import default_launcher as _default_launcher
154
+
155
+ @triton.jit
156
+ def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
157
+ num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
158
+ pid_0 = tl.program_id(0) % num_blocks_0
159
+ pid_1 = tl.program_id(0) // num_blocks_0
160
+ offset_0 = pid_0 * _BLOCK_SIZE_0
161
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
162
+ mask_0 = indices_0 < out_size_0
163
+ offset_1 = pid_1 * _BLOCK_SIZE_1
164
+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
165
+ mask_1 = indices_1 < out_size_1
166
+ load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
167
+ load_1 = tl.load(inp_tuple_item_1 + (indices_0[:, None] * inp_tuple_item_1_stride_0 + indices_1[None, :] * inp_tuple_item_1_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
168
+ v_0 = load_1.to(tl.float32)
169
+ v_1 = load + v_0
170
+ v_2 = inp_tuple_item_2.to(tl.float32)
171
+ v_3 = v_1 * v_2
172
+ tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
173
+
174
+ def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
175
+ out = torch.empty_like(inp_tuple[0])
176
+ _BLOCK_SIZE_0 = 8
177
+ _BLOCK_SIZE_1 = 8
178
+ _launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
179
+ return outfrom __future__ import annotations
180
+
181
+ import torch
182
+ import triton
183
+ import triton.language as tl
184
+ from helion.runtime import default_launcher as _default_launcher
185
+
186
+ @triton.jit
187
+ def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_0_size_0, inp_tuple_item_0_size_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
188
+ num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
189
+ pid_0 = tl.program_id(0) % num_blocks_0
190
+ pid_1 = tl.program_id(0) // num_blocks_0
191
+ offset_0 = pid_0 * _BLOCK_SIZE_0
192
+ offset_1 = pid_1 * _BLOCK_SIZE_1
193
+ load = tl.load(tl.make_block_ptr(inp_tuple_item_0, [inp_tuple_item_0_size_0, inp_tuple_item_0_size_1], [inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
194
+ load_1 = tl.load(tl.make_block_ptr(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
195
+ v_0 = load_1.to(tl.float32)
196
+ v_1 = load + v_0
197
+ v_2 = inp_tuple_item_2.to(tl.float32)
198
+ v_3 = v_1 * v_2
199
+ tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
200
+
201
+ def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
202
+ out = torch.empty_like(inp_tuple[0])
203
+ _BLOCK_SIZE_0 = 8
204
+ _BLOCK_SIZE_1 = 8
205
+ _launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[0].size(0), inp_tuple[0].size(1), inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
206
+ return out
207
+
208
+ --- assertExpectedJournal(TestMisc.test_tuple_literal_subscript_w_descriptor)
209
+ from __future__ import annotations
210
+
211
+ import torch
212
+ import helion
213
+ import triton
214
+ import triton.language as tl
215
+ from helion.runtime import default_launcher as _default_launcher
216
+
217
+ helion.runtime.set_triton_allocator()
218
+
219
+ @triton.jit
220
+ def _tuple_literal_index_kernel_kernel(out, inp_tuple_item_0, inp_tuple_item_1, inp_tuple_item_1_size_0, inp_tuple_item_1_size_1, out_size_0, out_size_1, inp_tuple_item_0_stride_0, inp_tuple_item_0_stride_1, inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1, out_stride_0, out_stride_1, inp_tuple_item_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
221
+ inp_tuple_item_1_desc = tl.make_tensor_descriptor(inp_tuple_item_1, [inp_tuple_item_1_size_0, inp_tuple_item_1_size_1], [inp_tuple_item_1_stride_0, inp_tuple_item_1_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1])
222
+ num_blocks_0 = tl.cdiv(out_size_0, _BLOCK_SIZE_0)
223
+ pid_0 = tl.program_id(0) % num_blocks_0
224
+ pid_1 = tl.program_id(0) // num_blocks_0
225
+ offset_0 = pid_0 * _BLOCK_SIZE_0
226
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
227
+ mask_0 = indices_0 < out_size_0
228
+ offset_1 = pid_1 * _BLOCK_SIZE_1
229
+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
230
+ mask_1 = indices_1 < out_size_1
231
+ load = tl.load(inp_tuple_item_0 + (indices_0[:, None] * inp_tuple_item_0_stride_0 + indices_1[None, :] * inp_tuple_item_0_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
232
+ load_1 = inp_tuple_item_1_desc.load([offset_0, offset_1])
233
+ v_0 = load_1.to(tl.float32)
234
+ v_1 = load + v_0
235
+ v_2 = inp_tuple_item_2.to(tl.float32)
236
+ v_3 = v_1 * v_2
237
+ tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
238
+
239
+ def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher):
240
+ out = torch.empty_like(inp_tuple[0])
241
+ _BLOCK_SIZE_0 = 8
242
+ _BLOCK_SIZE_1 = 8
243
+ _launcher(_tuple_literal_index_kernel_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), out, inp_tuple[0], inp_tuple[1], inp_tuple[1].size(0), inp_tuple[1].size(1), out.size(0), out.size(1), inp_tuple[0].stride(0), inp_tuple[0].stride(1), inp_tuple[1].stride(0), inp_tuple[1].stride(1), out.stride(0), out.stride(1), inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
244
+ return out
245
+
246
+ --- assertExpectedJournal(TestMisc.test_tuple_unpack)
247
+ from __future__ import annotations
248
+
249
+ import torch
250
+ import triton
251
+ import triton.language as tl
252
+ from helion.runtime import default_launcher as _default_launcher
253
+
254
+ @triton.jit
255
+ def _tuple_unpack_kernel_kernel(a, b, out, a_size_0, a_stride_0, b_stride_0, out_stride_0, x, _BLOCK_SIZE_0: tl.constexpr):
256
+ pid_0 = tl.program_id(0)
257
+ offset_0 = pid_0 * _BLOCK_SIZE_0
258
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
259
+ mask_0 = indices_0 < a_size_0
260
+ load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
261
+ load_1 = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
262
+ v_0 = load_1.to(tl.float32)
263
+ v_1 = load + v_0
264
+ v_2 = x.to(tl.float32)
265
+ v_3 = v_1 + v_2
266
+ tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
267
+
268
+ def tuple_unpack_kernel(inp_tuple, *, _launcher=_default_launcher):
269
+ a, b, x = inp_tuple
270
+ out = torch.empty_like(a)
271
+ _BLOCK_SIZE_0 = 4
272
+ _launcher(_tuple_unpack_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out, a.size(0), a.stride(0), b.stride(0), out.stride(0), x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
273
+ return out
0 commit comments