Make models amenable to scan #157

Merged: 2 commits, Mar 18, 2025

Changes from all commits
55 changes: 18 additions & 37 deletions README.md
@@ -169,77 +169,58 @@ Finally, each model may also provide a GPU "original" version that illustrates
and attributes where this model code came from, if any. This also helps to
showcase what changes we have made to make it performant on TPU. The original
version is not expected to be run.

## Contributing

Contributions are welcome! Please feel free to submit a pull request.

When developing, use `pip install -e '.[dev]'` to install dev dependencies such
as linter and formatter.

### How to run tests

```sh
pytest
```

### How to run some of the tests, and re-run them whenever you change a file

```sh
tp -i test ... # replace with path to tests/directories
```

### How to format

```sh
ruff format
```

### How to lint

```sh
ruff check [--fix]
```

You can install a Ruff VSCode plugin to check errors and format files from
the editor.

### How to run HuggingFace transformer models

Torchprime supports running HuggingFace models by taking advantage of `tp run`.
To use HuggingFace models, you can clone
[huggingface/transformers](https://github.com/huggingface/transformers) under
torchprime and name it `local_transformers`. This allows you to pick any
branch or make code modifications in transformers for an experiment:

```sh
git clone https://github.com/huggingface/transformers.git local_transformers
```

If the local transformers clone doesn't exist, torchprime will automatically clone
the repo and build the docker image for the experiment. To switch to HuggingFace
models, add the `--use-hf` flag to the `tp run` command:

```sh
tp run --use-hf torchprime/hf_models/train.py
```

### How to run inside the docker container locally

You can also run locally without XPK with docker. When running inside the docker
container, it will use the same dependencies and build process as used in the
XPK approach, improving hermeticity and reliability.

```sh
tp docker-run torchprime/torch_xla_models/train.py
```

This will run the TorchPrime docker image locally. You can also add `--use-hf`
to run a HuggingFace model locally:

```sh
tp docker-run --use-hf torchprime/hf_models/train.py
```

## Run distributed training with local torch/torch_xla wheel

Torchprime supports running with user-specified torch and torch_xla wheels placed
41 changes: 41 additions & 0 deletions torchprime/layers/sequential.py
@@ -0,0 +1,41 @@
from typing import Any

import torch.nn as nn

PyTree = Any


class HomogeneousSequential(nn.Sequential):
  """
  HomogeneousSequential is a sequential container that requires all child modules
  to be of the same type and have matching input/output shapes. In turn, it may be
  compiled with the `scan` higher order operator to save compile time.
  """

  repeated_layer: type
  """The type of the layer being looped over."""

  def __init__(self, *args: nn.Module) -> None:
    super().__init__(*args)
    types = set(type(module) for module in args)
    assert len(types) == 1, f"All modules must be of the same type. Got {types}"
    self.repeated_layer = types.pop()

  def forward(self, *input, **broadcasted_inputs: PyTree):
    """
    Much like `torch.nn.Sequential`, this takes `input` and forwards it to the
    first module it contains. It then "chains" outputs to inputs sequentially for
    each subsequent module, finally returning the output of the last module.
    Different from `torch.nn.Sequential`, you may specify `broadcasted_inputs` via
    keyword arguments. The same keyword arguments will be passed to every layer
    without changes (i.e. "broadcasted").
    """
    for module in self:
      input = module(*splat(input), **broadcasted_inputs)
    return input


def splat(input):
  if not isinstance(input, list | tuple):
    input = (input,)
  return input
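
For reference, here is a minimal usage sketch of `HomogeneousSequential`. The `ResidualMLP` layer and the sizes are made up for illustration; any stack of identical layers whose input and output shapes match works the same way.

```python
import torch
import torch.nn as nn

from torchprime.layers.sequential import HomogeneousSequential


class ResidualMLP(nn.Module):
  """Toy layer whose input and output shapes match, as scan requires."""

  def __init__(self, dim: int):
    super().__init__()
    self.linear = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor, *, scale: float = 1.0) -> torch.Tensor:
    return x + scale * torch.tanh(self.linear(x))


layers = HomogeneousSequential(*[ResidualMLP(16) for _ in range(4)])
x = torch.randn(2, 16)
# `scale` is a broadcasted input: it is passed unchanged to all four layers.
y = layers(x, scale=0.5)
print(y.shape)  # torch.Size([2, 16])
```

Because every child layer has the same type and signature, the container can later be lowered to the `scan` higher order operator instead of unrolling a Python loop over layers.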
45 changes: 27 additions & 18 deletions torchprime/torch_xla_models/llama/model.py
@@ -28,7 +28,9 @@
from transformers.activations import ACT2FN
from transformers.utils import logging

from torchprime.layers.sequential import HomogeneousSequential
from torchprime.rope.rope import RopeScaling, llama3_rope_frequencies
from torchprime.torch_xla_models import offloading
from torchprime.torch_xla_models.loss import cross_entropy_loss

logger = logging.get_logger(__name__)
@@ -328,15 +330,20 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
"""
# This gives the `hidden_states` tensor a name so that we can later specify
# that this tensor should be offloaded to host RAM to save memory. This is not a
# standard torch API because there is no such feature in PyTorch. Instead, the
# name becomes node metadata during FX graph capture.
hidden_states = offloading.offload_name(hidden_states, "decoder_input")

residual = hidden_states

@@ -370,10 +377,12 @@ class LlamaModel(nn.Module):
def __init__(self, config: DictConfig):
super().__init__()
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

# `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with
# `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html.
self.layers = HomogeneousSequential(
*[
LlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
@@ -385,15 +394,19 @@ def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.FloatTensor | None = None,
) -> torch.Tensor:
# convert input ids to embeddings
inputs_embeds = self.embed_tokens(input_ids)

seq_length = inputs_embeds.size(1)

# TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()`
# when `scan` can take non-differentiable inputs.
position_ids = (
torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).float()
)

# Create a causal attention mask
causal_mask = torch.triu(
torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device),
diagonal=1,
@@ -403,14 +416,10 @@ def forward(
if attention_mask is not None:
causal_mask = causal_mask * attention_mask[:, None, None, :]

hidden_states = self.layers(
inputs_embeds, attention_mask=causal_mask, position_ids=position_ids
)

hidden_states = self.norm(hidden_states)
return hidden_states
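For intuition, the following self-contained sketch (plain PyTorch on CPU, no XLA; `ToyDecoderLayer` and the sizes are made up for illustration) mirrors the forward pass above: build float `position_ids`, build the causal mask, then call the layer stack once instead of looping over decoder layers in Python.

```python
import torch
import torch.nn as nn

from torchprime.layers.sequential import HomogeneousSequential


class ToyDecoderLayer(nn.Module):
  """Stand-in for LlamaDecoderLayer with matching input/output shapes."""

  def __init__(self, hidden_size: int):
    super().__init__()
    self.proj = nn.Linear(hidden_size, hidden_size)

  def forward(self, hidden_states, attention_mask=None, position_ids=None):
    # A real decoder layer would use the mask and positions inside attention.
    return hidden_states + torch.tanh(self.proj(hidden_states))


batch, seq_length, hidden_size = 2, 8, 16
inputs_embeds = torch.randn(batch, seq_length, hidden_size)

# Float position ids, matching the TODO above, so every input is differentiable.
position_ids = torch.arange(seq_length).unsqueeze(0).float()

# Causal mask: -inf above the diagonal, broadcast over batch and heads.
causal_mask = torch.triu(
  torch.full((seq_length, seq_length), float("-inf")), diagonal=1
)[None, None, :, :]

layers = HomogeneousSequential(*[ToyDecoderLayer(hidden_size) for _ in range(4)])

# One call replaces the per-layer Python loop; keyword arguments are broadcast
# unchanged to every layer.
hidden_states = layers(
  inputs_embeds, attention_mask=causal_mask, position_ids=position_ids
)
print(hidden_states.shape)  # torch.Size([2, 8, 16])
```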
60 changes: 35 additions & 25 deletions torchprime/torch_xla_models/mixtral/model.py
@@ -31,6 +31,7 @@
from torch import nn
from torch.nn import init

from torchprime.layers.sequential import HomogeneousSequential
from torchprime.torch_xla_models.loss import cross_entropy_loss
from torchprime.torch_xla_models.topology import get_num_slices

@@ -129,8 +130,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids.long()].unsqueeze(unsqueeze_dim)
sin = sin[position_ids.long()].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@@ -204,7 +205,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
bsz, q_len, _ = hidden_states.size()

@@ -816,7 +817,9 @@ def load_balance_loss(self, top_k_indices, logits):
return loss

@xp.trace_me("MixtralMoeBlock")
def forward(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
@@ -851,6 +854,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
case "dropping":
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
mesh = xs.get_global_mesh()
assert mesh is not None
selected_experts = selected_experts.view(
batch_size, sequence_length, self.top_k
)
@@ -895,7 +899,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
case _:
raise NotImplementedError(
f"Unsupported moe implementation {self.moe_implementation}"
)
return final_hidden_states, router_logits, torch.tensor(loss)


class MixtralDecoderLayer(nn.Module):
@@ -915,9 +923,10 @@ def __init__(self, config: DictConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
cumulative_loss: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
@@ -936,7 +945,7 @@ def forward(
hidden_states, router_logits, loss = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states, cumulative_loss + loss)
return outputs


@@ -954,8 +963,11 @@ def __init__(self, config: DictConfig):
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

# `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with
# `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html.
self.layers = HomogeneousSequential(
*[
MixtralDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
@@ -978,12 +990,14 @@ def _init_weights(self, module):

@xp.trace_me("MixtralModel")
def forward(
self, input_ids: torch.LongTensor, attention_mask: torch.Tensor | None = None
) -> tuple:
batch_size, seq_length = input_ids.shape

device = input_ids.device
# TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()`
# when `scan` can take non-differentiable inputs.
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).float()
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

inputs_embeds = self.embed_tokens(input_ids)
@@ -999,22 +1013,18 @@ def forward(

hidden_states = inputs_embeds

load_balance_loss = torch.tensor(0.0, device=device)
hidden_states, load_balance_loss = self.layers(
hidden_states,
load_balance_loss,
attention_mask=causal_mask,
position_ids=position_ids,
)

load_balance_loss = load_balance_loss / len(self.layers)

hidden_states = self.norm(hidden_states)
return (hidden_states, load_balance_loss)


class MixtralForCausalLM(nn.Module):
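For intuition, here is a self-contained toy sketch (plain PyTorch; `ToyMoeLayer` and the sizes are made up for illustration) of the carry pattern above: each layer returns `(hidden_states, cumulative_loss + aux_loss)`, so the auxiliary load-balance loss flows through the layer stack as an explicit input/output pair rather than a Python-side accumulator, which is what lets the stack be expressed as a single `HomogeneousSequential` (and eventually `scan`) call.

```python
import torch
import torch.nn as nn

from torchprime.layers.sequential import HomogeneousSequential


class ToyMoeLayer(nn.Module):
  """Stand-in for MixtralDecoderLayer: returns hidden states and a running loss."""

  def __init__(self, hidden_size: int):
    super().__init__()
    self.proj = nn.Linear(hidden_size, hidden_size)

  def forward(
    self,
    hidden_states: torch.Tensor,
    cumulative_loss: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
  ) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states = hidden_states + torch.tanh(self.proj(hidden_states))
    aux_loss = hidden_states.abs().mean()  # placeholder for the load-balance loss
    return hidden_states, cumulative_loss + aux_loss


layers = HomogeneousSequential(*[ToyMoeLayer(16) for _ in range(4)])
hidden_states = torch.randn(2, 8, 16)

# The initial carry is an explicit tensor so it can be threaded through every layer.
load_balance_loss = torch.tensor(0.0)
hidden_states, load_balance_loss = layers(hidden_states, load_balance_loss)
load_balance_loss = load_balance_loss / len(layers)
print(hidden_states.shape, float(load_balance_loss))
```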