Skip to content

Commit 2761f26

Browse files
yashk2810jax authors
authored andcommitted
Set out_mut to None as default on from_hlo instead of in __init__ of MeshComputation and correct the types too.
PiperOrigin-RevId: 611814102
1 parent cfeb113 commit 2761f26

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def build_execute_fun(self):
891891
self.unordered_effects,
892892
self.ordered_effects, self.keepalive,
893893
bool(self.host_callbacks),
894-
set(range(len(input_indices))), [])
894+
set(range(len(input_indices))), None)
895895
return execute_fun
896896

897897
def load(self) -> PmapExecutable:
@@ -1155,7 +1155,7 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
11551155
unordered_effects: list[core.Effect],
11561156
ordered_effects: list[core.Effect], keepalive: Any,
11571157
has_host_callbacks: bool, kept_var_idx: set[int],
1158-
out_mut: Sequence[int | None]):
1158+
out_mut: Sequence[int | None] | None):
11591159
self.xla_executable = xla_executable
11601160
self.name = name
11611161
self.backend = backend
@@ -1210,7 +1210,7 @@ def __call__(self, *args):
12101210
out = self.out_handler(out_arrays)
12111211
else:
12121212
out = results.consume_with_handlers(self.out_handler.handlers)
1213-
if not self.out_mut:
1213+
if self.out_mut is None:
12141214
return out
12151215
else:
12161216
out_ = []
@@ -2282,7 +2282,6 @@ def lower_mesh_computation(
22822282
host_callbacks=lowering_result.host_callbacks,
22832283
keepalive=lowering_result.keepalive,
22842284
kept_var_idx=set(range(len(global_in_avals))),
2285-
out_mut=None,
22862285
backend=backend,
22872286
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
22882287
committed=True,
@@ -2297,7 +2296,6 @@ class MeshComputation(stages.XlaLowering):
22972296

22982297
def __init__(self, name: str, hlo: ir.Module | None,
22992298
donated_invars: Sequence[bool], **compile_args):
2300-
compile_args.setdefault('out_mut', None) # TODO(mattjj): remove default
23012299
self._name = name
23022300
self._hlo = hlo
23032301
self._donated_invars = donated_invars
@@ -2763,7 +2761,7 @@ class UnloadedMeshExecutable:
27632761
keepalive: Sequence[Any]
27642762
host_callbacks: Sequence[Any]
27652763
kept_var_idx: set[int]
2766-
out_mut: Sequence[None | int]
2764+
out_mut: Sequence[None | int] | None
27672765
auto_spmd_lowering: bool
27682766
in_layouts: Sequence[SpecifiedLayout | None]
27692767
out_layouts: Sequence[SpecifiedLayout | None]
@@ -2802,7 +2800,7 @@ def from_hlo(name: str,
28022800
global_out_avals: Sequence[ShapedArray],
28032801
in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
28042802
out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
2805-
UnspecifiedValue)],
2803+
UnspecifiedValue)],
28062804
spmd_lowering: bool,
28072805
tuple_args: bool,
28082806
auto_spmd_lowering: bool,
@@ -2811,13 +2809,13 @@ def from_hlo(name: str,
28112809
host_callbacks: list[Any],
28122810
keepalive: Any,
28132811
kept_var_idx: set[int],
2814-
out_mut: Sequence[None | int],
28152812
backend: xb.XlaBackend,
28162813
device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore
28172814
committed: bool,
28182815
in_layouts: MaybeLayout,
28192816
out_layouts: MaybeLayout,
28202817
pmap_nreps: int = 1,
2818+
out_mut: Sequence[None | int] | None = None,
28212819
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
28222820
all_default_mem_kind: bool = True,
28232821
all_args_info: AllArgsInfo | None = None,

0 commit comments

Comments
 (0)