Skip to content

Commit ccff555

Browse files
authored
Initial Llama 3.1 405B support (#108)
1 parent 07f3dfe commit ccff555

File tree

9 files changed

+386
-119
lines changed

9 files changed

+386
-119
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ exclude = ["torchprime.*.tests.*"]
5151

5252
[tool.pytest.ini_options]
5353
minversion = "6.0"
54-
addopts = "--forked" # ensure torchax and torch_xla tests don't conflict
54+
55+
# `--forked` ensures torchax and torch_xla tests don't conflict.
56+
# `--ignore local_transformers` ignores any local Hugging Face transformers checkout
57+
addopts = "--forked --ignore local_transformers"
5558

5659
[tool.ruff]
5760
indent-width = 2

torchprime/rope/rope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010

1111

12-
@dataclass
12+
@dataclass(kw_only=True)
1313
class RopeScaling:
1414
"""
1515
RoPE scaling parameters. The defaults are what was selected in Llama 3.1.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
defaults:
2+
- _self_
3+
- scaling: llama-fsdp-tp
4+
5+
model_class: llama.LlamaForCausalLM
6+
vocab_size: 128256
7+
hidden_size: 16384
8+
intermediate_size: 53248
9+
num_hidden_layers: 126
10+
num_attention_heads: 128
11+
num_key_value_heads: 8
12+
hidden_act: silu
13+
max_position_embeddings: 131072
14+
bos_token_id: 128000
15+
eos_token_id: 128001
16+
tokenizer_name: meta-llama/Meta-Llama-3.1-405B
17+
initializer_range: 0.02
18+
rms_norm_eps: 1.0e-05
19+
attention_dropout: false
20+
attention_bias: false
21+
flash_attention: true
22+
rope_theta: 500000.0
23+
rope_scaling:
24+
factor: 8.0
25+
low_freq_factor: 1.0
26+
high_freq_factor: 4.0
27+
original_context_len: 8192
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# 2D (FSDP + TP) sharding configuration for Llama models.
2+
3+
activation_checkpoint_layers:
4+
- LlamaDecoderLayer
5+
6+
# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
7+
optimization_barrier_layers:
8+
- LlamaDecoderLayer
9+
10+
sharding:
11+
# Weights
12+
13+
# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/114): This
14+
# cannot be `[tensor, fsdp]`, or the gradients will sometimes become NaN.
15+
model.embed_tokens.weight: [fsdp, tensor]
16+
17+
model.layers.*.self_attn.q_proj.weight: [tensor, fsdp]
18+
model.layers.*.self_attn.k_proj.weight: [tensor, fsdp]
19+
model.layers.*.self_attn.v_proj.weight: [tensor, fsdp]
20+
model.layers.*.self_attn.o_proj.weight: [fsdp, tensor]
21+
model.layers.*.mlp.gate_proj.weight: [tensor, fsdp]
22+
model.layers.*.mlp.up_proj.weight: [tensor, fsdp]
23+
model.layers.*.mlp.down_proj.weight: [fsdp, tensor]
24+
model.layers.*.input_layernorm.weight: [fsdp]
25+
model.layers.*.post_attention_layernorm.weight: [fsdp]
26+
model.norm.weight: [fsdp]
27+
lm_head.weight: [tensor, fsdp]

torchprime/torch_xla_models/llama/model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from transformers.activations import ACT2FN
2828
from transformers.utils import logging
2929

30+
from torchprime.rope.rope import RopeScaling, llama3_rope_frequencies
3031
from torchprime.torch_xla_models.loss import cross_entropy_loss
3132

3233
logger = logging.get_logger(__name__)
@@ -50,18 +51,16 @@ def forward(self, hidden_states):
5051

5152

5253
class LlamaRotaryEmbedding(nn.Module):
54+
inv_freq: nn.Buffer
55+
5356
def __init__(
54-
self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0
57+
self,
58+
head_dim,
59+
rope_theta,
60+
scaling: RopeScaling | None = None,
5561
):
5662
super().__init__()
57-
self.scaling_factor = scaling_factor
58-
self.dim = dim
59-
self.max_position_embeddings = max_position_embeddings
60-
self.base = base
61-
inv_freq = 1.0 / (
62-
self.base
63-
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
64-
)
63+
inv_freq = llama3_rope_frequencies(head_dim, theta=rope_theta, scaling=scaling)
6564
self.register_buffer("inv_freq", inv_freq, persistent=False)
6665

6766
@torch.no_grad()
@@ -203,10 +202,11 @@ def __init__(self, config: DictConfig, layer_idx: int | None = None):
203202
self._init_rope()
204203

205204
def _init_rope(self):
205+
scaling = self.config.get("rope_scaling", None)
206+
if scaling is not None:
207+
scaling = RopeScaling(**scaling)
206208
self.rotary_emb = LlamaRotaryEmbedding(
207-
self.head_dim,
208-
max_position_embeddings=self.max_position_embeddings,
209-
base=self.rope_theta,
209+
head_dim=self.head_dim, rope_theta=self.rope_theta, scaling=scaling
210210
)
211211

212212
@xp.trace_me("LlamaAttention")
Lines changed: 162 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import copy
2-
import unittest
2+
from dataclasses import dataclass
33

4+
import pytest
45
import torch
6+
import torch.nn as nn
7+
import torch.test
58
import torch_xla
69
from omegaconf import OmegaConf
710
from transformers import AutoConfig
@@ -10,95 +13,174 @@
1013
from torchprime.torch_xla_models.llama import LlamaForCausalLM
1114

1215

13-
class TestYourModule(unittest.TestCase):
14-
def setUp(self):
15-
super().setUp()
16-
torch.manual_seed(42)
17-
torch_xla.manual_seed(42)
18-
self.vocab_size = 128
19-
config = AutoConfig.from_pretrained(
20-
"meta-llama/Meta-Llama-3-8B",
21-
num_hidden_layers=1,
22-
num_attention_heads=8,
23-
hidden_size=8,
24-
intermediate_size=16,
25-
vocab_size=self.vocab_size,
26-
)
27-
config.flash_attention = False
28-
torchprime_config = OmegaConf.create(
29-
{
30-
"vocab_size": 128,
31-
"hidden_size": 8,
32-
"intermediate_size": 16,
33-
"num_hidden_layers": 1,
34-
"num_attention_heads": 8,
35-
"num_key_value_heads": 8,
36-
"hidden_act": "silu",
37-
"max_position_embeddings": 8192,
38-
"initializer_range": 0.02,
39-
"rms_norm_eps": 1.0e-05,
40-
"attention_dropout": False,
41-
"attention_bias": False,
42-
"flash_attention": False,
43-
"rope_theta": 500000.0,
44-
}
45-
)
46-
# place model on CPU device first
47-
with torch.device("cpu"):
48-
self.hf_model = HfLlamaForCausalLM(config)
49-
self.model = LlamaForCausalLM(torchprime_config)
50-
self.model.load_state_dict(self.hf_model.state_dict())
16+
@dataclass
17+
class LlamaFixture:
18+
vocab_size: int
19+
hf_model: HfLlamaForCausalLM
20+
model: LlamaForCausalLM
5121

52-
def test_forward_our_model_against_hf_model(self):
53-
device = torch_xla.device()
54-
model_xla = copy.deepcopy(self.model).to(device)
55-
hf_model_xla = copy.deepcopy(self.hf_model).to(device)
56-
torch_xla.sync()
57-
input_sizes = [8, 128, 256]
58-
for input_size in input_sizes:
59-
input = torch.randint(128, ((2, input_size // 2))).to(device)
60-
hf_output = hf_model_xla(
61-
input, labels=input, attention_mask=torch.ones_like(input)
62-
)
63-
llama_xla_logits, llama_xla_loss = model_xla(
64-
input, labels=input, attention_mask=torch.ones_like(input)
65-
)
66-
torch_xla.sync()
67-
self.assertTrue(
68-
torch.allclose(hf_output.logits, llama_xla_logits, atol=1e-6),
69-
"logits are not equal",
70-
)
71-
self.assertTrue(
72-
torch.allclose(hf_output.loss, llama_xla_loss, atol=1e-6),
73-
"loss is not equal",
74-
)
7522

76-
def test_forward_torch_xla_against_native(self):
77-
input_size = 8
78-
device = torch.device("cpu")
79-
input = torch.randint(self.vocab_size, ((2, input_size // 2)))
80-
llama_native_logits, llama_native_loss = self.model(
81-
input, labels=input, attention_mask=torch.ones_like(input)
23+
def get_llama_3_8b() -> LlamaFixture:
24+
torch.manual_seed(42)
25+
torch_xla.manual_seed(42)
26+
vocab_size = 128
27+
config = AutoConfig.from_pretrained(
28+
"meta-llama/Meta-Llama-3-8B",
29+
num_hidden_layers=1,
30+
num_attention_heads=8,
31+
hidden_size=64,
32+
intermediate_size=16,
33+
vocab_size=vocab_size,
34+
)
35+
config.flash_attention = False
36+
torchprime_config = OmegaConf.create(
37+
{
38+
"vocab_size": 128,
39+
"hidden_size": 64,
40+
"intermediate_size": 16,
41+
"num_hidden_layers": 1,
42+
"num_attention_heads": 8,
43+
"num_key_value_heads": 8,
44+
"hidden_act": "silu",
45+
"max_position_embeddings": 8192,
46+
"initializer_range": 0.02,
47+
"rms_norm_eps": 1.0e-05,
48+
"attention_dropout": False,
49+
"attention_bias": False,
50+
"flash_attention": False,
51+
"rope_theta": 500000.0,
52+
}
53+
)
54+
# Place model on CPU device first
55+
with torch.device("cpu"):
56+
hf_model = HfLlamaForCausalLM(config)
57+
model = LlamaForCausalLM(torchprime_config)
58+
model.load_state_dict(hf_model.state_dict())
59+
return LlamaFixture(vocab_size, hf_model, model)
60+
61+
62+
def get_llama_3_1_405b() -> LlamaFixture:
63+
torch.manual_seed(42)
64+
torch_xla.manual_seed(42)
65+
vocab_size = 256
66+
config = AutoConfig.from_pretrained(
67+
"meta-llama/Meta-Llama-3.1-405B",
68+
num_hidden_layers=2,
69+
num_attention_heads=8,
70+
hidden_size=64,
71+
intermediate_size=32,
72+
vocab_size=vocab_size,
73+
)
74+
config.flash_attention = False
75+
torchprime_config = OmegaConf.create(
76+
{
77+
"vocab_size": 256,
78+
"hidden_size": 64,
79+
"intermediate_size": 32,
80+
"num_hidden_layers": 2,
81+
"num_attention_heads": 8,
82+
"num_key_value_heads": 8,
83+
"hidden_act": "silu",
84+
"max_position_embeddings": 131072,
85+
"initializer_range": 0.02,
86+
"rms_norm_eps": 1.0e-05,
87+
"attention_dropout": False,
88+
"attention_bias": False,
89+
"flash_attention": False,
90+
"rope_theta": 500000.0,
91+
"rope_scaling": {
92+
"factor": 8.0,
93+
"low_freq_factor": 1.0,
94+
"high_freq_factor": 4.0,
95+
"original_context_len": 8192,
96+
},
97+
}
98+
)
99+
# Place model on CPU device first
100+
with torch.device("cpu"):
101+
hf_model = HfLlamaForCausalLM(config)
102+
model = LlamaForCausalLM(torchprime_config)
103+
# Assert that the `inv_freq` values are the same
104+
assert isinstance(model.model.layers[0].self_attn, nn.Module)
105+
assert isinstance(hf_model.model.layers[0].self_attn, nn.Module)
106+
assert isinstance(model.model.layers[0].self_attn.rotary_emb, nn.Module)
107+
assert isinstance(hf_model.model.layers[0].self_attn.rotary_emb, nn.Module)
108+
torch.testing.assert_close(
109+
model.model.layers[0].self_attn.rotary_emb.inv_freq,
110+
hf_model.model.layers[0].self_attn.rotary_emb.inv_freq,
82111
)
112+
# In this simplified model architecture, hidden_size 64 / num_attention_heads 8 = 8 head dim,
113+
# and the inv_freq size is half of the head dim.
114+
assert model.model.layers[0].self_attn.rotary_emb.inv_freq.shape == (4,)
115+
model.load_state_dict(hf_model.state_dict())
116+
return LlamaFixture(vocab_size, hf_model, model)
83117

84-
device = torch_xla.device()
85-
input = input.to(device)
86-
model_xla = copy.deepcopy(self.model).to(device)
87-
torch_xla.sync()
88118

119+
@pytest.mark.parametrize(
120+
"fixture",
121+
[get_llama_3_8b, get_llama_3_1_405b],
122+
ids=["Llama 3.0 8B", "Llama 3.1 405B"],
123+
)
124+
def test_forward_our_model_against_hf_model(fixture):
125+
fixture = fixture()
126+
device = torch_xla.device()
127+
model_xla = copy.deepcopy(fixture.model).to(device)
128+
hf_model_xla = copy.deepcopy(fixture.hf_model).to(device)
129+
torch_xla.sync()
130+
input_sizes = [8, 128, 256]
131+
for input_size in input_sizes:
132+
input = torch.randint(fixture.vocab_size, ((2, input_size // 2))).to(device)
133+
hf_output = hf_model_xla(input, labels=input, attention_mask=torch.ones_like(input))
89134
llama_xla_logits, llama_xla_loss = model_xla(
90135
input, labels=input, attention_mask=torch.ones_like(input)
91136
)
92137
torch_xla.sync()
93-
self.assertTrue(
94-
torch.allclose(llama_native_logits, llama_xla_logits.to("cpu"), atol=1e-2),
95-
"CPU run and XLA run logits are not equal",
138+
torch.testing.assert_close(
139+
hf_output.logits,
140+
llama_xla_logits,
141+
atol=1e-6,
142+
rtol=1e-9,
143+
msg="logits are not equal",
96144
)
97-
self.assertTrue(
98-
torch.allclose(llama_native_loss, llama_xla_loss.to("cpu"), atol=1e-2),
99-
"CPU run and XLA run loss is not equal",
145+
torch.testing.assert_close(
146+
hf_output.loss, llama_xla_loss, atol=1e-6, rtol=1e-9, msg="loss is not equal"
100147
)
101148

102149

103-
if __name__ == "__main__":
104-
unittest.main()
150+
@pytest.mark.parametrize(
151+
"fixture",
152+
[get_llama_3_8b, get_llama_3_1_405b],
153+
ids=["Llama 3.0 8B", "Llama 3.1 405B"],
154+
)
155+
def test_forward_torch_xla_against_native(fixture):
156+
fixture = fixture()
157+
input_size = 8
158+
device = torch.device("cpu")
159+
input = torch.randint(fixture.vocab_size, ((2, input_size // 2)))
160+
llama_native_logits, llama_native_loss = fixture.model(
161+
input, labels=input, attention_mask=torch.ones_like(input)
162+
)
163+
164+
device = torch_xla.device()
165+
input = input.to(device)
166+
model_xla = copy.deepcopy(fixture.model).to(device)
167+
torch_xla.sync()
168+
169+
llama_xla_logits, llama_xla_loss = model_xla(
170+
input, labels=input, attention_mask=torch.ones_like(input)
171+
)
172+
torch_xla.sync()
173+
torch.testing.assert_close(
174+
llama_native_logits,
175+
llama_xla_logits.to("cpu"),
176+
atol=1e-2,
177+
rtol=1e-6,
178+
msg="CPU run and XLA run logits are not equal",
179+
)
180+
torch.testing.assert_close(
181+
llama_native_loss,
182+
llama_xla_loss.to("cpu"),
183+
atol=1e-2,
184+
rtol=1e-6,
185+
msg="CPU run and XLA run loss is not equal",
186+
)

0 commit comments

Comments
 (0)