@@ -891,7 +891,7 @@ def build_execute_fun(self):
891
891
self .unordered_effects ,
892
892
self .ordered_effects , self .keepalive ,
893
893
bool (self .host_callbacks ),
894
- set (range (len (input_indices ))), [] )
894
+ set (range (len (input_indices ))), None )
895
895
return execute_fun
896
896
897
897
def load (self ) -> PmapExecutable :
@@ -1155,7 +1155,7 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
1155
1155
unordered_effects : list [core .Effect ],
1156
1156
ordered_effects : list [core .Effect ], keepalive : Any ,
1157
1157
has_host_callbacks : bool , kept_var_idx : set [int ],
1158
- out_mut : Sequence [int | None ]):
1158
+ out_mut : Sequence [int | None ] | None ):
1159
1159
self .xla_executable = xla_executable
1160
1160
self .name = name
1161
1161
self .backend = backend
@@ -1210,7 +1210,7 @@ def __call__(self, *args):
1210
1210
out = self .out_handler (out_arrays )
1211
1211
else :
1212
1212
out = results .consume_with_handlers (self .out_handler .handlers )
1213
- if not self .out_mut :
1213
+ if self .out_mut is None :
1214
1214
return out
1215
1215
else :
1216
1216
out_ = []
@@ -2282,7 +2282,6 @@ def lower_mesh_computation(
2282
2282
host_callbacks = lowering_result .host_callbacks ,
2283
2283
keepalive = lowering_result .keepalive ,
2284
2284
kept_var_idx = set (range (len (global_in_avals ))),
2285
- out_mut = None ,
2286
2285
backend = backend ,
2287
2286
device_assignment = _create_da_object (tuple (mesh .devices .flat )),
2288
2287
committed = True ,
@@ -2297,7 +2296,6 @@ class MeshComputation(stages.XlaLowering):
2297
2296
2298
2297
def __init__ (self , name : str , hlo : ir .Module | None ,
2299
2298
donated_invars : Sequence [bool ], ** compile_args ):
2300
- compile_args .setdefault ('out_mut' , None ) # TODO(mattjj): remove default
2301
2299
self ._name = name
2302
2300
self ._hlo = hlo
2303
2301
self ._donated_invars = donated_invars
@@ -2763,7 +2761,7 @@ class UnloadedMeshExecutable:
2763
2761
keepalive : Sequence [Any ]
2764
2762
host_callbacks : Sequence [Any ]
2765
2763
kept_var_idx : set [int ]
2766
- out_mut : Sequence [None | int ]
2764
+ out_mut : Sequence [None | int ] | None
2767
2765
auto_spmd_lowering : bool
2768
2766
in_layouts : Sequence [SpecifiedLayout | None ]
2769
2767
out_layouts : Sequence [SpecifiedLayout | None ]
@@ -2802,7 +2800,7 @@ def from_hlo(name: str,
2802
2800
global_out_avals : Sequence [ShapedArray ],
2803
2801
in_shardings : Sequence [sharding_impls .XLACompatibleSharding | AUTO ],
2804
2802
out_shardings : Sequence [(sharding_impls .XLACompatibleSharding | AUTO |
2805
- UnspecifiedValue )],
2803
+ UnspecifiedValue )],
2806
2804
spmd_lowering : bool ,
2807
2805
tuple_args : bool ,
2808
2806
auto_spmd_lowering : bool ,
@@ -2811,13 +2809,13 @@ def from_hlo(name: str,
2811
2809
host_callbacks : list [Any ],
2812
2810
keepalive : Any ,
2813
2811
kept_var_idx : set [int ],
2814
- out_mut : Sequence [None | int ],
2815
2812
backend : xb .XlaBackend ,
2816
2813
device_assignment : xc .DeviceList | Sequence [xc .Device ], # type: ignore
2817
2814
committed : bool ,
2818
2815
in_layouts : MaybeLayout ,
2819
2816
out_layouts : MaybeLayout ,
2820
2817
pmap_nreps : int = 1 ,
2818
+ out_mut : Sequence [None | int ] | None = None ,
2821
2819
shape_poly_state : mlir .ShapePolyLoweringState | None = None ,
2822
2820
all_default_mem_kind : bool = True ,
2823
2821
all_args_info : AllArgsInfo | None = None ,
0 commit comments