Skip to content

Commit 00dbcee

Browse files
yashk2810froystig
authored andcommitted
Expose out_info on Compiled just as it exists on Lowered and Traced.
Co-authored-by: Roy Frostig <frostig@google.com> PiperOrigin-RevId: 781233563
1 parent ee9e148 commit 00dbcee

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

jax/_src/stages.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,13 @@ def memory_analysis(self) -> Any | None:
427427
except NotImplementedError:
428428
return None
429429

430+
@property
431+
def out_info(self): # PyTree of OutInfo
432+
out_avals = self._executable.out_avals
433+
out_shardings = self._executable._out_shardings
434+
return self.out_tree.unflatten(
435+
[OutInfo(o.shape, o.dtype, s) for o, s in zip(out_avals, out_shardings)])
436+
430437
def runtime_executable(self) -> Any | None:
431438
"""An arbitrary object representation of this executable.
432439

tests/pjit_test.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4347,7 +4347,7 @@ def f(x, y):
43474347
input_shardings, _ = f.lower(inp, inp).compile().input_shardings
43484348
self.assertLen(input_shardings, 2)
43494349

4350-
def test_aot_out_info(self):
4350+
def test_lowered_out_info(self):
43514351
inp = np.arange(8, dtype=np.int32)
43524352
out_info = jax.jit(lambda x: x).lower((inp, inp)).out_info
43534353
self.assertEqual(out_info[0].shape, (8,))
@@ -4357,6 +4357,26 @@ def test_aot_out_info(self):
43574357
self.assertEqual(out_info[0].sharding, None)
43584358
self.assertEqual(out_info[1].sharding, None)
43594359

4360+
def test_lowered_out_info_mesh(self):
4361+
mesh = jtu.create_mesh((2,), 'x')
4362+
arr = jax.device_put(np.arange(8, dtype=np.int32),
4363+
NamedSharding(mesh, P('x')))
4364+
lowered = jax.jit(lambda x: x * 2).lower(arr)
4365+
out_info = lowered.out_info
4366+
self.assertEqual(out_info.shape, (8,))
4367+
self.assertEqual(out_info.dtype, np.int32)
4368+
self.assertEqual(out_info.sharding, None)
4369+
4370+
def test_compiled_out_info(self):
4371+
mesh = jtu.create_mesh((2,), 'x')
4372+
arr = jax.device_put(np.arange(8, dtype=np.int32),
4373+
NamedSharding(mesh, P('x')))
4374+
compiled = jax.jit(lambda x: x * 2).lower(arr).compile()
4375+
out_info = compiled.out_info
4376+
self.assertEqual(out_info.shape, (8,))
4377+
self.assertEqual(out_info.dtype, np.int32)
4378+
self.assertEqual(out_info.sharding, NamedSharding(mesh, P('x')))
4379+
43604380
def test_jit_trace(self):
43614381
def f(x):
43624382
return x * 2

0 commit comments

Comments
 (0)