File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -772,8 +772,7 @@ def append(x, d):
772
772
results .append (Zero (ct .aval ))
773
773
else :
774
774
if (not core .typecompat (a .at_least_vspace (), a_ := core .get_aval (ct ))
775
- # TODO(mattjj): don't skip check with extended dtype tangent types
776
- and not dtypes .issubdtype (a_ .dtype , dtypes .extended )):
775
+ and not _temporary_dtype_exception (a , a_ )):
777
776
msg = ("Custom VJP bwd rule must produce an output with the same "
778
777
"shape/dtypes as the args tuple of the primal function, but at "
779
778
f"output{ keystr (kp )} the bwd rule produced an output of "
@@ -783,6 +782,14 @@ def append(x, d):
783
782
results .append (ct )
784
783
yield results
785
784
785
+ # TODO(mattjj): remove both these exceptions to cotangent compatibility check
786
+ def _temporary_dtype_exception (a , a_ ) -> bool :
787
+ if isinstance (a , core .ShapedArray ) and isinstance (a_ , core .ShapedArray ):
788
+ return (a .shape == a_ .shape and
789
+ (dtypes .issubdtype (a_ .dtype , dtypes .extended ) or
790
+ dtypes .issubdtype (a .dtype , dtypes .np .inexact )))
791
+ return False
792
+
786
793
787
794
class CustomVJPCallPrimitive (core .CallPrimitive ):
788
795
initial_style : core .Primitive
You can’t perform that action at this time.
0 commit comments