@@ -211,14 +211,31 @@ def shard_torch_xla_model_from_config(
   If `mesh` is not given, there must be a registered global mesh.
   """
   import torch_xla.distributed.spmd as xs
+  from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction
 
-  def shard_fn(tensor, spec: tuple[str, ...]):
+  def shard_activation(tensor, spec: tuple[str, ...]):
+    the_mesh = mesh if mesh is not None else xs.get_global_mesh()
+    assert the_mesh is not None, "No mesh found"
+    # TODO(https://github.com/pytorch/xla/issues/8678): Replace with the simpler
+    # `mark_sharding_and_gradients`.
+    out = MarkShardingFunction.apply(tensor, the_mesh, spec)
+    assert isinstance(out, torch.Tensor)
+    return out
+
+  # TODO(https://github.com/pytorch/xla/issues/8809): If we shard parameters with
+  # `MarkShardingFunction.apply`, that causes Mixtral to OOM. Gradient HLO arrays end up
+  # living much longer than needed.
+  def shard_param(tensor, spec: tuple[str, ...]):
     the_mesh = mesh if mesh is not None else xs.get_global_mesh()
     assert the_mesh is not None, "No mesh found"
-    # TODO(https://github.com/pytorch/xla/issues/8678): Shard the gradient too.
     return xs.mark_sharding(tensor, the_mesh, spec).global_tensor
 
-  return shard_model_from_config(model, config, shard_fn)
+  return shard_model_from_config(
+    model,
+    config,
+    shard_activation,
+    shard_param,
+  )
 
 
 def _process_tail_index_syntax(
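For context, the two helpers in this diff differ only in how they annotate sharding: `shard_param` calls `xs.mark_sharding`, which attaches the sharding annotation to the tensor itself, while `shard_activation` goes through `MarkShardingFunction.apply`, an autograd-aware wrapper that also annotates the gradient in the backward pass. Below is a minimal standalone sketch of the two call patterns; it is not part of this commit, and the mesh setup, the "fsdp" axis name, and the tensor shapes are illustrative assumptions.

# Illustrative sketch (not from the commit), assuming an SPMD-enabled
# torch_xla runtime; axis name and shapes are made up for the example.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",))
xs.set_global_mesh(mesh)  # lets helpers fall back to xs.get_global_mesh()

device = xm.xla_device()

# Parameter-style sharding: annotate the tensor; the gradient is not
# annotated (mirrors `shard_param` above).
weight = torch.randn(1024, 1024, device=device)
weight = xs.mark_sharding(weight, mesh, ("fsdp", None)).global_tensor

# Activation-style sharding: MarkShardingFunction is an autograd.Function,
# so the same annotation is also applied to the gradient during backward
# (mirrors `shard_activation` above).
activation = torch.randn(8, 1024, device=device, requires_grad=True)
activation = MarkShardingFunction.apply(activation, mesh, ("fsdp", None))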