@@ -358,7 +358,6 @@ def test_bmm(self):
358
358
args ,
359
359
torch .bmm (args [0 ], args [1 ]),
360
360
block_sizes = [16 , 16 , 16 , 16 ],
361
- l2_grouping = 4 ,
362
361
),
363
362
"""\
364
363
from __future__ import annotations
@@ -375,18 +374,19 @@ def _bmm_kernel(A, B, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.conste
375
374
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
376
375
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
377
376
offset_0 = pid_0 * _BLOCK_SIZE_0
378
- indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
377
+ indices_0 = ( offset_0 + tl.arange(0, _BLOCK_SIZE_0) ).to(tl.int32)
379
378
offset_1 = pid_1 * _BLOCK_SIZE_1
380
- indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
379
+ indices_1 = ( offset_1 + tl.arange(0, _BLOCK_SIZE_1) ).to(tl.int32)
381
380
offset_2 = pid_2 * _BLOCK_SIZE_2
382
- indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
381
+ indices_2 = ( offset_2 + tl.arange(0, _BLOCK_SIZE_2) ).to(tl.int32)
383
382
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
384
383
for offset_3 in range(0, 768, _BLOCK_SIZE_3):
385
384
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
386
385
acc_copy = acc
386
+ acc_copy_0 = acc_copy
387
387
load = tl.load(A + (indices_0[:, None, None] * 393216 + indices_1[None, :, None] * 768 + indices_3[None, None, :] * 1), None)
388
388
load_1 = tl.load(B + (indices_0[:, None, None] * 786432 + indices_3[None, :, None] * 1024 + indices_2[None, None, :] * 1), None)
389
- acc = tl.dot(load, load_1, acc=acc_copy , input_precision='tf32')
389
+ acc = tl.dot(load, load_1, acc=acc_copy_0 , input_precision='tf32')
390
390
v_0 = acc.to(tl.float16)
391
391
tl.store(out + (indices_0[:, None, None] * 524288 + indices_1[None, :, None] * 1024 + indices_2[None, None, :] * 1), v_0, None)
392
392
0 commit comments