Merged

Changes from all commits (24 commits):
54d3560  Fix hyperlinks in README (#2)  (molereddy, Mar 1, 2025)
4c36e4f  Fixed DPO command  (Dornavineeth, Mar 2, 2025)
f7a69de  download idk  (Dornavineeth, Mar 2, 2025)
1bd4411  Merge pull request #3 from Dornavineeth/dpo_fix  (Dornavineeth, Mar 2, 2025)
332af36  Revert "Dpo fix"  (Dornavineeth, Mar 2, 2025)
7d6aef3  Merge pull request #4 from Dornavineeth/revert-3-dpo_fix  (Dornavineeth, Mar 2, 2025)
f468efb  download idk data  (Dornavineeth, Mar 2, 2025)
ca8d503  fix dpo experiment config  (Dornavineeth, Mar 2, 2025)
dde60d3  Merge pull request #5 from Dornavineeth/dpo_fix  (Dornavineeth, Mar 2, 2025)
6367fb6  Merge branch 'locuslab:main' into main  (molereddy, Mar 9, 2025)
8b073d6  RMU (#6)  (Dornavineeth, Mar 9, 2025)
855c5f3  Merge branch 'locuslab:main' into main  (molereddy, Mar 11, 2025)
4fb577c  Merge branch 'locuslab:main' into main  (Dornavineeth, Mar 23, 2025)
dccb831  Add structure to contributions, setup leaderboard, update documentati…  (Dornavineeth, Mar 27, 2025)
e3c4709  Merge branch 'locuslab:main' into main  (molereddy, Mar 27, 2025)
c2df505  Merge branch 'main' of https://github.com/Dornavineeth/open-unlearning  (molereddy, Apr 9, 2025)
26aa294  Fix documentation and some miscellaneous things (#13)  (molereddy, Apr 14, 2025)
8ff8d58  bunch of tokenization related bug-fixes (#14)  (molereddy, Apr 25, 2025)
3bad69d  Merge branch 'main' of https://github.com/Dornavineeth/open-unlearning  (molereddy, Apr 29, 2025)
0e8a571  Merge branch 'main' of https://github.com/Dornavineeth/open-unlearning  (molereddy, Jun 21, 2025)
ee70738  Add the OpenUnlearning paper (#20)  (molereddy, Jun 22, 2025)
b41e6cd  Merge remote-tracking branch 'locuslab/main'  (molereddy, Jun 23, 2025)
1ec7b17  Merge branch 'locuslab:main' into main  (molereddy, Jun 30, 2025)
abacb26  feat: add CE-U loss (#21)  (molereddy, Jun 30, 2025)
README.md (10 changes: 5 additions & 5 deletions)

@@ -19,7 +19,7 @@

## 📖 Overview

-We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 8+ unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.
+We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 11+ unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.

We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field.

@@ -33,7 +33,7 @@ We invite the LLM unlearning community to collaborate by adding new benchmarks,

🌟 **Highlights:**
- A detailed technical report on OpenUnlearning covering the design, features, and implementation.
-- A meta-evaluation framework for benchmarking unlearning evaluations across 450+ open-source models.
+- A meta-evaluation framework for benchmarking unlearning evaluations across 450+ models, open-sourced on HuggingFace 🤗: [TOFU Models w & w/o Knowledge](https://huggingface.co/collections/open-unlearning/tofu-models-w-and-w-o-knowledge-6861e4d935eb99ba162e55cd), [TOFU Unlearned Models](https://huggingface.co/collections/open-unlearning/tofu-unlearned-models-6860f6cf3fe35d0223d92e88).
- Results benchmarking 8 diverse unlearning methods in one place using 10 evaluation metrics on TOFU.

<details>

@@ -77,10 +77,10 @@ We provide several variants for each of the components in the unlearning pipelin

| **Component** | **Available Options** |
|------------------------|----------------------|
| **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/), [WMDP](https://www.wmdp.ai/) |
-| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL, AltPO, SatImp, WGA |
+| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL, AltPO, SatImp, WGA, CE-U |
| **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, Knowledge QA-ROUGE, Model Utility, Forget Quality, TruthRatio, Extraction Strength, Exact Memorization, 6 MIA attacks, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) |
| **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits), WMDP-Bio, WMDP-Cyber |
-| **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr |
+| **Model Families** | TOFU: Llama-3.2, Llama-3.1, Llama-2; MUSE: Llama-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr |

---

@@ -124,7 +124,7 @@ python setup_data.py --eval # saves/eval now contains evaluation results of the

### 🔄 Updated TOFU benchmark

-We've updated Open-Unlearning's TOFU benchmark target models to use a wider variety of newer architectures with sizes varying from 1B to 8B. These include LLaMA 3.2 1B, LLaMA 3.2 3B, LLaMA 3.1 8B, and the original LLaMA-2 7B (re-created) target models from [the old version of TOFU](github.com/locuslab/tofu).
+We've updated Open-Unlearning's TOFU benchmark target models to use a wider variety of newer architectures with sizes varying from 1B to 8B. These include Llama 3.2 1B, Llama 3.2 3B, Llama 3.1 8B, and the original Llama-2 7B (re-created) target models from [the old version of TOFU](github.com/locuslab/tofu).

For each architecture, we have finetuned with four different splits of the TOFU datasets: `full`, `retain90`, `retain95`, `retain99`, for a total of 16 finetuned models. The first serves as the target (base model for unlearning) and the rest are retain models used to measure performance against for each forget split. These models are on [HuggingFace](`https://huggingface.co/collections/open-unlearning/tofu-new-models-67bcf636334ea81727573a9f0`) and the paths to these models can be set in the experimental configs or in command-line overrides.
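For reference, a minimal sketch of such a command-line override, following the Hydra-style syntax used in community/methods/CEU/run.sh below (the model name and checkpoint path are illustrative):

python src/train.py --config-name=unlearn.yaml \
  experiment=unlearn/tofu/default.yaml \
  model=Llama-3.2-1B-Instruct \
  model.model_args.pretrained_model_name_or_path=open-unlearning/tofu_Llama-3.2-1B-Instruct_full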
community/methods/CEU/run.sh (new file: 68 additions)

@@ -0,0 +1,68 @@
#!/bin/bash

export MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
echo "Master Port: $MASTER_PORT"

########################################################################################################################
########################################### Unlearn TOFU models with CE-U #############################################
########################################################################################################################

models=(
    "Llama-3.2-1B-Instruct"
)
trainers_experiments=(
    "CEU unlearn/tofu/default.yaml"
)
forget_retain_splits=(
    "forget10 retain90"
    "forget05 retain95"
    "forget01 retain99"
)

per_device_train_batch_size=16
gradient_accumulation_steps=2

lrs=(1e-5)

for split in "${forget_retain_splits[@]}"; do
    forget_split=$(echo $split | cut -d' ' -f1)
    retain_split=$(echo $split | cut -d' ' -f2)
    for model in "${models[@]}"; do
        for trainer_experiment in "${trainers_experiments[@]}"; do
            trainer=$(echo $trainer_experiment | cut -d' ' -f1)
            experiment=$(echo $trainer_experiment | cut -d' ' -f2)
            for lr in "${lrs[@]}"; do
                task_name=tofu_${model}_${forget_split}_${trainer}_lr${lr}
                model_path=open-unlearning/tofu_${model}_full
                echo ${task_name}: Unlearning ${model_path} using ${trainer}

                # Unlearn
                CUDA_VISIBLE_DEVICES=0 \
                python src/train.py --config-name=unlearn.yaml \
                    experiment=${experiment} \
                    trainer=${trainer} \
                    task_name=${task_name} \
                    model=${model} \
                    forget_split=${forget_split} \
                    retain_split=${retain_split} \
                    model.model_args.pretrained_model_name_or_path=${model_path} \
                    retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json \
                    trainer.args.per_device_train_batch_size=$per_device_train_batch_size \
                    trainer.args.gradient_accumulation_steps=$gradient_accumulation_steps \
                    trainer.args.eval_strategy=no \
                    trainer.args.eval_on_start=False \
                    trainer.args.learning_rate=$lr

                # Eval
                CUDA_VISIBLE_DEVICES=0 python src/eval.py \
                    experiment=eval/tofu/default.yaml \
                    forget_split=${forget_split} \
                    model=${model} \
                    task_name=${task_name} \
                    model.model_args.pretrained_model_name_or_path=saves/unlearn/${task_name} \
                    paths.output_dir=saves/unlearn/${task_name}/evals \
                    retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json
            done
        done
    done
done
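
# Note (illustrative, inferred from the overrides above): each run leaves the
# unlearned checkpoint under saves/unlearn/${task_name} and its TOFU evaluation
# results under saves/unlearn/${task_name}/evals.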
configs/trainer/CEU.yaml (new file: 6 additions)

@@ -0,0 +1,6 @@
defaults:
  - finetune

handler: CEU
method_args:
  ignore_first_n_answer_tokens: 1
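A minimal sketch of selecting this trainer from the command line, assuming the same Hydra-style overrides as in run.sh above (the method_args override path is an inference from this config's layout, not shown in the PR):

python src/train.py --config-name=unlearn.yaml \
  experiment=unlearn/tofu/default.yaml \
  trainer=CEU \
  trainer.method_args.ignore_first_n_answer_tokens=1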
docs/links.md (3 changes: 2 additions & 1 deletion)

@@ -36,6 +36,7 @@ Links to research papers and resources corresponding to implemented features in

| AltPO | Paper[📄](https://arxiv.org/pdf/2409.13474), Code [🐙](https://github.com/molereddy/Alternate-Preference-Optimization) |
| SatImp | Paper[📄](https://arxiv.org/pdf/2505.11953), Code [🐙](https://github.com/Puning97/SatImp-for-LLM-Unlearning) |
| WGA (G-effect) | Paper[📄](https://arxiv.org/pdf/2502.19301), Code [🐙](https://github.com/tmlr-group/G-effect) |
+| CE-U (Cross-Entropy unlearning) | Paper[📄](https://arxiv.org/pdf/2503.01224) |

---

@@ -59,7 +60,7 @@ Links to research papers and resources corresponding to implemented features in

| Forget Quality, Truth Ratio, Model Utility | TOFU ([📄](https://arxiv.org/abs/2401.06121)) |
| Extraction Strength (ES) | Carlini et al., 2021 ([📄](https://www.usenix.org/conference/usenixsecurity21/presentation/carlini-extracting)), used for unlearning in Wang et al., 2025 ([📄](https://openreview.net/pdf?id=wUtCieKuQU)) |
| Exact Memorization (EM) | Tirumala et al., 2022 ([📄](https://proceedings.neurips.cc/paper_files/paper/2022/hash/fa0509f4dab6807e2cb465715bf2d249-Abstract-Conference.html)), used for unlearning in Wang et al., 2025 ([📄](https://openreview.net/pdf?id=wUtCieKuQU)) |
-| lm-evaluation-harness | [💻](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) |
+| lm-evaluation-harness | Repository: [💻](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) |

---
src/trainer/__init__.py (2 additions)

@@ -11,6 +11,7 @@
from trainer.unlearn.simnpo import SimNPO
from trainer.unlearn.rmu import RMU
from trainer.unlearn.undial import UNDIAL
+from trainer.unlearn.ceu import CEU
from trainer.unlearn.satimp import SatImp
from trainer.unlearn.wga import WGA

@@ -93,5 +94,6 @@ def load_trainer(
_register_trainer(SimNPO)
_register_trainer(RMU)
_register_trainer(UNDIAL)
+_register_trainer(CEU)
_register_trainer(SatImp)
_register_trainer(WGA)
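
# Note (inferred; the registry code sits outside these hunks): registration
# appears to key each trainer by its class name, which is how `handler: CEU`
# in configs/trainer/CEU.yaml resolves to the CEU class defined below.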
src/trainer/unlearn/ceu.py (new file: 96 additions)

@@ -0,0 +1,96 @@
from trainer.unlearn.base import UnlearnTrainer

import torch
import torch.nn.functional as F


def cross_entropy_unlearning_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """
    Implementation of Cross Entropy Unlearning Loss (CE-U).

    This function creates a modified target distribution by setting the logit
    corresponding to the true label to negative infinity, effectively forcing
    the model to assign zero probability to the correct answer. The loss then
    minimizes the KL divergence between this target distribution and the
    model's output.

    Args:
        logits: Model output logits with shape [batch_size, sequence_length, vocabulary_size]
        labels: Ground truth token indices with shape [batch_size, sequence_length]
        ignore_index: Token indices to ignore in the loss calculation (typically padding)

    Returns:
        A scalar tensor representing the mean unlearning loss across valid positions
    """
    batch_size, sequence_length, vocabulary_size = logits.shape
    # Extract valid logits and labels based on ignore_index.
    if ignore_index is not None:
        # Shape: [batch_size, sequence_length], boolean mask
        valid_mask = labels != ignore_index
        # Shape: [num_valid_positions, vocabulary_size]
        valid_logits = logits[valid_mask]
        # Shape: [num_valid_positions]
        valid_labels = labels[valid_mask]
    else:
        # Shape: [batch_size*sequence_length, vocabulary_size]
        valid_logits = logits.view(-1, vocabulary_size)
        # Shape: [batch_size*sequence_length]
        valid_labels = labels.view(-1)

    # Create a copy of valid_logits to generate the target distribution
    # Shape: [num_valid_positions, vocabulary_size]
    valid_target_logits = valid_logits.detach().clone()

    # Suppress the logits corresponding to the true token by setting them to -inf.
    # This ensures that the probability for the true token is effectively zero after softmax.
    valid_target_logits.scatter_(
        dim=-1,
        index=valid_labels.unsqueeze(-1),  # Shape: [num_valid_positions, 1]
        value=float("-inf"),
    )  # Result shape: [num_valid_positions, vocabulary_size]

    # Apply softmax to generate the target probability distribution
    # Shape: [num_valid_positions, vocabulary_size]
    valid_target_probabilities = F.softmax(valid_target_logits, dim=-1)

    # Compute the cross entropy loss between input logits and target probabilities
    # The loss is averaged over the valid positions and returns a scalar tensor
    return F.cross_entropy(
        input=valid_logits,
        target=valid_target_probabilities,
    )
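
# Worked example of the target construction above (illustrative, not part of
# the PR): with a two-token vocabulary, logits [2.0, 0.0] and true label 0,
# scatter_ yields target logits [-inf, 0.0], so the target distribution is
# [0.0, 1.0] and the loss is -log_softmax([2.0, 0.0])[1] = log(1 + e^2) ≈ 2.13.
# In general, CE-U is cross-entropy against the model's own (detached)
# distribution renormalized over every token except the true one.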


def compute_batch_ceu(model, inputs, ignore_first_n_answer_tokens=1):
    outputs = model(**inputs)
    logits = outputs.logits
    labels = inputs["labels"]

    # Implement the trick to ignore the first n answer tokens mentioned in the footnote in the Training Settings section of arXiv:2503.01224
    valid_mask = labels != -100
    update_mask = (
        valid_mask.cumsum(dim=-1) <= ignore_first_n_answer_tokens
    ) & valid_mask
    labels_without_first_n_answer_tokens = labels.masked_fill(update_mask, -100)

    shifted_labels = labels_without_first_n_answer_tokens[..., 1:].contiguous()
    shifted_logits = logits[..., :-1, :].contiguous()
    loss = cross_entropy_unlearning_loss(
        shifted_logits, shifted_labels, ignore_index=-100
    )
    return loss, outputs
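
# Illustrative trace of the masking above (not part of the PR). With
# ignore_first_n_answer_tokens=1 and labels [-100, -100, a1, a2, a3]
# (prompt positions already masked with -100):
#   valid_mask              = [F, F, T, T, T]
#   cumsum(valid_mask) <= 1 = [T, T, T, F, F]
#   update_mask             = [F, F, T, F, F]
# so a1 is also masked to -100, and after the next-token shift (logits[t]
# paired with labels[t+1]) CE-U applies only to a2 and a3.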


class CEU(UnlearnTrainer):
    def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens

    def compute_loss(self, model, inputs, return_outputs=False):
        forget_inputs = inputs["forget"]
        loss, outputs = compute_batch_ceu(
            model,
            forget_inputs,
            ignore_first_n_answer_tokens=self.ignore_first_n_answer_tokens,
        )
        return (loss, outputs) if return_outputs else loss
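
# Usage note (illustrative, not part of the PR): compute_loss reads only the
# "forget" split of each batch, e.g.
#   inputs = {"forget": {"input_ids": ..., "attention_mask": ..., "labels": ...}}
# so unlike GradDiff-style trainers, CE-U adds no retain-set term to the loss.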