diff --git a/README.md b/README.md index 1d3cdfb9..ea272743 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,7 @@ Finally, each model may also provide a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to show case what changes we have done to make it performant on TPU. The original version is not expected to be run. + ## Contributing Contributions are welcome! Please feel free to submit a pull request. @@ -176,70 +177,50 @@ Contributions are welcome! Please feel free to submit a pull request. When developing, use `pip install -e '.[dev]'` to install dev dependencies such as linter and formatter. -### How to run tests: +### How to run tests ```sh pytest ``` -### How to run some of the tests, and re-run them whenever you change a file: +### How to run some of the tests, and re-run them whenever you change a file ```sh tp -i test ... # replace with path to tests/directories ``` +### How to format -### How to run HuggingFace transformer models -Torchprime supports run with huggingface models by taking advantage of `tp run`. -To use huggingface models, you can clone -[huggingface/transformers](https://github.com/huggingface/transformers) under -torchprime and name it as `local_transformers`. This allows you to pick any -branch or make code modifications in transformers for experiment: -``` -git clone https://github.com/huggingface/transformers.git local_transformers -``` -If huggingface transformer doesn't exist, torchprime will automatically clone -the repo and build the docker for experiment. To switch to huggingface models, -add flag `--use-hf` to `tp run` command: +```sh +ruff format ``` -tp run --use-hf torchprime/hf_models/train.py + +### How to lint + +```sh +ruff check [--fix] ``` +You can install a Ruff VSCode plugin to check errors and format files from +the editor. + ### How to run inside the docker container locally + You can also run locally without XPK with docker. When running inside the docker container, it will use the same dependencies and build process as used in the XPK approach, improving the hermeticity and reliability. -``` + +```sh tp docker-run torchprime/torch_xla_models/train.py ``` -This will run the TorchPrime docker image locally. You can also add `--use-hf` -to run HuggingFace model locally. -``` -tp docker-run --use-hf torchprime/hf_models/train.py -``` -### How to run locally without XPK: -``` -tp dbrun torchprime/torch_xla_models/train.py -``` This will run the TorchPrime docker image locally. You can also add `--use-hf` to run HuggingFace model locally. -### How to format: - ```sh -ruff format -``` - -### How to lint: - -```sh -ruff check [--fix] +tp docker-run --use-hf torchprime/hf_models/train.py ``` -You can install a Ruff VSCode plugin to check errors and format files from -the editor. - ## Run distributed training with local torch/torch_xla wheel Torchprime supports running with user specified torch and torch_xla wheels placed diff --git a/torchprime/layers/sequential.py b/torchprime/layers/sequential.py new file mode 100644 index 00000000..44e51be7 --- /dev/null +++ b/torchprime/layers/sequential.py @@ -0,0 +1,41 @@ +from typing import Any + +import torch.nn as nn + +PyTree = Any + + +class HomogeneousSequential(nn.Sequential): + """ + HomogenousSequential is a sequential container that requires all child modules + to be of the same type and have matching input/output shapes. In turn, it may be + compiled with the `scan` higher order operator to save compile time. 
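+
+  A minimal usage sketch (`Block`, `config`, `num_layers`, `x`, and `mask` below are
+  stand-ins for illustration, not names defined in this module):
+
+      blocks = HomogeneousSequential(*[Block(config) for _ in range(num_layers)])
+      # The positional input is chained from layer to layer; keyword arguments
+      # are "broadcasted", i.e. passed unchanged to every layer.
+      out = blocks(x, attention_mask=mask)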
+ """ + + repeated_layer: type + """The type of the layer being looped over.""" + + def __init__(self, *args: nn.Module) -> None: + super().__init__(*args) + types = set(type(module) for module in args) + assert len(types) == 1, f"All modules must be of the same type. Got {types}" + self.repeated_layer = types.pop() + + def forward(self, *input, **broadcasted_inputs: PyTree): + """ + Much like `torch.nn.Sequential`, this takes `input` and forwards it to the + first module it contains. It then "chains" outputs to inputs sequentially for + each subsequent module, finally returning the output of the last module. + Different from `torch.nn.Sequential`, you may specify `broadcasted_inputs` via + keyword arguments. The same keyword arguments will be passed to every layer + without changes (i.e. "broadcasted"). + """ + for module in self: + input = module(*splat(input), **broadcasted_inputs) + return input + + +def splat(input): + if not isinstance(input, list | tuple): + input = (input,) + return input diff --git a/torchprime/torch_xla_models/llama/model.py b/torchprime/torch_xla_models/llama/model.py index e406182d..c3ace32c 100644 --- a/torchprime/torch_xla_models/llama/model.py +++ b/torchprime/torch_xla_models/llama/model.py @@ -28,7 +28,9 @@ from transformers.activations import ACT2FN from transformers.utils import logging +from torchprime.layers.sequential import HomogeneousSequential from torchprime.rope.rope import RopeScaling, llama3_rope_frequencies +from torchprime.torch_xla_models import offloading from torchprime.torch_xla_models.loss import cross_entropy_loss logger = logging.get_logger(__name__) @@ -328,8 +330,8 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - ) -> torch.FloatTensor: + position_ids: torch.Tensor | None = None, + ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -337,6 +339,11 @@ def forward( attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. """ + # This gives the `hidden_states` tensor a name so that we can layer specify + # to offload this tensor to host RAM to save memory. This is not a standard + # torch API because there is no such feature in PyTorch. Instead, the name + # becomes node metadata during FX graph capture. + hidden_states = offloading.offload_name(hidden_states, "decoder_input") residual = hidden_states @@ -370,10 +377,12 @@ class LlamaModel(nn.Module): def __init__(self, config: DictConfig): super().__init__() self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList( - [ + + # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with + # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html. 
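+    # Every `LlamaDecoderLayer` below has the same type and calling convention,
+    # which is what allows `scan_layers.compile` to swap this container for a
+    # scanned implementation.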
+ self.layers = HomogeneousSequential( + *[ LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] @@ -385,15 +394,19 @@ def forward( self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor | None = None, - ) -> torch.FloatTensor: + ) -> torch.Tensor: + # convert input ids to embeddings inputs_embeds = self.embed_tokens(input_ids) - position_ids = torch.arange( - inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) - - # Create a causal mask without calling the current method seq_length = inputs_embeds.size(1) + + # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()` + # when `scan` can take non-differentiable inputs. + position_ids = ( + torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).float() + ) + + # Create a causal attention mask causal_mask = torch.triu( torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device), diagonal=1, @@ -403,14 +416,10 @@ def forward( if attention_mask is not None: causal_mask = causal_mask * attention_mask[:, None, None, :] - # embed positions - hidden_states = inputs_embeds - # decoder layers - for decoder_layer in self.layers: - hidden_states = decoder_layer( - hidden_states, attention_mask=causal_mask, position_ids=position_ids - ) + hidden_states = self.layers( + inputs_embeds, attention_mask=causal_mask, position_ids=position_ids + ) hidden_states = self.norm(hidden_states) return hidden_states diff --git a/torchprime/torch_xla_models/mixtral/model.py b/torchprime/torch_xla_models/mixtral/model.py index ce3ad03a..7be65e30 100644 --- a/torchprime/torch_xla_models/mixtral/model.py +++ b/torchprime/torch_xla_models/mixtral/model.py @@ -31,6 +31,7 @@ from torch import nn from torch.nn import init +from torchprime.layers.sequential import HomogeneousSequential from torchprime.torch_xla_models.loss import cross_entropy_loss from torchprime.torch_xla_models.topology import get_num_slices @@ -129,8 +130,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos[position_ids.long()].unsqueeze(unsqueeze_dim) + sin = sin[position_ids.long()].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -204,7 +205,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, + position_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: bsz, q_len, _ = hidden_states.size() @@ -816,7 +817,9 @@ def load_balance_loss(self, top_k_indices, logits): return loss @xp.trace_me("MixtralMoeBlock") - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) @@ -851,6 +854,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: case "dropping": hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) mesh = xs.get_global_mesh() + assert mesh is not None selected_experts = selected_experts.view( batch_size, sequence_length, self.top_k ) @@ -895,7 +899,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) - return final_hidden_states, router_logits, loss + case _: + raise NotImplementedError( + f"Unsupported moe implementation {self.moe_implementation}" + ) + return final_hidden_states, router_logits, torch.tensor(loss) class MixtralDecoderLayer(nn.Module): @@ -915,9 +923,10 @@ def __init__(self, config: DictConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + cumulative_loss: torch.Tensor, attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + position_ids: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -936,7 +945,7 @@ def forward( hidden_states, router_logits, loss = self.block_sparse_moe(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states, loss) + outputs = (hidden_states, cumulative_loss + loss) return outputs @@ -954,8 +963,11 @@ def __init__(self, config: DictConfig): self.config = config self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList( - [ + + # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with + # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html. 
+ self.layers = HomogeneousSequential( + *[ MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] @@ -978,12 +990,14 @@ def _init_weights(self, module): @xp.trace_me("MixtralModel") def forward( - self, input_ids: torch.LongTensor = None, attention_mask: torch.Tensor | None = None + self, input_ids: torch.LongTensor, attention_mask: torch.Tensor | None = None ) -> tuple: batch_size, seq_length = input_ids.shape device = input_ids.device - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()` + # when `scan` can take non-differentiable inputs. + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).float() position_ids = position_ids.unsqueeze(0).view(-1, seq_length) inputs_embeds = self.embed_tokens(input_ids) @@ -999,22 +1013,18 @@ def forward( hidden_states = inputs_embeds - total_loss = 0.0 - for decoder_layer in self.layers: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - ) - - hidden_states = layer_outputs[0] - load_balance_loss = layer_outputs[-1] - total_loss += load_balance_loss + load_balance_loss = torch.tensor(0.0, device=device) + hidden_states, load_balance_loss = self.layers( + hidden_states, + load_balance_loss, + attention_mask=causal_mask, + position_ids=position_ids, + ) - total_loss = total_loss / len(self.layers) + load_balance_loss = load_balance_loss / len(self.layers) hidden_states = self.norm(hidden_states) - return (hidden_states, total_loss) + return (hidden_states, load_balance_loss) class MixtralForCausalLM(nn.Module): diff --git a/torchprime/torch_xla_models/scan_layers.py b/torchprime/torch_xla_models/scan_layers.py new file mode 100644 index 00000000..e15fa669 --- /dev/null +++ b/torchprime/torch_xla_models/scan_layers.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +from functorch.compile import default_partition +from torch_xla.experimental.scan_layers import scan_layers + +from torchprime.layers.sequential import HomogeneousSequential, PyTree, splat + + +class HomogeneousSequentialScan(HomogeneousSequential): + def __init__(self, *args, partition_fn=default_partition): + super().__init__(*args) + self.partition_fn = partition_fn + + def forward(self, *input, **broadcasted_inputs: PyTree): + # `self.children()` returns an iterator over the immediate submodules, i.e. + # the layers we want to scan over. In the `BroadcastArguments` we extend each + # layer's return value to also output the broadcasted inputs + # (position IDs in case of LLMs, etc). This plumbs those values across scan + # iterations so the same values are available to all layers. + layers = [BroadcastArguments(m) for m in self.children()] + if len(input) == 1: + # Handle single argument case: we don't need to call the module with a tuple. + input = input[0] + out, _broadcasted_inputs_back = scan_layers( + layers, (input, broadcasted_inputs), partition_fn=self.partition_fn + ) + return out + + +class BroadcastArguments(torch.nn.Module): + def __init__(self, mod: nn.Module): + super().__init__() + self.mod = mod + + def forward(self, orig_input, broadcasted_inputs): + out = self.mod(*splat(orig_input), **broadcasted_inputs) + return (out, broadcasted_inputs) + + +def compile_one_stack( + mod: HomogeneousSequential, partition_fn=default_partition +) -> HomogeneousSequential: + # Replace base class with our optimized subclass. 
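+  # Refuse to convert a stack that has already been wrapped for scanning.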
+ if isinstance(mod, HomogeneousSequentialScan): + raise NotImplementedError("Cannot compile HomogeneousSequential twice") + new_mod = HomogeneousSequentialScan(*mod.children(), partition_fn=partition_fn) + return new_mod + + +def compile( + mod: nn.Module, sequential_to_scan: str, partition_fn=default_partition +) -> nn.Module: + seq = mod.get_submodule(sequential_to_scan) + if not isinstance(seq, HomogeneousSequential): + raise ValueError(f"compile only supports HomogeneousSequential, got {type(seq)}") + # Replace the submodule + mod.set_submodule( + sequential_to_scan, compile_one_stack(seq, partition_fn=partition_fn) + ) + return mod diff --git a/torchprime/torch_xla_models/tests/test_llama.py b/torchprime/torch_xla_models/tests/test_llama.py index 1eb00ca4..ed13f27a 100644 --- a/torchprime/torch_xla_models/tests/test_llama.py +++ b/torchprime/torch_xla_models/tests/test_llama.py @@ -32,7 +32,7 @@ def get_llama_3_8b() -> LlamaFixture: intermediate_size=16, vocab_size=vocab_size, ) - config.attention_kernel = None + config.flash_attention = False torchprime_config = OmegaConf.create( { "vocab_size": 128, @@ -71,7 +71,7 @@ def get_llama_3_1_405b() -> LlamaFixture: intermediate_size=32, vocab_size=vocab_size, ) - config.attention_kernel = None + config.flash_attention = False torchprime_config = OmegaConf.create( { "vocab_size": 256, @@ -116,15 +116,27 @@ def get_llama_3_1_405b() -> LlamaFixture: return LlamaFixture(vocab_size, hf_model, model) +def noop(mod): + return mod + + +def scan_decoders(mod): + import torchprime.torch_xla_models.scan_layers + + return torchprime.torch_xla_models.scan_layers.compile(mod, "model.layers") + + @pytest.mark.parametrize( "fixture", [get_llama_3_8b, get_llama_3_1_405b], ids=["Llama 3.0 8B", "Llama 3.1 405B"], ) -def test_forward_our_model_against_hf_model(fixture): +@pytest.mark.parametrize("transform", [noop, scan_decoders]) +def test_forward_our_model_against_hf_model(fixture, transform): fixture = fixture() device = torch_xla.device() model_xla = copy.deepcopy(fixture.model).to(device) + model_xla = transform(model_xla) hf_model_xla = copy.deepcopy(fixture.hf_model).to(device) torch_xla.sync() input_sizes = [8, 128, 256] diff --git a/torchprime/torch_xla_models/tests/test_mixtral.py b/torchprime/torch_xla_models/tests/test_mixtral.py index 7b463412..527d18c3 100644 --- a/torchprime/torch_xla_models/tests/test_mixtral.py +++ b/torchprime/torch_xla_models/tests/test_mixtral.py @@ -1,7 +1,9 @@ import copy -import unittest +from dataclasses import dataclass +import pytest import torch +import torch.test import torch_xla from omegaconf import OmegaConf from transformers import AutoConfig @@ -10,98 +12,133 @@ from torchprime.torch_xla_models.mixtral import MixtralForCausalLM -class TestYourModule(unittest.TestCase): - def setUp(self): - super().setUp() - torch.manual_seed(42) - torch_xla.manual_seed(42) - self.vocab_size = 128 - config = AutoConfig.from_pretrained( - "mistralai/Mixtral-8x7B-v0.1", - num_hidden_layers=1, - num_attention_heads=8, - hidden_size=8, - intermediate_size=16, - vocab_size=self.vocab_size, - ) - config.attention_kernel = None - torchprime_config = OmegaConf.create( - { - "vocab_size": 128, - "hidden_size": 8, - "intermediate_size": 16, - "num_hidden_layers": 1, - "num_attention_heads": 8, - "num_key_value_heads": 8, - "max_position_embeddings": 32768, - "initializer_range": 0.02, - "rms_norm_eps": 1.0e-05, - "num_experts_per_tok": 2, - "num_local_experts": 8, - "rope_theta": 1000000.0, - "router_aux_loss_coef": 0.02, - 
"attention_dropout": 0.0, - "attention_bias": False, - "attention_kernel": None, - "moe_implementation": "static", - } - ) - # place model on CPU device first - with torch.device("cpu"): - self.hf_model = HfMixtralForCausalLM(config) - self.model = MixtralForCausalLM(torchprime_config) - self.model.load_state_dict(self.hf_model.state_dict()) - - def test_forward_our_model_against_hf_model(self): - device = torch_xla.device() - model_xla = copy.deepcopy(self.model).to(device) - hf_model_xla = copy.deepcopy(self.hf_model).to(device) - torch_xla.sync() - input_sizes = [8, 128, 256] - for input_size in input_sizes: - input = torch.randint(128, ((2, input_size // 2))).to(device) - hf_output = hf_model_xla( - input, labels=input, attention_mask=torch.ones_like(input) - ) - mixtral_xla_logits, mixtral_xla_loss = model_xla( - input, labels=input, attention_mask=torch.ones_like(input) - ) - torch_xla.sync() - self.assertTrue( - torch.allclose(hf_output.logits, mixtral_xla_logits, atol=1e-6), - "logits are not equal", - ) - self.assertTrue( - torch.allclose(hf_output.loss, mixtral_xla_loss, atol=1e-6), - "loss is not equal", - ) - - def test_forward_torch_xla_against_native(self): - input_size = 8 - device = torch.device("cpu") - input = torch.randint(self.vocab_size, ((2, input_size // 2))) - mixtral_native_logits, mixtral_native_loss = self.model( - input, labels=input, attention_mask=torch.ones_like(input) - ) +@dataclass +class MixtralFixture: + vocab_size: int + hf_model: HfMixtralForCausalLM + model: MixtralForCausalLM - device = torch_xla.device() - input = input.to(device) - model_xla = copy.deepcopy(self.model).to(device) - torch_xla.sync() +def get_mixtral_8x7b() -> MixtralFixture: + torch.manual_seed(42) + torch_xla.manual_seed(42) + vocab_size = 128 + config = AutoConfig.from_pretrained( + "mistralai/Mixtral-8x7B-v0.1", + num_hidden_layers=1, + num_attention_heads=8, + hidden_size=8, + intermediate_size=16, + vocab_size=vocab_size, + ) + config.flash_attention = False + torchprime_config = OmegaConf.create( + { + "vocab_size": vocab_size, + "hidden_size": 8, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 8, + "num_key_value_heads": 8, + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1.0e-05, + "num_experts_per_tok": 2, + "num_local_experts": 8, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "attention_dropout": 0.0, + "attention_bias": False, + "attention_kernel": None, + "moe_implementation": "static", + } + ) + # place model on CPU device first + with torch.device("cpu"): + hf_model = HfMixtralForCausalLM(config) + model = MixtralForCausalLM(torchprime_config) + model.load_state_dict(hf_model.state_dict()) + + return MixtralFixture(vocab_size, hf_model, model) + + +def noop(mod): + return mod + + +def scan_decoders(mod): + import torchprime.torch_xla_models.scan_layers + + return torchprime.torch_xla_models.scan_layers.compile(mod, "model.layers") + + +@pytest.mark.parametrize( + "fixture", + [get_mixtral_8x7b], + ids=["Mixtral 8x7B"], +) +@pytest.mark.parametrize("transform", [noop, scan_decoders]) +def test_forward_our_model_against_hf_model(fixture, transform): + fixture = fixture() + device = torch_xla.device() + model_xla = copy.deepcopy(fixture.model).to(device) + model_xla = transform(model_xla) + hf_model_xla = copy.deepcopy(fixture.hf_model).to(device) + torch_xla.sync() + input_sizes = [8, 128, 256] + for input_size in input_sizes: + input = torch.randint(128, ((2, input_size // 2))).to(device) + 
hf_output = hf_model_xla(input, labels=input, attention_mask=torch.ones_like(input)) mixtral_xla_logits, mixtral_xla_loss = model_xla( input, labels=input, attention_mask=torch.ones_like(input) ) torch_xla.sync() - self.assertTrue( - torch.allclose(mixtral_native_logits, mixtral_xla_logits.to("cpu"), atol=1e-2), - "CPU run and XLA run logits are not equal", + torch.testing.assert_close( + hf_output.logits, + mixtral_xla_logits, + atol=1e-6, + rtol=1e-9, + msg="logits are not equal", ) - self.assertTrue( - torch.allclose(mixtral_native_loss, mixtral_native_loss.to("cpu"), atol=1e-2), - "CPU run and XLA run loss is not equal", + torch.testing.assert_close( + hf_output.loss, mixtral_xla_loss, atol=1e-6, rtol=1e-9, msg="loss is not equal" ) -if __name__ == "__main__": - unittest.main() +@pytest.mark.parametrize( + "fixture", + [get_mixtral_8x7b], + ids=["Mixtral 8x7B"], +) +def test_forward_torch_xla_against_native(fixture): + fixture = fixture() + input_size = 8 + device = torch.device("cpu") + input = torch.randint(fixture.vocab_size, ((2, input_size // 2))) + mixtral_native_logits, mixtral_native_loss = fixture.model( + input, labels=input, attention_mask=torch.ones_like(input) + ) + + device = torch_xla.device() + input = input.to(device) + model_xla = copy.deepcopy(fixture.model).to(device) + torch_xla.sync() + + mixtral_xla_logits, mixtral_xla_loss = model_xla( + input, labels=input, attention_mask=torch.ones_like(input) + ) + torch_xla.sync() + torch.testing.assert_close( + mixtral_native_logits, + mixtral_xla_logits.to("cpu"), + atol=1e-2, + rtol=1e-4, + msg="CPU run and XLA run logits are not equal", + ) + torch.testing.assert_close( + mixtral_native_loss, + mixtral_native_loss.to("cpu"), + atol=1e-2, + rtol=1e-4, + msg="CPU run and XLA run loss is not equal", + )