GKD trainer + chunked JSD loss + FSDP #615

benyaminjami opened this issue Mar 17, 2025 · 1 comment

@benyaminjami

Hello Liger Kernel team,

First of all, thank you for making this project available! I’ve been exploring your codebase and tried to implement GKDTrainer using the chunked_jsd_loss similarly to how ORPOTrainer handles it. I’m now aiming to use Fully Sharded Data Parallel (FSDP) for both the teacher and student models but am unsure of the best way to integrate it.

I would greatly appreciate any guidance you could provide on:

  1. Implementing the chunked JSD loss function for FSDP-enabled training – Are there recommended patterns or helper functions within the codebase that can simplify this process?
  2. Key code structures or APIs in the GKDTrainer – Which parts of GKDTrainer might need modification or extension to properly handle chunked JSD loss under FSDP?
  3. Best practices or potential pitfalls – Have you encountered any common issues or gotchas when combining chunked losses with FSDP that I should be aware of?
  4. Code snippets or references – If you have any example snippets, documentation references, or design patterns that illustrate how to properly handle teacher and student models together under FSDP, that would be incredibly helpful.

Thank you in advance for your time and assistance! Any insights, tips, or examples you can share will help me get up and running much more quickly.

Additional Context:

I’m currently referencing the ORPOTrainer sample but see that it doesn’t fully address the GKD use case.

@shivam15s (Collaborator)

Implementing the chunked JSD loss function for FSDP-enabled training

The approach is similar to LigerORPOTrainer, where we pass the models' lm_head weights and last hidden states to Liger{ORPO/JSD}Loss, which should then return the expected loss.
The caveat with FSDP, however, is that you need to unshard the FSDP root params before doing the forward pass of Liger{ORPO/JSD}Loss (assuming you're using FSDP 1), as otherwise the lm_head weight would remain sharded across the GPUs. For reference, here is the _FSDPForwardRedirection helper that LigerORPOTrainer uses for this:

from typing import Any, Callable

from torch.distributed.fsdp import FullyShardedDataParallel


class _FSDPForwardRedirection:
    """
    Modified based on
    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648

    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
    post-forward can be properly executed around the method call.

    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
    the transformer-based wrapping policy), and not calling it through the FSDP root module's forward means its
    parameter will not be all-gathered, resulting in a "RuntimeError: 'weight' must be 2-D" error. Similarly, if we
    want to call just the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
    """

    def __call__(
        self,
        wrapper_module: FullyShardedDataParallel,
        method: Callable,
        *args: Any,
        **kwargs: Any,
    ):
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The FSDP wrapper around the module whose method we want to call.
            method: The callable that should be invoked after inputs get redirected through the
                `wrapper_module`'s `forward` method.
            *args: The positional arguments for `method`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments for `method`. They will get passed to a patched
                `forward` method instead.
        """
        assert isinstance(wrapper_module, FullyShardedDataParallel)
        original_module = wrapper_module._fsdp_wrapped_module
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`, because `method`
            # itself may want to call the real `forward`.
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method, e.g. `.training_step(...)`
            out = method(*_args, **_kwargs)
            return out

        # Patch the original module's forward so the arguments passed to the FSDP wrapper
        # get redirected to `method` after the root pre-forward has run.
        original_module.forward = wrapped_forward  # type: ignore[method-assign]
        wrapper_output = wrapper_module(*args, **kwargs)
        return wrapper_output
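
To make the unsharding step concrete, here is a minimal sketch of how the redirection could wrap the loss call. Everything below is illustrative rather than code from the repo: chunked_jsd_under_fsdp is a hypothetical helper name, and chunked_jsd_loss stands in for whichever Liger chunked JSD loss you wire in.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel


def chunked_jsd_under_fsdp(
    student_fsdp: FullyShardedDataParallel,
    teacher_fsdp: FullyShardedDataParallel,
    student_hidden: torch.Tensor,
    teacher_hidden: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical helper: run a chunked JSD loss on last hidden states while both
    models' root params (lm_head, embeddings) are all-gathered."""
    redirect = _FSDPForwardRedirection()

    def _loss() -> torch.Tensor:
        student_head = student_fsdp._fsdp_wrapped_module.lm_head
        teacher_head = teacher_fsdp._fsdp_wrapped_module.lm_head
        # Placeholder call: substitute the actual Liger chunked JSD loss here.
        return chunked_jsd_loss(
            student_hidden, student_head.weight,
            teacher_hidden, teacher_head.weight,
        )

    # Nest the redirections so the root pre-forward of *both* FSDP wrappers runs
    # (all-gathering their direct params) before the loss executes, and the
    # post-forward reshards them afterwards.
    return redirect(student_fsdp, lambda: redirect(teacher_fsdp, _loss))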

Key code structures or APIs in the GKDTrainer – Which parts of GKDTrainer might need modification or extension to properly handle chunked JSD loss under FSDP?

Took a quick look through the GKDTrainer in trl -- I'd say you need to patch the compute_loss function to get the last hidden states and then do the unsharding as discussed above to finally get the loss through chunked JSD.
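
A rough outline of that patch, assuming a recent trl where GKDTrainer keeps the teacher on self.teacher_model and compute_loss has the usual Trainer signature (the class name and the chunked_jsd_under_fsdp helper from the sketch above are hypothetical; label shifting and masking are omitted):

import torch
from trl import GKDTrainer


class ChunkedJSDGKDTrainer(GKDTrainer):
    """Hypothetical subclass: swap GKD's logit-based loss for a chunked JSD loss
    computed from last hidden states and lm_head weights."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Student forward pass, keeping hidden states. To save memory you would
        # ideally route only the backbone through _FSDPForwardRedirection so the
        # full-vocab logits are never materialized; this keeps the sketch short.
        student_out = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
        )
        with torch.no_grad():
            teacher_out = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                output_hidden_states=True,
            )
        # Unshard both lm_heads and run the chunked JSD loss
        # (see the chunked_jsd_under_fsdp sketch above).
        loss = chunked_jsd_under_fsdp(
            model, self.teacher_model,
            student_out.hidden_states[-1],
            teacher_out.hidden_states[-1],
        )
        return (loss, student_out) if return_outputs else loss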

Best practices or potential pitfalls – Have you encountered any common issues or gotchas when combining chunked losses with FSDP that I should be aware of?

torch.compile gave us some issues when we were using mixed_precision training. A workaround was to force the inputs to LigerJSDLoss to be float32.
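
For illustration, the workaround boils down to casting before the loss call (tensor names are placeholders):

# Illustrative only: force float32 inputs so torch.compile does not trip over
# mixed-precision dtypes inside the chunked JSD loss.
loss = chunked_jsd_loss(
    student_hidden.float(), student_head_weight.float(),
    teacher_hidden.float(), teacher_head_weight.float(),
)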

Code snippets or references – If you have any example snippets, documentation references, or design patterns that illustrate how to properly handle teacher and student models together under FSDP, that would be incredibly helpful.

Don't have the exact code snippet but can point you to two references:

  1. LigerORPOTrainer: has the FSDP redirection used to unshard weights.
  2. [WIP] [Liger] liger JSD support: has some WIP patching code for doing what you need, but this PR alone won't work for FSDP.
