Skip to content

Commit 98e4b9e

Browse files
superbobryjax authors
authored andcommitted
Fixed a typo in the lowering rule for lax.dot_general_p
See #19990 for a reproducer. PiperOrigin-RevId: 610743180
1 parent 4513a51 commit 98e4b9e

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,9 +1837,9 @@ def _dot_general_lowering(
18371837
assert batch_dims == ((), ())
18381838

18391839
if a_contract_dim == 0:
1840-
a = tt_dialect.permute(a, (1, 0))
1840+
a = tt_dialect.trans(a, (1, 0))
18411841
if b_contract_dim == 1:
1842-
b = tt_dialect.permute(b, (1, 0))
1842+
b = tt_dialect.trans(b, (1, 0))
18431843

18441844
if precision is None:
18451845
allow_tf32 = True

tests/pallas/pallas_test.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -388,15 +388,20 @@ def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
388388
interpret=self.INTERPRET), jnp.matmul(x, y)
389389
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
390390

391-
@parameterized.named_parameters(*(
392-
dict(testcase_name=f"{size}_{dtype}", size=size, dtype=dtype)
393-
for size in [16, 32, 64]
394-
for dtype in ["float32", "float16"]
395-
))
396-
def test_dot(self, size, dtype):
391+
@parameterized.product(
392+
size=[16, 32, 64],
393+
dtype=["float32", "float16"],
394+
trans_a=[False, True],
395+
trans_b=[False, True],
396+
)
397+
def test_dot(self, size, dtype, trans_a, trans_b):
397398
if not self.check_gpu_capability_at_least(70):
398399
raise unittest.SkipTest(
399400
"Matmul only works on GPUs with capability >= sm70")
401+
if trans_a or trans_b:
402+
# TODO(slebedev): Remove this once the problematic Triton pass is fixed.
403+
raise unittest.SkipTest(
404+
"Triton crashes if any of the operands are transposed")
400405

401406
@functools.partial(
402407
self.pallas_call,
@@ -405,7 +410,7 @@ def test_dot(self, size, dtype):
405410
def dot(x_ref, y_ref, o_ref):
406411
x = x_ref[:, :]
407412
y = y_ref[:, :]
408-
o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)
413+
o_ref[:, :] = pl.dot(x, y, trans_a, trans_b).astype(o_ref.dtype)
409414

410415
k1, k2 = random.split(random.key(0))
411416
x = random.normal(k1, (size, size), dtype=dtype)

0 commit comments

Comments
 (0)