Skip to content

Commit 2d65571

Browse files
superbobryjax authors
authored andcommitted
Really skip exp2 in Pallas GPU tests with older jaxlib
PiperOrigin-RevId: 618149873
1 parent cd79e71 commit 2d65571

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/pallas/pallas_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,7 @@ class PallasOpsTest(PallasTest):
15511551
for fn, dtype in itertools.product(*args)
15521552
)
15531553
def test_elementwise(self, fn, dtype):
1554-
if fn is lax.exp2 and jaxlib_version < (0, 4, 26):
1554+
if fn is jnp.exp2 and jaxlib_version < (0, 4, 26):
15551555
self.skipTest("Requires jaxlib 0.4.26 or later")
15561556

15571557
@functools.partial(

0 commit comments

Comments
 (0)