@@ -284,24 +284,15 @@ def __init__(
284
284
self .mesh = mesh
285
285
self .spec = spec
286
286
self ._memory_kind = memory_kind
287
- self ._parsed_pspec = _parsed_pspec
288
287
self ._manual_axes = _manual_axes
289
- self ._preprocess ()
290
-
291
- def _preprocess (self ):
292
- # This split exists because you can pass `_parsed_pspec` that has been
293
- # modified from the original. For example: Adding extra dimension to
294
- # axis_resources for vmap handlers. In such cases you need to preserve the
295
- # `sync` attribute of parsed pspecs.
296
- # PartitionSpec is inferred from the parsed pspec in this case.
297
- # TODO(yaskatariya): Remove this and replace this with a normalized
298
- # representation of Parsed Pspec
299
- if self ._parsed_pspec is None :
300
- self ._parsed_pspec , _ , _ = prepare_axis_resources (
301
- PartitionSpec () if self .spec is None else self .spec ,
302
- "NamedSharding spec" , allow_unconstrained_dims = True )
303
-
304
- _check_mesh_resource_axis (self .mesh , self ._parsed_pspec )
288
+ self ._parsed_pspec = preprocess (self .mesh , self .spec , _parsed_pspec )
289
+
290
+ # TODO(phawkins): remove this method when jaxlib 0.4.26 or newer is the
291
+ # minimum. This method is called by the C++ sharding implementation in earlier
292
+ # versions.
293
+ if xla_extension_version < 243 :
294
+ def _preprocess (self ):
295
+ self ._parsed_pspec = preprocess (self .mesh , self .spec , self ._parsed_pspec )
305
296
306
297
def __repr__ (self ):
307
298
mesh_repr = ", " .join (f"'{ k } ': { v } " for k , v in self .mesh .shape .items ())
@@ -1115,6 +1106,23 @@ def __repr__(self):
1115
1106
f"sync={ self .sync } )" )
1116
1107
1117
1108
1109
+ def preprocess (mesh , spec , parsed_pspec ):
1110
+ # This split exists because you can pass `_parsed_pspec` that has been
1111
+ # modified from the original. For example: Adding extra dimension to
1112
+ # axis_resources for vmap handlers. In such cases you need to preserve the
1113
+ # `sync` attribute of parsed pspecs.
1114
+ # PartitionSpec is inferred from the parsed pspec in this case.
1115
+ # TODO(yaskatariya): Remove this and replace this with a normalized
1116
+ # representation of Parsed Pspec
1117
+ if parsed_pspec is None :
1118
+ parsed_pspec , _ , _ = prepare_axis_resources (
1119
+ PartitionSpec () if spec is None else spec ,
1120
+ "NamedSharding spec" , allow_unconstrained_dims = True )
1121
+
1122
+ _check_mesh_resource_axis (mesh , parsed_pspec )
1123
+ return parsed_pspec
1124
+
1125
+
1118
1126
def prepare_axis_resources (axis_resources ,
1119
1127
arg_name ,
1120
1128
allow_unconstrained_dims = False ):
0 commit comments