Skip to content

Commit c945766

Browse files
superbobryjax authors
authored andcommitted
Skip PallasOpsTest.test_elementwise_exp2 on older jaxlib
PiperOrigin-RevId: 617873650
1 parent 0a0d65a commit c945766

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/pallas/pallas_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from jax._src import test_util as jtu
3232
from jax._src import state
3333
from jax._src.lax.control_flow.for_loop import for_loop
34+
from jax._src.lib import version as jaxlib_version
3435
from jax._src.pallas.pallas_call import _trace_to_jaxpr
3536
from jax.interpreters import partial_eval as pe
3637
import jax.numpy as jnp
@@ -1550,6 +1551,9 @@ class PallasOpsTest(PallasTest):
15501551
for fn, dtype in itertools.product(*args)
15511552
)
15521553
def test_elementwise(self, fn, dtype):
1554+
if fn is lax.exp2 and jaxlib_version < (0, 4, 26):
1555+
self.skipTest("Requires jaxlib 0.4.26 or later")
1556+
15531557
@functools.partial(
15541558
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1
15551559
)

0 commit comments

Comments
 (0)