RMU #67 (Closed)
6 changes: 3 additions & 3 deletions README.md
@@ -44,7 +44,7 @@ We provide several variants for each of the components in the unlearning pipeline

## 📌 Table of Contents
- 📖 [Overview](#-overview)
-- 🗃️ [Available Components](#-available-components)
+- 🗃️ [Available Components](#%EF%B8%8F-available-components)
- [Quickstart](#-quickstart)
- 🛠️ [Environment Setup](#-environment-setup)
- 💾 [Data Setup](#-data-setup)
@@ -56,7 +56,7 @@ We provide several variants for each of the components in the unlearning pipeline
- [How to Add New Components](#-how-to-add-new-components)
- 📚 [Further Documentation](#-further-documentation)
- 🔗 [Support & Contributors](#-support--contributors)
-- 📝 [Citing this work](#-citating-this-work)
+- 📝 [Citing this work](#-citing-this-work)
- 🤝 [Acknowledgements](#-acknowledgements)
- 📄 [License](#-license)

@@ -198,7 +198,7 @@ If you use OpenUnlearning in your research, please cite:

---

-### 🤝 Acknowledgments
+### 🤝 Acknowledgements

- This repo is inspired by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
- The [TOFU](https://github.com/locuslab/tofu) and [MUSE](https://github.com/jaechan-repo/muse_bench) benchmarks served as the foundation for our re-implementation.
1 change: 1 addition & 0 deletions configs/experiment/unlearn/muse/default.yaml
@@ -34,6 +34,7 @@ eval:
  muse:
    data_split: ${data_split}
    retain_logs_path: ${retain_logs_path}
    overwrite: true

trainer:
  args:
1 change: 1 addition & 0 deletions configs/experiment/unlearn/muse/scalability.yaml
@@ -34,6 +34,7 @@ eval:
  muse:
    data_split: ${data_split}
    retain_logs_path: ${retain_logs_path}
    overwrite: true

trainer:
  args:
1 change: 1 addition & 0 deletions configs/experiment/unlearn/muse/sustainabilty.yaml
@@ -34,6 +34,7 @@ eval:
  muse:
    data_split: ${data_split}
    retain_logs_path: ${retain_logs_path}
    overwrite: true

trainer:
  args:
1 change: 1 addition & 0 deletions configs/experiment/unlearn/tofu/default.yaml
@@ -20,6 +20,7 @@ eval:
  tofu:
    forget_split: ${forget_split}
    retain_logs_path: ${retain_logs_path}
    overwrite: true

data:
  anchor: forget
1 change: 1 addition & 0 deletions configs/experiment/unlearn/tofu/idk.yaml
@@ -20,6 +20,7 @@ eval:
  tofu:
    forget_split: ${forget_split}
    retain_logs_path: ${retain_logs_path}
    overwrite: true

data:
  anchor: forget
13 changes: 13 additions & 0 deletions configs/trainer/RMU.yaml
@@ -0,0 +1,13 @@
defaults:
  - GradDiff

handler: RMU
method_args:
  # These parameters depend strongly on the model and dataset; tune them carefully.
  gamma: 1.0
  alpha: 1000
  steering_coeff: 300
  retain_loss_type: null
  module_regex: model\.layers\.7
  trainable_params_regex:
    - model\.layers\.(5|6|7)\.mlp\.down_proj\.weight
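For readers unfamiliar with the two regex fields, here is a minimal sketch of how they select layers. It mirrors the `re.fullmatch` filtering done by the new `rmu.py` trainer later in this diff; the Llama-style module names below are illustrative, not read from a loaded model.

```python
import re

# Patterns from the config above, written as raw strings in Python.
module_regex = r"model\.layers\.7"
param_regex = r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"

# Illustrative names; a real run iterates model.named_modules() / named_parameters().
module_names = ["model.layers.6", "model.layers.7", "model.layers.17"]
param_names = [
    "model.layers.5.mlp.down_proj.weight",
    "model.layers.7.mlp.up_proj.weight",
    "model.layers.7.mlp.down_proj.weight",
]

print([n for n in module_names if re.fullmatch(module_regex, n)])
# ['model.layers.7']  -- fullmatch keeps layer 17 from matching layer 7
print([n for n in param_names if re.fullmatch(param_regex, n)])
# only the down_proj weights of layers 5 and 7 match
```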
36 changes: 35 additions & 1 deletion docs/results.md
@@ -23,7 +23,7 @@ For all the experiments below, we used the following setup
| **Hyperparameters** | Learning Rate (lr) = 1e-5 <br> α = 1, γ = 1, β = 0.1 (where applicable) <br> Number of Epochs = 10 <br> Optimizer: [paged_adamw_32bit](https://huggingface.co/docs/bitsandbytes/main/en/reference/optim/adamw#bitsandbytes.optim.PagedAdamW) |

__Note:__
-1. Results may vary even with the same effective hyperparameters when trained with modifications to the distributed training setup, including when training on a single GPU. For example: methods such as SimNPO, can be significantly improved with careful tuning. **Please use these numbers only for reproducibility purposes**.
+1. Results may vary even with the same effective hyperparameters when trained with modifications to the distributed training setup, including when training on a single GPU. For example, methods such as SimNPO and RMU can be significantly improved with careful tuning. **Please use these numbers only for reproducibility purposes**.
2. NPO in MUSE: for NPO, the MUSE implementation is inconsistent with the [original paper](https://github.com/licong-lin/negative-preference-optimization) as discussed [here]( https://github.com/jaechan-repo/muse_bench/issues/2). This inconsistency is carried over into implementations like [SimNPO](https://github.com/OPTML-Group/Unlearn-Simple/issues/5). Here, we use the original NPO implementation with the same loss function expression across datasets.


@@ -140,6 +140,17 @@ __Note:__
<td>0.6</td>
<td>3.17e-04</td>
</tr>
<tr>
<th>RMU</th>
<td>6.76e-03</td>
<td>7.18e-04</td>
<td>0.84</td>
<td>1.21e-10</td>
<td>0</td>
<td>0.81</td>
<td>1.18e-17</td>
<td>0</td>
<td>0.8</td>
</tr>
</tbody>
</table>
</div>
@@ -257,6 +268,18 @@ __Note:__
<td>0.54</td>
<td>1.07e-05</td>
</tr>
<tr>
<th>RMU</th>
<td>6.76e-03</td>
<td>0.60</td>
<td>0.47</td>
<td>2.89e-11</td>
<td>0.6</td>
<td>0.47</td>
<td>0.32</td>
<td>0.59</td>
<td>0.64</td>
</tr>
</tbody>
</table>
</div>
@@ -354,6 +377,17 @@ __Note:__
<td>-54.26</td>
<td>0.54</td>
</tr>
<tr>
<th>RMU</th>
<td>0.67</td>
<td>0.57</td>
<td>-99.81</td>
<td>0.56</td>
<td>0.47</td>
<td>1.0</td>
<td>-57.35</td>
<td>0.67</td>
</tr>
</tbody>
</table>
</div>
2 changes: 1 addition & 1 deletion scripts/tofu_unlearn.sh
@@ -13,7 +13,7 @@ trainers_experiments=(
    "GradAscent unlearn/tofu/default.yaml"
    "GradDiff unlearn/tofu/default.yaml"
    "NPO unlearn/tofu/default.yaml"
-    "DPO unlearn/tofu/default.yaml"
+    "DPO unlearn/tofu/idk.yaml"
)
forget_retain_splits=(
    "forget01 retain99"
9 changes: 9 additions & 0 deletions setup_data.py
@@ -1,8 +1,17 @@
from huggingface_hub import snapshot_download

# Setup retain model metrics
snapshot_download(
    repo_id="open-unlearning/eval",
    allow_patterns="*.json",
    repo_type="dataset",
    local_dir="saves/eval",
)

# Setup data
snapshot_download(
    repo_id="open-unlearning/idk",
    allow_patterns="*.jsonl",
    repo_type="dataset",
    local_dir="data",
)
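A quick sanity check, not part of the PR, assuming the default `local_dir` values above: it simply counts the JSON/JSONL files that the two snapshots place under `saves/eval` and `data`.

```python
from pathlib import Path

# Count the files pulled by the two snapshot_download calls in setup_data.py.
for expected in (Path("saves/eval"), Path("data")):
    files = list(expected.rglob("*.json*"))  # matches .json and .jsonl
    print(f"{expected}: {len(files)} files")
```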
2 changes: 2 additions & 0 deletions src/trainer/__init__.py
@@ -9,6 +9,7 @@
from trainer.unlearn.npo import NPO
from trainer.unlearn.dpo import DPO
from trainer.unlearn.simnpo import SimNPO
from trainer.unlearn.rmu import RMU

TRAINER_REGISTRY: Dict[str, Any] = {}

@@ -79,3 +80,4 @@ def load_trainer(
_register_trainer(NPO)
_register_trainer(DPO)
_register_trainer(SimNPO)
_register_trainer(RMU)
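The registration call relies on a simple name-to-class registry. The bodies of `_register_trainer` and `load_trainer` are not shown in this diff, so the snippet below is only an illustrative sketch of the pattern, not the repository's exact implementation.

```python
from typing import Any, Dict

TRAINER_REGISTRY: Dict[str, Any] = {}

def _register_trainer(trainer_cls):
    # Key the class by its name so a config handler string like "RMU"
    # can be resolved back to the trainer class.
    TRAINER_REGISTRY[trainer_cls.__name__] = trainer_cls

class RMU:  # stand-in for trainer.unlearn.rmu.RMU
    pass

_register_trainer(RMU)
print(TRAINER_REGISTRY["RMU"])  # <class '__main__.RMU'>
```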
2 changes: 1 addition & 1 deletion src/trainer/unlearn/grad_diff.py
@@ -14,7 +14,7 @@ def __init__(self, gamma=1.0, alpha=1.0, retain_loss_type="NLL", *args, **kwargs
        self.ref_model = self._prepare_ref_model(self.model)

    def _prepare_ref_model(self, model):
-        ref_model = copy.deepcopy(model).to("cuda")
+        ref_model = copy.deepcopy(model).to(self.accelerator.device)
        ref_model.eval()
        if self.is_deepspeed_enabled:
            ref_model = self._prepare_deepspeed(ref_model)
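For context on the one-line change: `accelerate` exposes the device it selected for the current process, so the reference model follows whatever backend is in use (CUDA, MPS, or CPU) instead of assuming `"cuda"`. A minimal standalone sketch, independent of the trainer:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Place a tensor on the device accelerate chose for this process,
# rather than hard-coding "cuda".
x = torch.zeros(4, device=accelerator.device)
print(accelerator.device, x.device)
```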
138 changes: 138 additions & 0 deletions src/trainer/unlearn/rmu.py
@@ -0,0 +1,138 @@
import re
import torch
import deepspeed
from torch import nn
from trainer.unlearn.grad_diff import GradDiff


class RMU(GradDiff):
    def __init__(self,
                 module_regex=r"model\.layers\.7",
                 trainable_params_regex=[r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"],
                 steering_coeff=20,
                 *args, **kwargs):
        """
        RMU Trainer that fine-tunes only specific layers and parameters using regex-based filtering.

        Args:
            module_regex (str): Regex pattern matching the module whose activations are steered.
            trainable_params_regex (list of str): Regex patterns selecting the trainable parameters.
            steering_coeff (float): Norm of the random control vector that forget activations are pushed towards.
        """
        super().__init__(*args, **kwargs)

        # Create reference model if not already set
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

        # Regex for selecting the parameters to unfreeze
        self.trainable_params_regex = trainable_params_regex

        # Regex for selecting the module whose activations are cached and steered
        self.module_regex = module_regex
        self.model_module = self._get_matching_module(self.model, self.module_regex)
        self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)
        self.steering_coeff = steering_coeff
        self.control_vec = None

    def create_optimizer(self):
        self._set_all_params(self.model, False)
        # This makes the optimizer select only the trainable params
        self._set_trainable_params(self.model, self.trainable_params_regex, True)
        super().create_optimizer()
        self._set_all_params(self.model, True)

    def _get_matching_module(self, model, module_regex):
        """Return the single module matching the given regex from a DeepSpeed/DDP-wrapped model."""
        # Handle DeepSpeed- and DDP-wrapped models by accessing the underlying module
        if isinstance(model, deepspeed.DeepSpeedEngine):
            model = model.module  # Extract the actual PyTorch model inside

        matched_modules = {name: module for name, module in model.named_modules() if re.fullmatch(module_regex, name)}

        if len(matched_modules) > 1:
            raise ValueError(f"More than one module matched with {module_regex}: {list(matched_modules.keys())}")
        elif not matched_modules:
            raise ValueError(f"No module matched with {module_regex}")

        return next(iter(matched_modules.values()))  # Return the single matched module

    def _set_all_params(self, model, requires_grad=True):
        """Set requires_grad on all parameters in the model."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
        """Unfreeze (or freeze) only the parameters that match the regex patterns."""
        for name, param in model.named_parameters():
            if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
                param.requires_grad = requires_grad
                print(f"{name}:requires_grad\t{requires_grad}")

    def forward_with_cache(self, model, inputs, module, no_grad=True):
        """Run a forward pass and return the cached output of `module`, captured via a forward hook."""
        cache = []

        def hook(module, input, output):
            if isinstance(output, tuple):
                cache.append(output[0])
            else:
                cache.append(output)
            return None

        hook_handle = module.register_forward_hook(hook)
        if no_grad:
            with torch.no_grad():
                _ = model(**inputs)
        else:
            _ = model(**inputs)
        hook_handle.remove()
        return cache[0]

    def get_control_vector(self, dim):
        """Lazily create a fixed random vector of norm `steering_coeff` and reuse it across steps."""
        if self.control_vec is None:
            random_vector = torch.rand(1, 1, dim)
            self.control_vec = random_vector / torch.norm(random_vector) * self.steering_coeff
        return self.control_vec

    def compute_activation_loss(self, activation1, activation2, mask):
        """Masked MSE between two activation tensors, normalized by the number of supervised tokens."""
        squared_diff = torch.nn.functional.mse_loss(activation1, activation2, reduction="none")  # Shape: [b, s, d]
        expanded_mask = mask.unsqueeze(-1).expand_as(squared_diff)  # Shape: [b, s, d]
        squared_diff_sum = (squared_diff * expanded_mask).mean(dim=2).sum(dim=1)  # Mean over d, sum over s -> [b]
        num_tokens = mask.sum(dim=-1)  # Supervised tokens per example, Shape: [b]
        return (squared_diff_sum / num_tokens).mean()

    def compute_retain_loss(self, model, retain_inputs):
        model_retain_activations = self.forward_with_cache(model, retain_inputs, module=self.model_module, no_grad=False)
        ref_retain_activations = self.forward_with_cache(self.ref_model, retain_inputs, module=self.ref_module, no_grad=True).to(model_retain_activations.device)
        mask = (retain_inputs["labels"] != -100)  # Shape: [b, s]
        retain_loss = self.compute_activation_loss(model_retain_activations, ref_retain_activations, mask)
        return retain_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        forget_inputs = inputs["forget"]
        forget_inputs = {
            "input_ids": forget_inputs["input_ids"],
            "attention_mask": forget_inputs["attention_mask"],
            "labels": forget_inputs["labels"],
        }

        # Forget loss: push the selected module's activations towards a scaled random control vector
        model_forget_activations = self.forward_with_cache(model, forget_inputs, self.model_module, no_grad=False)
        control_vec = forget_inputs.get("control_vec", self.get_control_vector(model_forget_activations.shape[-1]))
        control_vec = control_vec.to(dtype=model_forget_activations.dtype, device=model_forget_activations.device)
        control_vec = control_vec.expand_as(model_forget_activations)
        mask = (forget_inputs["labels"] != -100)  # Shape: [b, s]
        forget_loss = self.compute_activation_loss(model_forget_activations, control_vec, mask)

        # Retain loss: keep the same module's activations close to the reference model's
        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss = self.gamma * forget_loss + self.alpha * retain_loss

        return (loss, model_forget_activations) if return_outputs else loss
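A standalone toy check of the masked activation loss above, with random tensors standing in for real hidden states; the mask plays the role of `(labels != -100)` in the trainer. The loss is near zero when the two activation sets are identical and grows as they diverge, while masked-out positions contribute nothing.

```python
import torch

def masked_activation_loss(a1, a2, mask):
    # Same computation as RMU.compute_activation_loss: per-position MSE averaged
    # over the hidden dim, summed over the sequence, normalized per example.
    squared_diff = torch.nn.functional.mse_loss(a1, a2, reduction="none")  # (b, s, d)
    per_example = (squared_diff * mask.unsqueeze(-1)).mean(dim=2).sum(dim=1)  # (b,)
    return (per_example / mask.sum(dim=-1)).mean()

b, s, d = 2, 4, 8
acts = torch.randn(b, s, d)
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])

print(masked_activation_loss(acts, acts, mask))                    # ~0: identical activations
print(masked_activation_loss(acts, torch.zeros_like(acts), mask))  # > 0: activations differ
```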