Skip to content

Commit d57bb8c

Browse files
yashk2810jax authors
authored andcommitted
Raise a better error message when an invalid input is passed to jit call.
Before: ``` TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type. ``` After: ``` TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type. ``` The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information. PiperOrigin-RevId: 618014044
1 parent 7f7e0c0 commit d57bb8c

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

jax/_src/interpreters/xla.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ def tuple_sharding_proto(elems):
110110
return proto
111111

112112

113-
114-
115113
### handlers
116114

117115
# JAX abstract values -> XLA shapes
@@ -132,6 +130,10 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
132130

133131
# IR constants
134132

133+
class InvalidInputException(Exception):
134+
pass
135+
136+
135137
# TODO(mattjj): try to remove this canonicalize_dtype stuff
136138
def canonicalize_dtype(x):
137139
typ = type(x)
@@ -142,8 +144,8 @@ def canonicalize_dtype(x):
142144
if handler: return handler(x)
143145
if hasattr(x, '__jax_array__'):
144146
return canonicalize_dtype(x.__jax_array__())
145-
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid "
146-
"JAX type.")
147+
raise InvalidInputException(
148+
f"Argument '{x}' of type {type(x)} is not a valid JAX type.")
147149

148150
def _canonicalize_masked_array_dtype(x):
149151
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "

jax/_src/pjit.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,11 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
167167
_infer_params(jit_info, args, kwargs)
168168
for arg in args_flat:
169169
dispatch.check_arg(arg)
170+
170171
if attrs_tracked:
171172
init_states = _get_states(attrs_tracked)
172173
args_flat = [*init_states, *args_flat]
174+
173175
try:
174176
out_flat = pjit_p.bind(*args_flat, **params)
175177
except pxla.DeviceAssignmentMismatchError as e:
@@ -180,12 +182,29 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
180182
msg = _device_assignment_mismatch_error(
181183
fun_name, fails, args_flat, api_name, arg_names)
182184
raise ValueError(msg) from None
185+
except xla.InvalidInputException as e:
186+
arg_names = [''] * len(args_flat) if arg_names is None else arg_names
187+
# Run canonicalization again to figure out which arg failed.
188+
if params['jaxpr'].consts:
189+
raise TypeError(e.args[0]) from e
190+
else:
191+
for arg, name, aval in zip(args_flat, arg_names, params['jaxpr'].in_avals):
192+
try:
193+
xla.canonicalize_dtype(arg)
194+
except xla.InvalidInputException as _:
195+
# Reraise as TypeError with the new message.
196+
raise TypeError(
197+
f"Argument '{name}' of shape {aval.str_short()} of type"
198+
f' {type(arg)} is not a valid JAX type.') from e
199+
raise AssertionError("Unreachable") from e
200+
183201
if attrs_tracked:
184202
final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
185203
_set_states(attrs_tracked, final_states)
186204
outs = tree_unflatten(out_tree, out_flat)
187205
return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked
188206

207+
189208
def _set_states(attrs_tracked, vals):
190209
from jax.experimental.attrs import jax_setattr # type: ignore
191210
for ((obj, attr), val) in zip(attrs_tracked, vals):

tests/pjit_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2927,6 +2927,16 @@ def test_pjit_device_backend_axis_resources_error(self):
29272927
'out_shardings should not be specified.'):
29282928
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
29292929

2930+
def test_check_arg_error(self):
2931+
sds = jax.ShapeDtypeStruct((4, 2), np.int32)
2932+
inp = np.arange(8).reshape(4, 2)
2933+
2934+
with self.assertRaisesRegex(
2935+
TypeError,
2936+
r"Argument 'x\['b'\]\['c'\]' of shape int32\[4,2\] of "
2937+
"type.*ShapeDtypeStruct.*is not a valid JAX type."):
2938+
jax.jit(lambda x: x)({'a': inp, 'b': {'c': sds}})
2939+
29302940
def test_pjit_device_backend_both_error(self):
29312941
with self.assertRaisesRegex(
29322942
ValueError, "can't specify both a device and a backend for jit"):

0 commit comments

Comments
 (0)