Skip to content

Commit 9523547

Browse files
hawkinspjax authors
authored andcommitted
Add a fast path for Python scalars to shaped_abstractify.
PiperOrigin-RevId: 618015741
1 parent 07e45c3 commit 9523547

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

jax/_src/api_util.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -593,10 +593,9 @@ def _shaped_abstractify_slow(x):
593593

594594
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
595595
def shaped_abstractify(x):
596-
try:
597-
return _shaped_abstractify_handlers[type(x)](x)
598-
except KeyError:
599-
return _shaped_abstractify_slow(x)
596+
handler = _shaped_abstractify_handlers.get(type(x), None)
597+
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
598+
600599
_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}
601600

602601

@@ -619,6 +618,13 @@ def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
619618
_shaped_abstractify_handlers.update((t, _np_scalar_abstractify)
620619
for t in numpy_scalar_types)
621620

621+
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
622+
typ = type(x)
623+
dtype = dtypes._scalar_type_to_dtype(typ, x)
624+
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
625+
_shaped_abstractify_handlers.update((t, _python_scalar_abstractify)
626+
for t in dtypes.python_scalar_dtypes)
627+
622628
# This decorator exists to make it easier to monkey-patch APIs in JAX.
623629
# By default it does nothing, but it can be monkey-patched to do other things.
624630
def api_hook(fun, tag: str):

jax/_src/dtypes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,12 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
623623
return np.dtype(_least_upper_bound(config.numpy_dtype_promotion.value, a_tp, b_tp))
624624

625625
def is_weakly_typed(x: Any) -> bool:
626+
if type(x) in _weak_types:
627+
return True
626628
try:
627629
return x.aval.weak_type
628630
except AttributeError:
629-
return type(x) in _weak_types
631+
return False
630632

631633
def is_python_scalar(x: Any) -> bool:
632634
try:

0 commit comments

Comments
 (0)