Finetune from pre-trained models #1300
@@ -114,6 +114,36 @@ Llama 3 8B model locally on 8 GPUs

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```
### Fine-tuning from an existing checkpoint

You first need to download the Llama checkpoint. Here are the commands:
```bash
export HF_TOKEN=... # get your HF token from https://huggingface.co/settings/tokens
# Download the tokenizer and model weights
rm -rf tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir tmp
# Convert the model weights to the DCP format and move them and the tokenizer to the assets folder
mkdir -p assets/tokenizer && cp tmp/original/tokenizer.model assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model
uv run python -m scripts.convert_llama_to_dcp tmp/original/ assets/models/dcp/llama3.1-8B
```

**Review thread (tokenizer download line):**

> **Reviewer:** We covered downloading the tokenizer above, in the "Downloading a tokenizer" section.
>
> **Author:** Yeah, true. I was going to ask: do you want to replace that with …
>
> **Reviewer:** Oh I see, maybe let's first put the complete …
>
> **Author:** Sounds good!

**Review thread (lines +126 to +130):**

> **Reviewer:** using … instead of …
>
> **Author:** Ah, forgot to remove the …
>
> **Reviewer:** Like, instead of …
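Optionally, you can sanity-check that the conversion produced a readable DCP checkpoint before launching training. The sketch below is illustrative and not part of the PR; it assumes a recent PyTorch where `torch.distributed.checkpoint.FileSystemReader` is available, and it reuses the output path from the conversion command above.

```python
# Illustrative sanity check for the converted DCP checkpoint (not part of the PR).
# Reads only the checkpoint metadata, not the full weights.
from torch.distributed.checkpoint import FileSystemReader

reader = FileSystemReader("assets/models/dcp/llama3.1-8B")  # output of convert_llama_to_dcp above
metadata = reader.read_metadata()

# Print a few stored tensor names to confirm the expected parameter layout exists
# before pointing --checkpoint.initial_load_path at this directory.
for name in sorted(metadata.state_dict_metadata)[:10]:
    print(name)
```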
Then you can fine-tune from the checkpoint:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
  --model.tokenizer_path assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model \
  --training.max_seq_len 131072 \
  --checkpoint.initial_load_path "assets/models/dcp/llama3.1-8B" \
  --profiling.no_enable_profiling \
  --activation_checkpoint.mode full \
  --training.global_batch_size 64 \
  --lr_scheduler.warmup_steps 40 \
  --optimizer.lr 1e-5
```

**Review thread (`--training.max_seq_len`):**

> **Reviewer:** I wonder if it's necessary to create this config -- how is it different from specifying …
>
> **Author:** One example use case is when I don't actually have documents up to seq_len 131072, but the pre-trained model has a default seq_len of 131072.
>
> **Reviewer:** If I understand correctly, the …
>
> **Author:** Ah, I see the problem. The issue is that … In that case, we should probably re-use the same …
>
> **Reviewer:** My question is why you'd wish them to be different.
>
> **Author:** I am worried about setting … However, it appears I need to set …
>
> **Reviewer:** Oh, I see your worry. I don't think it should be the case. Like I said, the only place the … For your fine-tuning job, the model capability shouldn't be affected by specifying a smaller … BTW, you could consider using CP in torchtitan for long-sequence fine-tuning.
>
> **Author:** I see. It looks like because of how … Thanks. I will adjust the PR accordingly.
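For intuition on the thread above: in Llama-style models, the configured sequence length mainly determines how far the rotary position embeddings (RoPE) are precomputed; the weights themselves are unchanged. The sketch below is a minimal illustration under that assumption, with made-up function names and Llama-3-style defaults, not torchtitan's actual internals.

```python
# Minimal RoPE-precompute sketch (illustrative; not torchtitan code).
import torch

def precompute_rope_freqs(head_dim: int, seq_len: int, theta: float = 500000.0) -> torch.Tensor:
    """Precompute complex rotary factors for positions [0, seq_len), Llama-style."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)            # (seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)  # complex rotation factors

# A smaller training seq_len only precomputes a shorter buffer; the entries it does
# cover are identical to a prefix of the longer buffer, so fine-tuning at a shorter
# length should not in itself change the model's ability to run at longer contexts later.
short = precompute_rope_freqs(head_dim=128, seq_len=8192)
full = precompute_rope_freqs(head_dim=128, seq_len=131072)
assert torch.allclose(short, full[:8192])
```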
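The `--lr_scheduler.warmup_steps 40` and `--optimizer.lr 1e-5` overrides follow the usual fine-tuning pattern of a short warmup into a small peak learning rate. As a rough illustration only (plain PyTorch with a generic linear warmup; torchtitan's actual scheduler and decay shape may differ):

```python
# Generic linear-warmup illustration (assumption: not torchtitan's exact scheduler).
import torch

model = torch.nn.Linear(8, 8)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # peak lr from the command above

warmup_steps = 40  # from --lr_scheduler.warmup_steps

def warmup_lambda(step: int) -> float:
    # Scale factor on the base lr: ramps 0 -> 1 over warmup_steps, then holds at 1.
    return min(1.0, (step + 1) / warmup_steps)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

for step in range(3):
    optimizer.step()            # normally preceded by a forward/backward pass
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])
```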
### Multi-Node Training
For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.
**Review thread (on reverted changes):**

> **Reviewer:** Oh, we shouldn't just revert the changes -- instead we should investigate the root cause. cc @fegin, please take a look at whether recent changes break anything.
**Review thread (where this documentation should live):**

> **Reviewer:** Can we put this under `docs/finetune.md` instead of the main README? We can create a link to the doc around here.
>
> **Author:** Of course. Will do.