Skip to content

Commit 7413894

Browse files
author
jax authors
committed
Merge pull request #20599 from mattjj:temp-config-to-disable-custom-vjp-shape-check
PiperOrigin-RevId: 622224003
2 parents eff8a47 + 3d4687f commit 7413894

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

jax/_src/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,13 @@ def _update_disable_jit_thread_local(val):
13901390
upgrade=True,
13911391
help='Enable eager-mode pmap when jax_disable_jit is activated.')
13921392

1393+
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
1394+
custom_vjp_disable_shape_check = define_bool_state(
1395+
name='jax_custom_vjp_disable_shape_check',
1396+
default=False,
1397+
upgrade=True,
1398+
help='Disable the check from #19009 to enable some custom_vjp hacks.')
1399+
13931400
xla_runtime_errors = define_bool_state(
13941401
name='jax_experimental_unsafe_xla_runtime_errors',
13951402
default=False,

jax/_src/custom_derivatives.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,8 @@ def append(x, d):
772772
results.append(Zero(ct.aval))
773773
else:
774774
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
775-
and not _temporary_dtype_exception(a, a_)):
775+
and not (_temporary_dtype_exception(a, a_) or
776+
_temporary_shape_exception(a, a_))):
776777
msg = ("Custom VJP bwd rule must produce an output with the same "
777778
"shape/dtypes as the args tuple of the primal function, but at "
778779
f"output{keystr(kp)} the bwd rule produced an output of "
@@ -790,6 +791,9 @@ def _temporary_dtype_exception(a, a_) -> bool:
790791
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
791792
return False
792793

794+
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
795+
def _temporary_shape_exception(a, a_) -> bool:
796+
return config.custom_vjp_disable_shape_check.value
793797

794798
class CustomVJPCallPrimitive(core.CallPrimitive):
795799
initial_style: core.Primitive

tests/api_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9372,6 +9372,27 @@ def foo_bwd(_, g):
93729372
r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'):
93739373
jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4))
93749374

9375+
def test_bwd_rule_shape_mismatch_disable(self):
9376+
# TODO(mattjj): remove this test when the config option is removed
9377+
@jax.custom_vjp
9378+
def foo(x, y):
9379+
return x
9380+
9381+
def foo_fwd(x, y):
9382+
return x, None
9383+
9384+
def foo_bwd(_, g):
9385+
return jnp.zeros(3), jnp.zeros(3)
9386+
9387+
foo.defvjp(foo_fwd, foo_bwd)
9388+
9389+
try:
9390+
jax.config.update('jax_custom_vjp_disable_shape_check', True)
9391+
jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4))
9392+
finally:
9393+
jax.config.update('jax_custom_vjp_disable_shape_check', False)
9394+
9395+
93759396
def transpose_unary(f, x_example):
93769397
def transposed(y):
93779398
x, = api.linear_transpose(f, x_example)(y)

0 commit comments

Comments
 (0)