@@ -566,6 +566,16 @@ def h(x, y):
566
566
self .assertArraysAllClose (out , np .sin (np .arange (4. )) + np .arange (10. , 14. ),
567
567
rtol = 1E-7 , check_dtypes = False )
568
568
569
+ @jax .jit
570
+ @functools .partial (jax .vmap , in_axes = 1 , out_axes = 1 )
571
+ def h (x , y ):
572
+ out_shape = jax .ShapeDtypeStruct (x .shape , np .result_type (x .dtype , y .dtype ))
573
+ return jax .pure_callback (lambda x , y : np .sin (x ) + y , out_shape , x , y )
574
+ out = h (jnp .arange (4. )[None ], jnp .arange (10. , 14. )[None ])
575
+ self .assertArraysAllClose (out , np .sin (np .arange (4. )) + np .arange (10. ,
576
+ 14. )[None ],
577
+ rtol = 1E-7 , check_dtypes = False )
578
+
569
579
def test_vmap_vectorized_callback (self ):
570
580
571
581
def cb (x ):
@@ -598,6 +608,15 @@ def h(x, y):
598
608
out = h (jnp .arange (4. ), 4. )
599
609
np .testing .assert_allclose (out , np .sin (np .arange (4. )) + 4. )
600
610
611
+ @jax .jit
612
+ @functools .partial (jax .vmap , in_axes = (1 , None ), out_axes = 1 )
613
+ def h (x , y ):
614
+ return jax .pure_callback (lambda x , y : np .sin (x ) + y , x , x , y ,
615
+ vectorized = True )
616
+ out = h (jnp .arange (4. )[None ], 4. )
617
+ np .testing .assert_allclose (out , np .sin (np .arange (4. )[None ]) + 4. )
618
+
619
+
601
620
def test_vmap_vectorized_callback_errors_if_returns_wrong_shape (self ):
602
621
603
622
def cb (x ):
0 commit comments