@@ -593,10 +593,9 @@ def _shaped_abstractify_slow(x):
593
593
594
594
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
595
595
def shaped_abstractify (x ):
596
- try :
597
- return _shaped_abstractify_handlers [type (x )](x )
598
- except KeyError :
599
- return _shaped_abstractify_slow (x )
596
+ handler = _shaped_abstractify_handlers .get (type (x ), None )
597
+ return handler (x ) if handler is not None else _shaped_abstractify_slow (x )
598
+
600
599
_shaped_abstractify_handlers : dict [Any , Callable [[Any ], core .ShapedArray ]] = {}
601
600
602
601
@@ -619,6 +618,13 @@ def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
619
618
_shaped_abstractify_handlers .update ((t , _np_scalar_abstractify )
620
619
for t in numpy_scalar_types )
621
620
621
+ def _python_scalar_abstractify (x : int | float | complex | bool ) -> ShapedArray :
622
+ typ = type (x )
623
+ dtype = dtypes ._scalar_type_to_dtype (typ , x )
624
+ return ShapedArray ((), dtype , weak_type = typ in dtypes ._weak_types )
625
+ _shaped_abstractify_handlers .update ((t , _python_scalar_abstractify )
626
+ for t in dtypes .python_scalar_dtypes )
627
+
622
628
# This decorator exists to make it easier to monkey-patch APIs in JAX.
623
629
# By default it does nothing, but it can be monkey-patched to do other things.
624
630
def api_hook (fun , tag : str ):
0 commit comments