Finetune from pre-trained models #1300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 5 commits into main
30 changes: 30 additions & 0 deletions README.md
@@ -114,6 +114,36 @@ Llama 3 8B model locally on 8 GPUs
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```

### Fine-tuning from an existing checkpoint
Contributor:
Can we put this under docs/finetune.md instead of main README? We can create a link to the doc around here.

Author:
Of course. Will do.


You first need to download the Llama 3.1 8B checkpoint. Here are the commands:

```bash
export HF_TOKEN=... # get your HF token from https://huggingface.co/settings/tokens
# Download the tokenizer and model weights
rm -rf tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir tmp
Contributor:
We covered the downloading of the tokenizer above, in the section "Downloading a tokenizer".

Author:
Yeah, true. I was going to ask: do you want to replace that with huggingface-cli commands? We could use them for both downloading the tokenizer and the actual model weights.

Contributor:
Oh I see, maybe let's first put the complete huggingface-cli flow inside finetune.md. If people get used to it, we can change the version in the main README later.

Author:
Sounds good!

uv run huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir tmp
# Convert the model weights to the DCP format and move it and the tokenizer to the assets folder
mkdir -p assets/tokenizer && cp tmp/original/tokenizer.model assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model
uv run python -m scripts.convert_llama_to_dcp tmp/original/ assets/models/dcp/llama3.1-8B
Comment on lines +126 to +130
Contributor:
Using uv is fine, but as a general instruction we shouldn't assume users have to use uv.

Contributor:
Instead of tmp and assets/models/dcp, which look arbitrarily chosen, let's try to use generic placeholders.

Author:
Ah, forgot to remove the uv part. What do you mean by generic placeholders?

Contributor:
Like, instead of tmp, use [original_model_dir], [dcp_model_dir], [tokenizer_dir], so that people know what to replace.

```
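
As an optional sanity check (not part of this PR), you can list the keys recorded in the converted DCP checkpoint's metadata before training. This sketch assumes the output directory assets/models/dcp/llama3.1-8B used above:

```python
import torch.distributed.checkpoint as dcp

# Read only the checkpoint metadata; no tensor data is materialized.
reader = dcp.FileSystemReader("assets/models/dcp/llama3.1-8B")
metadata = reader.read_metadata()

# Print a few of the recorded state-dict keys to confirm the conversion
# produced the expected model weights.
for key in sorted(metadata.state_dict_metadata)[:10]:
    print(key)
```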

Then you can fine-tune from the checkpoint:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
--model.tokenizer_path assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model \
--training.max_seq_len 131072 \
Contributor:
I wonder if it's necessary to create this config -- how is it different from specifying --training.seq_len 131072?

Author:
One example use case is when I don't actually have documents up to seq_len 131072, but the pre-trained model has a default seq_len of 131072.

Contributor:
If I understand correctly, the seq_len field is only used when generating freqs_cis: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model/model.py#L397
This should be input-agnostic, so I feel you can just specify --training.seq_len to be however long you need (as long as it doesn't exceed model capability).
Let me know if that's not the case.

Author:
Ah, I see the problem. The issue is that HuggingFaceDataset uses --training.seq_len, so the packed dataset also has the same length.

In that case, we should probably re-use the same seq_len, while allowing the HuggingFaceDataset to use a separate packed_len. WDYT?
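
For readers following the thread, here is a schematic of the packing being described (an illustrative sketch, not the actual HuggingFaceDataset code): tokenized documents are concatenated and cut into fixed windows of training.seq_len, so the packed sample length is tied to that single config value.

```python
import torch

def pack_tokens(token_streams: list[list[int]], seq_len: int) -> list[torch.Tensor]:
    """Concatenate tokenized documents and cut them into fixed chunks of seq_len."""
    buffer: list[int] = []
    packed: list[torch.Tensor] = []
    for tokens in token_streams:
        buffer.extend(tokens)
        while len(buffer) >= seq_len:
            packed.append(torch.tensor(buffer[:seq_len], dtype=torch.long))
            buffer = buffer[seq_len:]
    return packed  # leftover tokens stay in the buffer until more documents arrive

# Even if no single document reaches 131072 tokens, packing at seq_len=8192 still
# yields full-length training samples.
docs = [[1] * 3000, [2] * 6000, [3] * 9000]
print(len(pack_tokens(docs, seq_len=8192)))  # 2
```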

Contributor:
My question is why you'd wish them to be different.
The requirement is: if the HF dataset uses seq_len_hf, then we need seq_len_transformer >= seq_len_hf to make sure freqs_cis is initialized with enough length.
But we don't need seq_len_transformer > seq_len_hf (or do we?), so it can just be seq_len_transformer = seq_len_hf = training.seq_len.

Author:
I am worried that setting seq_len_transformer=131072 would make training OOM, compared to seq_len_transformer=8192.

However, it appears I need to set seq_len_transformer=131072 if I am trying to load a pretrained model such as Llama 3.1 8B. Is this correct?

Contributor:
> it appears I need to set seq_len_transformer=131072 if I am trying to load a pretrained model such as Llama 3.1 8B. Is this correct?

Oh, I see your worry.

I don't think that should be the case. Like I said, the only place seq_len matters in the transformer is freqs_cis, which is a non-persistent buffer and shouldn't be included in the model checkpoint.
(Previously in torchtitan it could be, but after https://github.com/pytorch/torchtitan/pull/1236/files#diff-27a108fa6d4885d9c66306785cb36029c0b4f5a1542e63ae24e84eb7e9a273d1R87 it shouldn't.)

For your finetuning job, the model capability shouldn't be affected by specifying a smaller max_seq_len.

BTW, you could consider using CP in torchtitan for long-sequence finetuning.
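
To make the point above concrete, here is a minimal, simplified sketch of the pattern the reviewer describes (not the torchtitan implementation): freqs_cis is precomputed up to max_seq_len, registered as a non-persistent buffer so it never enters the checkpoint, and only sliced to the actual input length in forward, so training on shorter sequences is unaffected.

```python
import torch
import torch.nn as nn

def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Standard RoPE table: one complex rotation per (position, frequency) pair.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len)
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)  # complex64, [max_seq_len, head_dim // 2]

class RoPEHolder(nn.Module):
    def __init__(self, head_dim: int = 128, max_seq_len: int = 131072):
        super().__init__()
        # persistent=False keeps freqs_cis out of state_dict(), so the checkpoint
        # is independent of whatever max_seq_len the buffer was built with.
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(head_dim, max_seq_len), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        # Only the first seq_len rows are ever used, so shorter inputs are fine.
        return self.freqs_cis[:seq_len]

m = RoPEHolder()
print("freqs_cis" in m.state_dict())       # False -> not saved or loaded with weights
print(m(torch.zeros(2, 8192, 128)).shape)  # torch.Size([8192, 64])
```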

Author (@vwxyzjn, Jun 17, 2025):
> For your finetuning job, the model capability shouldn't be affected by specifying a smaller max_seq_len.

I guess an important question is this: if we have a pretrained model with seq_len=131072, should we always compute freqs_cis using seq_len=131072?

If the answer is yes, it would make sense to set up an arg called seq_len_transformer (in place of my current max_seq_len) and set it to 131072 when loading Llama 3.1 8B.

I see. It looks like, because of how freqs_cis is used in reshape_for_broadcast, it's fine if we calculate it without the full 131072. Then it doesn't make sense to save / load it.

Thanks. I will adjust the PR accordingly.

--checkpoint.initial_load_path "assets/models/dcp/llama3.1-8B" \
--profiling.no_enable_profiling \
--activation_checkpoint.mode full \
--training.global_batch_size 64 \
--lr_scheduler.warmup_steps 40 \
--optimizer.lr 1e-5
```

### Multi-Node Training
For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.

2 changes: 1 addition & 1 deletion scripts/convert_llama_to_dcp.py
@@ -10,7 +10,7 @@

import torch
import torch.distributed.checkpoint as DCP
from torchtitan.models.llama.model import precompute_freqs_cis
from torchtitan.models.llama3.model import precompute_freqs_cis
from torchtitan.tools.logging import init_logger, logger


76 changes: 40 additions & 36 deletions torchtitan/components/checkpoint.py
Contributor:
Oh, we shouldn't just revert the changes -- instead we should investigate the root cause.

cc @fegin, please take a look at whether the recent changes break anything.

@@ -252,9 +252,6 @@ def load_state_dict(state_dict):
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager

if not self.enable_checkpoint and self.ft_manager is None:
return

self.states = states
self.states.update(
{
@@ -265,18 +262,21 @@
}
)
self.ft_states = {DATALOADER: dataloader}
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.exclude_from_loading = ckpt_config.exclude_from_loading
self.initial_load_path = ckpt_config.initial_load_path
self.initial_load_model_weights_only = (
ckpt_config.initial_load_model_weights_only
)
if not self.enable_checkpoint and self.ft_manager is None:
return

self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.initial_load_path = ckpt_config.initial_load_path
self.initial_load_model_weights_only = (
ckpt_config.initial_load_model_weights_only
)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC or self.ft_manager:
@@ -299,7 +299,6 @@ def load_state_dict(state_dict):

self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading

self.mp = None
self.async_future = None
@@ -418,21 +417,11 @@ def load(self, step: int = -1) -> bool:
if self.ft_manager:
self._ft_load()

if not self.enable_checkpoint:
return False

model_only = False
if not os.path.exists(self.folder):
if self.initial_load_path:
checkpoint_id = self.initial_load_path
if not os.path.isdir(checkpoint_id):
raise ValueError(
"initial_load_full_checkpoint is specified but the path is not valid."
)
model_only = self.initial_load_model_weights_only
else:
# 1. Try load from checkpoint folder first if it exists
if os.path.exists(self.folder):
if not self.enable_checkpoint:
return False
else:
if self.initial_load_path:
logger.info(
"`initial_load_path` is provided but the checkpoint folder exists. "
@@ -446,11 +435,33 @@

if not os.path.isdir(checkpoint_id):
return False
return self.load_from_path(checkpoint_id, model_only)

# Then try `initial_load_path`
if self.initial_load_path:
checkpoint_id = self.initial_load_path
if not os.path.isdir(checkpoint_id):
raise ValueError(
"initial_load_full_checkpoint is specified but the path is not valid."
)
model_only = self.initial_load_model_weights_only
return self.load_from_path(checkpoint_id, model_only)

return False

def load_from_path(self, path: str, model_only: bool = False) -> bool:
"""Load the checkpoint from the given path.

logger.info(f"Loading the checkpoint from {checkpoint_id}.")
This function will load the checkpoint from the given path. If ``model_only`` is
True, it will load the model only.
"""
if not os.path.isdir(path):
raise ValueError(f"The path {path} is not a valid checkpoint folder.")

logger.info(f"Loading the checkpoint from {path}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
dcp.load(states, checkpoint_id=checkpoint_id)
dcp.load(states, checkpoint_id=path)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -561,20 +572,13 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
Dict[str, Any]: The states to load for the given step.
"""
# For the first step, we will only load the model weights.
if model_only:
sd = self.states[MODEL].state_dict()
for k in excluded_parameters_for_model_only:
sd.pop(k, None)
return sd

for exclude_key in self.exclude_from_loading:
if exclude_key not in self.states:
raise ValueError(f"{exclude_key} not found in state_dict.")

states = {MODEL: self.states[MODEL]} if model_only else self.states
states_to_load = {
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
k: v for k, v in states.items() if k not in self.exclude_from_loading
}

for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
if self.ft_manager:
states_to_load.pop(DATALOADER)

6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -207,6 +207,12 @@ class Training:
seq_len: int = 2048
"""Sequence length"""

max_seq_len: int | None = None
"""
Maximum sequence length (defaults to the `seq_len`, but could be
used for loading pre-trained model. E.g., 131072 for llama3.1-8B)
"""

max_norm: float | int = 1.0
"""Max norm for gradient clipping"""

2 changes: 1 addition & 1 deletion torchtitan/models/llama3/model.py
@@ -42,7 +42,7 @@ class TransformerModelArgs(BaseModelArgs):

def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
self.vocab_size = tokenizer.n_words
self.max_seq_len = job_config.training.seq_len
self.max_seq_len = job_config.training.max_seq_len or self.max_seq_len
self.eos_id = tokenizer.eos_id

if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: