diff --git a/README.md b/README.md
index fc423e8..135af95 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@
 
 ## πŸ“– Overview
 
-We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 6 unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.
+We provide efficient and streamlined implementations of the TOFU, MUSE and WMDP unlearning benchmarks while supporting 7 unlearning methods, 5+ datasets, 10+ evaluation metrics, and 7+ LLM architectures. Each of these can be easily extended to incorporate more variants.
 
 We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field.
 
@@ -62,7 +62,7 @@ We provide several variants for each of the components in the unlearning pipelin
 | **Component** | **Available Options** |
 |------------------------|----------------------|
 | **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/), [WMDP](https://www.wmdp.ai/) |
-| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU |
+| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, UNDIAL |
 | **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, Knowledge QA-ROUGE, Model Utility, Forget Quality, TruthRatio, Extraction Strength, Exact Memorization, 6 MIA attacks, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) |
 | **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits), WMDP-Bio, WMDP-Cyber |
 | **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2; Additional: Phi-3.5, Phi-1.5, Gemma, Zephyr |
diff --git a/community/methods/UNDIAL/README.md b/community/methods/UNDIAL/README.md
new file mode 100644
index 0000000..28c309d
--- /dev/null
+++ b/community/methods/UNDIAL/README.md
@@ -0,0 +1,24 @@
+# UNDIAL: Self-Distillation with Adjusted Logits for Robust Unlearning in Large Language Models (NAACL 2025)
+
+- Authors: Yijiang River Dong, Hongzhou Lin, Mikhail Belkin, RamΓ³n Huerta, Ivan VuliΔ‡
+- Link: https://arxiv.org/pdf/2402.10052
+
+# Setup
+- Hyperparameters: The original paper uses Llama-2 7B tuned with LoRA (rank=8, alpha=16) and a learning rate of 1e-4. It is suggested to search the learning rate over [1e-5, 1e-4, 3e-4] and to use an effective batch size of 32 (per_device_train_batch_size Γ— gradient_accumulation_steps). The other important hyperparameter is beta, the strength of the penalty on memorized tokens, which is typically chosen from [3, 10, 30]. When switching to other models, adjust the learning rate accordingly.
+
+- Computation Setup: All experiments are run on one A100.
+- Other Details: The original paper does not use the retain set, aiming to retain knowledge across all domains rather than only on the retain set, so alpha is set to 0. Practitioners can search over alpha or gamma to better retain performance on the retain set.
+
+# Results
+Run the `run.sh` script.
+
+# Citation
+@misc{dong2024undial,
+      title={UNDIAL: Self-Distillation with Adjusted Logits for Robust Unlearning in Large Language Models},
+      author={Yijiang River Dong and Hongzhou Lin and Mikhail Belkin and Ramon Huerta and Ivan VuliΔ‡},
+      year={2024},
+      eprint={2402.10052},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL},
+      url={https://arxiv.org/abs/2402.10052},
+}
\ No newline at end of file
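The README above describes beta as the strength of the penalty on memorized tokens. The sketch below is illustrative only (the function `undial_soft_labels` and the tensor shapes are assumptions, not part of the codebase): it shows how the teacher's logit for each memorized token is penalized by beta before the softmax that produces the distillation target, mirroring `compute_undial_loss` further down in this patch.

```python
import torch
import torch.nn.functional as F

def undial_soft_labels(teacher_logits, targets, beta=10.0):
    # teacher_logits: (batch, seq, vocab); targets: (batch, seq) token ids.
    # Subtract beta from the logit of each memorized (target) token...
    penalty = beta * F.one_hot(targets, num_classes=teacher_logits.size(-1))
    # ...so the softmax redistributes probability mass away from it.
    return F.softmax(teacher_logits - penalty, dim=-1)

soft = undial_soft_labels(torch.randn(2, 8, 100), torch.randint(0, 100, (2, 8)))
print(soft.shape)  # torch.Size([2, 8, 100])
```

The student is then trained with cross-entropy against these soft labels, which is the role of `compute_undial_loss` in `src/trainer/utils.py` below.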
diff --git a/community/methods/UNDIAL/run.sh b/community/methods/UNDIAL/run.sh
new file mode 100644
index 0000000..06caef2
--- /dev/null
+++ b/community/methods/UNDIAL/run.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+export MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
+echo "Master Port: $MASTER_PORT"
+
+########################################################################################################################
+########################################### Unlearn TOFU models ########################################################
+########################################################################################################################
+
+models=(
+    "Llama-3.2-1B-Instruct"
+)
+trainers_experiments=(
+    "UNDIAL unlearn/tofu/default.yaml"
+)
+forget_retain_splits=(
+    "forget10 retain90"
+    "forget05 retain95"
+    "forget01 retain99"
+)
+
+per_device_train_batch_size=16
+gradient_accumulation_steps=2
+
+
+lrs=(1e-5 1e-4 3e-4)
+alphas=(1 2 5)
+betas=(3 10 30)
+
+
+for split in "${forget_retain_splits[@]}"; do
+    forget_split=$(echo $split | cut -d' ' -f1)
+    retain_split=$(echo $split | cut -d' ' -f2)
+    for model in "${models[@]}"; do
+        for trainer_experiment in "${trainers_experiments[@]}"; do
+            trainer=$(echo $trainer_experiment | cut -d' ' -f1)
+            experiment=$(echo $trainer_experiment | cut -d' ' -f2)
+            for lr in "${lrs[@]}"; do
+                for beta in "${betas[@]}"; do
+                    for alpha in "${alphas[@]}"; do
+                        task_name=tofu_${model}_${forget_split}_${trainer}_lr${lr}_beta${beta}_alpha${alpha}
+                        model_path=open-unlearning/tofu_${model}_full
+                        echo ${task_name}: Unlearning ${model_path} using ${trainer}
+
+                        # Unlearn
+                        CUDA_VISIBLE_DEVICES=0 \
+                        python src/train.py --config-name=unlearn.yaml \
+                            experiment=${experiment} \
+                            trainer=${trainer} \
+                            task_name=${task_name} \
+                            model=${model} \
+                            forget_split=${forget_split} \
+                            retain_split=${retain_split} \
+                            model.model_args.pretrained_model_name_or_path=${model_path} \
+                            retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json \
+                            trainer.args.per_device_train_batch_size=$per_device_train_batch_size \
+                            trainer.args.gradient_accumulation_steps=$gradient_accumulation_steps \
+                            trainer.args.eval_strategy=no \
+                            trainer.args.eval_on_start=False \
+                            trainer.args.learning_rate=$lr \
+                            trainer.method_args.beta=$beta \
+                            trainer.method_args.alpha=$alpha
+
+                        # Eval
+                        CUDA_VISIBLE_DEVICES=0 python src/eval.py \
+                            experiment=eval/tofu/default.yaml \
+                            forget_split=${forget_split} \
+                            model=${model} \
+                            task_name=${task_name} \
+                            model.model_args.pretrained_model_name_or_path=saves/unlearn/${task_name} \
+                            paths.output_dir=saves/unlearn/${task_name}/evals \
+                            retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json
+                    done
+                done
+            done
+        done
+    done
+done
diff --git a/configs/trainer/UNDIAL.yaml b/configs/trainer/UNDIAL.yaml
new file mode 100644
index 0000000..5b2b0d3
--- /dev/null
+++ b/configs/trainer/UNDIAL.yaml
@@ -0,0 +1,12 @@
+defaults:
+  - finetune
+
+handler: UNDIAL # corresponds to the class defined in src/trainer/unlearn/undial.py
+args: # HuggingFace TrainingArguments
+  learning_rate: 1e-4
+  num_train_epochs: 10
+method_args: # method-specific arguments passed to the UNDIAL trainer
+  gamma: 1.0
+  alpha: 0.0
+  beta: 10.0 # strength of the penalty on memorized tokens
+  retain_loss_type: NLL
\ No newline at end of file
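For intuition about the default `beta: 10.0` above and the values swept in `run.sh` (3, 10, 30), here is a small snippet with toy numbers (not from the paper) showing how strongly each setting suppresses the memorized token in the resulting soft label:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0, 0.5])  # toy teacher logits; index 0 is the memorized token
for beta in (0.0, 3.0, 10.0, 30.0):
    adjusted = logits.clone()
    adjusted[0] -= beta  # the penalty UNDIAL applies before the softmax
    print(f"beta={beta:>4}: p(memorized token) = {F.softmax(adjusted, dim=-1)[0]:.4f}")
```

At beta=0 this reduces to plain self-distillation; by beta=30 essentially no probability mass remains on the memorized token.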
diff --git a/docs/links.md b/docs/links.md
index c41f325..b84977c 100644
--- a/docs/links.md
+++ b/docs/links.md
@@ -25,6 +25,7 @@ Links to research papers and resources corresponding to implemented features in
 | SimNPO | Paper [πŸ“„](https://arxiv.org/abs/2410.07163), Code [πŸ™](https://github.com/OPTML-Group/Unlearn-Simple) |
 | IdkDPO | TOFU ([πŸ“„](https://arxiv.org/abs/2401.06121)) |
 | RMU | WMDP paper ([πŸ™](https://github.com/centerforaisafety/wmdp/tree/main/rmu), [🌐](https://www.wmdp.ai/)), later used in G-effect ([πŸ™](https://github.com/tmlr-group/G-effect/blob/main/dataloader.py)) |
+| UNDIAL | Paper [πŸ“„](https://arxiv.org/pdf/2402.10052), Code [πŸ™](https://github.com/dong-river/LLM_unlearning/tree/main) |
 
 ---
 
diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py
index 27f44ee..a696e6c 100644
--- a/src/trainer/__init__.py
+++ b/src/trainer/__init__.py
@@ -10,6 +10,7 @@
 from trainer.unlearn.dpo import DPO
 from trainer.unlearn.simnpo import SimNPO
 from trainer.unlearn.rmu import RMU
+from trainer.unlearn.undial import UNDIAL
 
 import logging
 
@@ -88,3 +89,4 @@
 _register_trainer(DPO)
 _register_trainer(SimNPO)
 _register_trainer(RMU)
+_register_trainer(UNDIAL)
diff --git a/src/trainer/unlearn/undial.py b/src/trainer/unlearn/undial.py
new file mode 100644
index 0000000..e32147b
--- /dev/null
+++ b/src/trainer/unlearn/undial.py
@@ -0,0 +1,30 @@
+from trainer.utils import compute_undial_loss
+from trainer.unlearn.grad_diff import GradDiff
+
+
+class UNDIAL(GradDiff):
+    """UNDIAL: self-distillation with adjusted logits (https://arxiv.org/abs/2402.10052)."""
+
+    def __init__(self, beta=1.0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.beta = beta
+        if self.ref_model is None:
+            # The frozen teacher is a copy of the model taken before unlearning.
+            self.ref_model = self._prepare_ref_model(self.model)
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        forget_inputs = inputs["forget"]
+        forget_loss, forget_outputs = compute_undial_loss(
+            model, self.ref_model, forget_inputs, self.beta
+        )
+
+        retain_inputs = inputs["retain"]
+        retain_inputs = {
+            "input_ids": retain_inputs["input_ids"],
+            "attention_mask": retain_inputs["attention_mask"],
+            "labels": retain_inputs["labels"],
+        }
+        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)
+
+        loss = self.gamma * forget_loss + self.alpha * retain_loss
+        return (loss, forget_outputs) if return_outputs else loss
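One note on `compute_loss` above: the total is `gamma * forget_loss + alpha * retain_loss`, so the shipped config (`gamma: 1.0`, `alpha: 0.0`) reduces to the pure distillation objective, matching the paper's retain-set-free setup, while the `alphas=(1 2 5)` sweep in `run.sh` re-enables the retain term. A toy illustration with arbitrary loss values:

```python
import torch

forget_loss, retain_loss = torch.tensor(2.5), torch.tensor(0.8)  # arbitrary values
for gamma, alpha in [(1.0, 0.0), (1.0, 1.0), (1.0, 5.0)]:
    total = gamma * forget_loss + alpha * retain_loss  # mirrors UNDIAL.compute_loss
    print(f"gamma={gamma}, alpha={alpha}: total={total.item():.2f}")
```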
diff --git a/src/trainer/utils.py b/src/trainer/utils.py
index dfb6876..e1e8e86 100644
--- a/src/trainer/utils.py
+++ b/src/trainer/utils.py
@@ -66,3 +66,40 @@ def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1
     loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean()
 
     return loss, (win_outputs, lose_outputs)
+
+
+def compute_undial_loss(model, ref_model, inputs, beta):
+    # Forward pass on the student (trainable) model
+    outputs = model(**inputs)
+    logits = outputs.logits
+    labels = inputs["labels"]
+
+    shift_labels = labels[..., 1:].contiguous()
+    shift_logits = logits[..., :-1, :].contiguous()
+
+    # Forward pass on the teacher model (no grad)
+    with torch.no_grad():
+        teacher_logits = ref_model(**inputs).logits
+        shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
+
+    # Build the mask that identifies the tokens to be unlearned; positions whose
+    # label is the ignore index (-100, used by HF collators for prompt/padding
+    # tokens) must not receive a penalty, so they are zeroed out.
+    ignore = shift_labels.eq(-100)
+    index_labels = shift_labels.masked_fill(ignore, 0)
+    mask = torch.zeros_like(shift_teacher_logits)
+    batch_idx = torch.arange(mask.shape[0]).view(-1, 1, 1)
+    seq_idx = torch.arange(mask.shape[1]).view(1, -1, 1)
+    mask[batch_idx, seq_idx, index_labels.unsqueeze(-1)] = (~ignore).float().unsqueeze(-1)
+
+    # Adjust the teacher logits: subtract beta on the memorized (target) token
+    pre_softmax = shift_teacher_logits - mask * beta
+    soft_label = F.softmax(pre_softmax, dim=-1)
+
+    # Distill the student toward the adjusted teacher distribution
+    loss_fct = nn.CrossEntropyLoss(reduction="none")
+    loss = loss_fct(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        soft_label.view(-1, soft_label.size(-1)),
+    )
+    return loss.mean(), outputs
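To close, a quick smoke test for `compute_undial_loss` — a sketch under assumptions: the repository's `src/` directory is on `PYTHONPATH`, and `ToyLM` is a stand-in defined here (not in the codebase) that mimics the `.logits` attribute of a Hugging Face causal LM:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from trainer.utils import compute_undial_loss  # assumes src/ is on PYTHONPATH

class ToyLM(nn.Module):
    """Stand-in for a causal LM: forward returns an object with a .logits field."""
    def __init__(self, vocab_size=50, hidden=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return SimpleNamespace(logits=self.head(self.emb(input_ids)))

inputs = {
    "input_ids": torch.randint(0, 50, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
    "labels": torch.randint(0, 50, (2, 8)),
}
loss, outputs = compute_undial_loss(ToyLM(), ToyLM(), inputs, beta=10.0)
loss.backward()  # gradients flow through the student only; the teacher ran under no_grad
print(float(loss))
```

For untrained toy models the loss should come out as a positive scalar on the order of log(vocab_size).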