Skip to content

Commit 6f38f27

Browse files
mattjjjax authors
authored andcommitted
temporarily relax the cotangent dtype check introduced in #19009
PiperOrigin-RevId: 615883208
1 parent 993abb1 commit 6f38f27

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

jax/_src/custom_derivatives.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -772,8 +772,7 @@ def append(x, d):
772772
results.append(Zero(ct.aval))
773773
else:
774774
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
775-
# TODO(mattjj): don't skip check with extended dtype tangent types
776-
and not dtypes.issubdtype(a_.dtype, dtypes.extended)):
775+
and not _temporary_dtype_exception(a, a_)):
777776
msg = ("Custom VJP bwd rule must produce an output with the same "
778777
"shape/dtypes as the args tuple of the primal function, but at "
779778
f"output{keystr(kp)} the bwd rule produced an output of "
@@ -783,6 +782,14 @@ def append(x, d):
783782
results.append(ct)
784783
yield results
785784

785+
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
786+
def _temporary_dtype_exception(a, a_) -> bool:
787+
if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray):
788+
return (a.shape == a_.shape and
789+
(dtypes.issubdtype(a_.dtype, dtypes.extended) or
790+
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
791+
return False
792+
786793

787794
class CustomVJPCallPrimitive(core.CallPrimitive):
788795
initial_style: core.Primitive

0 commit comments

Comments
 (0)