@@ -2758,21 +2758,16 @@ def _maybe_get_and_check_in_shardings(
2758
2758
xla_s = aval .dtype ._rules .logical_sharding (aval , xla_s )
2759
2759
new_in_shardings .append (xla_s )
2760
2760
else :
2761
- # TODO(yashkatariya): Remove the if branch for abstract_token once
2762
- # choosing input shardings by XLA is enabled again.
2763
- if aval is core .abstract_token :
2764
- new_in_shardings .append (orig )
2765
- else :
2766
- xla_hlo_s = xla_s ._to_xla_hlo_sharding (aval .ndim ) # type: ignore
2767
- orig_hlo_s = orig ._to_xla_hlo_sharding (aval .ndim ) # type: ignore
2768
- # MANUAL HloSharding comes from other partitioning frameworks.
2769
- if (not dtypes .issubdtype (aval .dtype , dtypes .extended ) and
2770
- not xla_hlo_s .is_manual () and
2771
- (not op_shardings .are_op_shardings_equal (xla_hlo_s , orig_hlo_s ))):
2772
- raise AssertionError (
2773
- f"Unexpected XLA sharding override: (XLA) { xla_s } != { orig } "
2774
- "(User sharding)" )
2775
- new_in_shardings .append (orig )
2761
+ xla_hlo_s = xla_s ._to_xla_hlo_sharding (aval .ndim ) # type: ignore
2762
+ orig_hlo_s = orig ._to_xla_hlo_sharding (aval .ndim ) # type: ignore
2763
+ # MANUAL HloSharding comes from other partitioning frameworks.
2764
+ if (not dtypes .issubdtype (aval .dtype , dtypes .extended ) and
2765
+ not xla_hlo_s .is_manual () and
2766
+ (not op_shardings .are_op_shardings_equal (xla_hlo_s , orig_hlo_s ))):
2767
+ raise AssertionError (
2768
+ f"Unexpected XLA sharding override: (XLA) { xla_s } != { orig } "
2769
+ "(User sharding)" )
2770
+ new_in_shardings .append (orig )
2776
2771
return new_in_shardings
2777
2772
2778
2773
0 commit comments