Commit 6c3b297

Replace nn.Linear with einsum in the model (#147)
Fixes #139. Previously, nn.Linear always flattened the batch and sequence dimensions together, losing the sharding annotations on those dimensions. This is no longer the case. As further evidence that sharding propagation became more sensible, I tried sharding the q/k/v/o projection activations explicitly:

model.layers.*.self_attn.q_proj: [fsdp, null, tensor]
model.layers.*.self_attn.k_proj: [fsdp, null, tensor]
model.layers.*.self_attn.v_proj: [fsdp, null, tensor]
model.layers.*.self_attn.o_proj: [fsdp, null, tensor]

Without replacing nn.Linear with einsum, this actually _decreases_ MFU. With the replacement, MFU stays unchanged, which is the expected behavior.
1 parent 7d19eb5 commit 6c3b297
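
For intuition on the fix, here is a minimal standalone sketch (hypothetical shapes, not code from this commit). `F.linear`, which backs `nn.Linear`, contracts over the last dimension and treats all leading dimensions as one flattened batch, whereas the einsum form names the batch and sequence dimensions explicitly, so sharding annotations attached to them can survive the lowering:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: [batch, seq, hidden] activations and an [out, in] weight.
x = torch.randn(4, 128, 256)
w = torch.randn(512, 256)

# nn.Linear / F.linear contracts over the last dimension and treats everything
# else as one flattened leading dimension, merging batch and sequence.
y_linear = F.linear(x, w)

# The einsum form keeps batch (b) and sequence (s) as distinct named dimensions,
# so per-dimension sharding annotations on the activation are not lost.
y_einsum = torch.einsum("bsh,oh->bso", x, w)

assert torch.allclose(y_linear, y_einsum, atol=1e-4)
```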

File tree

3 files changed: +16 −2 lines

torchprime/torch_xla_models/README.md

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ tp run torchprime/torch_xla_models/train.py \
 ### Llama 3.1 405B on v6e-256
 
 Recipe for global batch size 64, sequence length 8192.
-Expected step duration: 27.349s. MFU: 21.48%.
+Expected step duration: 19.642s. MFU: 29.91%.
 
 ```sh
 export LIBTPU_INIT_ARGS='--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=98304'

torchprime/torch_xla_models/configs/model/scaling/llama-fsdp-tp.yaml

Lines changed: 4 additions & 1 deletion
@@ -27,7 +27,10 @@ sharding:
   lm_head.weight: [tensor, fsdp]
 
   # Activations
-  model.layers.*.self_attn: [fsdp, null, tensor]
+  model.layers.*.self_attn.q_proj: [fsdp, null, tensor]
+  model.layers.*.self_attn.k_proj: [fsdp, null, tensor]
+  model.layers.*.self_attn.v_proj: [fsdp, null, tensor]
+  model.layers.*.self_attn.o_proj: [fsdp, null, tensor]
   model.layers.*.input_layernorm: [fsdp, null, tensor]
   model.layers.*.post_attention_layernorm: [fsdp, null, tensor]
   model.layers.*.mlp: [fsdp, null, tensor]
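
For reference, an annotation such as `[fsdp, null, tensor]` on a `[batch, sequence, hidden]` activation corresponds roughly to the following `mark_sharding` call. This is a sketch, not code from this commit; the mesh shape and tensor sizes are assumptions, and the device count is assumed to be divisible by 4:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Assumed 2D device mesh for illustration: most chips on the `fsdp` axis,
# a small `tensor` axis for tensor parallelism.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("fsdp", "tensor"))

# A [batch, sequence, hidden] activation, e.g. the q_proj output (sizes assumed).
activation = torch.zeros(8, 8192, 4096, device=xm.xla_device())

# [fsdp, null, tensor]: shard batch over the `fsdp` axis, replicate the
# sequence dimension, and shard hidden over the `tensor` axis.
xs.mark_sharding(activation, mesh, ("fsdp", None, "tensor"))
```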

torchprime/torch_xla_models/train.py

Lines changed: 11 additions & 0 deletions
@@ -24,6 +24,7 @@
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 from torch_xla.distributed.fsdp import checkpoint_module
+from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
 
 # Transformers imports
 from transformers import (
@@ -81,13 +82,23 @@ def __init__(
     self.input_sharding_spec = xs.ShardingSpec(
       mesh, (("data", "fsdp"), None), minibatch=minibatch
     )
+
+    # Recursively replace `nn.Linear` layers with einsum operations in the model.
+    # Without this patch, an `nn.Linear` module will flatten non-contracting dimensions
+    # (e.g. batch and sequence), thus destroying the sharding constraints on those dimensions.
+    model = apply_xla_patch_to_nn_linear(model)
+
+    # Annotate model weights and activations with sharding constraints to distribute
+    # the training across devices following the SPMD paradigm.
     sharding_config = OmegaConf.to_container(
       self.config.model.scaling.sharding, resolve=True
     )
     assert isinstance(
       sharding_config, dict
     ), f"Sharding config {sharding_config} must be a dict"
     model = shard_torch_xla_model_from_config(model, config=sharding_config)
+
+    # Rematerialize forward computation during the backward pass if requested.
     model = self._checkpoint_model(model)
     model = self._add_optimization_barrier_model(model)
     self.model = model
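
The patch is applied before the sharding annotations so that the constraints propagate through the einsum-based linears. As a standalone sketch of the call on a toy module (the module and its sizes are hypothetical, only the import and the one-argument call mirror the diff above):

```python
import torch.nn as nn
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear

# Toy stand-in for the transformer; sizes are hypothetical.
toy = nn.Sequential(nn.Linear(256, 1024), nn.SiLU(), nn.Linear(1024, 256))

# After the patch, each nn.Linear forward runs through an einsum-based op
# instead of flattening the leading dimensions, so activation sharding
# annotations on [batch, seq, hidden] tensors are preserved.
toy = apply_xla_patch_to_nn_linear(toy)
```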
