
[llama3] add configurations for Llama 3 1B and 3B models #1376

Open · wants to merge 1 commit into main
18 changes: 18 additions & 0 deletions torchtitan/models/llama3/__init__.py
@@ -39,6 +39,24 @@
use_flex_attn=True,
attn_mask_type="block_causal",
),
"1B": TransformerModelArgs(

Contributor:
> As for the weight-tying, I think it is fine, as this is just a more relaxed configuration with higher expressivity, so if it is used for inference or fine-tuning, the starting point would be identical to Llama 3.2 1B and 3B.

I'm OK with this, as long as you could help add some comments here for 1B and 3B

Suggested change
"1B": TransformerModelArgs(
# NOTE: The original model checkpoints of Llama 3.2 1B and 3B are provided
# with weight-tying between the embedding layer and the output layer,
# which is not supported in torchtitan.
"1B": TransformerModelArgs(

Contributor:

Also, we plan to provide ways to load HF checkpoints into torchtitan for training.
A note for us: the mapping designed for Llama 3.1 may not work for Llama 3.2, for reasons like the tied language-modeling head discussed in https://www.reddit.com/r/LocalLLaMA/comments/1fzn8c9/where_did_llama_32s_language_modeling_head_go/
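
As a rough sketch of the kind of remapping that would require (the HF key names model.embed_tokens.weight / lm_head.weight and the torchtitan-style names tok_embeddings.weight / output.weight are assumptions here, not an existing converter API):

def remap_llama32_hf_to_titan(hf_sd: dict) -> dict:
    # Llama 3.2 1B/3B tie the output projection to the token embedding,
    # so their HF checkpoints ship without a separate lm_head.weight.
    # A model that does not tie weights has to reuse the embedding tensor
    # for its output layer when loading such a checkpoint.
    titan_sd = {}
    embed = hf_sd["model.embed_tokens.weight"]
    titan_sd["tok_embeddings.weight"] = embed
    titan_sd["output.weight"] = hf_sd.get("lm_head.weight", embed.clone())
    # ... remap the remaining transformer-block keys here ...
    return titan_sd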

dim=2048,
n_layers=16,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.5,

Contributor:
Does this give the same result as when it is 1.4? Asking because I've verified that it works with 1.4:
https://github.com/pytorch/torchtitan/pull/1040/files#r2043164137

Contributor Author (@idoh, Jul 10, 2025):
Yes, 1.4 results in the same configuration, but 1.5 is more accurate in this case. The ffn_dim_multiplier to intermediate_size conversion is a bit tricky, but it comes down to these lines of code:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model/model.py#L272
https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model/model.py#L226
https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model/model.py#L230

So basically you take the dim size, multiply it by 4, then by 2/3, then by ffn_dim_multiplier, and finally round it up to the nearest multiple of multiple_of. In the case of 1.4 you get 7645.86, which the rounding brings up to 8192, the correct answer. For 1.5 it is exactly 8192, which IMO is safer in case someone decides to change multiple_of.
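
For reference, here is a small sketch of that arithmetic, following the sizing steps described above (the helper name is illustrative, not torchtitan's API):

def ffn_hidden_dim(dim: int, ffn_dim_multiplier: float, multiple_of: int) -> int:
    # 4 * dim, shrunk by 2/3, scaled by ffn_dim_multiplier,
    # then rounded up to the nearest multiple of multiple_of.
    hidden = int(2 * (4 * dim) / 3)
    hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(2048, 1.4, 1024))  # 8192 (rounded up from ~7645)
print(ffn_hidden_dim(2048, 1.5, 1024))  # 8192 (2048 * 4 * 2/3 * 1.5 = 8192 exactly)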

Contributor:
cool, thanks for explaining

multiple_of=1024,
rope_theta=500000,
),
"3B": TransformerModelArgs(
dim=3072,
n_layers=28,
n_heads=24,
n_kv_heads=8,
ffn_dim_multiplier=1.0,
multiple_of=1024,
rope_theta=500000,
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
62 changes: 62 additions & 0 deletions torchtitan/models/llama3/train_configs/llama3_1b.toml

Contributor:
I don't think these files are particularly interesting, and there's also no evidence that these settings (e.g., batch size, learning rate) would provide stable training and good throughput.

Contributor Author:
It's true that the batch size and learning rate might need to change for optimal settings; however, the existing ones should be a sufficient starting point.

As for the weight-tying, I think it is fine, as this is just a more relaxed configuration with higher expressivity, so if it is used for inference or fine-tuning, the starting point would be identical to Llama 3.2 1B and 3B.

Finally, I think there is a lot of value in adding these configurations: labs, even if they have the funds, have at most one chance to train a large 70B model, whereas most ablations and testing are conducted on smaller sizes such as 1B or 3B.

Contributor:
> Finally, I think there is a lot of value in adding these configurations: labs, even if they have the funds, have at most one chance to train a large 70B model, whereas most ablations and testing are conducted on smaller sizes such as 1B or 3B.

I'm still not 100% comfortable adding .toml configs that are not verified.
Given that the file is almost identical to llama3_8b.toml, users can just run:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.flavor 1B

What do you think?

Contributor Author:
Oh nice, I didn't think of that option. It depends on whether we want to optimize the hyper-parameters for an H100 setting, since with a 3B model we can fit a much larger batch size per GPU. I'm comfortable either way.
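
For example (assuming the CLI exposes the toml fields as overrides in the same way --model.flavor is used above, which I have not verified), one could reuse the 8B config and just bump the per-GPU batch size for the smaller model:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.flavor 3B --training.local_batch_size 4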

@@ -0,0 +1,62 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 1B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "1B"
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200 # lr scheduler warm up

[training]
local_batch_size = 1
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 1000
compile = false
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
62 changes: 62 additions & 0 deletions torchtitan/models/llama3/train_configs/llama3_3b.toml
@@ -0,0 +1,62 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 3B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "3B"
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200 # lr scheduler warm up

[training]
local_batch_size = 1
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 1000
compile = false
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]