Skip to content

Commit 69bf3b8

Browse files
yashk2810jax authors
authored andcommitted
Don't do layout checks during compiled safe call on DCE'd args.
PiperOrigin-RevId: 623347380
1 parent c09a45a commit 69bf3b8

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3200,6 +3200,7 @@ def check_array_xla_sharding_layout_match(
32003200

32013201
if (xla_extension_version >= 249 and not db_xs and arg._committed and
32023202
arg.layout.device_local_layout is not None and xl is not None and
3203+
not isinstance(xl, AutoLayout) and
32033204
arg.layout.device_local_layout != xl):
32043205
errors.append(
32053206
("Got input layout(s) that compiled object was called with: "

tests/layout_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,25 @@ def f(x, y, z, a, b, c):
233233
self.assertArraysEqual(out1, out5)
234234
self.assertArraysEqual(out2, out6)
235235

236+
def test_no_error_dced_args(self):
237+
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
238+
shape = (8, 2)
239+
s = NamedSharding(mesh, P('x', 'y'))
240+
np_inp = np.arange(math.prod(shape)).reshape(shape)
241+
arr1 = jax.device_put(np_inp, s)
242+
arr2 = jax.device_put(np_inp, s)
243+
arrs = [arr1, arr2]
244+
245+
def f(x, y):
246+
return x * 2
247+
248+
jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s),
249+
out_shardings=Layout(DLL.AUTO, s))
250+
compiled = jf.lower(np_inp, np_inp).compile()
251+
arg_layouts, _ = compiled.input_layouts()
252+
arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)]
253+
compiled(*arrs)
254+
236255
def test_aot_layout_mismatch(self):
237256
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
238257
shape = (256, 4, 2)

0 commit comments

Comments
 (0)