@@ -2574,23 +2574,23 @@ def _get_out_sharding_from_orig_sharding(
2574
2574
out .append (o )
2575
2575
return out
2576
2576
2577
- def maybe_get_orig_out_sharding (
2578
- in_shardings , out_shardings , in_avals , out_avals ):
2579
- if all (not isinstance (o , sharding_impls .GSPMDSharding ) for o in out_shardings ):
2580
- return out_shardings
2577
+ def maybe_recover_user_shardings (
2578
+ old_shardings , new_shardings , old_avals , new_avals ):
2579
+ if all (not isinstance (o , sharding_impls .GSPMDSharding ) for o in new_shardings ):
2580
+ return new_shardings
2581
2581
2582
2582
orig_in_s = None
2583
2583
orig_aval = None
2584
- for oi , aval in safe_zip (in_shardings , in_avals ):
2584
+ for oi , aval in safe_zip (old_shardings , old_avals ):
2585
2585
if type (oi ) in _orig_out_sharding_handlers :
2586
2586
orig_in_s = oi
2587
2587
orig_aval = aval
2588
2588
break
2589
2589
if orig_in_s is not None :
2590
2590
return _get_out_sharding_from_orig_sharding (
2591
- out_shardings , out_avals , orig_in_s , orig_aval )
2591
+ new_shardings , new_avals , orig_in_s , orig_aval )
2592
2592
2593
- return out_shardings
2593
+ return new_shardings
2594
2594
2595
2595
2596
2596
def _get_layouts_from_executable (
@@ -2744,6 +2744,10 @@ def _maybe_get_and_check_in_shardings(
2744
2744
f"Unexpected XLA sharding override: (XLA) { xla_s } != { orig } "
2745
2745
"(User sharding)" )
2746
2746
new_in_shardings .append (orig )
2747
+
2748
+ new_in_shardings = maybe_recover_user_shardings (
2749
+ in_shardings , new_in_shardings , global_in_avals , global_in_avals )
2750
+
2747
2751
return new_in_shardings
2748
2752
2749
2753
@@ -2921,7 +2925,7 @@ def from_hlo(name: str,
2921
2925
assert all (i is None for i in in_layouts )
2922
2926
assert all (o is None for o in out_layouts )
2923
2927
2924
- out_shardings = maybe_get_orig_out_sharding (
2928
+ out_shardings = maybe_recover_user_shardings (
2925
2929
in_shardings , out_shardings , global_in_avals , global_out_avals )
2926
2930
2927
2931
out_shardings = finalize_out_shardings (out_shardings , da )
0 commit comments