Commit d6f2452

Make models amenable to scan (#157)
* Make models amenable to scan

  We replace the `for` loop in both Llama and Mixtral with an equivalent
  `HomogeneousSequential` layer, which can either run a for loop or use
  `torch_xla`'s scan operator. This is a clean-ish way to turn scan on/off
  without cluttering the modeling code.

  I also adjusted Mixtral slightly so that we can even run `scan` in Mixtral
  with its static MoE implementation. Scanning over GMM, on the other hand,
  won't work until GMM forward/backward is wrapped in a custom op similar to
  pytorch/xla#8654.

  Test: added unit test. The next PR will change the trainer to apply scan.

* Address comments
1 parent 6f26df9 commit d6f2452

File tree

7 files changed: +319 -169 lines changed


README.md

Lines changed: 18 additions & 37 deletions
@@ -169,77 +169,58 @@ Finally, each model may also provide a GPU "original" version that illustrates
 and attributes where this model code came from, if any. This also helps to
 show case what changes we have done to make it performant on TPU. The original
 version is not expected to be run.
+
 ## Contributing
 
 Contributions are welcome! Please feel free to submit a pull request.
 
 When developing, use `pip install -e '.[dev]'` to install dev dependencies such
 as linter and formatter.
 
-### How to run tests:
+### How to run tests
 
 ```sh
 pytest
 ```
 
-### How to run some of the tests, and re-run them whenever you change a file:
+### How to run some of the tests, and re-run them whenever you change a file
 
 ```sh
 tp -i test ... # replace with path to tests/directories
 ```
 
+### How to format
 
-### How to run HuggingFace transformer models
-Torchprime supports run with huggingface models by taking advantage of `tp run`.
-To use huggingface models, you can clone
-[huggingface/transformers](https://github.com/huggingface/transformers) under
-torchprime and name it as `local_transformers`. This allows you to pick any
-branch or make code modifications in transformers for experiment:
-```
-git clone https://github.com/huggingface/transformers.git local_transformers
-```
-If huggingface transformer doesn't exist, torchprime will automatically clone
-the repo and build the docker for experiment. To switch to huggingface models,
-add flag `--use-hf` to `tp run` command:
+```sh
+ruff format
 ```
-tp run --use-hf torchprime/hf_models/train.py
+
+### How to lint
+
+```sh
+ruff check [--fix]
 ```
 
+You can install a Ruff VSCode plugin to check errors and format files from
+the editor.
+
 ### How to run inside the docker container locally
+
 You can also run locally without XPK with docker. When running inside the docker
 container, it will use the same dependencies and build process as used in the
 XPK approach, improving the hermeticity and reliability.
-```
+
+```sh
 tp docker-run torchprime/torch_xla_models/train.py
 ```
-This will run the TorchPrime docker image locally. You can also add `--use-hf`
-to run HuggingFace model locally.
-```
-tp docker-run --use-hf torchprime/hf_models/train.py
-```
 
-### How to run locally without XPK:
-```
-tp dbrun torchprime/torch_xla_models/train.py
-```
 This will run the TorchPrime docker image locally. You can also add `--use-hf`
 to run HuggingFace model locally.
 
-### How to format:
-
 ```sh
-ruff format
-```
-
-### How to lint:
-
-```sh
-ruff check [--fix]
+tp docker-run --use-hf torchprime/hf_models/train.py
 ```
 
-You can install a Ruff VSCode plugin to check errors and format files from
-the editor.
-
 ## Run distributed training with local torch/torch_xla wheel
 
 Torchprime supports running with user specified torch and torch_xla wheels placed

torchprime/layers/sequential.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+from typing import Any
+
+import torch.nn as nn
+
+PyTree = Any
+
+
+class HomogeneousSequential(nn.Sequential):
+  """
+  HomogenousSequential is a sequential container that requires all child modules
+  to be of the same type and have matching input/output shapes. In turn, it may be
+  compiled with the `scan` higher order operator to save compile time.
+  """
+
+  repeated_layer: type
+  """The type of the layer being looped over."""
+
+  def __init__(self, *args: nn.Module) -> None:
+    super().__init__(*args)
+    types = set(type(module) for module in args)
+    assert len(types) == 1, f"All modules must be of the same type. Got {types}"
+    self.repeated_layer = types.pop()
+
+  def forward(self, *input, **broadcasted_inputs: PyTree):
+    """
+    Much like `torch.nn.Sequential`, this takes `input` and forwards it to the
+    first module it contains. It then "chains" outputs to inputs sequentially for
+    each subsequent module, finally returning the output of the last module.
+    Different from `torch.nn.Sequential`, you may specify `broadcasted_inputs` via
+    keyword arguments. The same keyword arguments will be passed to every layer
+    without changes (i.e. "broadcasted").
+    """
+    for module in self:
+      input = module(*splat(input), **broadcasted_inputs)
+    return input
+
+
+def splat(input):
+  if not isinstance(input, list | tuple):
+    input = (input,)
+  return input

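For context, here is a minimal usage sketch of the new container. Positional outputs are chained from layer to layer, while keyword arguments are broadcast unchanged to every layer. The `Block` module and the sizes below are made up for illustration; only `HomogeneousSequential` comes from this commit.

```python
import torch
import torch.nn as nn

from torchprime.layers.sequential import HomogeneousSequential


class Block(nn.Module):
  """Toy layer whose input and output shapes match, as the container requires."""

  def __init__(self, dim: int):
    super().__init__()
    self.linear = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return self.linear(x) * scale


layers = HomogeneousSequential(*[Block(16) for _ in range(4)])
x = torch.randn(2, 16)
# `x` is chained through the four blocks; `scale` is broadcast to each call.
y = layers(x, scale=0.5)
print(y.shape)  # torch.Size([2, 16])
```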
torchprime/torch_xla_models/llama/model.py

Lines changed: 27 additions & 18 deletions
@@ -28,7 +28,9 @@
 from transformers.activations import ACT2FN
 from transformers.utils import logging
 
+from torchprime.layers.sequential import HomogeneousSequential
 from torchprime.rope.rope import RopeScaling, llama3_rope_frequencies
+from torchprime.torch_xla_models import offloading
 from torchprime.torch_xla_models.loss import cross_entropy_loss
 
 logger = logging.get_logger(__name__)
@@ -328,15 +330,20 @@ def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: torch.Tensor | None = None,
-    position_ids: torch.LongTensor | None = None,
-  ) -> torch.FloatTensor:
+    position_ids: torch.Tensor | None = None,
+  ) -> torch.Tensor:
     """
     Args:
       hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
       attention_mask (`torch.FloatTensor`, *optional*):
        attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
        query_sequence_length, key_sequence_length)` if default attention is used.
     """
+    # This gives the `hidden_states` tensor a name so that we can layer specify
+    # to offload this tensor to host RAM to save memory. This is not a standard
+    # torch API because there is no such feature in PyTorch. Instead, the name
+    # becomes node metadata during FX graph capture.
+    hidden_states = offloading.offload_name(hidden_states, "decoder_input")
 
     residual = hidden_states
 
@@ -370,10 +377,12 @@ class LlamaModel(nn.Module):
   def __init__(self, config: DictConfig):
     super().__init__()
     self.vocab_size = config.vocab_size
-
     self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-    self.layers = nn.ModuleList(
-      [
+
+    # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with
+    # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html.
+    self.layers = HomogeneousSequential(
+      *[
         LlamaDecoderLayer(config, layer_idx)
         for layer_idx in range(config.num_hidden_layers)
       ]
@@ -385,15 +394,19 @@ def forward(
     self,
     input_ids: torch.LongTensor,
     attention_mask: torch.FloatTensor | None = None,
-  ) -> torch.FloatTensor:
+  ) -> torch.Tensor:
+    # convert input ids to embeddings
     inputs_embeds = self.embed_tokens(input_ids)
 
-    position_ids = torch.arange(
-      inputs_embeds.shape[1], device=inputs_embeds.device
-    ).unsqueeze(0)
-
-    # Create a causal mask without calling the current method
     seq_length = inputs_embeds.size(1)
+
+    # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()`
+    # when `scan` can take non-differentiable inputs.
+    position_ids = (
+      torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).float()
+    )
+
+    # Create a causal attention mask
     causal_mask = torch.triu(
       torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device),
       diagonal=1,
@@ -403,14 +416,10 @@ def forward(
     if attention_mask is not None:
       causal_mask = causal_mask * attention_mask[:, None, None, :]
 
-    # embed positions
-    hidden_states = inputs_embeds
-
     # decoder layers
-    for decoder_layer in self.layers:
-      hidden_states = decoder_layer(
-        hidden_states, attention_mask=causal_mask, position_ids=position_ids
-      )
+    hidden_states = self.layers(
+      inputs_embeds, attention_mask=causal_mask, position_ids=position_ids
+    )
 
     hidden_states = self.norm(hidden_states)
     return hidden_states

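The point of turning the decoder stack into a `HomogeneousSequential` is that the per-layer Python loop can later be swapped for `torch_xla`'s scan operator (the follow-up PR wires this into the trainer). Below is a rough sketch of that swap, assuming the `scan_layers` helper described in the linked torch_xla scan guide; the import path and signature are taken from that guide, not from this commit, and broadcasted inputs such as the attention mask are omitted for brevity.

```python
# Assumed per the torch_xla r2.6 scan guide; not part of this commit.
from torch_xla.experimental.scan_layers import scan_layers

from torchprime.layers.sequential import HomogeneousSequential


def run_decoder_stack(layers: HomogeneousSequential, hidden_states, use_scan: bool):
  if use_scan:
    # Trace one layer and loop it with a compiled scan instead of unrolling
    # every decoder layer into the HLO graph.
    return scan_layers(layers, hidden_states)
  # Fall back to the plain for-loop behaviour of HomogeneousSequential.
  return layers(hidden_states)
```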
torchprime/torch_xla_models/mixtral/model.py

Lines changed: 35 additions & 25 deletions
@@ -31,6 +31,7 @@
 from torch import nn
 from torch.nn import init
 
+from torchprime.layers.sequential import HomogeneousSequential
 from torchprime.torch_xla_models.loss import cross_entropy_loss
 from torchprime.torch_xla_models.topology import get_num_slices
 
@@ -129,8 +130,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
   Returns:
     `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
   """
-  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+  cos = cos[position_ids.long()].unsqueeze(unsqueeze_dim)
+  sin = sin[position_ids.long()].unsqueeze(unsqueeze_dim)
   q_embed = (q * cos) + (rotate_half(q) * sin)
   k_embed = (k * cos) + (rotate_half(k) * sin)
   return q_embed, k_embed
@@ -204,7 +205,7 @@ def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: torch.Tensor | None = None,
-    position_ids: torch.LongTensor | None = None,
+    position_ids: torch.Tensor | None = None,
   ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
     bsz, q_len, _ = hidden_states.size()
 
@@ -816,7 +817,9 @@ def load_balance_loss(self, top_k_indices, logits):
     return loss
 
   @xp.trace_me("MixtralMoeBlock")
-  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+  def forward(
+    self, hidden_states: torch.Tensor
+  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     batch_size, sequence_length, hidden_dim = hidden_states.shape
     hidden_states = hidden_states.view(-1, hidden_dim)
     # router_logits: (batch * sequence_length, n_experts)
@@ -851,6 +854,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
       case "dropping":
         hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
         mesh = xs.get_global_mesh()
+        assert mesh is not None
         selected_experts = selected_experts.view(
           batch_size, sequence_length, self.top_k
         )
@@ -895,7 +899,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = final_hidden_states.reshape(
           batch_size, sequence_length, hidden_dim
         )
-    return final_hidden_states, router_logits, loss
+      case _:
+        raise NotImplementedError(
+          f"Unsupported moe implementation {self.moe_implementation}"
+        )
+    return final_hidden_states, router_logits, torch.tensor(loss)
 
 
 class MixtralDecoderLayer(nn.Module):
@@ -915,9 +923,10 @@ def __init__(self, config: DictConfig, layer_idx: int):
   def forward(
     self,
     hidden_states: torch.Tensor,
+    cumulative_loss: torch.Tensor,
     attention_mask: torch.Tensor | None = None,
-    position_ids: torch.LongTensor | None = None,
-  ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+    position_ids: torch.Tensor | None = None,
+  ) -> tuple[torch.Tensor, torch.Tensor]:
     residual = hidden_states
 
     hidden_states = self.input_layernorm(hidden_states)
@@ -936,7 +945,7 @@ def forward(
     hidden_states, router_logits, loss = self.block_sparse_moe(hidden_states)
     hidden_states = residual + hidden_states
 
-    outputs = (hidden_states, loss)
+    outputs = (hidden_states, cumulative_loss + loss)
     return outputs
 
 
@@ -954,8 +963,11 @@ def __init__(self, config: DictConfig):
     self.config = config
     self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-    self.layers = nn.ModuleList(
-      [
+
+    # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with
+    # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html.
+    self.layers = HomogeneousSequential(
+      *[
         MixtralDecoderLayer(config, layer_idx)
         for layer_idx in range(config.num_hidden_layers)
       ]
@@ -978,12 +990,14 @@ def _init_weights(self, module):
 
   @xp.trace_me("MixtralModel")
   def forward(
-    self, input_ids: torch.LongTensor = None, attention_mask: torch.Tensor | None = None
+    self, input_ids: torch.LongTensor, attention_mask: torch.Tensor | None = None
   ) -> tuple:
     batch_size, seq_length = input_ids.shape
 
     device = input_ids.device
-    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+    # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()`
+    # when `scan` can take non-differentiable inputs.
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).float()
     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 
     inputs_embeds = self.embed_tokens(input_ids)
@@ -999,22 +1013,18 @@ def forward(
 
     hidden_states = inputs_embeds
 
-    total_loss = 0.0
-    for decoder_layer in self.layers:
-      layer_outputs = decoder_layer(
-        hidden_states,
-        attention_mask=causal_mask,
-        position_ids=position_ids,
-      )
-
-      hidden_states = layer_outputs[0]
-      load_balance_loss = layer_outputs[-1]
-      total_loss += load_balance_loss
+    load_balance_loss = torch.tensor(0.0, device=device)
+    hidden_states, load_balance_loss = self.layers(
+      hidden_states,
+      load_balance_loss,
+      attention_mask=causal_mask,
+      position_ids=position_ids,
+    )
 
-    total_loss = total_loss / len(self.layers)
+    load_balance_loss = load_balance_loss / len(self.layers)
 
     hidden_states = self.norm(hidden_states)
-    return (hidden_states, total_loss)
+    return (hidden_states, load_balance_loss)
 
 
 class MixtralForCausalLM(nn.Module):

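A note on the Mixtral change: because `HomogeneousSequential` (and, later, scan) can only carry state through the chained layer inputs and outputs, the running load-balance loss now travels through each decoder layer as an explicit tensor argument instead of being accumulated in a Python loop. The toy sketch below exercises that pattern in isolation; `ToyMoeLayer` and its made-up auxiliary loss are illustrative only.

```python
import torch
import torch.nn as nn

from torchprime.layers.sequential import HomogeneousSequential


class ToyMoeLayer(nn.Module):
  """Chains (hidden_states, cumulative_loss), mirroring MixtralDecoderLayer."""

  def __init__(self, dim: int):
    super().__init__()
    self.proj = nn.Linear(dim, dim)

  def forward(self, hidden_states, cumulative_loss):
    aux_loss = hidden_states.square().mean()  # stand-in for the load-balance loss
    return self.proj(hidden_states), cumulative_loss + aux_loss


layers = HomogeneousSequential(*[ToyMoeLayer(8) for _ in range(3)])
hidden = torch.randn(2, 4, 8)
hidden, total_aux = layers(hidden, torch.tensor(0.0))
print(hidden.shape, total_aux / len(layers))  # averaged like MixtralModel does
```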