
Commit 7d19eb5

Shard both forward and backwards (#144)
We replace `xs.mark_sharding` with `MarkShardingFunction`, which shards the gradient too. This should make GSPMD much more robust.
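For background: `MarkShardingFunction` is a `torch.autograd.Function`, so the sharding annotation it applies to a tensor in the forward pass is re-applied to that tensor's gradient in the backward pass, whereas plain `xs.mark_sharding` only annotates the forward value. A minimal conceptual sketch of that mechanism (an illustration of the idea, not the torch_xla source; `ShardFwdAndBwd` is a made-up name):

```python
import torch
import torch_xla.distributed.spmd as xs

class ShardFwdAndBwd(torch.autograd.Function):
  """Illustrative stand-in for MarkShardingFunction (not the real source)."""

  @staticmethod
  def forward(ctx, tensor, mesh, spec):
    # Remember the mesh and partition spec so backward can reuse them.
    ctx.mesh = mesh
    ctx.spec = spec
    return xs.mark_sharding(tensor, mesh, spec).global_tensor

  @staticmethod
  def backward(ctx, grad_output):
    # This is the step plain xs.mark_sharding lacks: annotate the incoming
    # gradient with the same sharding, so GSPMD sees it in the backward pass.
    grad = xs.mark_sharding(grad_output, ctx.mesh, ctx.spec).global_tensor
    # Only the tensor input receives a gradient; mesh and spec do not.
    return grad, None, None
```

Usage mirrors the diff below: `ShardFwdAndBwd.apply(tensor, mesh, spec)` returns an annotated tensor whose gradient carries the same partition spec.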
Parent: 39a73b6

1 file changed: torchprime/sharding/shard_model.py (+20 −3)
```diff
@@ -211,14 +211,31 @@ def shard_torch_xla_model_from_config(
   If `mesh` is not given, there must be a registered global mesh.
   """
   import torch_xla.distributed.spmd as xs
+  from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction
 
-  def shard_fn(tensor, spec: tuple[str, ...]):
+  def shard_activation(tensor, spec: tuple[str, ...]):
+    the_mesh = mesh if mesh is not None else xs.get_global_mesh()
+    assert the_mesh is not None, "No mesh found"
+    # TODO(https://github.com/pytorch/xla/issues/8678): Replace with the simpler
+    # `mark_sharding_and_gradients`.
+    out = MarkShardingFunction.apply(tensor, the_mesh, spec)
+    assert isinstance(out, torch.Tensor)
+    return out
+
+  # TODO(https://github.com/pytorch/xla/issues/8809): If we shard parameters with
+  # `MarkShardingFunction.apply`, that causes Mixtral to OOM. Gradient HLO arrays end up
+  # living much longer than needed.
+  def shard_param(tensor, spec: tuple[str, ...]):
     the_mesh = mesh if mesh is not None else xs.get_global_mesh()
     assert the_mesh is not None, "No mesh found"
-    # TODO(https://github.com/pytorch/xla/issues/8678): Shard the gradient too.
     return xs.mark_sharding(tensor, the_mesh, spec).global_tensor
 
-  return shard_model_from_config(model, config, shard_fn)
+  return shard_model_from_config(
+    model,
+    config,
+    shard_activation,
+    shard_param,
+  )
 
 
 def _process_tail_index_syntax(
```
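For completeness, a hypothetical usage sketch of the updated entry point. The mesh setup uses real torch_xla APIs, but `build_model`, the config keys, and the partition specs are illustrative assumptions; torchprime's actual sharding config schema may differ:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torchprime.sharding.shard_model import shard_torch_xla_model_from_config

# Register a global 1D "fsdp" mesh over all devices, so the `mesh` argument
# can be omitted below (the function falls back to xs.get_global_mesh()).
num_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",)))

model = build_model()  # hypothetical helper: any nn.Module on the XLA device

sharding_config = {
  # Assumed schema: parameter entries go through shard_param (plain
  # xs.mark_sharding, avoiding the OOM tracked in issue 8809) ...
  "model.layers.*.mlp.up_proj.weight": ("fsdp", None),
  # ... while module-output entries go through shard_activation
  # (MarkShardingFunction), so gradients pick up the same annotation.
  "model.layers.*": ("fsdp", None, None),
}

model = shard_torch_xla_model_from_config(model, sharding_config)
```

The split follows the TODOs in the diff: activations get gradient-aware sharding via `MarkShardingFunction`, while parameters keep the cheaper `xs.mark_sharding` path until the upstream issues are resolved.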
