@@ -141,6 +141,72 @@ def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
        )

+    def test_bmm(self):
+        args = (
+            torch.randn([16, 512, 768], device=DEVICE, dtype=torch.float16),
+            torch.randn([16, 768, 1024], device=DEVICE, dtype=torch.float16),
+        )
+        self.assertExpectedInline(
+            run_example(
+                "bmm",
+                args,
+                torch.bmm(args[0], args[1]),
+                block_sizes=[[16, 16, 16], 16],
+                l2_grouping=4,
+            ),
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _bmm_kernel(A, B, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(512, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+    for offset_3 in range(0, 768, _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        acc_copy = acc
+        load = tl.load(A + (indices_0[:, None, None] * 393216 + indices_1[None, :, None] * 768 + indices_3[None, None, :] * 1), None)
+        load_1 = tl.load(B + (indices_0[:, None, None] * 786432 + indices_3[None, :, None] * 1024 + indices_2[None, None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+    v_0 = acc.to(tl.float16)
+    tl.store(out + (indices_0[:, None, None] * 524288 + indices_1[None, :, None] * 1024 + indices_2[None, None, :] * 1), v_0, None)
+
+def bmm(A: torch.Tensor, B: torch.Tensor):
+    b, m, k = A.size()
+    b, k, n = B.size()
+    out = torch.empty([b, m, n], device=A.device, dtype=torch.promote_types(A.dtype, B.dtype))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_3 = 16
+    _bmm_kernel[triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),](A, B, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+def _bmm_make_precompiler(A: torch.Tensor, B: torch.Tensor):
+    b, m, k = A.size()
+    b, k, n = B.size()
+    out = torch.empty([b, m, n], device=A.device, dtype=torch.promote_types(A.dtype, B.dtype))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_3 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_bmm_kernel)(A, B, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
+        )
+
    def test_template_via_closure0(self):
        bias = torch.randn([1, 1024], device=DEVICE, dtype=torch.float16)
        args = (