Merged
26 commits
- 54d3560 Fix hyperlinks in README (#2) (molereddy, Mar 1, 2025)
- 4c36e4f Fixed DPO command (Dornavineeth, Mar 2, 2025)
- f7a69de download idk (Dornavineeth, Mar 2, 2025)
- 1bd4411 Merge pull request #3 from Dornavineeth/dpo_fix (Dornavineeth, Mar 2, 2025)
- 332af36 Revert "Dpo fix" (Dornavineeth, Mar 2, 2025)
- 7d6aef3 Merge pull request #4 from Dornavineeth/revert-3-dpo_fix (Dornavineeth, Mar 2, 2025)
- f468efb download idk data (Dornavineeth, Mar 2, 2025)
- ca8d503 fix dpo experiment config (Dornavineeth, Mar 2, 2025)
- dde60d3 Merge pull request #5 from Dornavineeth/dpo_fix (Dornavineeth, Mar 2, 2025)
- 6367fb6 Merge branch 'locuslab:main' into main (molereddy, Mar 9, 2025)
- 8b073d6 RMU (#6) (Dornavineeth, Mar 9, 2025)
- 855c5f3 Merge branch 'locuslab:main' into main (molereddy, Mar 11, 2025)
- 4fb577c Merge branch 'locuslab:main' into main (Dornavineeth, Mar 23, 2025)
- dccb831 Add structure to contributions, setup leaderboard, update documentati… (Dornavineeth, Mar 27, 2025)
- e3c4709 Merge branch 'locuslab:main' into main (molereddy, Mar 27, 2025)
- 5a7dfb4 UNDIAL (dong-river, Apr 2, 2025)
- c4c8000 UNDIAL2 (dong-river, Apr 2, 2025)
- 4aec929 UNDIAL3 (dong-river, Apr 2, 2025)
- c2df505 Merge branch 'main' of https://github.com/Dornavineeth/open-unlearning (molereddy, Apr 9, 2025)
- be84613 Merge remote-tracking branch 'origin/main' into UNDIAL (molereddy, Apr 13, 2025)
- a08a26b Ruff quality formatting changes (molereddy, Apr 13, 2025)
- 37fb738 Merge remote-tracking branch 'upstream/main' into UNDIAL (Dornavineeth, May 4, 2025)
- 2d7c66c fix config (Dornavineeth, May 11, 2025)
- 716d7e9 fix docs and script (Dornavineeth, May 22, 2025)
- bb48240 Merge remote-tracking branch 'upstream/main' into UNDIAL (Dornavineeth, May 22, 2025)
- 0687b35 Update readme (Dornavineeth, May 22, 2025)
4 changes: 2 additions & 2 deletions README.md
@@ -18,7 +18,7 @@

## 📖 Overview

-We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 6 unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.
+We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 7 unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.

We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field.

@@ -62,7 +62,7 @@ We provide several variants for each of the components in the unlearning pipeline
| **Component** | **Available Options** |
|------------------------|----------------------|
| **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/), [WMDP](https://www.wmdp.ai/) |
-| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU |
+| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL |
| **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, Knowledge QA-ROUGE, Model Utility, Forget Quality, TruthRatio, Extraction Strength, Exact Memorization, 6 MIA attacks, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) |
| **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits), WMDP-Bio, WMDP-Cyber |
| **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr |
24 changes: 24 additions & 0 deletions community/methods/UNDIAL/README.md
@@ -0,0 +1,24 @@
# UNDIAL: Self-Distillation with Adjusted Logits for Robust Unlearning in Large Language Models (NAACL 2025)

- Authors: Yijiang River Dong, Hongzhou Lin, Mikhail Belkin, Ramón Huerta, Ivan Vulić
- Link: https://arxiv.org/pdf/2402.10052

# Setup
- Hyperparameters: The original paper tunes Llama-2 7B with LoRA (rank=8, alpha=16) and a learning rate of 1e-4. It is suggested to search the learning rate over [1e-5, 1e-4, 3e-4] and to use an effective batch size of 32 (batch_size * gradient_accumulation); see the sweep sketch after this list. The other important hyperparameter is beta, the strength of the penalty, which is typically chosen from [3, 10, 30]. When switching to other models, adjust the learning rate accordingly.

- Computation Setup: All experiments are run on one A100.
- Other Details: The original paper does not use the retain set and aims to retain knowledge across all domains, not just on the retain set, so alpha is set to 0. Practitioners could search over alpha or gamma to better retain performance on the retain set.
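
Below is a minimal sketch of the suggested sweep (the grid values come from this README and from `run.sh` in this PR; the snippet itself is illustrative and not part of the repo):

```python
# Illustrative sweep grid (values from this README / run.sh); the actual
# launcher is community/methods/UNDIAL/run.sh added in this PR.
from itertools import product

per_device_train_batch_size = 16
gradient_accumulation_steps = 2
# effective batch size = 16 * 2 = 32, as recommended above
assert per_device_train_batch_size * gradient_accumulation_steps == 32

for lr, beta in product([1e-5, 1e-4, 3e-4], [3, 10, 30]):
    print(f"run: learning_rate={lr} beta={beta}")
```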

# Results
Run the `run.sh` script to sweep learning rates, `beta`, and `alpha` across the TOFU splits and evaluate each run.

# Citation
@misc{dong2024undial,
    title={UNDIAL: Self-Distillation with Adjusted Logits for Robust Unlearning in Large Language Models},
    author={Yijiang River Dong and Hongzhou Lin and Mikhail Belkin and Ramon Huerta and Ivan Vulić},
    year={2024},
    eprint={2402.10052},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2402.10052},
}
78 changes: 78 additions & 0 deletions community/methods/UNDIAL/run.sh
@@ -0,0 +1,78 @@
#!/bin/bash

export MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
echo "Master Port: $MASTER_PORT"

########################################################################################################################
########################################### Unlearn TOFU models ########################################################
########################################################################################################################

models=(
    "Llama-3.2-1B-Instruct"
)
trainers_experiments=(
    "UNDIAL unlearn/tofu/default.yaml"
)
forget_retain_splits=(
    "forget10 retain90"
    "forget05 retain95"
    "forget01 retain99"
)

per_device_train_batch_size=16
gradient_accumulation_steps=2

lrs=(1e-5 1e-4 3e-4)
alphas=(1 2 5)
betas=(3 10 30)

for split in "${forget_retain_splits[@]}"; do
    forget_split=$(echo $split | cut -d' ' -f1)
    retain_split=$(echo $split | cut -d' ' -f2)
    for model in "${models[@]}"; do
        for trainer_experiment in "${trainers_experiments[@]}"; do
            trainer=$(echo $trainer_experiment | cut -d' ' -f1)
            experiment=$(echo $trainer_experiment | cut -d' ' -f2)
            for lr in "${lrs[@]}"; do
                for beta in "${betas[@]}"; do
                    for alpha in "${alphas[@]}"; do
                        task_name=tofu_${model}_${forget_split}_${trainer}_lr${lr}_beta${beta}_alpha${alpha}
                        model_path=open-unlearning/tofu_${model}_full
                        echo ${task_name}: Unlearning ${model_path} using ${trainer}

                        # Unlearn
                        CUDA_VISIBLE_DEVICES=0 \
                        python src/train.py --config-name=unlearn.yaml \
                            experiment=${experiment} \
                            trainer=${trainer} \
                            task_name=${task_name} \
                            model=${model} \
                            forget_split=${forget_split} \
                            retain_split=${retain_split} \
                            model.model_args.pretrained_model_name_or_path=${model_path} \
                            retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json \
                            trainer.args.per_device_train_batch_size=$per_device_train_batch_size \
                            trainer.args.gradient_accumulation_steps=$gradient_accumulation_steps \
                            trainer.args.eval_strategy=no \
                            trainer.args.eval_on_start=False \
                            trainer.args.learning_rate=$lr \
                            trainer.method_args.beta=$beta \
                            trainer.method_args.alpha=$alpha

                        # Eval
                        CUDA_VISIBLE_DEVICES=0 python src/eval.py \
                            experiment=eval/tofu/default.yaml \
                            forget_split=${forget_split} \
                            model=${model} \
                            task_name=${task_name} \
                            model.model_args.pretrained_model_name_or_path=saves/unlearn/${task_name} \
                            paths.output_dir=saves/unlearn/${task_name}/evals \
                            retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json
                    done
                done
            done
        done
    done
done
12 changes: 12 additions & 0 deletions configs/trainer/UNDIAL.yaml
@@ -0,0 +1,12 @@
defaults:
  - finetune

handler: UNDIAL # corresponds to the class defined in src/trainer/unlearn/undial.py
args: # HuggingFace TrainingArguments
  learning_rate: 1e-4
  num_train_epochs: 10
method_args: # method-specific arguments
  gamma: 1.0 # weight on the forget (unlearning) loss
  alpha: 0.0 # weight on the retain loss (0 disables it, per the paper)
  beta: 10.0 # the strength of the penalty for memorized tokens
  retain_loss_type: NLL
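
For orientation, here is a minimal sketch of how these `method_args` enter the objective (the function name is hypothetical; the actual logic is in `src/trainer/unlearn/undial.py` later in this diff):

```python
# Illustrative only: how gamma and alpha weight UNDIAL's two loss terms.
def undial_objective(forget_loss: float, retain_loss: float,
                     gamma: float = 1.0, alpha: float = 0.0) -> float:
    # With the defaults above (alpha=0.0), the retain term is disabled and the
    # objective reduces to the beta-adjusted self-distillation loss alone.
    return gamma * forget_loss + alpha * retain_loss
```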
1 change: 1 addition & 0 deletions docs/links.md
@@ -25,6 +25,7 @@ Links to research papers and resources corresponding to implemented features in
| SimNPO | Paper [📄](https://arxiv.org/abs/2410.07163), Code [🐙](https://github.com/OPTML-Group/Unlearn-Simple) |
| IdkDPO | TOFU ([📄](https://arxiv.org/abs/2401.06121)) |
| RMU | WMDP paper ([🐙](https://github.com/centerforaisafety/wmdp/tree/main/rmu), [🌐](https://www.wmdp.ai/)), later used in G-effect ([🐙](https://github.com/tmlr-group/G-effect/blob/main/dataloader.py)) |
| UNDIAL | Paper [📄](https://arxiv.org/pdf/2402.10052), Code [🐙](https://github.com/dong-river/LLM_unlearning/tree/main) |

---

2 changes: 2 additions & 0 deletions src/trainer/__init__.py
Expand Up @@ -10,6 +10,7 @@
from trainer.unlearn.dpo import DPO
from trainer.unlearn.simnpo import SimNPO
from trainer.unlearn.rmu import RMU
from trainer.unlearn.undial import UNDIAL

import logging

Expand Down Expand Up @@ -88,3 +89,4 @@ def load_trainer(
_register_trainer(DPO)
_register_trainer(SimNPO)
_register_trainer(RMU)
_register_trainer(UNDIAL)
27 changes: 27 additions & 0 deletions src/trainer/unlearn/undial.py
@@ -0,0 +1,27 @@
from trainer.utils import compute_undial_loss
from trainer.unlearn.grad_diff import GradDiff


class UNDIAL(GradDiff):
    def __init__(self, beta=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta  # strength of the logit penalty on memorized tokens
        if self.ref_model is None:
            # UNDIAL distills from a frozen reference copy of the model (the teacher)
            self.ref_model = self._prepare_ref_model(self.model)

    def compute_loss(self, model, inputs, return_outputs=False):
        # Forget loss: cross-entropy against the teacher's beta-adjusted soft labels
        forget_inputs = inputs["forget"]
        forget_loss, forget_outputs = compute_undial_loss(
            model, self.ref_model, forget_inputs, self.beta
        )

        # Retain loss: standard objective on the retain inputs (weighted by alpha,
        # which defaults to 0 since the original paper does not use a retain set)
        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss = self.gamma * forget_loss + self.alpha * retain_loss
        return (loss, forget_outputs) if return_outputs else loss
32 changes: 32 additions & 0 deletions src/trainer/utils.py
@@ -66,3 +66,35 @@ def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1

    loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean()
    return loss, (win_outputs, lose_outputs)


def compute_undial_loss(model, ref_model, inputs, beta):
    # Forward pass on the student (trainable) model
    outputs = model(**inputs)
    logits = outputs.logits
    labels = inputs["labels"]

    shift_labels = labels[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()

    # Forward pass on the teacher model (no grad)
    with torch.no_grad():
        teacher_logits = ref_model(**inputs).logits
    shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()

    # Build the mask identifying the tokens to be unlearned: a one-hot tensor
    # with 1.0 at each position's target (label) token
    mask = torch.zeros_like(shift_teacher_logits)
    batch_idx = torch.arange(mask.shape[0]).view(-1, 1, 1)
    seq_idx = torch.arange(mask.shape[1]).view(1, -1, 1)
    mask[batch_idx, seq_idx, shift_labels.unsqueeze(-1)] = 1.0

    # Adjust the teacher logits: subtract beta on the correct token, then
    # renormalize to obtain the soft distillation targets
    pre_softmax = shift_teacher_logits - mask * beta
    soft_label = F.softmax(pre_softmax, dim=-1)

    # Cross-entropy of the student against the soft labels (nn.CrossEntropyLoss
    # accepts class probabilities as targets)
    loss_fct = nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        soft_label.view(-1, soft_label.size(-1)),
    )
    return loss.mean(), outputs
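
To make the logit adjustment concrete, here is a tiny standalone example (not part of the repo) showing how subtracting `beta` on the label token reshapes the teacher distribution into the soft target:

```python
# Toy demonstration of UNDIAL's adjusted-logits soft label (standalone sketch).
import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[2.0, 0.5, 0.1]])  # vocab of 3; token 0 is the label
mask = torch.tensor([[1.0, 0.0, 0.0]])            # one-hot on the label token
beta = 10.0

print(F.softmax(teacher_logits, dim=-1))                # teacher puts most mass on token 0
print(F.softmax(teacher_logits - beta * mask, dim=-1))  # mass on token 0 collapses
```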