|
30 | 30 | from jax._src import test_util as jtu
|
31 | 31 | from jax._src import util
|
32 | 32 | from jax._src.lib import xla_client
|
| 33 | +from jax._src.lib import xla_extension_version |
33 | 34 | from jax.experimental import io_callback
|
34 | 35 | from jax.experimental import pjit
|
35 | 36 | from jax.experimental.maps import xmap
|
@@ -639,7 +640,6 @@ def h(x, y):
|
639 | 640 | out = h(jnp.arange(4.)[None], 4.)
|
640 | 641 | np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
|
641 | 642 |
|
642 |
| - |
643 | 643 | def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
|
644 | 644 |
|
645 | 645 | def cb(x):
|
@@ -716,6 +716,18 @@ def f(x):
|
716 | 716 | ValueError, "Pure callbacks do not support JVP."):
|
717 | 717 | f(2.)
|
718 | 718 |
|
| 719 | + @unittest.skipIf(xla_extension_version < 245, "jaxlib version too old") |
| 720 | + def test_error_propagation(self): |
| 721 | + def throws_error_fn(x): |
| 722 | + raise RuntimeError("Errors should propagate.") |
| 723 | + |
| 724 | + @jax.jit |
| 725 | + def f(x): |
| 726 | + return jax.pure_callback(throws_error_fn, x, x) |
| 727 | + |
| 728 | + with self.assertRaisesRegex(Exception, "Errors should propagate."): |
| 729 | + print(np.array(f(2.0)), flush=True) |
| 730 | + |
719 | 731 | def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
|
720 | 732 |
|
721 | 733 | @jax.custom_jvp
|
@@ -833,7 +845,6 @@ def f(self, ys):
|
833 | 845 | # callback alive.
|
834 | 846 | np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
|
835 | 847 |
|
836 |
| - |
837 | 848 | def test_callback_inside_xmap(self):
|
838 | 849 |
|
839 | 850 | def _callback(x):
|
|
0 commit comments