Skip to content

Commit 55233a0

Browse files
yashk2810jax authors
authored andcommitted
device_local_layout can be None on a jax.Array for backends that don't implement certain required methods for a jax.Array to populate the device_local_layout.
Skip the error checks when arr.layout.device_local_layout is None. PiperOrigin-RevId: 622007598
1 parent b322d39 commit 55233a0

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

jax/_src/pjit.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,11 +1279,14 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
12791279
else:
12801280
resolved_in_layouts.append(None)
12811281
else:
1282-
if committed and arg_layout != jit_in_l:
1282+
# arg_layout can be None because some backends don't implement the
1283+
# required layout methods. Hence `arr.layout` can return
1284+
# `Layout(None, sharding)`
1285+
if committed and arg_layout is not None and arg_layout != jit_in_l:
12831286
raise ValueError('Layout passed to jit does not match the layout '
12841287
'on the respective arg. '
12851288
f'Got pjit layout: {jit_in_l},\n'
1286-
f'arg sharding: {arg_layout} for '
1289+
f'arg layout: {arg_layout} for '
12871290
f'arg shape: {shaped_abstractify(arg).str_short()}')
12881291
resolved_in_layouts.append(jit_in_l)
12891292
return tuple(resolved_in_layouts)

0 commit comments

Comments
 (0)