[Liger] liger DPO support #2568

Open · wants to merge 25 commits into main

Conversation

kashif (Collaborator) commented Jan 14, 2025

What does this PR do?

Add support for Liger-Kernel losses for the DPO trainer.

Needs: linkedin/Liger-Kernel#521

PEFT support: #3065
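
For orientation, a hedged usage sketch of the feature: the model, dataset, and trainer wiring below are placeholders, and only the `use_liger_loss` flag comes from this PR (see the docs snippet quoted further down in this thread).

```python
# Minimal usage sketch, not taken from this PR: model and dataset are
# placeholders; the only flag the PR adds is `use_liger_loss`.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(
    output_dir="dpo-liger",
    use_liger_loss=True,  # route the DPO loss through the Liger-Kernel fused implementation
)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```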

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

3. Loss values are reasonable and finite
4. Training works with both default and custom beta values
"""
beta_values = [0.1, 0.5] # Test multiple beta values

Can you use @parameterized.expand instead?
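
A sketch of that suggestion, assuming the `parameterized` package; the class and test names are illustrative, not the PR's actual test code.

```python
# Illustrative sketch of the reviewer's suggestion: feed the beta values
# through @parameterized.expand instead of looping inside a single test.
import unittest

from parameterized import parameterized


class TestLigerDPOLoss(unittest.TestCase):
    @parameterized.expand([(0.1,), (0.5,)])
    def test_dpo_trainer_with_liger_loss(self, beta):
        # The real test would build DPOConfig(beta=beta, use_liger_loss=True),
        # run a few steps, and assert the loss is finite; elided in this sketch.
        self.assertGreater(beta, 0.0)


if __name__ == "__main__":
    unittest.main()
```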

qgallouedec (Member)

The Liger loss isn't compatible with reference log-prob precomputing, right? If so, we could add a warning or an error.
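
A hedged sketch of what such a guard could look like; the function name, its placement, and the exact message are assumptions, while `precompute_ref_log_probs` is the existing DPOConfig flag and `use_liger_loss` is the flag this PR adds.

```python
# Sketch of the suggested guard; placement and wording are assumptions.
from trl import DPOConfig


def validate_liger_loss_config(args: DPOConfig) -> None:
    if getattr(args, "use_liger_loss", False) and args.precompute_ref_log_probs:
        raise ValueError(
            "`use_liger_loss=True` cannot be combined with `precompute_ref_log_probs=True`; "
            "the fused chunked loss computes reference log-probs on the fly."
        )
```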

Comment on lines 87 to 105

## Liger for reducing peak memory usage

[To complete]

<hfoptions id="liger">
<hfoption id="DPO">

To use Liger for reducing peak memory usage, use the following code snippet:

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_loss=True)
```

</hfoption>
</hfoptions>

@kashif I've added this section in the new guide for reducing memory usage, in case you have some words to fill it in.

VProv mentioned this pull request on Mar 26, 2025.

kashif (Collaborator, Author) commented Mar 26, 2025

@VProv, at the moment I'm having issues getting the same outputs/metrics with and without Liger in the trainer.

VProv commented Mar 26, 2025

> @VProv, at the moment I'm having issues getting the same outputs/metrics with and without Liger in the trainer.

What setup are you using?

vaibhavjindal (Contributor)

Hi, I am working on fixing the output/metrics issue.
I've added a PR in Liger-Kernel: linkedin/Liger-Kernel#676.

vaibhavjindal (Contributor)

@kashif @qgallouedec can you please review the following PR, which fixes the output/metrics issue? Thanks :)
#3346

kashif (Collaborator, Author) commented Apr 23, 2025

Thanks @vaibhavjindal, done. I'll fix the merge conflict and then review this PR.

@hanbyul-kim

Hi, thanks for sharing your work! Can I use your code with DeepSpeed ZeRO-3? I tried running it with that setup, but it doesn't seem to work. Based on my analysis of the error log, I think it's related to parameter partitioning.

[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/dpo_loss.py", line 94, in forward
[rank5]:     return super().forward(
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 241, in forward
[rank5]:     accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 159, in accumulate_chunk
[rank5]:     ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 120, in fused_fwd_bwd
[rank5]:     return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
[rank5]:   File "/root/.dpo_trainer_venv/lib/python3.10/site-packages/torch/_functorch/apis.py", line 440, in wrapper
[rank5]:     return eager_transforms.grad_and_value_impl(
[rank5]:   File "/root/.dpo_trainer_venv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 48, in fn
[rank5]:     return f(*args, **kwargs)
[rank5]:   File "/root/.dpo_trainer_venv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 1409, in grad_and_value_impl
[rank5]:     output = func(*args, **kwargs)
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 377, in _compute_loss
[rank5]:     ) = LigerFusedLinearPreferenceBase.chunk_forward(
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 289, in chunk_forward
[rank5]:     logits_chunk = input_chunk @ weight.t()
[rank5]: RuntimeError: size mismatch, got input (322), mat (322x4096), vec (0)

@hanbyul-kim

Continuing my analysis, I can confirm that it's definitely connected to DeepSpeed ZeRO-3. When I switched to stage 2, it ran smoothly without any issues.
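
For context, a minimal sketch of the gather pattern ZeRO-3 usually requires when a partitioned weight is read directly; the stand-in module and any use of the gathered weight are assumptions, only the DeepSpeed context manager shown is a real API.

```python
# Hedged sketch, not a proposed fix: under ZeRO-3 the lm_head weight is
# partitioned, so code that reads the full matrix directly (as the chunked loss
# does in `input_chunk @ weight.t()`) sees a 0-element tensor, which matches the
# "vec (0)" in the trace above.
import deepspeed
import torch.nn as nn

lm_head = nn.Linear(4096, 32000, bias=False)  # stand-in for the model's lm_head

with deepspeed.zero.GatheredParameters([lm_head.weight], modifier_rank=None):
    # Inside this context the weight is fully materialized on every rank, so a
    # fused/chunked loss could safely read it; the actual loss call is omitted.
    full_weight = lm_head.weight
```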

kashif (Collaborator, Author) commented May 5, 2025

Thanks @hanbyul-kim for the report.
