From 5d159a722b2ef663b666a458e4e403cb225ad91c Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Mon, 17 Mar 2025 00:49:39 -0700 Subject: [PATCH 1/2] Make models amenable to scan We replace the `for` loop in both Llama and Mixtral with an equivalent `HomogenousSequential` layer, which can be either run a for loop or use `torch_xla`'s scan operator. This is a clean-ish way to turn scan on/off without cluttering the modeling code. I also adjusted Mixtral slightly so that we can even run `scan` in Mixtral with its static MoE implementation. Scanning over GMM on the other hand won't work until GMM forward/backward is wrapped in a custom op similar to https://github.com/pytorch/xla/pull/8654. Test: added unit test. Next PR will change the trainer to apply scan. --- README.md | 55 ++--- torchprime/layers/sequential.py | 41 ++++ torchprime/torch_xla_models/llama/model.py | 42 ++-- torchprime/torch_xla_models/mixtral/model.py | 58 ++--- torchprime/torch_xla_models/scan_layers.py | 55 +++++ .../torch_xla_models/tests/test_llama.py | 18 +- .../torch_xla_models/tests/test_mixtral.py | 209 +++++++++++------- 7 files changed, 309 insertions(+), 169 deletions(-) create mode 100644 torchprime/layers/sequential.py create mode 100644 torchprime/torch_xla_models/scan_layers.py diff --git a/README.md b/README.md index 1d3cdfb9..ea272743 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,7 @@ Finally, each model may also provide a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to show case what changes we have done to make it performant on TPU. The original version is not expected to be run. + ## Contributing Contributions are welcome! Please feel free to submit a pull request. @@ -176,70 +177,50 @@ Contributions are welcome! Please feel free to submit a pull request. When developing, use `pip install -e '.[dev]'` to install dev dependencies such as linter and formatter. -### How to run tests: +### How to run tests ```sh pytest ``` -### How to run some of the tests, and re-run them whenever you change a file: +### How to run some of the tests, and re-run them whenever you change a file ```sh tp -i test ... # replace with path to tests/directories ``` +### How to format -### How to run HuggingFace transformer models -Torchprime supports run with huggingface models by taking advantage of `tp run`. -To use huggingface models, you can clone -[huggingface/transformers](https://github.com/huggingface/transformers) under -torchprime and name it as `local_transformers`. This allows you to pick any -branch or make code modifications in transformers for experiment: -``` -git clone https://github.com/huggingface/transformers.git local_transformers -``` -If huggingface transformer doesn't exist, torchprime will automatically clone -the repo and build the docker for experiment. To switch to huggingface models, -add flag `--use-hf` to `tp run` command: +```sh +ruff format ``` -tp run --use-hf torchprime/hf_models/train.py + +### How to lint + +```sh +ruff check [--fix] ``` +You can install a Ruff VSCode plugin to check errors and format files from +the editor. + ### How to run inside the docker container locally + You can also run locally without XPK with docker. When running inside the docker container, it will use the same dependencies and build process as used in the XPK approach, improving the hermeticity and reliability. -``` + +```sh tp docker-run torchprime/torch_xla_models/train.py ``` -This will run the TorchPrime docker image locally. You can also add `--use-hf` -to run HuggingFace model locally. -``` -tp docker-run --use-hf torchprime/hf_models/train.py -``` -### How to run locally without XPK: -``` -tp dbrun torchprime/torch_xla_models/train.py -``` This will run the TorchPrime docker image locally. You can also add `--use-hf` to run HuggingFace model locally. -### How to format: - ```sh -ruff format -``` - -### How to lint: - -```sh -ruff check [--fix] +tp docker-run --use-hf torchprime/hf_models/train.py ``` -You can install a Ruff VSCode plugin to check errors and format files from -the editor. - ## Run distributed training with local torch/torch_xla wheel Torchprime supports running with user specified torch and torch_xla wheels placed diff --git a/torchprime/layers/sequential.py b/torchprime/layers/sequential.py new file mode 100644 index 00000000..44e51be7 --- /dev/null +++ b/torchprime/layers/sequential.py @@ -0,0 +1,41 @@ +from typing import Any + +import torch.nn as nn + +PyTree = Any + + +class HomogeneousSequential(nn.Sequential): + """ + HomogenousSequential is a sequential container that requires all child modules + to be of the same type and have matching input/output shapes. In turn, it may be + compiled with the `scan` higher order operator to save compile time. + """ + + repeated_layer: type + """The type of the layer being looped over.""" + + def __init__(self, *args: nn.Module) -> None: + super().__init__(*args) + types = set(type(module) for module in args) + assert len(types) == 1, f"All modules must be of the same type. Got {types}" + self.repeated_layer = types.pop() + + def forward(self, *input, **broadcasted_inputs: PyTree): + """ + Much like `torch.nn.Sequential`, this takes `input` and forwards it to the + first module it contains. It then "chains" outputs to inputs sequentially for + each subsequent module, finally returning the output of the last module. + Different from `torch.nn.Sequential`, you may specify `broadcasted_inputs` via + keyword arguments. The same keyword arguments will be passed to every layer + without changes (i.e. "broadcasted"). + """ + for module in self: + input = module(*splat(input), **broadcasted_inputs) + return input + + +def splat(input): + if not isinstance(input, list | tuple): + input = (input,) + return input diff --git a/torchprime/torch_xla_models/llama/model.py b/torchprime/torch_xla_models/llama/model.py index e406182d..ff5ca681 100644 --- a/torchprime/torch_xla_models/llama/model.py +++ b/torchprime/torch_xla_models/llama/model.py @@ -28,7 +28,9 @@ from transformers.activations import ACT2FN from transformers.utils import logging +from torchprime.layers.sequential import HomogeneousSequential from torchprime.rope.rope import RopeScaling, llama3_rope_frequencies +from torchprime.torch_xla_models import offloading from torchprime.torch_xla_models.loss import cross_entropy_loss logger = logging.get_logger(__name__) @@ -328,8 +330,8 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - ) -> torch.FloatTensor: + position_ids: torch.Tensor | None = None, + ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -337,6 +339,11 @@ def forward( attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. """ + # This gives the `hidden_states` tensor a name so that we can layer specify + # to offload this tensor to host RAM to save memory. This is not a standard + # torch API because there is no such feature in PyTorch. Instead, the name + # becomes node metadata during FX graph capture. + hidden_states = offloading.offload_name(hidden_states, "decoder_input") residual = hidden_states @@ -370,10 +377,12 @@ class LlamaModel(nn.Module): def __init__(self, config: DictConfig): super().__init__() self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList( - [ + + # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with + # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html. + self.layers = HomogeneousSequential( + *[ LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] @@ -385,15 +394,16 @@ def forward( self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor | None = None, - ) -> torch.FloatTensor: + ) -> torch.Tensor: + # convert input ids to embeddings inputs_embeds = self.embed_tokens(input_ids) - position_ids = torch.arange( - inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) - - # Create a causal mask without calling the current method seq_length = inputs_embeds.size(1) + position_ids = ( + torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).float() + ) + + # Create a causal attention mask causal_mask = torch.triu( torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device), diagonal=1, @@ -403,14 +413,10 @@ def forward( if attention_mask is not None: causal_mask = causal_mask * attention_mask[:, None, None, :] - # embed positions - hidden_states = inputs_embeds - # decoder layers - for decoder_layer in self.layers: - hidden_states = decoder_layer( - hidden_states, attention_mask=causal_mask, position_ids=position_ids - ) + hidden_states = self.layers( + inputs_embeds, attention_mask=causal_mask, position_ids=position_ids + ) hidden_states = self.norm(hidden_states) return hidden_states diff --git a/torchprime/torch_xla_models/mixtral/model.py b/torchprime/torch_xla_models/mixtral/model.py index ce3ad03a..1c80b479 100644 --- a/torchprime/torch_xla_models/mixtral/model.py +++ b/torchprime/torch_xla_models/mixtral/model.py @@ -32,6 +32,7 @@ from torch.nn import init from torchprime.torch_xla_models.loss import cross_entropy_loss +from torchprime.torch_xla_models.scan_layers import HomogeneousSequential from torchprime.torch_xla_models.topology import get_num_slices @@ -129,8 +130,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos[position_ids.long()].unsqueeze(unsqueeze_dim) + sin = sin[position_ids.long()].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -204,7 +205,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, + position_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: bsz, q_len, _ = hidden_states.size() @@ -816,7 +817,9 @@ def load_balance_loss(self, top_k_indices, logits): return loss @xp.trace_me("MixtralMoeBlock") - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) @@ -851,6 +854,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: case "dropping": hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) mesh = xs.get_global_mesh() + assert mesh is not None selected_experts = selected_experts.view( batch_size, sequence_length, self.top_k ) @@ -895,7 +899,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) - return final_hidden_states, router_logits, loss + case _: + raise NotImplementedError( + f"Unsupported moe implementation {self.moe_implementation}" + ) + return final_hidden_states, router_logits, torch.tensor(loss) class MixtralDecoderLayer(nn.Module): @@ -915,9 +923,10 @@ def __init__(self, config: DictConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + cumulative_loss: torch.Tensor, attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + position_ids: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -936,7 +945,7 @@ def forward( hidden_states, router_logits, loss = self.block_sparse_moe(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states, loss) + outputs = (hidden_states, cumulative_loss + loss) return outputs @@ -954,8 +963,11 @@ def __init__(self, config: DictConfig): self.config = config self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList( - [ + + # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with + # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html. + self.layers = HomogeneousSequential( + *[ MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] @@ -978,12 +990,12 @@ def _init_weights(self, module): @xp.trace_me("MixtralModel") def forward( - self, input_ids: torch.LongTensor = None, attention_mask: torch.Tensor | None = None + self, input_ids: torch.LongTensor, attention_mask: torch.Tensor | None = None ) -> tuple: batch_size, seq_length = input_ids.shape device = input_ids.device - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).float() position_ids = position_ids.unsqueeze(0).view(-1, seq_length) inputs_embeds = self.embed_tokens(input_ids) @@ -999,22 +1011,18 @@ def forward( hidden_states = inputs_embeds - total_loss = 0.0 - for decoder_layer in self.layers: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - ) - - hidden_states = layer_outputs[0] - load_balance_loss = layer_outputs[-1] - total_loss += load_balance_loss + load_balance_loss = torch.tensor(0.0, device=device) + hidden_states, load_balance_loss = self.layers( + hidden_states, + load_balance_loss, + attention_mask=causal_mask, + position_ids=position_ids, + ) - total_loss = total_loss / len(self.layers) + load_balance_loss = load_balance_loss / len(self.layers) hidden_states = self.norm(hidden_states) - return (hidden_states, total_loss) + return (hidden_states, load_balance_loss) class MixtralForCausalLM(nn.Module): diff --git a/torchprime/torch_xla_models/scan_layers.py b/torchprime/torch_xla_models/scan_layers.py new file mode 100644 index 00000000..e30f24c0 --- /dev/null +++ b/torchprime/torch_xla_models/scan_layers.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +from functorch.compile import default_partition +from torch_xla.experimental.scan_layers import scan_layers + +from torchprime.layers.sequential import HomogeneousSequential, PyTree, splat + + +class HomogeneousSequentialScan(HomogeneousSequential): + def __init__(self, *args, partition_fn=default_partition): + super().__init__(*args) + self.partition_fn = partition_fn + + def forward(self, *input, **broadcasted_inputs: PyTree): + layers = [BroadcastArguments(m) for m in self.children()] + if len(input) == 1: + # Handle single argument case: we don't need to call the module with a tuple. + input = input[0] + out, _broadcasted_inputs_back = scan_layers( + layers, (input, broadcasted_inputs), partition_fn=self.partition_fn + ) + return out + + +class BroadcastArguments(torch.nn.Module): + def __init__(self, mod: nn.Module): + super().__init__() + self.mod = mod + + def forward(self, orig_input, broadcasted_inputs): + out = self.mod(*splat(orig_input), **broadcasted_inputs) + return (out, broadcasted_inputs) + + +def compile_one_stack( + mod: HomogeneousSequential, partition_fn=default_partition +) -> HomogeneousSequential: + # Replace base class with our optimized subclass. + if isinstance(mod, HomogeneousSequentialScan): + raise NotImplementedError("Cannot compile HomogeneousSequential twice") + new_mod = HomogeneousSequentialScan(*mod.children(), partition_fn=partition_fn) + return new_mod + + +def compile( + mod: nn.Module, sequential_to_scan: str, partition_fn=default_partition +) -> nn.Module: + seq = mod.get_submodule(sequential_to_scan) + if not isinstance(seq, HomogeneousSequential): + raise ValueError(f"compile only supports HomogeneousSequential, got {type(seq)}") + # Replace the submodule + mod.set_submodule( + sequential_to_scan, compile_one_stack(seq, partition_fn=partition_fn) + ) + return mod diff --git a/torchprime/torch_xla_models/tests/test_llama.py b/torchprime/torch_xla_models/tests/test_llama.py index 1eb00ca4..ed13f27a 100644 --- a/torchprime/torch_xla_models/tests/test_llama.py +++ b/torchprime/torch_xla_models/tests/test_llama.py @@ -32,7 +32,7 @@ def get_llama_3_8b() -> LlamaFixture: intermediate_size=16, vocab_size=vocab_size, ) - config.attention_kernel = None + config.flash_attention = False torchprime_config = OmegaConf.create( { "vocab_size": 128, @@ -71,7 +71,7 @@ def get_llama_3_1_405b() -> LlamaFixture: intermediate_size=32, vocab_size=vocab_size, ) - config.attention_kernel = None + config.flash_attention = False torchprime_config = OmegaConf.create( { "vocab_size": 256, @@ -116,15 +116,27 @@ def get_llama_3_1_405b() -> LlamaFixture: return LlamaFixture(vocab_size, hf_model, model) +def noop(mod): + return mod + + +def scan_decoders(mod): + import torchprime.torch_xla_models.scan_layers + + return torchprime.torch_xla_models.scan_layers.compile(mod, "model.layers") + + @pytest.mark.parametrize( "fixture", [get_llama_3_8b, get_llama_3_1_405b], ids=["Llama 3.0 8B", "Llama 3.1 405B"], ) -def test_forward_our_model_against_hf_model(fixture): +@pytest.mark.parametrize("transform", [noop, scan_decoders]) +def test_forward_our_model_against_hf_model(fixture, transform): fixture = fixture() device = torch_xla.device() model_xla = copy.deepcopy(fixture.model).to(device) + model_xla = transform(model_xla) hf_model_xla = copy.deepcopy(fixture.hf_model).to(device) torch_xla.sync() input_sizes = [8, 128, 256] diff --git a/torchprime/torch_xla_models/tests/test_mixtral.py b/torchprime/torch_xla_models/tests/test_mixtral.py index 7b463412..527d18c3 100644 --- a/torchprime/torch_xla_models/tests/test_mixtral.py +++ b/torchprime/torch_xla_models/tests/test_mixtral.py @@ -1,7 +1,9 @@ import copy -import unittest +from dataclasses import dataclass +import pytest import torch +import torch.test import torch_xla from omegaconf import OmegaConf from transformers import AutoConfig @@ -10,98 +12,133 @@ from torchprime.torch_xla_models.mixtral import MixtralForCausalLM -class TestYourModule(unittest.TestCase): - def setUp(self): - super().setUp() - torch.manual_seed(42) - torch_xla.manual_seed(42) - self.vocab_size = 128 - config = AutoConfig.from_pretrained( - "mistralai/Mixtral-8x7B-v0.1", - num_hidden_layers=1, - num_attention_heads=8, - hidden_size=8, - intermediate_size=16, - vocab_size=self.vocab_size, - ) - config.attention_kernel = None - torchprime_config = OmegaConf.create( - { - "vocab_size": 128, - "hidden_size": 8, - "intermediate_size": 16, - "num_hidden_layers": 1, - "num_attention_heads": 8, - "num_key_value_heads": 8, - "max_position_embeddings": 32768, - "initializer_range": 0.02, - "rms_norm_eps": 1.0e-05, - "num_experts_per_tok": 2, - "num_local_experts": 8, - "rope_theta": 1000000.0, - "router_aux_loss_coef": 0.02, - "attention_dropout": 0.0, - "attention_bias": False, - "attention_kernel": None, - "moe_implementation": "static", - } - ) - # place model on CPU device first - with torch.device("cpu"): - self.hf_model = HfMixtralForCausalLM(config) - self.model = MixtralForCausalLM(torchprime_config) - self.model.load_state_dict(self.hf_model.state_dict()) - - def test_forward_our_model_against_hf_model(self): - device = torch_xla.device() - model_xla = copy.deepcopy(self.model).to(device) - hf_model_xla = copy.deepcopy(self.hf_model).to(device) - torch_xla.sync() - input_sizes = [8, 128, 256] - for input_size in input_sizes: - input = torch.randint(128, ((2, input_size // 2))).to(device) - hf_output = hf_model_xla( - input, labels=input, attention_mask=torch.ones_like(input) - ) - mixtral_xla_logits, mixtral_xla_loss = model_xla( - input, labels=input, attention_mask=torch.ones_like(input) - ) - torch_xla.sync() - self.assertTrue( - torch.allclose(hf_output.logits, mixtral_xla_logits, atol=1e-6), - "logits are not equal", - ) - self.assertTrue( - torch.allclose(hf_output.loss, mixtral_xla_loss, atol=1e-6), - "loss is not equal", - ) - - def test_forward_torch_xla_against_native(self): - input_size = 8 - device = torch.device("cpu") - input = torch.randint(self.vocab_size, ((2, input_size // 2))) - mixtral_native_logits, mixtral_native_loss = self.model( - input, labels=input, attention_mask=torch.ones_like(input) - ) +@dataclass +class MixtralFixture: + vocab_size: int + hf_model: HfMixtralForCausalLM + model: MixtralForCausalLM - device = torch_xla.device() - input = input.to(device) - model_xla = copy.deepcopy(self.model).to(device) - torch_xla.sync() +def get_mixtral_8x7b() -> MixtralFixture: + torch.manual_seed(42) + torch_xla.manual_seed(42) + vocab_size = 128 + config = AutoConfig.from_pretrained( + "mistralai/Mixtral-8x7B-v0.1", + num_hidden_layers=1, + num_attention_heads=8, + hidden_size=8, + intermediate_size=16, + vocab_size=vocab_size, + ) + config.flash_attention = False + torchprime_config = OmegaConf.create( + { + "vocab_size": vocab_size, + "hidden_size": 8, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 8, + "num_key_value_heads": 8, + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1.0e-05, + "num_experts_per_tok": 2, + "num_local_experts": 8, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "attention_dropout": 0.0, + "attention_bias": False, + "attention_kernel": None, + "moe_implementation": "static", + } + ) + # place model on CPU device first + with torch.device("cpu"): + hf_model = HfMixtralForCausalLM(config) + model = MixtralForCausalLM(torchprime_config) + model.load_state_dict(hf_model.state_dict()) + + return MixtralFixture(vocab_size, hf_model, model) + + +def noop(mod): + return mod + + +def scan_decoders(mod): + import torchprime.torch_xla_models.scan_layers + + return torchprime.torch_xla_models.scan_layers.compile(mod, "model.layers") + + +@pytest.mark.parametrize( + "fixture", + [get_mixtral_8x7b], + ids=["Mixtral 8x7B"], +) +@pytest.mark.parametrize("transform", [noop, scan_decoders]) +def test_forward_our_model_against_hf_model(fixture, transform): + fixture = fixture() + device = torch_xla.device() + model_xla = copy.deepcopy(fixture.model).to(device) + model_xla = transform(model_xla) + hf_model_xla = copy.deepcopy(fixture.hf_model).to(device) + torch_xla.sync() + input_sizes = [8, 128, 256] + for input_size in input_sizes: + input = torch.randint(128, ((2, input_size // 2))).to(device) + hf_output = hf_model_xla(input, labels=input, attention_mask=torch.ones_like(input)) mixtral_xla_logits, mixtral_xla_loss = model_xla( input, labels=input, attention_mask=torch.ones_like(input) ) torch_xla.sync() - self.assertTrue( - torch.allclose(mixtral_native_logits, mixtral_xla_logits.to("cpu"), atol=1e-2), - "CPU run and XLA run logits are not equal", + torch.testing.assert_close( + hf_output.logits, + mixtral_xla_logits, + atol=1e-6, + rtol=1e-9, + msg="logits are not equal", ) - self.assertTrue( - torch.allclose(mixtral_native_loss, mixtral_native_loss.to("cpu"), atol=1e-2), - "CPU run and XLA run loss is not equal", + torch.testing.assert_close( + hf_output.loss, mixtral_xla_loss, atol=1e-6, rtol=1e-9, msg="loss is not equal" ) -if __name__ == "__main__": - unittest.main() +@pytest.mark.parametrize( + "fixture", + [get_mixtral_8x7b], + ids=["Mixtral 8x7B"], +) +def test_forward_torch_xla_against_native(fixture): + fixture = fixture() + input_size = 8 + device = torch.device("cpu") + input = torch.randint(fixture.vocab_size, ((2, input_size // 2))) + mixtral_native_logits, mixtral_native_loss = fixture.model( + input, labels=input, attention_mask=torch.ones_like(input) + ) + + device = torch_xla.device() + input = input.to(device) + model_xla = copy.deepcopy(fixture.model).to(device) + torch_xla.sync() + + mixtral_xla_logits, mixtral_xla_loss = model_xla( + input, labels=input, attention_mask=torch.ones_like(input) + ) + torch_xla.sync() + torch.testing.assert_close( + mixtral_native_logits, + mixtral_xla_logits.to("cpu"), + atol=1e-2, + rtol=1e-4, + msg="CPU run and XLA run logits are not equal", + ) + torch.testing.assert_close( + mixtral_native_loss, + mixtral_native_loss.to("cpu"), + atol=1e-2, + rtol=1e-4, + msg="CPU run and XLA run loss is not equal", + ) From 51d4568620b20e804224c91fcb3ad2c09efbff4e Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Mon, 17 Mar 2025 15:40:04 -0700 Subject: [PATCH 2/2] Address comments --- torchprime/torch_xla_models/llama/model.py | 3 +++ torchprime/torch_xla_models/mixtral/model.py | 4 +++- torchprime/torch_xla_models/scan_layers.py | 5 +++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torchprime/torch_xla_models/llama/model.py b/torchprime/torch_xla_models/llama/model.py index ff5ca681..c3ace32c 100644 --- a/torchprime/torch_xla_models/llama/model.py +++ b/torchprime/torch_xla_models/llama/model.py @@ -399,6 +399,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) seq_length = inputs_embeds.size(1) + + # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()` + # when `scan` can take non-differentiable inputs. position_ids = ( torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).float() ) diff --git a/torchprime/torch_xla_models/mixtral/model.py b/torchprime/torch_xla_models/mixtral/model.py index 1c80b479..7be65e30 100644 --- a/torchprime/torch_xla_models/mixtral/model.py +++ b/torchprime/torch_xla_models/mixtral/model.py @@ -31,8 +31,8 @@ from torch import nn from torch.nn import init +from torchprime.layers.sequential import HomogeneousSequential from torchprime.torch_xla_models.loss import cross_entropy_loss -from torchprime.torch_xla_models.scan_layers import HomogeneousSequential from torchprime.torch_xla_models.topology import get_num_slices @@ -995,6 +995,8 @@ def forward( batch_size, seq_length = input_ids.shape device = input_ids.device + # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()` + # when `scan` can take non-differentiable inputs. position_ids = torch.arange(seq_length, dtype=torch.long, device=device).float() position_ids = position_ids.unsqueeze(0).view(-1, seq_length) diff --git a/torchprime/torch_xla_models/scan_layers.py b/torchprime/torch_xla_models/scan_layers.py index e30f24c0..e15fa669 100644 --- a/torchprime/torch_xla_models/scan_layers.py +++ b/torchprime/torch_xla_models/scan_layers.py @@ -12,6 +12,11 @@ def __init__(self, *args, partition_fn=default_partition): self.partition_fn = partition_fn def forward(self, *input, **broadcasted_inputs: PyTree): + # `self.children()` returns an iterator over the immediate submodules, i.e. + # the layers we want to scan over. In the `BroadcastArguments` we extend each + # layer's return value to also output the broadcasted inputs + # (position IDs in case of LLMs, etc). This plumbs those values across scan + # iterations so the same values are available to all layers. layers = [BroadcastArguments(m) for m in self.children()] if len(input) == 1: # Handle single argument case: we don't need to call the module with a tuple.