Skip to content

Commit 4244b21

Browse files
hawkinspjax authors
authored andcommitted
[XLA:Python] Port sharding and device lists to nanobind.
PiperOrigin-RevId: 613933518
1 parent 4b0382e commit 4244b21

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

jax/_src/sharding_impls.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -284,24 +284,15 @@ def __init__(
284284
self.mesh = mesh
285285
self.spec = spec
286286
self._memory_kind = memory_kind
287-
self._parsed_pspec = _parsed_pspec
288287
self._manual_axes = _manual_axes
289-
self._preprocess()
290-
291-
def _preprocess(self):
292-
# This split exists because you can pass `_parsed_pspec` that has been
293-
# modified from the original. For example: Adding extra dimension to
294-
# axis_resources for vmap handlers. In such cases you need to preserve the
295-
# `sync` attribute of parsed pspecs.
296-
# PartitionSpec is inferred from the parsed pspec in this case.
297-
# TODO(yaskatariya): Remove this and replace this with a normalized
298-
# representation of Parsed Pspec
299-
if self._parsed_pspec is None:
300-
self._parsed_pspec, _, _ = prepare_axis_resources(
301-
PartitionSpec() if self.spec is None else self.spec,
302-
"NamedSharding spec", allow_unconstrained_dims=True)
303-
304-
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
288+
self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec)
289+
290+
# TODO(phawkins): remove this method when jaxlib 0.4.26 or newer is the
291+
# minimum. This method is called by the C++ sharding implementation in earlier
292+
# versions.
293+
if xla_extension_version < 243:
294+
def _preprocess(self):
295+
self._parsed_pspec = preprocess(self.mesh, self.spec, self._parsed_pspec)
305296

306297
def __repr__(self):
307298
mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items())
@@ -1115,6 +1106,23 @@ def __repr__(self):
11151106
f"sync={self.sync})")
11161107

11171108

1109+
def preprocess(mesh, spec, parsed_pspec):
1110+
# This split exists because you can pass `_parsed_pspec` that has been
1111+
# modified from the original. For example: Adding extra dimension to
1112+
# axis_resources for vmap handlers. In such cases you need to preserve the
1113+
# `sync` attribute of parsed pspecs.
1114+
# PartitionSpec is inferred from the parsed pspec in this case.
1115+
# TODO(yaskatariya): Remove this and replace this with a normalized
1116+
# representation of Parsed Pspec
1117+
if parsed_pspec is None:
1118+
parsed_pspec, _, _ = prepare_axis_resources(
1119+
PartitionSpec() if spec is None else spec,
1120+
"NamedSharding spec", allow_unconstrained_dims=True)
1121+
1122+
_check_mesh_resource_axis(mesh, parsed_pspec)
1123+
return parsed_pspec
1124+
1125+
11181126
def prepare_axis_resources(axis_resources,
11191127
arg_name,
11201128
allow_unconstrained_dims=False):

0 commit comments

Comments
 (0)