@@ -60,7 +60,7 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
-from jax._src.layout import SpecifiedLayout, AutoLayout
+from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
@@ -1985,14 +1985,13 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
       return False
   return True

-MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]
+MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]


 class AllArgsInfo(NamedTuple):
   """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
   in_shardings: Any
-  in_layouts: Any
   debug_info: core.JaxprDebugInfo | None

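For orientation, here is a minimal, self-contained sketch of the record this hunk reshapes: a NamedTuple that snapshots per-argument metadata before dead-code elimination (DCE), now without the in_layouts field. The field types below are simplified stand-ins for the JAX internals (core.ShapedArray, core.JaxprDebugInfo), and all values are made up.

    from typing import Any, NamedTuple, Optional, Sequence

    class AllArgsInfo(NamedTuple):
      """Avals, shardings and debug_info for all arguments prior to DCE."""
      in_avals: Sequence[Any]    # stand-in for Sequence[core.ShapedArray]
      in_shardings: Any
      debug_info: Optional[Any]  # stand-in for core.JaxprDebugInfo | None

    # One entry per original argument, recorded before DCE drops unused ones:
    info = AllArgsInfo(in_avals=("f32[8]", "f32[8,128]"),
                       in_shardings=(None, None),
                       debug_info=None)
    assert len(info.in_avals) == len(info.in_shardings)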
@@ -2024,7 +2023,7 @@ def lower_sharding_computation(
   auto_spmd_lowering = check_if_any_auto(
       it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore

-  all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
                               closed_jaxpr.jaxpr.debug_info)

   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@@ -2228,6 +2227,8 @@ def lower_mesh_computation(
     out_jaxpr_avals = fun_or_jaxpr.out_avals
     consts = fun_or_jaxpr.consts

+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
+
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2318,7 +2319,7 @@ def lower_mesh_computation(
       in_layouts=(None,) * len(global_in_avals),
       out_layouts=(None,) * len(global_out_avals),
       shape_poly_state=lowering_result.shape_poly_state,
-      all_args_info=None)
+      all_args_info=all_args_info)

 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
@@ -2598,7 +2599,7 @@ def _get_layouts_from_executable(
     if isinstance(i, SpecifiedLayout):
       if i != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
+            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
       new_in_layouts.append(i)
     else:
       new_in_layouts.append(x)
@@ -2609,7 +2610,7 @@ def _get_layouts_from_executable(
     if isinstance(o, SpecifiedLayout):
       if o != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
+            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
       new_out_layouts.append(o)
     else:
       new_out_layouts.append(x)
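The two hunks above change only the assertion message; the surrounding logic keeps a user-specified layout when XLA agreed with it and otherwise adopts whatever XLA chose. A toy, standalone version of that reconciliation pattern, with None standing in for an unspecified layout and strings for layout descriptions (names here are illustrative, not JAX API):

    def reconcile_layouts(user_layouts, xla_layouts):
      out = []
      for u, x in zip(user_layouts, xla_layouts):
        if u is not None:  # plays the role of isinstance(u, SpecifiedLayout)
          if u != x:
            raise AssertionError(
                f"Unexpected XLA layout override: (XLA) {x} != {u}")
          out.append(u)    # user asked for it and XLA honored it
        else:
          out.append(x)    # layout was left to the compiler; accept its pick
      return out

    assert reconcile_layouts([None, "0,1"], ["1,0", "0,1"]) == ["1,0", "0,1"]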
@@ -3015,24 +3016,19 @@ def call(self, *args):
       kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
       ref_avals = self.in_avals
       in_shardings = self._in_shardings
-      in_layouts = self._in_layouts
       debug_info = None
     else:
       kept_args = args
       ref_avals = self._all_args_info.in_avals
       iter_in_shardings = iter(self._in_shardings)
       in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
                       for i, s in enumerate(self._all_args_info.in_shardings)]
-      iter_in_layouts = iter(self._in_layouts)
-      in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
-                    for i, s in enumerate(self._all_args_info.in_layouts)]
       debug_info = self._all_args_info.debug_info

     arg_avals = map(xla.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
     # Check the GDA sharding and the input sharding.
-    check_array_xla_sharding_layout_match(kept_args, in_shardings,
-                                          in_layouts, debug_info)
+    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
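The unchanged else-branch above re-expands the post-DCE shardings back to the original argument positions by draining an iterator at the kept indices. A self-contained illustration of that idiom with made-up values:

    kept_var_idx = {0, 2}              # argument positions that survived DCE
    kept_values = iter(["s0", "s2"])   # per-argument values, post-DCE order
    pre_dce = ["a", "b", "c"]          # pre-DCE per-argument metadata

    expanded = [next(kept_values) if i in kept_var_idx else s
                for i, s in enumerate(pre_dce)]
    assert expanded == ["s0", "b", "s2"]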
@@ -3188,17 +3184,15 @@ def check_device_backend_on_shardings(shardings) -> bool:
   return False


-def check_array_xla_sharding_layout_match(
+def check_gda_or_array_xla_sharding_match(
     args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    in_xla_layouts: Sequence[SpecifiedLayout],
     jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
   from jax._src.array import ArrayImpl
   arg_names = ([''] * len(args) if jaxpr_debug_info is None else
                jaxpr_debug_info.arg_names)
   errors = []
   num_errors = 5
-  for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
-                                    arg_names):
+  for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
     if is_unspecified_or_auto(xs):
3211
3205
# Raise memory kind mismatch error even if the arg is uncommitted.
3212
3206
if arg .sharding .memory_kind != xs .memory_kind :
3213
3207
errors .append (
3214
- ( "Got input sharding(s) that compiled object was called with: "
3208
+ "Got input sharding(s) that compiled object was called with: "
3215
3209
f"{ arg .sharding } and sharding(s) the computation was compiled "
3216
- f"with: { xs } for arg { name } with shape: { arg .aval .str_short ()} " ,
3217
- 'sharding' ))
3210
+ f"with: { xs } for arg { name } with shape: { arg .aval .str_short ()} " )
3218
3211
3219
3212
if (not db_xs and arg ._committed and
3220
3213
not op_shardings .are_op_shardings_equal (
3221
3214
arg .sharding ._to_xla_hlo_sharding (arg .ndim ),
3222
3215
xs ._to_xla_hlo_sharding (arg .ndim ))):
3223
3216
errors .append (
3224
- ( "Got input sharding(s) that compiled object was called with: "
3217
+ "Got input sharding(s) that compiled object was called with: "
3225
3218
f"{ arg .sharding } and sharding(s) the computation was compiled "
3226
- f"with: { xs } for arg { name } with shape: { arg .aval .str_short ()} " ,
3227
- 'sharding' ))
3228
-
3229
- # TODO(yashkatariya): Remove `arg.layout is not None` check after pathways
3230
- # supports layout on Array.
3231
- if (xla_extension_version >= 249 and not db_xs and arg ._committed and
3232
- arg .layout is not None and arg .layout != xl ):
3233
- errors .append (
3234
- ("Got input layout(s) that compiled object was called with: "
3235
- f"{ arg .layout } and layout(s) the computation was compiled "
3236
- f"with: { xl } for arg { name } with shape: { arg .aval .str_short ()} " ,
3237
- 'layout' ))
3219
+ f"with: { xs } for arg { name } with shape: { arg .aval .str_short ()} " )
3238
3220
3239
3221
if errors :
3240
- first_errors , error_kinds = unzip2 (errors [:num_errors ])
3241
- str_errors = '\n ' .join (first_errors )
3242
- if all (k == 'sharding' for k in error_kinds ):
3243
- kind_str = r'sharding(s)'
3244
- elif all (k == 'layout' for k in error_kinds ):
3245
- kind_str = 'layout(s)'
3246
- else :
3247
- kind_str = 'sharding(s) and layout(s)'
3222
+ str_errors = '\n ' .join (errors [:num_errors ])
3248
3223
num_mismatch_str = (
3249
3224
f'the { len (errors )} mismatches' if len (errors ) < num_errors else
3250
3225
f"{ num_errors } mismatches out of { len (errors )} " )
3251
3226
raise ValueError (
3252
- f"Compiled object called with input { kind_str } does "
3253
- f"not match the { kind_str } the computation was "
3254
- "compiled with. "
3227
+ "Compiled object called with input sharding(s) does not match the "
3228
+ "sharding(s) the computation was compiled with. "
3255
3229
f"Here are { num_mismatch_str } :\n { str_errors } " )
3256
3230
3257
3231
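The final hunk collapses the sharding/layout error bookkeeping back to plain strings: mismatches are collected per argument, then at most num_errors of them are reported in a single ValueError. A standalone sketch of that batching pattern, with fabricated messages:

    def report(errors, num_errors=5):
      # Join the first num_errors messages and summarize how many were found.
      if not errors:
        return
      str_errors = "\n".join(errors[:num_errors])
      num_mismatch_str = (
          f"the {len(errors)} mismatches" if len(errors) < num_errors else
          f"{num_errors} mismatches out of {len(errors)}")
      raise ValueError(
          "Compiled object called with input sharding(s) does not match the "
          "sharding(s) the computation was compiled with. "
          f"Here are {num_mismatch_str}:\n{str_errors}")

    try:
      report([f"mismatch for arg x{i}" for i in range(7)])
    except ValueError as e:
      assert "5 mismatches out of 7" in str(e)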