We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d16424e commit be9b002Copy full SHA for be9b002
keras/src/backend/jax/distribution_lib.py
@@ -78,11 +78,16 @@ def distribute_tensor(tensor, layout):
78
layout, jax.sharding.Sharding
79
) and tensor.sharding.is_equivalent_to(layout, ndim=len(tensor.shape)):
80
return tensor
81
- # JAX explicit layout support.
+ # JAX explicit "layout" support.
82
elif hasattr(layout, "layout"):
83
current_layout = getattr(tensor, "layout", None)
84
if current_layout == layout:
85
86
+ # JAX explicit "format" support.
87
+ elif hasattr(layout, "format"):
88
+ current_layout = getattr(tensor, "format", None)
89
+ if current_layout == layout:
90
+ return tensor
91
92
return jax.device_put(tensor, layout)
93
0 commit comments