@@ -2789,7 +2789,7 @@ def build_unsafe_call(self):
2789
2789
2790
2790
def load (self ) -> MeshExecutable :
2791
2791
return MeshExecutable (self .xla_executable , self .build_unsafe_call ,
2792
- self .input_avals ,
2792
+ self .input_avals , self . output_avals ,
2793
2793
self .input_shardings , self .output_shardings ,
2794
2794
self .auto_spmd_lowering , self .kept_var_idx ,
2795
2795
self .in_layouts , self .out_layouts ,
@@ -2942,12 +2942,13 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat):
2942
2942
class MeshExecutable (stages .XlaExecutable ):
2943
2943
__slots__ = [
2944
2944
"xla_executable" , "_unsafe_call" , "build_unsafe_call" , "in_avals" ,
2945
- "_in_shardings" , "_out_shardings" , "_auto_spmd_lowering" , "_kept_var_idx" ,
2946
- "_in_layouts" , "_out_layouts" , "_all_args_info" , "_unloaded_executable" ,
2945
+ "out_avals" , "_in_shardings" , "_out_shardings" , "_auto_spmd_lowering" ,
2946
+ "_kept_var_idx" , "_in_layouts" , "_out_layouts" , "_all_args_info" ,
2947
+ "_unloaded_executable" ,
2947
2948
]
2948
2949
2949
- def __init__ (self , xla_executable , build_unsafe_call , in_avals , in_shardings ,
2950
- out_shardings , auto_spmd_lowering , kept_var_idx ,
2950
+ def __init__ (self , xla_executable , build_unsafe_call , in_avals , out_avals ,
2951
+ in_shardings , out_shardings , auto_spmd_lowering , kept_var_idx ,
2951
2952
in_layouts , out_layouts ,
2952
2953
all_args_info : AllArgsInfo | None = None ,
2953
2954
unloaded_executable = None ):
@@ -2956,6 +2957,7 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
2956
2957
# in_avals is a list of global and local avals. Aval is global if input
2957
2958
# is a GDA or jax.Array else local.
2958
2959
self .in_avals = in_avals
2960
+ self .out_avals = out_avals
2959
2961
self ._unsafe_call = None
2960
2962
self ._in_shardings = in_shardings
2961
2963
self ._out_shardings = out_shardings
@@ -3118,8 +3120,9 @@ def _compile_replicated_mesh_executable_from_hlo(
3118
3120
committed = committed , pmap_nreps = pmap_nreps )
3119
3121
xla_executable = None
3120
3122
return MeshExecutable (xla_executable , lambda : unsafe_call , global_in_avals ,
3121
- in_shardings , out_shardings , auto_spmd_lowering ,
3122
- kept_var_idx , (None ,) * len (global_in_avals ),
3123
+ global_out_avals , in_shardings , out_shardings ,
3124
+ auto_spmd_lowering , kept_var_idx ,
3125
+ (None ,) * len (global_in_avals ),
3123
3126
(None ,) * len (global_out_avals ))
3124
3127
3125
3128
0 commit comments