Skip to content

Commit fab8f6c

Browse files
author
jax authors
committed
Merge pull request #19986 from sharadmv:pure-callback-bug
PiperOrigin-RevId: 610563393
2 parents 4b74d03 + d7bf956 commit fab8f6c

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

jax/_src/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def pure_callback_batching_rule(
9797
vectorized: bool,
9898
result_avals: Sequence[core.ShapedArray],
9999
):
100-
axis_size = next(a.shape[0] for a, d in zip(args, dims)
100+
axis_size = next(a.shape[d] for a, d in zip(args, dims)
101101
if d is not batching.not_mapped)
102102
new_args = [arg if dim is batching.not_mapped else
103103
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]

tests/python_callback_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,16 @@ def h(x, y):
566566
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
567567
rtol=1E-7, check_dtypes=False)
568568

569+
@jax.jit
570+
@functools.partial(jax.vmap, in_axes=1, out_axes=1)
571+
def h(x, y):
572+
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
573+
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
574+
out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None])
575+
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10.,
576+
14.)[None],
577+
rtol=1E-7, check_dtypes=False)
578+
569579
def test_vmap_vectorized_callback(self):
570580

571581
def cb(x):
@@ -598,6 +608,15 @@ def h(x, y):
598608
out = h(jnp.arange(4.), 4.)
599609
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
600610

611+
@jax.jit
612+
@functools.partial(jax.vmap, in_axes=(1, None), out_axes=1)
613+
def h(x, y):
614+
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
615+
vectorized=True)
616+
out = h(jnp.arange(4.)[None], 4.)
617+
np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
618+
619+
601620
def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
602621

603622
def cb(x):

0 commit comments

Comments
 (0)