@@ -1989,10 +1989,8 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
1989
1989
1990
1990
1991
1991
class AllArgsInfo (NamedTuple ):
1992
- """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
1992
+ """Avals and debug_info for all arguments prior to DCE."""
1993
1993
in_avals : Sequence [core .ShapedArray ]
1994
- in_shardings : Any
1995
- in_layouts : Any
1996
1994
debug_info : core .JaxprDebugInfo | None
1997
1995
1998
1996
@@ -2038,8 +2036,7 @@ def lower_sharding_computation(
2038
2036
auto_spmd_lowering = check_if_any_auto (
2039
2037
it .chain .from_iterable ([in_shardings , out_shardings ])) # type: ignore
2040
2038
2041
- all_args_info = AllArgsInfo (global_in_avals , in_shardings , in_layouts ,
2042
- closed_jaxpr .jaxpr .debug_info )
2039
+ all_args_info = AllArgsInfo (global_in_avals , closed_jaxpr .jaxpr .debug_info )
2043
2040
2044
2041
(closed_jaxpr , global_in_avals , global_out_avals , donated_invars ,
2045
2042
kept_var_idx , name_stack ) = _dce_jaxpr (
@@ -3013,28 +3010,22 @@ def xla_extension_executable(self):
3013
3010
return self .xla_executable
3014
3011
3015
3012
def call (self , * args ):
3013
+ args_after_dce = [a for i , a in enumerate (args ) if i in self ._kept_var_idx ]
3016
3014
if self ._all_args_info is None :
3017
- kept_args = [ a for i , a in enumerate ( args ) if i in self . _kept_var_idx ]
3015
+ kept_args = args_after_dce
3018
3016
ref_avals = self .in_avals
3019
- in_shardings = self ._in_shardings
3020
- in_layouts = self ._in_layouts
3021
3017
debug_info = None
3022
3018
else :
3023
3019
kept_args = args
3024
3020
ref_avals = self ._all_args_info .in_avals
3025
- iter_in_shardings = iter (self ._in_shardings )
3026
- in_shardings = [next (iter_in_shardings ) if i in self ._kept_var_idx else s
3027
- for i , s in enumerate (self ._all_args_info .in_shardings )]
3028
- iter_in_layouts = iter (self ._in_layouts )
3029
- in_layouts = [next (iter_in_layouts ) if i in self ._kept_var_idx else s
3030
- for i , s in enumerate (self ._all_args_info .in_layouts )]
3031
3021
debug_info = self ._all_args_info .debug_info
3032
3022
3033
- arg_avals = map (xla .abstractify , kept_args )
3034
- check_arg_avals_for_call (ref_avals , arg_avals , debug_info )
3023
+ all_arg_avals = map (xla .abstractify , kept_args )
3024
+ check_arg_avals_for_call (ref_avals , all_arg_avals , debug_info )
3035
3025
# Check the GDA sharding and the input sharding.
3036
- check_array_xla_sharding_layout_match (kept_args , in_shardings ,
3037
- in_layouts , debug_info )
3026
+ check_array_xla_sharding_layout_match (
3027
+ args_after_dce , self ._in_shardings , self ._in_layouts , debug_info ,
3028
+ self ._kept_var_idx )
3038
3029
return self .unsafe_call (* args ) # pylint: disable=not-callable
3039
3030
3040
3031
def input_shardings (self ) -> Sequence [sharding_impls .XLACompatibleSharding ]:
@@ -3163,16 +3154,22 @@ def check_device_backend_on_shardings(shardings) -> bool:
3163
3154
3164
3155
3165
3156
def check_array_xla_sharding_layout_match (
3166
- args , in_xla_shardings : Sequence [sharding_impls .XLACompatibleSharding ],
3157
+ args_after_dce ,
3158
+ in_xla_shardings : Sequence [sharding_impls .XLACompatibleSharding ],
3167
3159
in_xla_layouts : Sequence [DeviceLocalLayout ],
3168
- jaxpr_debug_info : core .JaxprDebugInfo | None ) -> None :
3160
+ jaxpr_debug_info : core .JaxprDebugInfo | None ,
3161
+ kept_var_idx : set [int ]) -> None :
3169
3162
from jax ._src .array import ArrayImpl
3170
- arg_names = (['' ] * len (args ) if jaxpr_debug_info is None else
3171
- jaxpr_debug_info .arg_names )
3163
+ # jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
3164
+ arg_names = (
3165
+ ["" ] * len (args_after_dce ) if jaxpr_debug_info is None
3166
+ else [a for i , a in enumerate (jaxpr_debug_info .arg_names ) # type: ignore
3167
+ if i in kept_var_idx ]
3168
+ )
3172
3169
errors = []
3173
3170
num_errors = 5
3174
- for arg , xs , xl , name in safe_zip (args , in_xla_shardings , in_xla_layouts ,
3175
- arg_names ):
3171
+ for arg , xs , xl , name in safe_zip (
3172
+ args_after_dce , in_xla_shardings , in_xla_layouts , arg_names ):
3176
3173
if not isinstance (arg , ArrayImpl ):
3177
3174
continue
3178
3175
if is_unspecified_or_auto (xs ):
@@ -3200,7 +3197,6 @@ def check_array_xla_sharding_layout_match(
3200
3197
3201
3198
if (xla_extension_version >= 249 and not db_xs and arg ._committed and
3202
3199
arg .layout .device_local_layout is not None and xl is not None and
3203
- not isinstance (xl , AutoLayout ) and
3204
3200
arg .layout .device_local_layout != xl ):
3205
3201
errors .append (
3206
3202
("Got input layout(s) that compiled object was called with: "
0 commit comments