From b4c23584d0306273738841300d040537dec41eac Mon Sep 17 00:00:00 2001 From: Bo Yang Date: Mon, 17 Mar 2025 00:46:09 +0000 Subject: [PATCH 1/2] feat: add CE-U loss --- README.md | 4 +- configs/trainer/CEU.yaml | 6 +++ scripts/tofu_unlearn.sh | 1 + src/trainer/__init__.py | 2 + src/trainer/unlearn/ceu.py | 11 ++++++ src/trainer/utils.py | 77 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 configs/trainer/CEU.yaml create mode 100644 src/trainer/unlearn/ceu.py diff --git a/README.md b/README.md index 403754b..ec88872 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ ## 📖 Overview -We provide efficient and streamlined implementations of the TOFU, MUSE unlearning benchmarks while supporting 6 unlearning methods, 3+ datasets, 6+ evaluation metrics, and 7+ LLMs. Each of these can be easily extended to incorporate more variants. +We provide efficient and streamlined implementations of the TOFU, MUSE unlearning benchmarks while supporting 7 unlearning methods, 3+ datasets, 6+ evaluation metrics, and 7+ LLMs. Each of these can be easily extended to incorporate more variants. We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field. @@ -35,7 +35,7 @@ We provide several variants for each of the components in the unlearning pipelin | **Component** | **Available Options** | |------------------------|----------------------| | **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/) | -| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU | +| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, CE-U | | **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, QA-ROUGE, MIA Attacks, TruthRatio, Model Utility | | **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits) | | **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2, ICLM; Additional: Phi-3.5, Phi-1.5, Gemma | diff --git a/configs/trainer/CEU.yaml b/configs/trainer/CEU.yaml new file mode 100644 index 0000000..0af0600 --- /dev/null +++ b/configs/trainer/CEU.yaml @@ -0,0 +1,6 @@ +defaults: + - finetune + +handler: CEU +method_args: + ignore_first_n_answer_tokens: 1 diff --git a/scripts/tofu_unlearn.sh b/scripts/tofu_unlearn.sh index ae33189..d02d452 100644 --- a/scripts/tofu_unlearn.sh +++ b/scripts/tofu_unlearn.sh @@ -15,6 +15,7 @@ trainers_experiments=( "NPO unlearn/tofu/default.yaml" "DPO unlearn/tofu/idk.yaml" "RMU unlearn/tofu/default.yaml" + "CEU unlearn/tofu/default.yaml" ) forget_retain_splits=( "forget01 retain99" diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py index 7e195fa..d30b173 100644 --- a/src/trainer/__init__.py +++ b/src/trainer/__init__.py @@ -10,6 +10,7 @@ from trainer.unlearn.dpo import DPO from trainer.unlearn.simnpo import SimNPO from trainer.unlearn.rmu import RMU +from trainer.unlearn.ceu import CEU TRAINER_REGISTRY: Dict[str, Any] = {} @@ -81,3 +82,4 @@ def load_trainer( _register_trainer(DPO) _register_trainer(SimNPO) _register_trainer(RMU) +_register_trainer(CEU) diff --git a/src/trainer/unlearn/ceu.py b/src/trainer/unlearn/ceu.py new file mode 100644 index 0000000..417617f --- /dev/null +++ b/src/trainer/unlearn/ceu.py @@ -0,0 +1,11 @@ +from trainer.unlearn.base import UnlearnTrainer +from trainer.utils import compute_batch_ceu +class CEU(UnlearnTrainer): + def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens + + def compute_loss(self, model, inputs, return_outputs=False): + forget_inputs = inputs["forget"] + loss, outputs = compute_batch_ceu(model, forget_inputs, ignore_first_n_answer_tokens=self.ignore_first_n_answer_tokens) + return (loss, outputs) if return_outputs else loss diff --git a/src/trainer/utils.py b/src/trainer/utils.py index c5125b7..b783029 100644 --- a/src/trainer/utils.py +++ b/src/trainer/utils.py @@ -55,3 +55,80 @@ def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1 loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean() return loss, (win_outputs, lose_outputs) + + +def cross_entropy_unlearning_loss( + logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, +) -> torch.Tensor: + """ + Implementation of Cross Entropy Unlearning Loss (CE-U). + + This function creates a modified target distribution by setting the logit corresponding to the true label to negative infinity, effectively forcing the model to assign zero probability to the correct answer. The loss then minimizes the KL divergence between this target distribution and the model's output. + + Args: + logits: Model output logits with shape [batch_size, sequence_length, vocabulary_size] + labels: Ground truth token indices with shape [batch_size, sequence_length] + ignore_index: Token indices to ignore in the loss calculation (typically padding) + + Returns: + A scalar tensor representing the mean unlearning loss across valid positions + """ + batch_size, sequence_length, vocabulary_size = logits.shape + # Extract valid logits and labels based on ignore_index. + if ignore_index is not None: + # Shape: [batch_size, sequence_length], boolean mask + valid_mask = labels != ignore_index + # Shape: [num_valid_positions, vocabulary_size] + valid_logits = logits[valid_mask] + # Shape: [num_valid_positions] + valid_labels = labels[valid_mask] + else: + # Shape: [batch_size*sequence_length, vocabulary_size] + valid_logits = logits.view(-1, vocabulary_size) + # Shape: [batch_size*sequence_length] + valid_labels = labels.view(-1) + + # Create a copy of valid_logits to generate the target distribution + # Shape: [num_valid_positions, vocabulary_size] + valid_target_logits = valid_logits.detach().clone() + + # Suppress the logits corresponding to the true token by setting them to -inf. + # This ensures that the probability for the true token is effectively zero after softmax. + valid_target_logits.scatter_( + dim=-1, + index=valid_labels.unsqueeze(-1), # Shape: [num_valid_positions, 1] + value=float("-inf"), + ) # Result shape: [num_valid_positions, vocabulary_size] + + # Apply softmax to generate the target probability distribution + # Shape: [num_valid_positions, vocabulary_size] + valid_target_probabilities = F.softmax(valid_target_logits, dim=-1) + + # Compute the cross entropy loss between input logits and target probabilities + # The loss is averaged over the valid positions and returns a scalar tensor + return F.cross_entropy( + input=valid_logits, + target=valid_target_probabilities, + ) + + +def compute_batch_ceu(model, inputs, ignore_first_n_answer_tokens=1): + outputs = model(**inputs) + logits = outputs.logits + labels = inputs["labels"] + + # Implement the trick to ignore the first n answer tokens mentioned in the footnote in the Training Settings section of arXiv:2503.01224 + valid_mask = labels != -100 + update_mask = ( + valid_mask.cumsum(dim=-1) <= ignore_first_n_answer_tokens + ) & valid_mask + labels_without_first_n_answer_tokens = labels.masked_fill(update_mask, -100) + + shifted_labels = labels_without_first_n_answer_tokens[..., 1:].contiguous() + shifted_logits = logits[..., :-1, :].contiguous() + loss = cross_entropy_unlearning_loss( + shifted_logits, shifted_labels, ignore_index=-100 + ) + return loss, outputs From c32e5f42a4b4434289f4078adcb9f218a5a27eeb Mon Sep 17 00:00:00 2001 From: Bo Yang Date: Mon, 17 Mar 2025 22:57:30 +0000 Subject: [PATCH 2/2] chore: move CE-U related functions to `ceu.py` --- src/trainer/unlearn/ceu.py | 82 +++++++++++++++++++++++++++++++++++++- src/trainer/utils.py | 77 ----------------------------------- 2 files changed, 81 insertions(+), 78 deletions(-) diff --git a/src/trainer/unlearn/ceu.py b/src/trainer/unlearn/ceu.py index 417617f..70e5bba 100644 --- a/src/trainer/unlearn/ceu.py +++ b/src/trainer/unlearn/ceu.py @@ -1,5 +1,85 @@ from trainer.unlearn.base import UnlearnTrainer -from trainer.utils import compute_batch_ceu + +import torch +import torch.nn.functional as F + + +def cross_entropy_unlearning_loss( + logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, +) -> torch.Tensor: + """ + Implementation of Cross Entropy Unlearning Loss (CE-U). + + This function creates a modified target distribution by setting the logit corresponding to the true label to negative infinity, effectively forcing the model to assign zero probability to the correct answer. The loss then minimizes the KL divergence between this target distribution and the model's output. + + Args: + logits: Model output logits with shape [batch_size, sequence_length, vocabulary_size] + labels: Ground truth token indices with shape [batch_size, sequence_length] + ignore_index: Token indices to ignore in the loss calculation (typically padding) + + Returns: + A scalar tensor representing the mean unlearning loss across valid positions + """ + batch_size, sequence_length, vocabulary_size = logits.shape + # Extract valid logits and labels based on ignore_index. + if ignore_index is not None: + # Shape: [batch_size, sequence_length], boolean mask + valid_mask = labels != ignore_index + # Shape: [num_valid_positions, vocabulary_size] + valid_logits = logits[valid_mask] + # Shape: [num_valid_positions] + valid_labels = labels[valid_mask] + else: + # Shape: [batch_size*sequence_length, vocabulary_size] + valid_logits = logits.view(-1, vocabulary_size) + # Shape: [batch_size*sequence_length] + valid_labels = labels.view(-1) + + # Create a copy of valid_logits to generate the target distribution + # Shape: [num_valid_positions, vocabulary_size] + valid_target_logits = valid_logits.detach().clone() + + # Suppress the logits corresponding to the true token by setting them to -inf. + # This ensures that the probability for the true token is effectively zero after softmax. + valid_target_logits.scatter_( + dim=-1, + index=valid_labels.unsqueeze(-1), # Shape: [num_valid_positions, 1] + value=float("-inf"), + ) # Result shape: [num_valid_positions, vocabulary_size] + + # Apply softmax to generate the target probability distribution + # Shape: [num_valid_positions, vocabulary_size] + valid_target_probabilities = F.softmax(valid_target_logits, dim=-1) + + # Compute the cross entropy loss between input logits and target probabilities + # The loss is averaged over the valid positions and returns a scalar tensor + return F.cross_entropy( + input=valid_logits, + target=valid_target_probabilities, + ) + + +def compute_batch_ceu(model, inputs, ignore_first_n_answer_tokens=1): + outputs = model(**inputs) + logits = outputs.logits + labels = inputs["labels"] + + # Implement the trick to ignore the first n answer tokens mentioned in the footnote in the Training Settings section of arXiv:2503.01224 + valid_mask = labels != -100 + update_mask = ( + valid_mask.cumsum(dim=-1) <= ignore_first_n_answer_tokens + ) & valid_mask + labels_without_first_n_answer_tokens = labels.masked_fill(update_mask, -100) + + shifted_labels = labels_without_first_n_answer_tokens[..., 1:].contiguous() + shifted_logits = logits[..., :-1, :].contiguous() + loss = cross_entropy_unlearning_loss( + shifted_logits, shifted_labels, ignore_index=-100 + ) + return loss, outputs + class CEU(UnlearnTrainer): def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/trainer/utils.py b/src/trainer/utils.py index b783029..c5125b7 100644 --- a/src/trainer/utils.py +++ b/src/trainer/utils.py @@ -55,80 +55,3 @@ def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1 loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean() return loss, (win_outputs, lose_outputs) - - -def cross_entropy_unlearning_loss( - logits: torch.Tensor, - labels: torch.Tensor, - ignore_index: int = -100, -) -> torch.Tensor: - """ - Implementation of Cross Entropy Unlearning Loss (CE-U). - - This function creates a modified target distribution by setting the logit corresponding to the true label to negative infinity, effectively forcing the model to assign zero probability to the correct answer. The loss then minimizes the KL divergence between this target distribution and the model's output. - - Args: - logits: Model output logits with shape [batch_size, sequence_length, vocabulary_size] - labels: Ground truth token indices with shape [batch_size, sequence_length] - ignore_index: Token indices to ignore in the loss calculation (typically padding) - - Returns: - A scalar tensor representing the mean unlearning loss across valid positions - """ - batch_size, sequence_length, vocabulary_size = logits.shape - # Extract valid logits and labels based on ignore_index. - if ignore_index is not None: - # Shape: [batch_size, sequence_length], boolean mask - valid_mask = labels != ignore_index - # Shape: [num_valid_positions, vocabulary_size] - valid_logits = logits[valid_mask] - # Shape: [num_valid_positions] - valid_labels = labels[valid_mask] - else: - # Shape: [batch_size*sequence_length, vocabulary_size] - valid_logits = logits.view(-1, vocabulary_size) - # Shape: [batch_size*sequence_length] - valid_labels = labels.view(-1) - - # Create a copy of valid_logits to generate the target distribution - # Shape: [num_valid_positions, vocabulary_size] - valid_target_logits = valid_logits.detach().clone() - - # Suppress the logits corresponding to the true token by setting them to -inf. - # This ensures that the probability for the true token is effectively zero after softmax. - valid_target_logits.scatter_( - dim=-1, - index=valid_labels.unsqueeze(-1), # Shape: [num_valid_positions, 1] - value=float("-inf"), - ) # Result shape: [num_valid_positions, vocabulary_size] - - # Apply softmax to generate the target probability distribution - # Shape: [num_valid_positions, vocabulary_size] - valid_target_probabilities = F.softmax(valid_target_logits, dim=-1) - - # Compute the cross entropy loss between input logits and target probabilities - # The loss is averaged over the valid positions and returns a scalar tensor - return F.cross_entropy( - input=valid_logits, - target=valid_target_probabilities, - ) - - -def compute_batch_ceu(model, inputs, ignore_first_n_answer_tokens=1): - outputs = model(**inputs) - logits = outputs.logits - labels = inputs["labels"] - - # Implement the trick to ignore the first n answer tokens mentioned in the footnote in the Training Settings section of arXiv:2503.01224 - valid_mask = labels != -100 - update_mask = ( - valid_mask.cumsum(dim=-1) <= ignore_first_n_answer_tokens - ) & valid_mask - labels_without_first_n_answer_tokens = labels.masked_fill(update_mask, -100) - - shifted_labels = labels_without_first_n_answer_tokens[..., 1:].contiguous() - shifted_logits = logits[..., :-1, :].contiguous() - loss = cross_entropy_unlearning_loss( - shifted_logits, shifted_labels, ignore_index=-100 - ) - return loss, outputs