Skip to content

Commit 0b70244

Browse files
yashk2810jax authors
authored andcommitted
Thread out_avals to MeshExecutable
PiperOrigin-RevId: 612037684
1 parent 8569b89 commit 0b70244

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2789,7 +2789,7 @@ def build_unsafe_call(self):
27892789

27902790
def load(self) -> MeshExecutable:
27912791
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
2792-
self.input_avals,
2792+
self.input_avals, self.output_avals,
27932793
self.input_shardings, self.output_shardings,
27942794
self.auto_spmd_lowering, self.kept_var_idx,
27952795
self.in_layouts, self.out_layouts,
@@ -2942,12 +2942,13 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat):
29422942
class MeshExecutable(stages.XlaExecutable):
29432943
__slots__ = [
29442944
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
2945-
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
2946-
"_in_layouts", "_out_layouts", "_all_args_info", "_unloaded_executable",
2945+
"out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
2946+
"_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
2947+
"_unloaded_executable",
29472948
]
29482949

2949-
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
2950-
out_shardings, auto_spmd_lowering, kept_var_idx,
2950+
def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
2951+
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
29512952
in_layouts, out_layouts,
29522953
all_args_info: AllArgsInfo | None = None,
29532954
unloaded_executable=None):
@@ -2956,6 +2957,7 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
29562957
# in_avals is a list of global and local avals. Aval is global if input
29572958
# is a GDA or jax.Array else local.
29582959
self.in_avals = in_avals
2960+
self.out_avals = out_avals
29592961
self._unsafe_call = None
29602962
self._in_shardings = in_shardings
29612963
self._out_shardings = out_shardings
@@ -3118,8 +3120,9 @@ def _compile_replicated_mesh_executable_from_hlo(
31183120
committed=committed, pmap_nreps=pmap_nreps)
31193121
xla_executable = None
31203122
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
3121-
in_shardings, out_shardings, auto_spmd_lowering,
3122-
kept_var_idx, (None,) * len(global_in_avals),
3123+
global_out_avals, in_shardings, out_shardings,
3124+
auto_spmd_lowering, kept_var_idx,
3125+
(None,) * len(global_in_avals),
31233126
(None,) * len(global_out_avals))
31243127

31253128

0 commit comments

Comments
 (0)