 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
-from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
+from jax._src.layout import SpecifiedLayout, AutoLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
@@ -1985,13 +1985,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
     return False
   return True

-MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]
+MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]


 class AllArgsInfo(NamedTuple):
   """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
   in_shardings: Any
+  in_layouts: Any
   debug_info: core.JaxprDebugInfo | None

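
Note: `in_layouts` is added as a positional field, so every `AllArgsInfo(...)` construction site must now pass layouts as the third argument (see the call-site updates below). A minimal standalone sketch of the shape of this change, with simplified field types rather than the real `core.ShapedArray`/`JaxprDebugInfo` annotations:

    # Sketch only: simplified types, not the actual pxla.py definition.
    from typing import Any, NamedTuple, Optional, Sequence

    class AllArgsInfo(NamedTuple):
      """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
      in_avals: Sequence[Any]
      in_shardings: Any
      in_layouts: Any        # new field added by this change
      debug_info: Optional[Any]

    # Constructor calls change accordingly:
    info = AllArgsInfo((), (), (), None)  # avals, shardings, layouts, debug_info
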
@@ -2023,7 +2024,7 @@ def lower_sharding_computation(
   auto_spmd_lowering = check_if_any_auto(
       it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore

-  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
                               closed_jaxpr.jaxpr.debug_info)

   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@@ -2227,8 +2228,6 @@ def lower_mesh_computation(
     out_jaxpr_avals = fun_or_jaxpr.out_avals
     consts = fun_or_jaxpr.consts

-  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
-
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2319,7 +2318,7 @@ def lower_mesh_computation(
       in_layouts=(None,) * len(global_in_avals),
       out_layouts=(None,) * len(global_out_avals),
       shape_poly_state=lowering_result.shape_poly_state,
-      all_args_info=all_args_info)
+      all_args_info=None)

 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
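
Passing `all_args_info=None` here means executables built via `lower_mesh_computation` skip the pre-DCE argument bookkeeping; as the `call` hunk further down shows, a `None` record makes `call` validate only the kept (post-DCE) arguments. A small sketch of that branch, with hypothetical names:

    # Hypothetical sketch of how a None all_args_info steers validation in
    # call(): with no pre-DCE record, only kept (post-DCE) args are checked.
    def pick_args_to_check(all_args_info, args, kept_var_idx):
      if all_args_info is None:
        return [a for i, a in enumerate(args) if i in kept_var_idx]
      return list(args)  # otherwise validate every original argument

    assert pick_args_to_check(None, ['a', 'b', 'c'], {0, 2}) == ['a', 'c']
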
@@ -2599,7 +2598,7 @@ def _get_layouts_from_executable(
     if isinstance(i, SpecifiedLayout):
       if i != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
+            f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
       new_in_layouts.append(i)
     else:
       new_in_layouts.append(x)
@@ -2610,7 +2609,7 @@ def _get_layouts_from_executable(
     if isinstance(o, SpecifiedLayout):
       if o != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
+            f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
       new_out_layouts.append(o)
     else:
       new_out_layouts.append(x)
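
These two hunks only reword the assertion message (the mismatched values here are layouts, not shardings). The invariant itself: when the user pinned a layout via `SpecifiedLayout`, the layout XLA actually chose must match it exactly; unconstrained slots adopt whatever XLA picked. A self-contained sketch using a stand-in dataclass instead of the real `SpecifiedLayout`:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class StubLayout:  # hypothetical stand-in for jax._src.layout.SpecifiedLayout
      minor_to_major: tuple

    def resolve(user_layouts, xla_layouts):
      out = []
      for u, x in zip(user_layouts, xla_layouts):
        if isinstance(u, StubLayout):
          if u != x:
            raise AssertionError(
                f"Unexpected XLA layout override: (XLA) {x} != {u} (User layout)")
          out.append(u)
        else:
          out.append(x)  # no user constraint; take XLA's choice
      return out

    row_major = StubLayout((1, 0))
    assert resolve([row_major, None], [row_major, StubLayout((0, 1))]) == \
        [row_major, StubLayout((0, 1))]
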
@@ -3016,19 +3015,24 @@ def call(self, *args):
       kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
       ref_avals = self.in_avals
       in_shardings = self._in_shardings
+      in_layouts = self._in_layouts
       debug_info = None
     else:
       kept_args = args
       ref_avals = self._all_args_info.in_avals
       iter_in_shardings = iter(self._in_shardings)
       in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
                       for i, s in enumerate(self._all_args_info.in_shardings)]
+      iter_in_layouts = iter(self._in_layouts)
+      in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
+                    for i, s in enumerate(self._all_args_info.in_layouts)]
       debug_info = self._all_args_info.debug_info

     arg_avals = map(xla.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
     # Check the GDA sharding and the input sharding.
-    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
+    check_array_xla_sharding_layout_match(kept_args, in_shardings,
+                                          in_layouts, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
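
The added `iter_in_layouts` block mirrors the existing sharding logic: post-DCE layouts are spliced back into their original pre-DCE argument positions so error messages can name every argument. A standalone sketch of the pattern (helper name hypothetical):

    # Hypothetical helper showing the iterator splice used above: kept_vals
    # only covers args that survived DCE, so walk the pre-DCE list and put
    # each kept value back at its original index.
    def expand_to_all_args(kept_vals, kept_var_idx, pre_dce_vals):
      it = iter(kept_vals)
      return [next(it) if i in kept_var_idx else v
              for i, v in enumerate(pre_dce_vals)]

    # Args 0 and 2 survived DCE; arg 1 keeps its recorded pre-DCE value.
    assert expand_to_all_args(['L0', 'L2'], {0, 2},
                              ['old0', 'old1', 'old2']) == ['L0', 'old1', 'L2']
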
@@ -3184,15 +3188,17 @@ def check_device_backend_on_shardings(shardings) -> bool:
   return False


-def check_gda_or_array_xla_sharding_match(
+def check_array_xla_sharding_layout_match(
     args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    in_xla_layouts: Sequence[SpecifiedLayout],
    jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
   from jax._src.array import ArrayImpl
   arg_names = ([''] * len(args) if jaxpr_debug_info is None else
                jaxpr_debug_info.arg_names)
   errors = []
   num_errors = 5
-  for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
+  for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
+                                    arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
     if is_unspecified_or_auto(xs):
@@ -3205,27 +3211,47 @@ def check_gda_or_array_xla_sharding_match(
     # Raise memory kind mismatch error even if the arg is uncommitted.
     if arg.sharding.memory_kind != xs.memory_kind:
       errors.append(
-          "Got input sharding(s) that compiled object was called with: "
+          ("Got input sharding(s) that compiled object was called with: "
           f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
+          'sharding'))

     if (not db_xs and arg._committed and
         not op_shardings.are_op_shardings_equal(
             arg.sharding._to_xla_hlo_sharding(arg.ndim),
             xs._to_xla_hlo_sharding(arg.ndim))):
       errors.append(
-          "Got input sharding(s) that compiled object was called with: "
+          ("Got input sharding(s) that compiled object was called with: "
           f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
+          'sharding'))
+
+    # TODO(yashkatariya): Remove `arg.layout is not None` check after pathways
+    # supports layout on Array.
+    if (xla_extension_version >= 249 and not db_xs and arg._committed and
+        arg.layout is not None and arg.layout != xl):
+      errors.append(
+          ("Got input layout(s) that compiled object was called with: "
+           f"{arg.layout} and layout(s) the computation was compiled "
+           f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
+           'layout'))

   if errors:
-    str_errors = '\n'.join(errors[:num_errors])
+    first_errors, error_kinds = unzip2(errors[:num_errors])
+    str_errors = '\n'.join(first_errors)
+    if all(k == 'sharding' for k in error_kinds):
+      kind_str = r'sharding(s)'
+    elif all(k == 'layout' for k in error_kinds):
+      kind_str = 'layout(s)'
+    else:
+      kind_str = 'sharding(s) and layout(s)'
     num_mismatch_str = (
         f'the {len(errors)} mismatches' if len(errors) < num_errors else
         f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
-        "Compiled object called with input sharding(s) does not match the "
-        "sharding(s) the computation was compiled with. "
+        f"Compiled object called with input {kind_str} does "
+        f"not match the {kind_str} the computation was "
+        "compiled with. "
         f"Here are {num_mismatch_str}:\n{str_errors}")
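
Each error is now a (message, kind) pair so the final ValueError can say "sharding(s)", "layout(s)", or both, depending on what actually mismatched. A minimal sketch of that aggregation, with a local `unzip2` standing in for the helper from `jax._src.util`:

    def unzip2(pairs):  # local stand-in for jax._src.util.unzip2
      xs, ys = [], []
      for x, y in pairs:
        xs.append(x)
        ys.append(y)
      return tuple(xs), tuple(ys)

    def mismatch_kind(errors, num_errors=5):
      _, error_kinds = unzip2(errors[:num_errors])
      if all(k == 'sharding' for k in error_kinds):
        return 'sharding(s)'
      elif all(k == 'layout' for k in error_kinds):
        return 'layout(s)'
      return 'sharding(s) and layout(s)'

    assert mismatch_kind([('msg1', 'sharding'), ('msg2', 'layout')]) == \
        'sharding(s) and layout(s)'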