Skip to content

Commit 9a9ea25

Browse files
authored
Add support for Llama 3.2 3B in eval/bench pipelines (#1271)
* feat: adds support for Llama 3.2 3B in benchmarks
* chore: add model to prepare.sh
1 parent 39f16f4 commit 9a9ea25

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

scripts/convert_hf_checkpoint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def permute(w, n_head):
8585
else:
8686
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
8787
merged_result.update(state_dict)
88+
89+
if config.tie_word_embeddings:
90+
merged_result["lm_head.weight"] = merged_result["model.embed_tokens.weight"].clone()
91+
8892
final_result = {}
8993
for key, value in merged_result.items():
9094
if "layers" in key:
@@ -112,7 +116,7 @@ def permute(w, n_head):
112116
del final_result[key.replace("wq", "wv")]
113117
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
114118
torch.save(final_result, checkpoint_dir / "model.pth")
115-
if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
119+
if any([x in model_name.lower() for x in ["llama-3-", "llama-3.1-", "llama-3.2-"]]):
116120
if 'llama-3.1-405b' in model_name.lower():
117121
original_dir = checkpoint_dir / "original" / "mp16"
118122
else:

scripts/prepare.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
22
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
33
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
4+
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
45
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
56
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
67
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
8+
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B

torchao/_models/llama/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class ModelArgs:
3535
rope_base: float = 10000
3636
norm_eps: float = 1e-5
3737
use_scaled_rope: bool = False
38+
tie_word_embeddings: bool = False
3839

3940
def __post_init__(self):
4041
if self.n_local_heads == -1:
@@ -79,6 +80,9 @@ def from_name(cls, name: str):
7980
"Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
8081
use_scaled_rope=True
8182
),
83+
"Llama-3.2-3B": dict(block_size=131072, n_layer=28, n_head=24, n_local_heads=8, dim=3072, intermediate_size=8192, vocab_size=128256, rope_base=500000,
84+
use_scaled_rope=True, tie_word_embeddings=True
85+
),
8286
}
8387

8488
# this is a model specific variable that controls whether index_put is used for the kv_cache update,

0 commit comments

Comments (0)