@@ -2006,8 +2006,15 @@ def lower_sharding_computation(
2006
2006
any (not is_unspecified (js ) for js , _ in jaxpr_sharding ) or
2007
2007
any (not is_unspecified (o ) for o in out_shardings ))
2008
2008
2009
- gs = sharding_impls .GSPMDSharding .get_replicated (device_assignment )
2010
- in_shardings = tuple (gs if is_unspecified (i ) else i for i in in_shardings )
2009
+ gs = GSPMDSharding .get_replicated (device_assignment )
2010
+ if xla_extension_version < 240 or hasattr (backend , "compile_replicated" ):
2011
+ in_shardings = tuple (gs if is_unspecified (i ) else i for i in in_shardings )
2012
+
2013
+ # TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
2014
+ # output sharding of XLA is partially constrained on the trailing dimensions.
2015
+ in_shardings = tuple (
2016
+ gs if a is not core .abstract_token and dtypes .issubdtype (a .dtype , dtypes .extended )
2017
+ else i for i , a in safe_zip (in_shardings , global_in_avals ))
2011
2018
2012
2019
da_object = _create_da_object (tuple (device_assignment ))
2013
2020
@@ -2318,7 +2325,7 @@ def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
2318
2325
return input_indices
2319
2326
2320
2327
2321
- def get_gspmd_shardings_from_executable (
2328
+ def get_out_shardings_from_executable (
2322
2329
xla_executable ,
2323
2330
device_assignment : Sequence [xc .Device ],
2324
2331
num_out_avals : int ,
@@ -2374,6 +2381,32 @@ def get_gspmd_shardings_from_executable(
2374
2381
for os , mk in safe_zip (out_op_shardings , omk )]
2375
2382
2376
2383
2384
def _get_in_shardings_from_xla(
    xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
    num_ordered_effects: int
) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
  """Extracts the input shardings that XLA chose for a compiled executable.

  Returns one sharding per input aval, or None when the executable does not
  expose input op shardings (callers then keep their own shardings).
  """
  from jax._src import pjit  # local import, presumably to break an import cycle

  # When the device assignment only has 1 device, the SPMD partitioner will
  # not run, so no op shardings are set on the `hlo_module`.
  if len(device_assignment) == 1:
    single_dev = sharding_impls.SingleDeviceSharding(device_assignment[0])
    return [single_dev] * num_in_avals

  xla_in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
  if not xla_in_op_shardings:
    return None

  # The leading entries correspond to ordered effects (not real array
  # inputs); drop them so the result lines up with the avals.
  if num_ordered_effects > 0:
    xla_in_op_shardings = xla_in_op_shardings[num_ordered_effects:]

  assert len(xla_in_op_shardings) == num_in_avals, (
      len(xla_in_op_shardings), num_in_avals)

  return [sharding_impls.GSPMDSharding(device_assignment, op_s)
          for op_s in xla_in_op_shardings]
2408
+
2409
+
2377
2410
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
2378
2411
# without mesh.
2379
2412
def _get_mesh_pspec_shardings_from_executable (
@@ -2526,8 +2559,8 @@ def get_logical_mesh_ids(mesh_shape):
2526
2559
2527
2560
@weakref_lru_cache
2528
2561
def _cached_compilation (computation , name , mesh , spmd_lowering ,
2529
- tuple_args , auto_spmd_lowering ,
2530
- _allow_propagation_to_outputs , host_callbacks , backend ,
2562
+ tuple_args , auto_spmd_lowering , allow_prop_to_inputs ,
2563
+ allow_prop_to_outputs , host_callbacks , backend ,
2531
2564
da , pmap_nreps , compiler_options_keys ,
2532
2565
compiler_options_values ):
2533
2566
# TODO(phawkins): One would normally just write:
@@ -2580,7 +2613,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
2580
2613
get_logical_mesh_ids (list (mesh .shape .values ()))
2581
2614
.reshape (- 1 ))
2582
2615
compile_options .parameter_is_tupled_arguments = tuple_args
2583
- opts .allow_spmd_sharding_propagation_to_output = list (_allow_propagation_to_outputs )
2616
+ if xla_extension_version >= 240 :
2617
+ opts .allow_spmd_sharding_propagation_to_parameters = list (allow_prop_to_inputs )
2618
+ opts .allow_spmd_sharding_propagation_to_output = list (allow_prop_to_outputs )
2584
2619
2585
2620
if hasattr (backend , "compile_replicated" ):
2586
2621
return None , compile_options
@@ -2593,22 +2628,59 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
2593
2628
return xla_executable , compile_options
2594
2629
2595
2630
2596
- def _get_shardings_from_executable (
2631
def _maybe_get_and_check_in_shardings(
    xla_executable, in_shardings, device_assignment,
    global_in_avals, num_ordered_effects):
  """Returns in_shardings extracted from XLA or checks and returns original
  shardings.

  If in_shardings exist on `jit` or on `jax.Array`, then this function will
  check that sharding against what XLA returns as in_shardings. If they don't
  match, an error is raised.

  If in_sharding is unspecified, then the sharding returned by XLA is returned.
  """
  in_shardings_xla = _get_in_shardings_from_xla(  # type: ignore
      xla_executable, device_assignment, len(global_in_avals),
      num_ordered_effects)  # type: ignore
  if in_shardings_xla is None:
    # Executable exposes no input shardings; keep what the user supplied.
    return in_shardings

  checked = []
  for from_xla, user_s, aval in safe_zip(in_shardings_xla, in_shardings,
                                         global_in_avals):
    if is_unspecified(user_s):
      # No user constraint: adopt whatever XLA inferred for this input.
      checked.append(from_xla)
      continue
    xla_hlo = from_xla._to_xla_hlo_sharding(aval.ndim)  # type: ignore
    user_hlo = user_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
    # MANUAL HloSharding comes from other partitioning frameworks, so it is
    # exempted from the equality check, as are extended dtypes.
    if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
        not xla_hlo.is_manual() and
        (not op_shardings.are_op_shardings_equal(xla_hlo, user_hlo) or
         from_xla.memory_kind != user_s.memory_kind)):  # type: ignore
      raise AssertionError(
          f"Unexpected XLA sharding override: (XLA) {from_xla} != {user_s} "
          "(User sharding)")
    checked.append(user_s)
  return checked
2667
+
2668
+
2669
def _get_out_shardings_from_executable(
    xla_executable, out_shardings, device_assignment, global_out_avals,
    num_ordered_effects, all_default_mem_kind
):
  """Merges user out_shardings with the shardings XLA chose for the outputs.

  For each output: if the user sharding is unspecified, XLA's sharding is
  taken (and flagged as XLA-derived); otherwise the user sharding is checked
  against XLA's and kept. Returns (shardings, are_out_shardings_from_xla).
  """
  out_shardings_xla = get_out_shardings_from_executable(  # type: ignore
      xla_executable, device_assignment, len(global_out_avals),
      num_ordered_effects, all_default_mem_kind)  # type: ignore
  if out_shardings_xla is None:
    # Nothing extracted from XLA: every output keeps its user sharding.
    return out_shardings, (False,) * len(global_out_avals)

  new_out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
  for from_xla, user_s, aval in safe_zip(out_shardings_xla, out_shardings,
                                         global_out_avals):
    if is_unspecified(user_s):
      new_out_shardings.append(from_xla)
      are_out_shardings_from_xla.append(True)
      continue
    xla_hlo = from_xla._to_xla_hlo_sharding(aval.ndim)  # type: ignore
    user_hlo = user_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
    # MANUAL HloSharding comes from other partitioning frameworks and is
    # exempt from the check, as are extended dtypes.
    if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
        not xla_hlo.is_manual() and
        (not op_shardings.are_op_shardings_equal(xla_hlo, user_hlo) or
         from_xla.memory_kind != user_s.memory_kind)):  # type: ignore
      raise AssertionError(
          f"Unexpected XLA sharding override: (XLA) {from_xla} != {user_s} "
          "(User sharding)")
    new_out_shardings.append(user_s)
    are_out_shardings_from_xla.append(False)
  return new_out_shardings, are_out_shardings_from_xla
2627
2699
2628
2700
2629
2701
def finalize_out_shardings (out_shardings , are_out_shardings_from_xla ,
@@ -2722,6 +2794,8 @@ def from_hlo(name: str,
2722
2794
else :
2723
2795
da = _create_da_object (tuple (device_assignment ))
2724
2796
del device_assignment
2797
+
2798
+ allow_prop_to_inputs = tuple (is_unspecified (i ) for i in in_shardings )
2725
2799
allow_prop_to_outputs = tuple (is_unspecified (o ) for o in out_shardings )
2726
2800
2727
2801
mesh = None
@@ -2733,8 +2807,8 @@ def from_hlo(name: str,
2733
2807
2734
2808
xla_executable , compile_options = _cached_compilation (
2735
2809
hlo , name , mesh , spmd_lowering ,
2736
- tuple_args , auto_spmd_lowering , allow_prop_to_outputs ,
2737
- tuple (host_callbacks ), backend , da , pmap_nreps ,
2810
+ tuple_args , auto_spmd_lowering , allow_prop_to_inputs ,
2811
+ allow_prop_to_outputs , tuple (host_callbacks ), backend , da , pmap_nreps ,
2738
2812
compiler_options_keys , compiler_options_values )
2739
2813
2740
2814
if hasattr (backend , "compile_replicated" ):
@@ -2761,9 +2835,11 @@ def from_hlo(name: str,
2761
2835
else :
2762
2836
if pmap_nreps == 1 :
2763
2837
assert mesh is None
2764
- # TODO(yashkatariya): Make da directly usable in the downstream code
2765
- # without tuple conversion.
2766
- out_shardings , are_out_shardings_from_xla = _get_shardings_from_executable (
2838
+ if xla_extension_version >= 240 :
2839
+ in_shardings = _maybe_get_and_check_in_shardings (
2840
+ xla_executable , in_shardings , tuple (da ), global_in_avals ,
2841
+ len (ordered_effects ))
2842
+ out_shardings , are_out_shardings_from_xla = _get_out_shardings_from_executable (
2767
2843
xla_executable , out_shardings , tuple (da ), global_out_avals ,
2768
2844
len (ordered_effects ), all_default_mem_kind )
2769
2845
else :
0 commit comments