
Conversation

@calvinpelletier (Contributor) commented Jun 24, 2024

Context

Adding support for DoRA: https://arxiv.org/abs/2402.09353

Also refactoring LoRALinear. The adapter logic is now encapsulated in a submodule LowRankAdapter while the base layer logic remains in LoRALinear. This lets us just wrap the LowRankAdapter module for FSDP, instead of separately wrapping the lora_a/lora_b projections and the new magnitude parameter.
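
A rough sketch of the structure described above (illustrative only: the constructor signature is simplified, and the `lora` attribute name is assumed from the key mapping mentioned in the changelog below):

```python
import torch
from torch import nn


class LowRankAdapter(nn.Module):
    """All trainable adapter state lives in this submodule, so FSDP can wrap it as one unit."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, use_dora: bool = False):
        super().__init__()
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        # DoRA's extra trainable parameter: one magnitude per output channel
        self.magnitude = nn.Parameter(torch.ones(out_dim)) if use_dora else None


class LoRALinear(nn.Module):
    """Frozen base weight plus a single adapter submodule."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, use_dora: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim), requires_grad=False)
        self.lora = LowRankAdapter(in_dim, out_dim, rank, use_dora=use_dora)
```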

Addresses #1100, #893, #936

Changelog

  • Refactor LoRALinear/LowRankAdapter
  • Support DoRA in LoRALinear (enabled via use_dora in constructor)
  • Support DoRA in get_merged_lora_ckpt
  • Add use_dora to all model/component builders
  • Example configs for DoRA and QDoRA
  • Unit test for DoRA
  • Update LoRA recipes so that they notify the model when the base parameters have been loaded (for DoRA initialization)
  • Fix FSDP unit tests
  • Update recipes/dev/lora_finetune_fsdp2.py
  • Support DoRA models in state dict conversion PEFT <-> torchtune (*.lora_magnitude_vector.weight <-> *.lora.magnitude)
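
For the last bullet, a hedged sketch of the key renaming pattern (the key names are taken from the bullet above; the actual conversion in torchtune is table-driven and covers more keys):

```python
def peft_to_torchtune_dora_key(peft_key: str) -> str:
    # e.g. "...q_proj.lora_magnitude_vector.weight" -> "...q_proj.lora.magnitude"
    return peft_key.replace("lora_magnitude_vector.weight", "lora.magnitude")


def torchtune_to_peft_dora_key(tt_key: str) -> str:
    return tt_key.replace("lora.magnitude", "lora_magnitude_vector.weight")
```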

Test plan

pytest tests/torchtune/modules/peft/test_lora.py::TestLoRALinear::test_dora

  • checks parity against a reference implementation based on Hugging Face's PEFT
  • checks that DoRA initializes correctly
  • checks that the merged linear module is approximately the same as the unmerged DoRA linear module
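
The merged-vs-unmerged check in the last bullet boils down to the DoRA merge formula. A minimal sketch (not the torchtune implementation) of how the merged weight can be computed:

```python
import torch


def merge_dora_weight(
    weight: torch.Tensor,     # (out_dim, in_dim) frozen base weight
    lora_a: torch.Tensor,     # (rank, in_dim)
    lora_b: torch.Tensor,     # (out_dim, rank)
    magnitude: torch.Tensor,  # (out_dim,)
    scaling: float,           # alpha / rank
) -> torch.Tensor:
    # DoRA merge: W' = m * (W + s * B A) / ||W + s * B A||, norm taken per output channel
    adapted = weight + scaling * (lora_b @ lora_a)
    norm = torch.linalg.norm(adapted, dim=1, keepdim=True)
    return magnitude.view(-1, 1) * adapted / norm
```

A plain linear layer using this merged weight should then produce approximately the same outputs as the unmerged DoRA module.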

Compared finetunes for LoRA variants (config based on llama3/8B_lora(_single_device) for 2 epochs):

Runs compared: LoRA 1xH100, DoRA 1xH100, QLoRA 1xH100, QDoRA 1xH100, DoRA 2xH100.

Attempted to replicate section 3.2 ("Weight Decomposition Analysis") of the DoRA paper.

The code is here.

Results (for comparison with figure 2 of the paper):
Each point is a single query projection (avg direction delta on the X axis, avg magnitude delta on the Y axis).
Panels: full finetune, LoRA finetune, DoRA finetune.

The difference with the paper is probably due to a different experimental setup (Llama3 finetuned on alpaca instead of VL-BART finetuned on four image-text tasks).
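
For reference, a hedged sketch of how the per-layer deltas behind these plots can be computed (loosely following the paper's definitions; w0 is a pretrained query projection weight and w_ft its finetuned counterpart, names are illustrative):

```python
import torch
import torch.nn.functional as F


def magnitude_direction_deltas(w0: torch.Tensor, w_ft: torch.Tensor) -> tuple[float, float]:
    """Average per-column magnitude and direction change between two weight matrices."""
    delta_m = (torch.linalg.norm(w_ft, dim=0) - torch.linalg.norm(w0, dim=0)).abs().mean()
    delta_d = (1 - F.cosine_similarity(w_ft, w0, dim=0)).mean()
    return delta_m.item(), delta_d.item()
```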


pytorch-bot bot commented Jun 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1115

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5f243af with merge base 0bdd308:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 24, 2024
@ebsmothers ebsmothers mentioned this pull request Jun 25, 2024
11 tasks
@ebsmothers ebsmothers requested a review from SLR722 June 25, 2024 01:24
@SLR722 (Member) commented Jul 1, 2024

Thanks for introducing DoRA to torchtune!

  • could you add an example distributed yaml config to showcase DoRA?
  • In the DoRA paper, DoRA outperforms LoRA on several tasks, but they don't mention the loss difference. I saw in your comparison that the loss is pretty similar between LoRA and DoRA; could you also eval DoRA and LoRA on some benchmarks to see the difference?

_TO_PEFT_KEYS = {
    "lora_a": "lora_A",
    "lora_b": "lora_B",
    "lora.a": "lora_A",
Member

why do we rename the LoRA keys?

        lora_out = self._dora_forward(x, base_weight, lora_out)
        return base_out + lora_out

    def _dora_forward(self, x: Tensor, base_weight: Tensor, lora_out: Tensor) -> Tensor:
Member

The DoRA logic looks good to me based on HF's implementation
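
For readers following along, a condensed sketch of the math this forward pass computes (matching the HF-style formulation referenced above, without dropout; not the exact torchtune code):

```python
import torch


def dora_forward_sketch(x, weight, lora_a, lora_b, magnitude, scaling):
    # out = (m / ||W + s*B*A||) * (x @ (W + s*B*A).T) == mag_scale * (base_out + lora_out)
    base_out = x @ weight.T
    lora_out = scaling * (x @ lora_a.T @ lora_b.T)
    adapted = weight + scaling * (lora_b @ lora_a)
    weight_norm = torch.linalg.norm(adapted, dim=1).detach()  # treated as a constant, as in the paper
    mag_scale = (magnitude / weight_norm).view(1, -1)
    return mag_scale * (base_out + lora_out)
```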



class LoRALinear(nn.Module, AdapterModule):
class LoRALinear(nn.Module):
Member

For this LoRA refactor part, defer to @ebsmothers to check

@ebsmothers (Contributor)

Update: I commandeered this PR and pushed some changes with help from @weifengpy. Instead of the previous version where we split out LoRALinear into two classes to work with FSDP, I've kept our existing LoRA design and simply added a new DoRALinear class that's pretty similar to LoRALinear but with the addition of the magnitude parameter.

A lot of the reason for the design of the original version was the usage of FSDP1 APIs, which were a bit too restrictive in terms of sharding params with different values of requires_grad in the same nn.Module. But with FSDP2 APIs, this is not a problem. However, loading the magnitude param on the meta device was a bit tricky. I wound up using the workaround suggested by Wei in pytorch/pytorch#132721, calling load_state_dict() on the magnitude params to force the call to reset_sharded_param after we initialize them.

There is still some more work to do here though:

(1) Checkpoint save does not work (I think this should be relatively straightforward though)
(2) A lot of cleanup, I'm sure I broke a bunch of tests with my refactor

I think QDoRA will also need more extensive testing. cc @SLR722

@SalmanMohammadi SalmanMohammadi mentioned this pull request Aug 7, 2024
10 tasks
raise AssertionError("Unexpected key loading adapter")


def load_dora_magnitudes(model: nn.Module) -> None:
Contributor

FYI, after https://github.com/pytorch/pytorch/pull/132954/files landed we can now remove this hack by following the example there.
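
For context, a hedged sketch of the kind of hack being referred to (based on the load_state_dict workaround described earlier in the thread; the actual torchtune helper may differ):

```python
from torch import nn


def load_dora_magnitudes(model: nn.Module) -> None:
    """Re-load already-initialized DoRA magnitudes so FSDP2 resets the sharded params."""
    sd = {k: v for k, v in model.state_dict().items() if k.endswith("magnitude")}
    # Round-tripping the values through load_state_dict forces the call to
    # reset_sharded_param (the workaround from pytorch/pytorch#132721).
    model.load_state_dict(sd, strict=False)
```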

@SLR722 SLR722 force-pushed the dora2 branch 4 times, most recently from 0d487b2 to 4455169 Compare August 16, 2024 03:46
@SLR722 SLR722 marked this pull request as draft August 16, 2024 18:04
@SLR722 SLR722 force-pushed the dora2 branch 2 times, most recently from 29998ad to 6744221 Compare August 21, 2024 00:01
@RdoubleA RdoubleA mentioned this pull request Aug 21, 2024
for m in reversed(list(model.modules())):
    if isinstance(m, nn.Linear) and m.weight.requires_grad:
        fully_shard(m, **fsdp_kwargs)
    if isinstance(m, DoRALinear):
Member

why do we treat DoRALinear separately here? why don't we treat it the same as LoRALinear?

Contributor

I think it's necessary for proper initialization of the magnitude from the sharded LoRA A and B weights. Otherwise they will be in different FSDP param groups, so we would probably need to manually gather and then reshard the weights. You can double-check by commenting out L319-320 to do the usual LoRA sharding; my guess is you will see some kind of DTensor error though.
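
The snippet quoted above, completed and annotated (the truncated body is a hedged guess; `model`, `fsdp_kwargs`, `DoRALinear`, and `nn` come from the recipe context):

```python
for m in reversed(list(model.modules())):
    if isinstance(m, nn.Linear) and m.weight.requires_grad:
        fully_shard(m, **fsdp_kwargs)
    if isinstance(m, DoRALinear):
        # Shard the DoRALinear module itself so lora_a, lora_b, and the magnitude
        # end up in the same FSDP param group, letting the magnitude be initialized
        # from the sharded A/B weights without a manual gather/reshard.
        fully_shard(m, **fsdp_kwargs)
```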

adapter_params = get_adapter_params(model)
set_trainable_params(model, adapter_params)
num_lora_ab, num_transformer_layers = _get_n_lora_and_tformer_layers(model)
num_lora, num_transformer_layers = _get_n_lora_and_tformer_layers(model)
Member

I saw multiple changes to the test_lora_fsdp_wrap test, but the unit test doesn't pass with those changes.

Contributor

Sorry, can you elaborate? Do you mean changes when you merge with the latest main? Or changes from Calvin's initial version?

Member

There are several changes to test_distributed.py compared with master, and I think they come either from Calvin's work or from your FSDP refactor. Those tests are broken and I don't have context on why we need these changes.

Contributor

OK yeah I can take a look. The only changes I made were these ones

Member

I fixed the test_distributed.py issue by reverting some of Calvin's changes to _distributed.py, which are unnecessary after the DoRA refactor.

@codecov-commenter commented Aug 22, 2024

Codecov Report

Attention: Patch coverage is 86.95652% with 36 lines in your changes missing coverage. Please review.

Project coverage is 70.12%. Comparing base (0bdd308) to head (fb8266c).
Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| torchtune/models/mistral/_component_builders.py | 41.17% | 10 Missing ⚠️ |
| torchtune/models/llama3/_component_builders.py | 14.28% | 6 Missing ⚠️ |
| torchtune/models/llama3_1/_component_builders.py | 14.28% | 6 Missing ⚠️ |
| torchtune/models/gemma/_component_builders.py | 16.66% | 5 Missing ⚠️ |
| recipes/lora_finetune_single_device.py | 0.00% | 3 Missing ⚠️ |
| torchtune/modules/peft/_utils.py | 85.00% | 3 Missing ⚠️ |
| torchtune/models/llama2/_component_builders.py | 89.47% | 2 Missing ⚠️ |
| torchtune/modules/peft/dora.py | 98.48% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1115      +/-   ##
==========================================
+ Coverage   69.80%   70.12%   +0.32%     
==========================================
  Files         272      274       +2     
  Lines       13053    13271     +218     
==========================================
+ Hits         9111     9306     +195     
- Misses       3942     3965      +23     

☔ View full report in Codecov by Sentry.

@SLR722 (Member) commented Aug 22, 2024

I made several updates to this PR to push it closer to the finish line.

  • fix and clean up several unit tests so that the majority of them pass (remaining issues are listed below)
  • split test_dora.py from test_lora.py and improve the DoRA logic so that it matches the reference implementation
  • fix the checkpointer logic so that dora_distributed, dora_single_device, and qdora_single_device can run successfully E2E

There are some remaining issues that may need help from @ebsmothers or other people:

  • the numerics in test_dora.py don't match the reference implementation when dropout p = 0.1 (maybe this is expected because of the non-determinism of dropout?)
  • some tests in test_distributed.py don't pass (see the discussion here: DoRA #1115 (comment))
  • 8B_dora_fsdp2.yaml takes a long time in get_full_model_state_dict() with >2 ranks, which causes an NCCL timeout

cc: @ebsmothers

@SLR722 (Member) commented Aug 27, 2024

To further verify the correctness of this DoRA work:

  • compare the loss between LoRA and DoRA

  • compare the correlation between changes in direction and magnitude

  • run eval on mmlu

    • LoRA (0.6187)
    • DoRA (0.6214)
    • DoRA has slightly better eval results

@SLR722 SLR722 marked this pull request as ready for review August 28, 2024 23:34

# Training
epochs: 1
max_steps_per_epoch: 50
Contributor

debug code sneaking in 😃

"""
# NOTE: this function has to be updated if the names of "lora_a" and "lora_b"
# in this module change.
# TODO: need to add back magnitude, but causing initial check to error out cause it's not defined yet
Contributor

Seems like this can be removed now

Comment on lines 35 to 36
"_lora_a_init_params",
"_lora_b_init_params",
Contributor

I don't think we want these to be public

Comment on lines 173 to 270
class _DoraReference(nn.Module):
    """
    DoRA linear layer reference.
    Paper: https://arxiv.org/abs/2402.09353
    Based on the code from:
    https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
    https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py
    For more info, see the discussion here:
    https://github.com/huggingface/peft/pull/1474
    """

    def __init__(
        self,
        dtype: torch.dtype,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
        use_bias: bool = False,
        quantize_base: bool = False,
        use_dora: bool = True,
    ):
        super().__init__()
        self.use_bias = use_bias
        self.quantize_base = quantize_base
        self.use_dora = use_dora

        linear = nn.Linear(
            in_features=in_dim, out_features=out_dim, bias=use_bias, dtype=dtype
        )
        weight = linear.weight if not quantize_base else to_nf4(linear.weight)
        bias = None
        if use_bias:
            if quantize_base:
                raise NotImplementedError()
            bias = linear.bias
        self.register_parameter("weight", nn.Parameter(weight))
        self.register_parameter(
            "bias", nn.Parameter(bias) if bias is not None else None
        )

        self.lora_a = nn.Linear(in_dim, rank, bias=False, dtype=dtype)
        self.lora_b = nn.Linear(rank, out_dim, bias=False, dtype=dtype)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)
        self.scaling = alpha / rank
        if use_dora:
            self.lora_magnitude = nn.Parameter(torch.randn(out_dim, dtype=dtype))
        self.dropout = nn.Dropout(p=dropout)

    def initialize_dora(self):
        weight = self.weight.to(self.lora_a.weight.dtype)
        lora_weight = self.lora_b.weight @ self.lora_a.weight
        weight_norm = self._get_weight_norm(weight, lora_weight)
        self.lora_magnitude = nn.Parameter(weight_norm, requires_grad=True)

    def forward(self, x):
        result = self._base_forward(x)
        torch_result_dtype = result.dtype
        x = x.to(self.lora_a.weight.dtype)
        if not self.use_dora:
            result = result + self.lora_b(self.lora_a(self.dropout(x))) * self.scaling
        else:
            x = self.dropout(x)
            result = result + self._dora_forward(x)
        result = result.to(torch_result_dtype)
        return result

    def _base_forward(self, x):
        if self.quantize_base:
            return linear_nf4(input=x, weight=self.weight)
        return F.linear(x, self.weight, self.bias)

    def _dora_forward(self, x):
        lora_result = self.lora_b(self.lora_a(x))
        x_eye = torch.eye(
            self.lora_a.weight.shape[1], device=self.lora_a.weight.device, dtype=x.dtype
        )
        lora_weight = self.lora_b(self.lora_a(x_eye)).T

        magnitude = self.lora_magnitude
        weight = self.weight.to(x.dtype)
        weight_norm = self._get_weight_norm(weight, lora_weight.detach())
        weight_norm = weight_norm.detach()
        mag_norm_scale = (magnitude / weight_norm).view(1, -1)
        result_dora = (mag_norm_scale - 1) * (
            F.linear(x, weight)
        ) + mag_norm_scale * lora_result * self.scaling
        return result_dora

    def _get_weight_norm(self, weight, lora_weight):
        weight = weight + self.scaling * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm
Contributor

We should refactor this test. In general we try not to put the reference implementation directly in the unit test, but instead use it to determine the expected values, hardcode those in the test, and point to where we got them.

Member

Addressed this comment by only comparing the numerics in the unit test and moving the reference implementation to tests/torchtune/models/llama2/scripts/compare_dora.py

torch.manual_seed(0)
qdora_linear_out = qdora_linear(inputs)
torch.testing.assert_close(
    dora_linear_out, qdora_linear_out, rtol=1e-01, atol=1e-01
Member

I needed to make this assert very loose since I found that to_nf4 changes the original values significantly.
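
A quick way to see this effect (a hedged sketch; to_nf4 comes from torchao as used elsewhere in this PR, while get_original_weight and the default block sizes are assumptions):

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4

w = torch.randn(64, 64, dtype=torch.bfloat16)
w_nf4 = to_nf4(w)
# Round-tripping through 4-bit NF4 perturbs the weights noticeably, which is why
# the quantized-vs-unquantized comparison needs loose tolerances.
print((w - w_nf4.get_original_weight()).abs().max())
```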

Contributor

Hmm this is a bit suspicious to me. For comparison in the corresponding LoRA test we do not have to do this. Can we make sure nothing unexpected is happening with the magnitude vector here?

@ebsmothers (Contributor) left a comment

A few small things; otherwise I want to make sure there are no correctness issues around quantization with the magnitude. After that this looks good to go!

rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
use_dora=use_dora,
Contributor

We shouldn't have to pass use_dora to adapter_cls or DoRALinear (it seems there are a bunch of such instances, but just in this file).

path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
Contributor

FYI @RdoubleA is working on landing some changes to rename this path from .utils. to .training.; let's coordinate since these are both big sets of changes.

quantize_base=False,
)

# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
Contributor

How come this is commented out?


@ebsmothers (Contributor) left a comment

This one is finally ready to go! Huge thanks to both @calvinpelletier for all the initial work and to @SLR722 for taking this one the last mile (really quite a few miles). Really excited to be able to support DoRA!

@ebsmothers ebsmothers merged commit e55a41c into meta-pytorch:main Aug 30, 2024
20 checks passed
@calvinpelletier calvinpelletier deleted the dora2 branch December 8, 2024 18:19
FlamingoPg pushed a commit to FlamingoPg/sgl-tune-eagle that referenced this pull request May 26, 2025
