Skip to content

Conversation

pbontrager
Copy link
Contributor

@pbontrager pbontrager commented Dec 11, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

decoder_trainable flag was not hooked up right. This PR makes it so that setting it to lora or frozen toggles the decoder between frozen weights and lora correctly. This also disables mixed lora and full finetuning training. Setting x_trainable=full for a lora model will now throw an error. This will be re-enabled once checkpointing can properly support this.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

tune run lora_finetune_single_device --config llama3_2_vision/11B_lora_single_device metric_logger=torchtune.training.metric_logging.WandBLogger model.decoder_trainable=frozen

Screenshot 2024-12-11 at 11 08 03 AM

tune run lora_finetune_single_device --config llama3_2_vision/11B_lora_single_device metric_logger=torchtune.training.metric_logging.WandBLogger model.decoder_trainable=lora

Screenshot 2024-12-11 at 11 21 35 AM

For further testing I inspected the model parameters with the two options and manually verified that the correct decoder parameters existed or not and were trainable or not.

Copy link

pytorch-bot bot commented Dec 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2150

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3584823 with merge base 5370e0d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 11, 2024
decoder_type,
decoder_type,
fusion_type,
], "We've temporarily removed support for mixed LoRA + Full Finetuning yet. Please don't use the 'full' option and use llama3_2_vision_11b if you need full finetuning"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just say "X is not currently supported" instead of "We've temporarily removed support for X" (and same comment on the error message further down)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worded it this way since the option is still there, but we don't let you use it. Also in case someone is already using it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've temporarily removed support for mixed LoRA + Full Finetuning yet.

😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this could be worded in a kinder manner. Feels aggressive. Ask ChatGPT to reword or smthing.

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For further testing I inspected the model parameters with the two options and manually verified that the correct decoder parameters existed or not and were trainable or not.

Sounds worthy of a unit test to me 👀

@pbontrager
Copy link
Contributor Author

Unit test added

@felipemello1
Copy link
Contributor

Setting x_trainable=full for a lora model will now throw an error. This will be re-enabled once checkpointing can properly support this.

Just to confirm, it does support it, we just don't support resuming from ckpt. Is that correct?

@felipemello1
Copy link
Contributor

felipemello1 commented Dec 12, 2024

Did you have a chance to compare the memory between these 2 wandb runs?

felipemello1
felipemello1 previously approved these changes Dec 12, 2024
Comment on lines +552 to +581
if decoder_lora:
self_attn = lora_llama3_attention(
lora_modules=lora_attn_modules,
pos_embeddings=rope,
head_dim=head_dim,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_seq_len=max_seq_len,
attn_dropout=0.0,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
)
else:
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
pos_embeddings=rope,
max_seq_len=max_seq_len,
attn_dropout=0.0,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean that before we always added lora to the decoder?

@felipemello1 felipemello1 dismissed their stale review December 12, 2024 00:47

wait until changes are made to the error message

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the unit test! I still would update the error message (it reads like a mishmash of two competing sentences). Otherwise this looks good to me

decoder_type,
decoder_type,
fusion_type,
], "We've temporarily removed support for mixed LoRA + Full Finetuning yet. Please don't use the 'full' option and use llama3_2_vision_11b if you need full finetuning"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this could be worded in a kinder manner. Feels aggressive. Ask ChatGPT to reword or smthing.

@felipemello1
Copy link
Contributor

i will merge it so i can cut the release. @pbontrager , please write a new PR with a new error msg, and i can cherry pick it

@felipemello1 felipemello1 merged commit cdf5ea2 into meta-pytorch:main Dec 12, 2024
17 checks passed
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 18, 2024
* Llama 3.3 70B (meta-pytorch#2124)

* Llama 3.3 readme updates (meta-pytorch#2125)

* update configs (meta-pytorch#2107)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Reduce logging output for distributed KD (meta-pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (meta-pytorch#1076)

Co-authored-by: ebsmothers <ebs@meta.com>

* Update checkpointing directory (meta-pytorch#2074)

Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>

* pass correct arg (meta-pytorch#2127)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* update configs (meta-pytorch#2128)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* fix qat_lora_test (meta-pytorch#2131)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* guard ckpt imports (meta-pytorch#2133)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [bug fix] add parents=True (meta-pytorch#2136)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [bug fix] re-add model (meta-pytorch#2135)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Update save sizes into GiB (meta-pytorch#2143)

* [bug fix] remove config download when source is kaggle (meta-pytorch#2144)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [fix] remove "with_suffix" (meta-pytorch#2146)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* DoRA fixes (meta-pytorch#2139)



Co-authored-by: Mircea Mironenco <5738815+mirceamironenco@users.noreply.github.com>

* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (meta-pytorch#2150)

* Small readme, config updates (meta-pytorch#2157)

* Using `FormattedCheckpointFiles` in configs (meta-pytorch#2147)

* Move ``get_world_size_and_rank`` to utils (meta-pytorch#2155)

* Faster intermediate checkpoints with DCP async save in TorchTune (meta-pytorch#2006)

Co-authored-by: Saurabh Mishra <msaurabh@fb.com>

* torchdata integration - multi-dataset and streaming support (meta-pytorch#1929)

* Allow higher version of lm-eval (meta-pytorch#2165)

* Using `FormattedCheckpointFiles` in configs... round 2 (meta-pytorch#2167)

* [EZ] Fix set_torch_num_threads in multi-node. (meta-pytorch#2164)

---------

Co-authored-by: Philip Bontrager <pbontrager@gmail.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
Co-authored-by: Mircea Mironenco <5738815+mirceamironenco@users.noreply.github.com>
Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: Saurabh Mishra <msaurabh@meta.com>
Co-authored-by: Saurabh Mishra <msaurabh@fb.com>
Co-authored-by: Andrew Ho <andrew.kenneth.ho@gmail.com>
Co-authored-by: Eugen Hotaj <eugen_hotaj_91@hotmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants