Skip to content

Commit be9b002

Browse files
authored
Add "format" support for latest JAX. (#21406)
The `jax.experimental.layout.Format` configuration has a `"format"` attribute rather than a `"layout"` attribute.
1 parent d16424e commit be9b002

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,16 @@ def distribute_tensor(tensor, layout):
7878
layout, jax.sharding.Sharding
7979
) and tensor.sharding.is_equivalent_to(layout, ndim=len(tensor.shape)):
8080
return tensor
81-
# JAX explicit layout support.
81+
# JAX explicit "layout" support.
8282
elif hasattr(layout, "layout"):
8383
current_layout = getattr(tensor, "layout", None)
8484
if current_layout == layout:
8585
return tensor
86+
# JAX explicit "format" support.
87+
elif hasattr(layout, "format"):
88+
current_layout = getattr(tensor, "format", None)
89+
if current_layout == layout:
90+
return tensor
8691

8792
return jax.device_put(tensor, layout)
8893

0 commit comments

Comments
 (0)