Skip to content

Commit 4669fdc

Browse files
authored
Fix failing tests on main (#231)
1 parent 0377217 commit 4669fdc

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

.pyre_configuration

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"source": "examples"
1111
}
1212
],
13-
"optional_search_path": ["../pytorch", "../pytorch-hg", "../pytorch-nightly"],
13+
"optional_search_path": ["../pytorch", "../pytorch-hg", "../pytorch-nightly", "../triton/python"],
1414
"python_version": "3.10",
1515
"exclude": [".*third_party.*"],
1616
"strict": true

test/test_examples.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ def test_bmm(self):
358358
args,
359359
torch.bmm(args[0], args[1]),
360360
block_sizes=[16, 16, 16, 16],
361-
l2_grouping=4,
362361
),
363362
"""\
364363
from __future__ import annotations
@@ -375,18 +374,19 @@ def _bmm_kernel(A, B, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.conste
375374
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
376375
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
377376
offset_0 = pid_0 * _BLOCK_SIZE_0
378-
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
377+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
379378
offset_1 = pid_1 * _BLOCK_SIZE_1
380-
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
379+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
381380
offset_2 = pid_2 * _BLOCK_SIZE_2
382-
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
381+
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
383382
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
384383
for offset_3 in range(0, 768, _BLOCK_SIZE_3):
385384
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
386385
acc_copy = acc
386+
acc_copy_0 = acc_copy
387387
load = tl.load(A + (indices_0[:, None, None] * 393216 + indices_1[None, :, None] * 768 + indices_3[None, None, :] * 1), None)
388388
load_1 = tl.load(B + (indices_0[:, None, None] * 786432 + indices_3[None, :, None] * 1024 + indices_2[None, None, :] * 1), None)
389-
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
389+
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
390390
v_0 = acc.to(tl.float16)
391391
tl.store(out + (indices_0[:, None, None] * 524288 + indices_1[None, :, None] * 1024 + indices_2[None, None, :] * 1), v_0, None)
392392

0 commit comments

Comments
 (0)