diff --git a/README.md b/README.md index 0b5d65a..dcc6f80 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ ## 📖 Overview -We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 8+ unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants. +We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 11+ unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants. We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field. @@ -33,7 +33,7 @@ We invite the LLM unlearning community to collaborate by adding new benchmarks, 🌟 **Highlights:** - A detailed technical report on OpenUnlearning covering the design, features, and implementation. -- A meta-evaluation framework for benchmarking unlearning evaluations across 450+ open-source models. +- A meta-evaluation framework for benchmarking unlearning evaluations across 450+ models, open-sourced on HuggingFace 🤗: [TOFU Models w & w/o Knowledge](https://huggingface.co/collections/open-unlearning/tofu-models-w-and-w-o-knowledge-6861e4d935eb99ba162e55cd), [TOFU Unlearned Models](https://huggingface.co/collections/open-unlearning/tofu-unlearned-models-6860f6cf3fe35d0223d92e88). - Results benchmarking 8 diverse unlearning methods in one place using 10 evaluation metrics on TOFU.
@@ -77,10 +77,10 @@ We provide several variants for each of the components in the unlearning pipelin | **Component** | **Available Options** | |------------------------|----------------------| | **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/), [WMDP](https://www.wmdp.ai/) | -| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL, AltPO, SatImp, WGA | +| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL, AltPO, SatImp, WGA, CE-U | | **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, Knowledge QA-ROUGE, Model Utility, Forget Quality, TruthRatio, Extraction Strength, Exact Memorization, 6 MIA attacks, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) | | **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits), WMDP-Bio, WMDP-Cyber | -| **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr | +| **Model Families** | TOFU: Llama-3.2, Llama-3.1, Llama-2; MUSE: Llama-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr | --- @@ -124,7 +124,7 @@ python setup_data.py --eval # saves/eval now contains evaluation results of the ### 🔄 Updated TOFU benchmark -We've updated Open-Unlearning's TOFU benchmark target models to use a wider variety of newer architectures with sizes varying from 1B to 8B. These include LLaMA 3.2 1B, LLaMA 3.2 3B, LLaMA 3.1 8B, and the original LLaMA-2 7B (re-created) target models from [the old version of TOFU](github.com/locuslab/tofu). +We've updated Open-Unlearning's TOFU benchmark target models to use a wider variety of newer architectures with sizes varying from 1B to 8B. These include Llama 3.2 1B, Llama 3.2 3B, Llama 3.1 8B, and the original Llama-2 7B (re-created) target models from [the old version of TOFU](github.com/locuslab/tofu). For each architecture, we have finetuned with four different splits of the TOFU datasets: `full`, `retain90`, `retain95`, `retain99`, for a total of 16 finetuned models. The first serves as the target (base model for unlearning) and the rest are retain models used to measure performance against for each forget split. These models are on [HuggingFace](`https://huggingface.co/collections/open-unlearning/tofu-new-models-67bcf636334ea81727573a9f0`) and the paths to these models can be set in the experimental configs or in command-line overrides. diff --git a/community/methods/CEU/run.sh b/community/methods/CEU/run.sh new file mode 100644 index 0000000..b91d504 --- /dev/null +++ b/community/methods/CEU/run.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +export MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()") +echo "Master Port: $MASTER_PORT" + +######################################################################################################################## +########################################### Unlearn TOFU models with CE-U ############################################# +######################################################################################################################## + +models=( + "Llama-3.2-1B-Instruct" +) +trainers_experiments=( + "CEU unlearn/tofu/default.yaml" +) +forget_retain_splits=( + "forget10 retain90" + "forget05 retain95" + "forget01 retain99" +) + +per_device_train_batch_size=16 +gradient_accumulation_steps=2 + +lrs=(1e-5) + +for split in "${forget_retain_splits[@]}"; do + forget_split=$(echo $split | cut -d' ' -f1) + retain_split=$(echo $split | cut -d' ' -f2) + for model in "${models[@]}"; do + for trainer_experiment in "${trainers_experiments[@]}"; do + trainer=$(echo $trainer_experiment | cut -d' ' -f1) + experiment=$(echo $trainer_experiment | cut -d' ' -f2) + for lr in "${lrs[@]}"; do + task_name=tofu_${model}_${forget_split}_${trainer}_lr${lr} + model_path=open-unlearning/tofu_${model}_full + echo ${task_name}: Unlearning ${model_path} using ${trainer} + + # Unlearn + CUDA_VISIBLE_DEVICES=0 \ + python src/train.py --config-name=unlearn.yaml \ + experiment=${experiment} \ + trainer=${trainer} \ + task_name=${task_name} \ + model=${model} \ + forget_split=${forget_split} \ + retain_split=${retain_split} \ + model.model_args.pretrained_model_name_or_path=${model_path} \ + retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json \ + trainer.args.per_device_train_batch_size=$per_device_train_batch_size \ + trainer.args.gradient_accumulation_steps=$gradient_accumulation_steps \ + trainer.args.eval_strategy=no \ + trainer.args.eval_on_start=False \ + trainer.args.learning_rate=$lr + + # Eval + CUDA_VISIBLE_DEVICES=0 python src/eval.py \ + experiment=eval/tofu/default.yaml \ + forget_split=${forget_split} \ + model=${model} \ + task_name=${task_name} \ + model.model_args.pretrained_model_name_or_path=saves/unlearn/${task_name} \ + paths.output_dir=saves/unlearn/${task_name}/evals \ + retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json + done + done + done +done \ No newline at end of file diff --git a/configs/trainer/CEU.yaml b/configs/trainer/CEU.yaml new file mode 100644 index 0000000..0af0600 --- /dev/null +++ b/configs/trainer/CEU.yaml @@ -0,0 +1,6 @@ +defaults: + - finetune + +handler: CEU +method_args: + ignore_first_n_answer_tokens: 1 diff --git a/docs/links.md b/docs/links.md index 2e1cd84..bbc032c 100644 --- a/docs/links.md +++ b/docs/links.md @@ -36,6 +36,7 @@ Links to research papers and resources corresponding to implemented features in | AltPO | Paper[📄](https://arxiv.org/pdf/2409.13474), Code [🐙](https://github.com/molereddy/Alternate-Preference-Optimization) | | SatImp | Paper[📄](https://arxiv.org/pdf/2505.11953), Code [🐙](https://github.com/Puning97/SatImp-for-LLM-Unlearning) | | WGA (G-effect) | Paper[📄](https://arxiv.org/pdf/2502.19301), Code [🐙](https://github.com/tmlr-group/G-effect) | +| CE-U (Cross-Entropy unlearning) | Paper[📄](https://arxiv.org/pdf/2503.01224) | --- @@ -59,7 +60,7 @@ Links to research papers and resources corresponding to implemented features in | Forget Quality, Truth Ratio, Model Utility | TOFU ([📄](https://arxiv.org/abs/2401.06121)) | | Extraction Strength (ES) | Carlini et al., 2021 ([📄](https://www.usenix.org/conference/usenixsecurity21/presentation/carlini-extracting)), used for unlearning in Wang et al., 2025 ([📄](https://openreview.net/pdf?id=wUtCieKuQU)) | | Exact Memorization (EM) | Tirumala et al., 2022 ([📄](https://proceedings.neurips.cc/paper_files/paper/2022/hash/fa0509f4dab6807e2cb465715bf2d249-Abstract-Conference.html)), used for unlearning in Wang et al., 2025 ([📄](https://openreview.net/pdf?id=wUtCieKuQU)) | -| lm-evaluation-harness | [💻](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) | +| lm-evaluation-harness | Repository: [💻](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) | --- diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py index 97759c1..fa89bf2 100644 --- a/src/trainer/__init__.py +++ b/src/trainer/__init__.py @@ -11,6 +11,7 @@ from trainer.unlearn.simnpo import SimNPO from trainer.unlearn.rmu import RMU from trainer.unlearn.undial import UNDIAL +from trainer.unlearn.ceu import CEU from trainer.unlearn.satimp import SatImp from trainer.unlearn.wga import WGA @@ -93,5 +94,6 @@ def load_trainer( _register_trainer(SimNPO) _register_trainer(RMU) _register_trainer(UNDIAL) +_register_trainer(CEU) _register_trainer(SatImp) _register_trainer(WGA) diff --git a/src/trainer/unlearn/ceu.py b/src/trainer/unlearn/ceu.py new file mode 100644 index 0000000..33da99c --- /dev/null +++ b/src/trainer/unlearn/ceu.py @@ -0,0 +1,96 @@ +from trainer.unlearn.base import UnlearnTrainer + +import torch +import torch.nn.functional as F + + +def cross_entropy_unlearning_loss( + logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, +) -> torch.Tensor: + """ + Implementation of Cross Entropy Unlearning Loss (CE-U). + + This function creates a modified target distribution by setting the logit corresponding to the true label to negative infinity, effectively forcing the model to assign zero probability to the correct answer. The loss then minimizes the KL divergence between this target distribution and the model's output. + + Args: + logits: Model output logits with shape [batch_size, sequence_length, vocabulary_size] + labels: Ground truth token indices with shape [batch_size, sequence_length] + ignore_index: Token indices to ignore in the loss calculation (typically padding) + + Returns: + A scalar tensor representing the mean unlearning loss across valid positions + """ + batch_size, sequence_length, vocabulary_size = logits.shape + # Extract valid logits and labels based on ignore_index. + if ignore_index is not None: + # Shape: [batch_size, sequence_length], boolean mask + valid_mask = labels != ignore_index + # Shape: [num_valid_positions, vocabulary_size] + valid_logits = logits[valid_mask] + # Shape: [num_valid_positions] + valid_labels = labels[valid_mask] + else: + # Shape: [batch_size*sequence_length, vocabulary_size] + valid_logits = logits.view(-1, vocabulary_size) + # Shape: [batch_size*sequence_length] + valid_labels = labels.view(-1) + + # Create a copy of valid_logits to generate the target distribution + # Shape: [num_valid_positions, vocabulary_size] + valid_target_logits = valid_logits.detach().clone() + + # Suppress the logits corresponding to the true token by setting them to -inf. + # This ensures that the probability for the true token is effectively zero after softmax. + valid_target_logits.scatter_( + dim=-1, + index=valid_labels.unsqueeze(-1), # Shape: [num_valid_positions, 1] + value=float("-inf"), + ) # Result shape: [num_valid_positions, vocabulary_size] + + # Apply softmax to generate the target probability distribution + # Shape: [num_valid_positions, vocabulary_size] + valid_target_probabilities = F.softmax(valid_target_logits, dim=-1) + + # Compute the cross entropy loss between input logits and target probabilities + # The loss is averaged over the valid positions and returns a scalar tensor + return F.cross_entropy( + input=valid_logits, + target=valid_target_probabilities, + ) + + +def compute_batch_ceu(model, inputs, ignore_first_n_answer_tokens=1): + outputs = model(**inputs) + logits = outputs.logits + labels = inputs["labels"] + + # Implement the trick to ignore the first n answer tokens mentioned in the footnote in the Training Settings section of arXiv:2503.01224 + valid_mask = labels != -100 + update_mask = ( + valid_mask.cumsum(dim=-1) <= ignore_first_n_answer_tokens + ) & valid_mask + labels_without_first_n_answer_tokens = labels.masked_fill(update_mask, -100) + + shifted_labels = labels_without_first_n_answer_tokens[..., 1:].contiguous() + shifted_logits = logits[..., :-1, :].contiguous() + loss = cross_entropy_unlearning_loss( + shifted_logits, shifted_labels, ignore_index=-100 + ) + return loss, outputs + + +class CEU(UnlearnTrainer): + def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens + + def compute_loss(self, model, inputs, return_outputs=False): + forget_inputs = inputs["forget"] + loss, outputs = compute_batch_ceu( + model, + forget_inputs, + ignore_first_n_answer_tokens=self.ignore_first_n_answer_tokens, + ) + return (loss, outputs) if return_outputs else loss