Skip to content

Commit 9a00721

Browse files
pschuhjax authors
authored andcommitted
Propagate effects errors to the results (only if effects are enabled).
This will now happen when results of effectful computations are converted to numpy arrays. PiperOrigin-RevId: 615883363
1 parent 6f38f27 commit 9a00721

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

tests/python_callback_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from jax._src import test_util as jtu
3131
from jax._src import util
3232
from jax._src.lib import xla_client
33+
from jax._src.lib import xla_extension_version
3334
from jax.experimental import io_callback
3435
from jax.experimental import pjit
3536
from jax.experimental.maps import xmap
@@ -639,7 +640,6 @@ def h(x, y):
639640
out = h(jnp.arange(4.)[None], 4.)
640641
np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
641642

642-
643643
def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
644644

645645
def cb(x):
@@ -716,6 +716,18 @@ def f(x):
716716
ValueError, "Pure callbacks do not support JVP."):
717717
f(2.)
718718

719+
@unittest.skipIf(xla_extension_version < 245, "jaxlib version too old")
720+
def test_error_propagation(self):
721+
def throws_error_fn(x):
722+
raise RuntimeError("Errors should propagate.")
723+
724+
@jax.jit
725+
def f(x):
726+
return jax.pure_callback(throws_error_fn, x, x)
727+
728+
with self.assertRaisesRegex(Exception, "Errors should propagate."):
729+
print(np.array(f(2.0)), flush=True)
730+
719731
def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
720732

721733
@jax.custom_jvp
@@ -833,7 +845,6 @@ def f(self, ys):
833845
# callback alive.
834846
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
835847

836-
837848
def test_callback_inside_xmap(self):
838849

839850
def _callback(x):

0 commit comments

Comments
 (0)