Skip to content

Commit 00489be

Browse files
author
jax authors
committed
Fix a bug where exceptions were thrown in debug message formatting, when sharding was set to None on arrays.
PiperOrigin-RevId: 621193460
1 parent 2ee4c0f commit 00489be

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/pjit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
112112
if arg_names is None:
113113
arg_names = [''] * len(args_flat)
114114
for a, n in zip(args_flat, arg_names):
115-
da = a.sharding._device_assignment if hasattr(a, 'sharding') else None
115+
da = (a.sharding._device_assignment
116+
if getattr(a, 'sharding', None) is not None else None)
116117
arg_list.append((n, da, shaped_abstractify(a)))
117118

118119
mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name)

0 commit comments

Comments
 (0)