Skip to content

Commit dd9132f

Browse files
authored
Fix dtype detection for JAX types. (#21184)
The jax types like `jax.float32` have a string representation of ``` <class 'jax.numpy.float32'> ``` so with the previous code, would be "standardized" as `float32'>` (trailing quote and angle bracket), which is an invalid type. But, the JAX dtypes _do_ have a `__name__` property, so should be properly detected if we switch the order around. Kept the old `jax.numpy` string version in place in case that worked with older versions of JAX.
1 parent 50dae30 commit dd9132f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

keras/src/backend/common/variables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,12 +564,12 @@ def standardize_dtype(dtype):
564564
dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype)
565565
if hasattr(dtype, "name"):
566566
dtype = dtype.name
567+
elif hasattr(dtype, "__name__"):
568+
dtype = dtype.__name__
567569
elif hasattr(dtype, "__str__") and (
568570
"torch" in str(dtype) or "jax.numpy" in str(dtype)
569571
):
570572
dtype = str(dtype).split(".")[-1]
571-
elif hasattr(dtype, "__name__"):
572-
dtype = dtype.__name__
573573

574574
if dtype not in dtypes.ALLOWED_DTYPES:
575575
raise ValueError(f"Invalid dtype: {dtype}")

0 commit comments

Comments
 (0)