From 04561bdf22feb4534e26f4c4b6adce1575cb7f17 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 23 Jun 2025 07:02:45 +0100 Subject: [PATCH] [Algorithm] DPO --- docs/source/reference/llms.rst | 23 ++ test/llm/test_objectives.py | 314 ++++++++++++++++++- torchrl/data/llm/acceptance.py | 240 +++++++++++++++ torchrl/objectives/llm/__init__.py | 16 +- torchrl/objectives/llm/dpo.py | 471 +++++++++++++++++++++++++++++ torchrl/objectives/ppo.py | 20 +- 6 files changed, 1071 insertions(+), 13 deletions(-) create mode 100644 torchrl/data/llm/acceptance.py create mode 100644 torchrl/objectives/llm/dpo.py diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst index 402e18ffa97..f3379f2ac38 100644 --- a/docs/source/reference/llms.rst +++ b/docs/source/reference/llms.rst @@ -282,6 +282,8 @@ SFT SFTLoss SFTLossOutput + sft_loss + minor_sft_loss .. currentmodule:: torchrl.data.llm @@ -290,3 +292,24 @@ SFT :template: rl_template.rst TopKRewardSelector + +DPO +~~~ + +.. currentmodule:: torchrl.objectives.llm + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + DPOLoss + DPOLossOutput + dpo_loss + +.. currentmodule:: torchrl.data.llm + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + AcceptanceRewardSelector diff --git a/test/llm/test_objectives.py b/test/llm/test_objectives.py index baf301e5f33..e72a0187966 100644 --- a/test/llm/test_objectives.py +++ b/test/llm/test_objectives.py @@ -11,14 +11,19 @@ import pytest import torch from mocking_classes_llm import DummyStrDataLoader - -from tensordict import lazy_stack, set_capture_non_tensor_stack, TensorDict +from tensordict import ( + lazy_stack, + NonTensorStack, + set_capture_non_tensor_stack, + TensorDict, +) from torchrl.data import History, LazyStackStorage, ReplayBuffer, Unbounded from torchrl.envs import Transform from torchrl.envs.llm import LLMEnv from torchrl.envs.llm.transforms.kl import RetrieveLogProb from torchrl.modules.llm import TransformersWrapper from torchrl.objectives import ClipPPOLoss +from torchrl.objectives.llm.dpo import DPOLoss, DPOLossOutput from torchrl.objectives.llm.grpo import GRPOLoss, GRPOLossOutput, MCAdvantage from torchrl.objectives.llm.sft import SFTLoss @@ -249,8 +254,6 @@ def test_sft( data, policy_train, ): - pass - policy_train, tokenizer = policy_train loss = SFTLoss( actor_network=policy_train, @@ -338,6 +341,309 @@ def test_sft_assistant_only(self, data): loss(td) +class TestDPO: + @pytest.fixture(scope="class") + def preference_data(self): + from transformers import AutoTokenizer + + # Create preference data with chosen/rejected pairs + chats = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, # chosen + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's 2+2?"}, + {"role": "assistant", "content": "I don't know."}, # rejected + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Explain quantum physics."}, + { + "role": "assistant", + "content": "Quantum physics is complex.", + }, # chosen + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Explain quantum physics."}, + {"role": "assistant", "content": "2+2 equals 4."}, # chosen + ], + ] + # with LLMs, rewards have 2 singleton dimensions + rewards = torch.tensor([1.0, -1.0, 1.0, -1.0]).unsqueeze(-1) + history = 
History.from_chats(chats) + assert history.shape == (4, 3) # 2 conversations, 4 messages each + + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + tokenizer.pad_token = tokenizer.eos_token + + # Create preference labels (True for chosen, False for rejected) + is_chosen = torch.tensor([True, False, True, False]) + + # Prepare text for each response + text = history[:, :-2].apply_chat_template( + tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True + ) + text_chosen = history[:, -2:-1].apply_chat_template( + tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False + ) + text_rejected = history[:, -1:].apply_chat_template( + tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False + ) + + # Create tensordict with preference data + # We have 4 trajectories of 1 step each + td = TensorDict( + history=history, + done=torch.zeros(4, dtype=torch.bool), + next=TensorDict( + is_chosen=is_chosen, + done=torch.ones(4, dtype=torch.bool), + reward=rewards, + history=history, + ), + batch_size=(4,), + ).unsqueeze( + 1 + ) # unsqueeze time dim - there's a single step + yield lazy_stack(list(td.unbind(0))) + + @pytest.fixture(scope="class") + def policy_train(self): + from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM + + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + tokenizer.pad_token = tokenizer.eos_token + + model = OPTForCausalLM(OPTConfig()).eval() + policy_train = TransformersWrapper( + model, + tokenizer=tokenizer, + generate=False, + from_text=True, + chat_template_name="qwen", + ) + + return policy_train, tokenizer + + @pytest.mark.skipif( + not _has_transformers, reason="transformers lib required to test DPO" + ) + @pytest.mark.parametrize("beta", [0.1, 0.5, 1.0]) + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + @pytest.mark.parametrize("normalize_by_seq_length", [True, False]) + @pytest.mark.parametrize("kl_to_ref_coeff", [None, 0.1]) + def test_dpo( + self, + beta, + reduction, + normalize_by_seq_length, + kl_to_ref_coeff, + preference_data, + policy_train, + ): + policy_train, tokenizer = policy_train + + loss = DPOLoss( + actor_network=policy_train, + tokenizer=tokenizer, + beta=beta, + reduction=reduction, + normalize_by_seq_length=normalize_by_seq_length, + kl_to_ref_coeff=kl_to_ref_coeff, + tokenizer_kwargs={"chat_template_name": "qwen"}, + ) + + td = preference_data + + # Add reference log probabilities if needed + if kl_to_ref_coeff is not None: + policy_ref = TransformersWrapper( + policy_train.model, + tokenizer=tokenizer, + generate=False, + from_text=True, + return_log_probs=True, + chat_template_name="qwen", + ) + transform = RetrieveLogProb( + policy_ref, + assistant_only=True, + tokenizer_kwargs={"chat_template_name": "qwen"}, + tokenizer=tokenizer, + ) + with torch.no_grad(): + # Compute ref log-probs + transform(td) + + loss_vals = loss(td) + + # Check output structure + assert isinstance(loss_vals, DPOLossOutput) + assert loss_vals.loss_dpo.requires_grad + assert loss_vals.chosen_rewards is not None + assert loss_vals.rejected_rewards is not None + assert loss_vals.accuracy is not None + + # Check shapes based on reduction + if reduction == "mean": + assert loss_vals.loss_dpo.shape == () + elif reduction == "sum": + assert loss_vals.loss_dpo.shape == () + elif reduction == "none": + # Should have shape matching the number of preference pairs + assert loss_vals.loss_dpo.shape == (2,) + + # Check KL loss if enabled + if kl_to_ref_coeff is not None: + 
assert loss_vals.loss_kl_to_ref is not None + assert loss_vals.kl_to_ref is not None + assert loss_vals.loss_kl_to_ref.shape == () + assert loss_vals.kl_to_ref.shape == () + else: + assert loss_vals.loss_kl_to_ref is None + assert loss_vals.kl_to_ref is None + + # Check that total loss can be computed + total_loss = loss_vals.sum(reduce=True) + assert total_loss.shape == () + assert total_loss.requires_grad + + # Check accuracy is reasonable (should be between 0 and 1) + assert 0.0 <= loss_vals.accuracy.item() <= 1.0 + + @pytest.mark.skipif( + not _has_transformers, reason="transformers lib required to test DPO" + ) + def test_dpo_no_preference_pairs(self, policy_train): + """Test that DPO raises an error when no preference pairs are present.""" + policy_train, tokenizer = policy_train + + # Create data with only chosen responses (no rejected) + chats = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello?"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ] + history = History.from_chats(chats) + + # All responses marked as chosen (no rejected) + is_chosen = torch.tensor([True]) + + td = TensorDict( + history=history, + next=TensorDict( + is_chosen=is_chosen, + done=torch.zeros(1, dtype=torch.bool), + history=history, + ), + batch_size=(1,), + ) + + loss = DPOLoss( + actor_network=policy_train, + tokenizer=tokenizer, + beta=0.1, + tokenizer_kwargs={"chat_template_name": "qwen"}, + ) + + with pytest.raises( + ValueError, match="Both chosen and rejected responses must be present" + ): + loss(td) + + def test_dpo_loss_function(self, preference_data): + """Test the standalone dpo_loss function.""" + from torchrl.objectives.llm.dpo import dpo_loss + + # Create some dummy log probabilities + policy_chosen_logprob = torch.tensor([1.0, 2.0]).requires_grad_(True) + policy_rejected_logprob = torch.tensor([0.5, 1.0]).requires_grad_(True) + reference_chosen_logprob = torch.tensor([0.8, 1.5]).requires_grad_(False) + reference_rejected_logprob = torch.tensor([0.3, 0.8]).requires_grad_(False) + beta = 0.1 + + # Test different reductions + for reduction in ["mean", "sum", "none"]: + loss = dpo_loss( + policy_chosen_logprob, + policy_rejected_logprob, + reference_chosen_logprob, + reference_rejected_logprob, + beta, + reduction, + ) + + assert loss.requires_grad + if reduction == "mean": + assert loss.shape == () + elif reduction == "sum": + assert loss.shape == () + elif reduction == "none": + assert loss.shape == (2,) + + assert (loss > 0).all() + + @pytest.mark.skipif( + not _has_transformers, reason="transformers lib required to test DPO" + ) + @pytest.mark.parametrize("reward_threshold", [0.0, "mean", "median"]) + def test_dpo_acceptance_reward_selector( + self, preference_data, reward_threshold, policy_train + ): + from torchrl.data import LazyStackStorage, ReplayBuffer + from torchrl.data.llm.acceptance import ( + AcceptanceRewardSampler, + AcceptanceRewardSelector, + ) + + policy_train, tokenizer = policy_train + rb = ReplayBuffer( + storage=LazyStackStorage(4), + transform=AcceptanceRewardSelector( + reward_threshold=reward_threshold, total_dialog_turns=2 + ), + sampler=AcceptanceRewardSampler(total_dialog_turns=2), + ) + + td = preference_data.copy() + del td["next", "is_chosen"] + td["text"] = NonTensorStack( + *[ + h.apply_chat_template( + tokenizer=tokenizer, + chat_template_name="qwen", + add_generation_prompt=True, + ) + for h in td["history"][..., 0].unbind(0) + ] + ).unsqueeze(-1) + + assert len(td["text"]) == 4 + assert 
td["text"][0] == td["text"][1] + assert td["text"][2] == td["text"][3] + assert td.shape == (4, 1) + rb.extend(td) + assert len(rb) == 2 + data = rb.sample(10) + assert data["next", "is_chosen"].shape == (2, 10, 1, 1) + assert data["next", "is_chosen"][0].all() + assert not data["next", "is_chosen"][1].any() + + data = rb[:] + assert ( + data["next", "is_chosen"].squeeze() + == torch.tensor([True, False, True, False]).view(2, 2) + ).all() + + # Test loss execution + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/llm/acceptance.py b/torchrl/data/llm/acceptance.py new file mode 100644 index 00000000000..f77f9d76242 --- /dev/null +++ b/torchrl/data/llm/acceptance.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from collections import defaultdict, deque +from typing import Any + +import torch +from tensordict import NestedKey, TensorDictBase, lazy_stack +from torchrl._utils import logger as torchrl_logger +from torchrl.data.replay_buffers.samplers import Sampler +from torchrl.envs.transforms import Transform +from torchrl.data.replay_buffers.storages import Storage +from torchrl.data.replay_buffers.writers import RoundRobinWriter + +from typing import Literal + +class AcceptanceRewardSelector(Transform): + """A replay-buffer transform that marks items as accepted or rejected, based on a reward threshold. + + Args: + reward_threshold (float | Literal["mean", "median"]): Threshold for the reward to be considered accepted. + Can be a `float` value or `"mean"` or `"median"`, in which case the acceptance is based on the mean or median of the rewards + over cumulated batches (`total = total_dialog_turns`). + + Keyword Args: + total_dialog_turns (int): Number of dialog turns to keep in memory for the acceptance selection. + reward_key (NestedKey): Key to the reward in the tensordict. Defaults to ("next", "reward"). + done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done"). + accept_key (NestedKey): Key to the accept state in the tensordict. Defaults to ("next", "is_chosen"). + verbose (bool): Whether to print verbose information. Defaults to `False`. 
+
+    """
+
+    def __init__(
+        self,
+        reward_threshold: float | Literal["mean", "median"],
+        *,
+        total_dialog_turns: int,
+        reward_key: NestedKey = ("next", "reward"),
+        done_key: NestedKey = ("next", "done"),
+        accept_key: NestedKey = ("next", "is_chosen"),
+        prompt_key: NestedKey = "text",
+        verbose: bool = False,
+    ):
+        super().__init__()
+        self.reward_threshold = reward_threshold
+        self.total_dialog_turns = total_dialog_turns
+        self.queues = defaultdict(deque)
+        self._cumul = isinstance(reward_threshold, str)
+
+        self.reward_key = reward_key
+        self.done_key = done_key
+        self.accept_key = accept_key
+        self.prompt_key = prompt_key
+        self.verbose = verbose
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Flip the batch and accept/reject dims so the output has shape [2, batch_size]
+        tensordict = tensordict.transpose(1, 0)
+        if tensordict.shape[0] != 2:
+            raise ValueError(
+                f"Expected a TD of shape [2, batch_size], got {tensordict.shape=}"
+            )
+        return tensordict
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase | None:
+        # This transform expects trajectories, either in batches or a single (cat of) trajectories
+        if tensordict.ndim == 1:
+            # Check how many done states we have
+            num_done = tensordict[self.done_key].sum()
+            if num_done > 1:
+                done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+                splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+                tensordicts = tensordict.split(splits)
+                tensordicts = [self._inv_call(td) for td in tensordicts]
+                tensordicts = [td for td in tensordicts if td is not None]
+                return torch.cat(tensordicts, 0) if tensordicts else None
+            # Then we have a single trajectory. Check if it's done
+            if not tensordict[-1][self.done_key].all():
+                raise RuntimeError("Expected the trajectory to be done.")
+            # Now we have a single, done trajectory. Get the prompt and add it to the corresponding queue
+            prompt = tensordict[0][self.prompt_key]
+            if not isinstance(prompt, str):
+                raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+            self.queues[prompt].append(tensordict)
+            # If the queue is full, we can process it and pass it to the buffer
+            if len(self.queues[prompt]) == self.total_dialog_turns:
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"Selecting accepted/rejected trajectories for {prompt=}"
+                    )
+                # lazy_stack of the trajectories
+                tds = lazy_stack(list(self.queues.pop(prompt)), 0)
+                # Collect rewards: they will have shape (total_dialog_turns, traj_len, *reward_shape)
+                reward = tds.get(self.reward_key, as_nested_tensor=True)
+                reward = self._aggregate_rewards(reward)
+                # Check if all rewards are equal
+                if (reward == reward[0]).all():
+                    # If all rewards are equal, we cannot split accepted/rejected - discard the trajectories
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"All rewards are equal ({reward.unique()=})"
+                        )
+                    return
+                # Mark rewards above the mean / median / target value as accepted
+                if self.reward_threshold == "median":
+                    reward_threshold = reward.median(dim=-1, keepdim=True)[0]
+                elif self.reward_threshold == "mean":
+                    reward_threshold = reward.mean(dim=-1, keepdim=True)[0]
+                else:
+                    reward_threshold = self.reward_threshold
+                mask = reward > reward_threshold
+                try:
+                    tds.set(self.accept_key, mask.view(tds.shape))
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Failed setting the accept key with shape {mask.shape} for {tds.shape=}. "
+                        "It is expected that the number of elements of the accept key matches the number of elements in the tensordict."
+                    ) from e
+                accepted_tds = tds[mask.nonzero(as_tuple=True)]
+                rejected_tds = tds[(~mask).nonzero(as_tuple=True)]
+                # Make a lazy stack of accepted / rejected trajectories. This stack will have shape
+                # (1, 2, total_dialog_turns // 2, traj_len)
+                tds = lazy_stack([accepted_tds, rejected_tds]).unsqueeze(0)  # 0 is accepted, 1 is rejected
+                return tds
+            return
+        elif tensordict.ndim > 2:
+            # keep the time dim at the end
+            tensordict = tensordict.flatten(0, -2)
+        trajs = tensordict.unbind(0)
+        # Iterate over the trajectories
+        result = []
+        for traj in trajs:
+            td_out = self._inv_call(traj)
+            if td_out is None:
+                continue
+            result.append(td_out)
+        if result:
+            return torch.cat(result, 0)
+        return
+
+    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+        """Aggregate the rewards across the dialog turns.
+
+        `reward` is expected to be a nested tensor.
+
+        The default implementation is to take the mean of the rewards across the dialog turns.
+        """
+        # reward = reward.to_padded_tensor(padding=0.0)
+        if reward.ndim < 2 or reward.ndim > 3:
+            raise ValueError(
+                f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+            )
+        return reward.mean(dim=-2).squeeze(-1)
+
+
+class AcceptanceRewardSampler(Sampler):
+    """A sampler that draws paired accepted and rejected trajectories for randomly sampled prompts."""
+
+    def __init__(self, total_dialog_turns: int) -> None:
+        super().__init__()
+        self.total_dialog_turns = total_dialog_turns
+        self._num_accepted_samples = defaultdict(lambda: 0)
+        self._num_rejected_samples = defaultdict(lambda: 0)
+
+    def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]:
+        # Sample an index corresponding to a prompt
+        prompt_idx = torch.randint(0, len(storage), (batch_size,))
+
+        # Within that prompt, sample an index corresponding to a dialog turn - independently for accepted and rejected
+        higher_accept = torch.tensor(
+            [self._num_accepted_samples[idx.item()] for idx in prompt_idx]
+        )
+        higher_rej = torch.tensor(
+            [self._num_rejected_samples[idx.item()] for idx in prompt_idx]
+        )
+
+        # Equivalent to randint with a variable upper bound
+        accepted_idx = (torch.rand(higher_accept.shape) * higher_accept).floor().long()
+        rejected_idx = (torch.rand(higher_rej.shape) * higher_rej).floor().long()
+
+        # Compound the indices
+        accepted_idx = torch.stack([prompt_idx, torch.zeros_like(prompt_idx), accepted_idx])
+        rejected_idx = torch.stack([prompt_idx, torch.ones_like(prompt_idx), rejected_idx])
+        return torch.cat([accepted_idx, rejected_idx]), {}
+
+    def extend(self, index: int | torch.Tensor) -> None:
+        if isinstance(index, torch.Tensor):
+            index = index.tolist()
+        if isinstance(index, list):
+            for i in index:
+                self.extend(i)
+            return
+        if not isinstance(index, int):
+            raise ValueError(f"Expected an int, got {type(index)=}")
+        # Keep track of the number of accepted and rejected entries per prompt
+        self._num_accepted_samples[index] += self.total_dialog_turns
+        self._num_rejected_samples[index] += self.total_dialog_turns
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "num_accepted_samples": self._num_accepted_samples,
+            "num_rejected_samples": self._num_rejected_samples,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self._num_accepted_samples = state_dict["num_accepted_samples"]
+        self._num_rejected_samples = state_dict["num_rejected_samples"]
+
+    def dumps(self) -> str:
+        raise NotImplementedError("Not implemented")
+
+    def loads(self, state_dict: str) -> None:
+        raise NotImplementedError("Not implemented")
+
+    def _empty(self):
+        self.__init__(total_dialog_turns=self.total_dialog_turns)
+
+
+class AcceptanceRewardWriter(RoundRobinWriter):
+    """A round-robin writer for acceptance/rejection data (write logic not implemented yet)."""
+
+    def __init__(self, total_dialog_turns: int) -> None:
+        super().__init__()
+        self.total_dialog_turns = total_dialog_turns
+        self._num_accepted_samples = defaultdict(lambda: 0)
+        self._num_rejected_samples = defaultdict(lambda: 0)
+
+    def add(self, data: Any) -> torch.Tensor | int:
+        pass
+
+    def extend(self, data: Any) -> torch.Tensor:
+        pass
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "num_accepted_samples": self._num_accepted_samples,
+            "num_rejected_samples": self._num_rejected_samples,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self._num_accepted_samples = state_dict["num_accepted_samples"]
+        self._num_rejected_samples = state_dict["num_rejected_samples"]
+
+    def dumps(self) -> str:
+        raise NotImplementedError("Not implemented")
+
+    def loads(self, state_dict: str) -> None:
+        raise NotImplementedError("Not implemented")
+
+    def _empty(self):
+        self.__init__(total_dialog_turns=self.total_dialog_turns)
\ No newline at end of file
diff --git a/torchrl/objectives/llm/__init__.py b/torchrl/objectives/llm/__init__.py
index eb3920845d5..5435d5e41d8 100644
--- a/torchrl/objectives/llm/__init__.py
+++ b/torchrl/objectives/llm/__init__.py
@@ -4,7 +4,19 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from .dpo import dpo_loss, DPOLoss, DPOLossOutput
 from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
-from .sft import SFTLoss, SFTLossOutput
+from .sft import minor_sft_loss, sft_loss, SFTLoss, SFTLossOutput
 
-__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
+__all__ = [
+    "DPOLoss",
+    "DPOLossOutput",
+    "GRPOLoss",
+    "GRPOLossOutput",
+    "MCAdvantage",
+    "SFTLoss",
+    "SFTLossOutput",
+    "dpo_loss",
+    "sft_loss",
+    "minor_sft_loss",
+]
diff --git a/torchrl/objectives/llm/dpo.py b/torchrl/objectives/llm/dpo.py
new file mode 100644
index 00000000000..37a9d502d49
--- /dev/null
+++ b/torchrl/objectives/llm/dpo.py
@@ -0,0 +1,471 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import contextlib
+import warnings
+
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from tensordict import NestedKey, TensorClass, TensorDictBase
+from tensordict.nn import TensorDictModule
+from tensordict.utils import _zip_strict
+from torchrl.data import History
+from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+from torchrl.objectives.common import LossModule
+
+
+def dpo_loss(
+    policy_chosen_logprob: torch.Tensor,
+    policy_rejected_logprob: torch.Tensor,
+    reference_chosen_logprob: torch.Tensor,
+    reference_rejected_logprob: torch.Tensor,
+    beta: float,
+    reduction: Literal["mean", "sum", "none"],
+) -> torch.Tensor:
+    """Compute the DPO loss.
+
+    Args:
+        policy_chosen_logprob (torch.Tensor): Log probabilities of chosen responses from the policy model.
+        policy_rejected_logprob (torch.Tensor): Log probabilities of rejected responses from the policy model.
+        reference_chosen_logprob (torch.Tensor): Log probabilities of chosen responses from the reference model.
+        reference_rejected_logprob (torch.Tensor): Log probabilities of rejected responses from the reference model.
+ beta (float): The beta parameter controlling the strength of the preference optimization. + reduction (str): The reduction to apply to the loss. + + Returns: + torch.Tensor: The DPO loss. + + References: + - Rafailov, R., Sharma, A., Mitchell, E., Ermon, S., Manning, C. D., & Finn, C. (2023). + `"Direct Preference Optimization: Your Language Model is Secretly a Reward Model" `_ + """ + chosen_rewards = beta * (policy_chosen_logprob - reference_chosen_logprob) + rejected_rewards = beta * (policy_rejected_logprob - reference_rejected_logprob) + + losses = -torch.nn.functional.logsigmoid(chosen_rewards - rejected_rewards) + + if reduction == "mean": + return losses.mean() + elif reduction == "sum": + return losses.sum() + elif reduction == "none": + return losses + else: + raise ValueError(f"Invalid reduction: {reduction}") + + +class DPOLossOutput(TensorClass["nocast"]): + """DPO Loss Output. + + Attributes: + loss_dpo (torch.Tensor): The loss for the DPO objective. + loss_kl_to_ref (torch.Tensor | None): The loss for the KL divergence to the reference model. + kl_to_ref (torch.Tensor | None): The KL divergence to the reference model. + chosen_rewards (torch.Tensor): The rewards for chosen responses. + rejected_rewards (torch.Tensor): The rewards for rejected responses. + accuracy (torch.Tensor): The accuracy of preference prediction. + + .. note:: + The loss components are kept separate to allow for logging and visualization. + Before backpropagation, the loss components are to be summed together. Since non-loss components are not differentiable + when the loss is constructed via :class:`~torchrl.objectives.llm.dpo.DPOLoss`, summing + the :class:`~torchrl.objectives.llm.dpo.DPOLossOutput` directly is a proper way of obtaining the total loss. + + >>> loss_fn = DPOLoss(...) + >>> loss_output = loss_fn(td) + >>> loss = loss_output.loss_dpo + loss_output.loss_kl_to_ref + >>> loss.backward() + >>> # or equivalently + >>> loss = loss_fn(td) + >>> loss.sum(reduce=True).backward() + """ + + loss_dpo: torch.Tensor + loss_kl_to_ref: torch.Tensor | None = None + kl_to_ref: torch.Tensor | None = None + chosen_rewards: torch.Tensor | None = None + rejected_rewards: torch.Tensor | None = None + accuracy: torch.Tensor | None = None + + +class DPOLoss(LossModule): + r"""Direct Preference Optimization loss. + + Args: + actor_network (TensorDictModule): the actor network. Usually a :class:`~torchrl.modules.llm.TransformersWrapper` instance, + with `return_log_prob=True` and `from_text=True`. + tokenizer (`Tokenizer`): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor_network`. + tokenizer_kwargs (dict, optional): keyword arguments to pass to the tokenizer during :meth:`~torchrl.data.llm.chat.History.apply_chat_template`. + This can be used to override arguments such as the `chat_template` or `chat_template_name`. + beta (float): The beta parameter controlling the strength of the preference optimization. Higher values make the optimization more aggressive. + reduction (Literal["mean", "sum", "none"], optional): the reduction to apply to the loss. Defaults to `"mean"`. + normalize_by_seq_length (bool, optional): whether to normalize the loss by the sequence length. Defaults to `True`. + kl_to_ref_coeff (float | None, optional): coefficient for KL divergence to reference model. Defaults to `None`. + device (torch.device | None, optional): the device to use for the loss, when tokenizing the input. Defaults to `None`. 
+ + .. note:: + The input tensordict is expected to contain the following keys by default: + - ``("next", "history")``: The chat history + - ``("next", "is_chosen")``: Boolean tensor indicating which response is chosen (True) vs rejected (False) + - ``("next", "ref_log_prob")`` (optional): Reference model log probabilities, required if kl_to_ref_coeff is set + + These keys can be customized using the ``set_keys()`` method. + + .. seealso:: :class:`~torchrl.envs.llm.transforms.RetrieveLogProb` for the KL divergence computation. + + References: + - Rafailov, R., Sharma, A., Mitchell, E., Ermon, S., Manning, C. D., & Finn, C. (2023). + `"Direct Preference Optimization: Your Language Model is Secretly a Reward Model" `_ + + Examples: + >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES + >>> from torchrl.modules.llm import TransformersWrapper + >>> from torchrl.objectives.llm.dpo import DPOLoss + >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM + >>> from tensordict import TensorDict, lazy_stack + >>> import torch + >>> + >>> # Create preference data + >>> chats = [ + ... [ + ... {"role": "system", "content": "You are a helpful assistant."}, + ... {"role": "user", "content": "What's 2+2?"}, + ... {"role": "assistant", "content": "2+2 equals 4."}, # chosen + ... {"role": "assistant", "content": "I don't know."}, # rejected + ... ], + ... [ + ... {"role": "system", "content": "You are a helpful assistant."}, + ... {"role": "user", "content": "Explain quantum physics."}, + ... {"role": "assistant", "content": "Quantum physics is complex."}, # chosen + ... {"role": "assistant", "content": "It's magic."}, # rejected + ... ], + ... ] + >>> history = History.from_chats(chats) + >>> + >>> # Setup tokenizer and model + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + >>> tokenizer.pad_token = tokenizer.eos_token + >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"] + >>> model = OPTForCausalLM(OPTConfig()).eval() + >>> + >>> # Create training and reference policies + >>> policy_train = TransformersWrapper( + ... model, + ... tokenizer=tokenizer, + ... generate=False, + ... from_text=True, + ... chat_template_name="qwen", + ... ) + >>> policy_ref = TransformersWrapper( + ... model, + ... tokenizer=tokenizer, + ... generate=False, + ... from_text=True, + ... return_log_probs=True, + ... chat_template_name="qwen", + ... ) + >>> + >>> # Create the RetrieveLogProb transform + >>> transform = RetrieveLogProb( + ... policy_ref, + ... assistant_only=True, + ... tokenizer_kwargs={"chat_template_name": "qwen"}, + ... tokenizer=tokenizer, + ... ) + >>> + >>> # Prepare data with preference labels + >>> text = history[:, :-2].apply_chat_template( + ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True + ... ) + >>> text_chosen = history[:, -2:-1].apply_chat_template( + ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False + ... ) + >>> text_rejected = history[:, -1:].apply_chat_template( + ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False + ... ) + >>> + >>> # Create preference labels (True for chosen, False for rejected) + >>> is_chosen = torch.tensor([True, False, True, False]).reshape(2, 2) + >>> + >>> td = TensorDict( + ... text=text, + ... text_chosen=text_chosen, + ... text_rejected=text_rejected, + ... history=history, + ... next=TensorDict( + ... is_chosen=is_chosen, + ... done=torch.zeros(2, dtype=torch.bool), + ... history=history, + ... ), + ... 
batch_size=(2,), + ... ) + >>> data = lazy_stack(list(td.unbind(0))) + >>> + >>> # Apply the transform to get reference log probabilities + >>> data = transform(data) + >>> assert "ref_log_prob" in data["next"].keys() + >>> + >>> # Use with DPOLoss + >>> loss = DPOLoss( + ... actor_network=policy_train, + ... tokenizer=tokenizer, + ... beta=0.1, + ... reduction="mean", + ... normalize_by_seq_length=True, + ... kl_to_ref_coeff=0.1, + ... tokenizer_kwargs={"chat_template_name": "qwen"}, + ... ) + >>> loss_vals = loss(data) + >>> print(f"DPO Loss: {loss_vals.loss_dpo.item():.4f}") + >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}") + >>> print(f"Accuracy: {loss_vals.accuracy.item():.4f}") + + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + history (NestedKey): The input tensordict key where the chat history is expected. + Defaults to ``("next", "history")``. + is_chosen (NestedKey): The input tensordict key where the preference labels are expected. + Defaults to ``("next", "is_chosen")``. + ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected. + Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_prob")``. + log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written. + Defaults to ``"log_probs"``. + """ + + history: NestedKey = ("next", "history") + is_chosen: NestedKey = ("next", "is_chosen") + ref_log_prob: NestedKey = ("next", "ref_log_prob") + log_probs: NestedKey = "log_probs" + + default_keys = _AcceptedKeys + tensor_keys: _AcceptedKeys + + def __init__( + self, + actor_network: TensorDictModule | TransformersWrapper, + tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821 + tokenizer_kwargs: dict | None = None, + beta: float = 0.1, + reduction: Literal["mean", "sum", "none"] = "mean", + normalize_by_seq_length: bool = True, + kl_to_ref_coeff: float | None = None, + device: torch.device | None = None, + ): + super().__init__() + self.in_keys = [] + self.actor_network = actor_network + if tokenizer is None: + tokenizer = actor_network.tokenizer + self.tokenizer = tokenizer + if tokenizer_kwargs is None: + tokenizer_kwargs = {} + if tokenizer is None: + raise ValueError("Tokenizer must be provided.") + tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True) + tokenizer_kwargs.setdefault("tokenize", True) + tokenizer_kwargs.setdefault("return_tensors", "pt") + tokenizer_kwargs.setdefault("padding", False) + tokenizer_kwargs.setdefault("add_generation_prompt", False) + self.tokenizer_kwargs = tokenizer_kwargs + self.beta = beta + self.reduction = reduction + self.normalize_by_seq_length = normalize_by_seq_length + self.kl_to_ref_coeff = kl_to_ref_coeff + self._set_in_keys() + self.device = device + + def _set_in_keys(self) -> None: + """Sets the input keys for the loss module.""" + in_keys = [self.tensor_keys.history, self.tensor_keys.is_chosen] + if self.kl_to_ref_coeff is not None: + in_keys.append(self.tensor_keys.ref_log_prob) + self.in_keys = in_keys + self.out_keys = [] # Loss modules typically don't have out_keys + + def _kl_to_ref( + self, + cur_log_prob: list[torch.Tensor], + ref_log_prob: list[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute KL divergence to reference model. 
+ + Args: + cur_log_prob (List[torch.Tensor]): Log probabilities from current model. Must have shape [T] where T is the number of tokens in the assistant response. + ref_log_prob (List[torch.Tensor]): Log probabilities from reference model. Must have shape [T] where T is the number of tokens in the assistant response. + + Returns: + tuple[torch.Tensor, torch.Tensor]: (KL loss term, KL penalty for logging) + """ + # Apply mask + ref_log_prob = torch.cat(ref_log_prob) + cur_log_prob = torch.cat(cur_log_prob) + if cur_log_prob.shape != ref_log_prob.shape: + raise ValueError( + f"Current log probabilities and reference log probabilities have different shapes: {cur_log_prob.shape=} vs {ref_log_prob.shape=}." + ) + # Compute KL using same approximation as GRPO + diff = ref_log_prob - cur_log_prob + + kl_penalty = (diff.expm1() - diff).mean() + return self.kl_to_ref_coeff * kl_penalty, kl_penalty + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + # Gather history and preference labels + history: History = tensordict[self.tensor_keys.history] + is_chosen: torch.Tensor = tensordict[self.tensor_keys.is_chosen] + + # Apply tokenizer to history and gather mask + with torch.device( + self.device + ) if self.device is not None else contextlib.nullcontext(): + token_struct = history.apply_chat_template( + tokenizer=self.tokenizer, **self.tokenizer_kwargs + ) + if "assistant_masks" not in token_struct: + raise ValueError( + f"Assistant masks are not present in the token structure: {token_struct=}." + ) + assistant_masks = token_struct.get( + "assistant_masks", + as_list=True, + ) + assistant_masks = [mask.bool() for mask in assistant_masks] + attention_mask = token_struct.get("attention_mask", as_list=True) + attention_mask = [mask.bool() for mask in attention_mask] + assistant_masks = [ + mask & a_mask for mask, a_mask in zip(assistant_masks, attention_mask) + ] + + if not any(mask.any(-1).all() for mask in assistant_masks): + raise ValueError("Some inputs have no valid assistant masks.") + + input_loss = tensordict.select(self.tensor_keys.history) + if ( + isinstance(self.tensor_keys.history, tuple) + and self.tensor_keys.history[0] == "next" + ): + input_loss = input_loss["next"] + + with torch.device( + self.device + ) if self.device is not None else contextlib.nullcontext(): + output_loss = self.actor_network(input_loss) + + # get log-probs + log_probs = output_loss.get( + self.tensor_keys.log_probs, + as_list=True, + ) + # apply mask + if not all( + mask.shape == lp.shape + for mask, lp in _zip_strict(assistant_masks, log_probs) + ): + raise ValueError( + f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs {[lp.shape for lp in log_probs]}. 
Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}" + ) + + log_probs_masked = [ + lp.masked_fill(~mask, 0.0) + for lp, mask in _zip_strict(log_probs, assistant_masks) + ] + + # Sum log probs, optionally normalize by sequence length + summed_log_probs = torch.stack( + [lp.sum(tensordict.ndim - 1) for lp in log_probs_masked] + ) + seq_lengths = torch.stack( + [mask.sum(tensordict.ndim - 1) for mask in assistant_masks] + ) + if self.normalize_by_seq_length: + # Compute sequence lengths for normalization (number of assistant tokens) + summed_log_probs = summed_log_probs / seq_lengths.clamp(min=1) + + # Split log probs into chosen and rejected based on preference labels + chosen_mask = is_chosen.bool() + rejected_mask = ~is_chosen.bool() + + if not chosen_mask.any() or not rejected_mask.any(): + raise ValueError("Both chosen and rejected responses must be present in the batch.") + + policy_chosen_logps = summed_log_probs[chosen_mask] + policy_rejected_logps = summed_log_probs[rejected_mask] + + # Get reference log probabilities if available + if self.kl_to_ref_coeff is not None: + ref_log_probs = tensordict.get( + self.tensor_keys.ref_log_prob, + default=None, + as_list=True, + ) + if ref_log_probs is None: + raise ValueError( + "Reference log probs not found in tensordict but kl_to_ref_coeff was set" + ) + + # Sum reference log probs similarly to policy log probs + summed_ref_log_probs = torch.stack([lp.sum() for lp in ref_log_probs]).to( + summed_log_probs.device + ) + if self.normalize_by_seq_length: + summed_ref_log_probs = summed_ref_log_probs / seq_lengths.clamp(min=1) + + reference_chosen_logps = summed_ref_log_probs[chosen_mask] + reference_rejected_logps = summed_ref_log_probs[rejected_mask] + else: + # If no reference model, use zeros (equivalent to no reference model in DPO) + reference_chosen_logps = torch.zeros_like(policy_chosen_logps) + reference_rejected_logps = torch.zeros_like(policy_rejected_logps) + + # Compute DPO loss + loss = dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + self.beta, + self.reduction, + ) + + # Compute additional metrics for logging + with torch.no_grad(): + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps) + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps) + accuracy = (chosen_rewards > rejected_rewards).float().mean() + + # Add KL divergence loss if reference model is provided + if self.kl_to_ref_coeff is not None: + loss_kl, kl_penalty = self._kl_to_ref( + [lp[mask] for lp, mask in _zip_strict(log_probs, assistant_masks)], + ref_log_probs, + ) + output = DPOLossOutput( + loss_dpo=loss, + loss_kl_to_ref=loss_kl, + kl_to_ref=kl_penalty.detach(), + chosen_rewards=chosen_rewards.detach(), + rejected_rewards=rejected_rewards.detach(), + accuracy=accuracy, + ) + else: + output = DPOLossOutput( + loss_dpo=loss, + chosen_rewards=chosen_rewards.detach(), + rejected_rewards=rejected_rewards.detach(), + accuracy=accuracy, + ) + + return output \ No newline at end of file diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 1e8f33268ed..272093eb9b1 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -752,10 +752,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: explained_variance = None if self.log_explained_variance: - with torch.no_grad(): # <‑‑ break grad‐flow - tgt = target_return.detach() - pred = 
state_value.detach() - eps = torch.finfo(tgt.dtype).eps + with torch.no_grad(): # <‑‑ break grad‐flow + tgt = target_return.detach() + pred = state_value.detach() + eps = torch.finfo(tgt.dtype).eps resid = torch.var(tgt - pred, unbiased=False, dim=0) total = torch.var(tgt, unbiased=False, dim=0) explained_variance = 1.0 - resid / (total + eps) @@ -819,7 +819,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("loss_entropy", self._weighted_loss_entropy(entropy)) if self._has_critic: - loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict) + loss_critic, value_clip_fraction, explained_variance = self.loss_critic( + tensordict + ) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: td_out.set("value_clip_fraction", value_clip_fraction) @@ -1189,7 +1191,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("loss_entropy", self._weighted_loss_entropy(entropy)) if self._has_critic: - loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict) + loss_critic, value_clip_fraction, explained_variance = self.loss_critic( + tensordict + ) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: td_out.set("value_clip_fraction", value_clip_fraction) @@ -1537,7 +1541,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("loss_entropy", self._weighted_loss_entropy(entropy)) if self._has_critic: - loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy) + loss_critic, value_clip_fraction, explained_variance = self.loss_critic( + tensordict_copy + ) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: td_out.set("value_clip_fraction", value_clip_fraction)
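---

Note on the objective: `dpo_loss` above is the standard DPO negative log-sigmoid of the implicit reward margin, i.e. -log sigmoid(beta * [(log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))]). The following standalone sketch (plain PyTorch, not part of the patch; names are illustrative only) spells out the same computation:

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
        # Implicit rewards under the DPO parameterization: beta * (log pi - log pi_ref)
        chosen_rewards = beta * (policy_chosen - ref_chosen)
        rejected_rewards = beta * (policy_rejected - ref_rejected)
        # Negative log-likelihood of the Bradley-Terry preference model
        return -F.logsigmoid(chosen_rewards - rejected_rewards)

    # Element-wise losses for two preference pairs (same values as in the unit test above)
    losses = dpo_loss_sketch(
        torch.tensor([1.0, 2.0]),
        torch.tensor([0.5, 1.0]),
        torch.tensor([0.8, 1.5]),
        torch.tensor([0.3, 0.8]),
    )
    print(losses)  # strictly positive, one value per pair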
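Note on the KL term: as stated in `_kl_to_ref`, the penalty uses the same approximation as GRPO, i.e. the non-negative estimator exp(d) - 1 - d with d = log pi_ref - log pi, averaged over assistant tokens. A minimal sketch (plain PyTorch, assumed per-token log-probability inputs):

    import torch

    def kl_penalty_sketch(cur_log_prob: torch.Tensor, ref_log_prob: torch.Tensor) -> torch.Tensor:
        # d = log pi_ref - log pi; exp(d) - 1 - d >= 0 and is ~0 when the two policies agree
        diff = ref_log_prob - cur_log_prob
        return (diff.expm1() - diff).mean()

    cur = torch.tensor([-1.2, -0.7, -2.3])
    ref = torch.tensor([-1.0, -0.9, -2.0])
    print(kl_penalty_sketch(cur, ref))  # small non-negative scalar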
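Note on the acceptance selection: `AcceptanceRewardSelector` turns per-trajectory rewards collected for the same prompt into chosen/rejected labels by thresholding at a fixed value, the mean, or the median of the queued rewards. A condensed sketch of that decision (plain PyTorch, assuming one scalar reward per trajectory; not the actual transform):

    import torch

    def acceptance_mask(rewards: torch.Tensor, threshold="mean") -> torch.Tensor:
        if threshold == "mean":
            cut = rewards.mean()
        elif threshold == "median":
            cut = rewards.median()
        else:
            cut = torch.as_tensor(threshold)
        # True -> accepted ("is_chosen"), False -> rejected
        return rewards > cut

    rewards = torch.tensor([1.0, -1.0, 1.0, -1.0])
    mask = acceptance_mask(rewards, "mean")
    print(mask)  # tensor([ True, False,  True, False])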