6 changes: 4 additions & 2 deletions src/data/__init__.py
@@ -46,7 +46,7 @@ def get_datasets(dataset_cfgs: Union[Dict, DictConfig], **kwargs):
     return dataset


-def get_data(data_cfg: DictConfig, mode="train", **kwargs):
+def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs):
     data = {}
     data_cfg = dict(data_cfg)
     anchor = data_cfg.pop("anchor", "forget")
@@ -56,7 +56,9 @@ def get_data(data_cfg: DictConfig, mode="train", **kwargs):
         return data
     elif mode == "unlearn":
         unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")}
-        unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor)
+        unlearn_dataset = ForgetRetainDataset(
+            **unlearn_splits, anchor=anchor, seed=seed
+        )
         data["train"] = unlearn_dataset
         for split in unlearn_splits:
             data.pop(split)
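In unlearn mode, get_data pops the non-eval splits, wraps them in a ForgetRetainDataset, and now forwards the seed to that wrapper. A hypothetical call site (a sketch, not the repo's exact invocation; data_cfg, tokenizer, and template_args are assumed to come from the surrounding setup, as in src/train.py further down):

data = get_data(
    data_cfg,
    mode="unlearn",
    seed=42,  # forwarded to ForgetRetainDataset
    tokenizer=tokenizer,
    template_args=template_args,
)
unlearn_dataset = data["train"]  # the ForgetRetainDataset wrapper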
16 changes: 13 additions & 3 deletions src/data/unlearn.py
@@ -4,17 +4,19 @@


 class ForgetRetainDataset(Dataset):
     # https://github.com/OPTML-Group/SOUL/blob/main/src/dataset/Base.py
-    def __init__(self, forget, retain, anchor="forget"):
+    def __init__(self, forget, retain, anchor="forget", seed=0):
         """Wraps the forget retain dataset into unlearning dataset.

         Args:
             forget (Dataset): Forget Dataset
             retain (Dataset): Retain Dataset
             anchor (str, optional): Specifies which dataset to anchor while randomly sampling from the other dataset. Defaults to 'forget'.
+            seed (int, optional): Random seed for reproducibility. Defaults to 0.
         """
         self.forget = forget
         self.retain = retain
         self.anchor = anchor
+        self.seed = seed

     def __len__(self):
         """Ensures the sampled dataset matches the anchor dataset's length."""
@@ -33,14 +35,22 @@ def __len__(self):

     def __getitem__(self, idx):
         item = {}
+        g = torch.Generator()
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        rank_seed = self.seed + rank + idx
+        g.manual_seed(rank_seed)
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
             if self.retain:
-                retain_idx = torch.randint(0, len(self.retain), (1,)).item()
+                retain_idx = torch.randint(
+                    0, len(self.retain), (1,), generator=g
+                ).item()
                 item["retain"] = self.retain[retain_idx]
         elif self.anchor == "retain":
             item["retain"] = self.retain[idx]
             if self.forget:
-                forget_idx = torch.randint(0, len(self.forget), (1,)).item()
+                forget_idx = torch.randint(
+                    0, len(self.forget), (1,), generator=g
+                ).item()
                 item["forget"] = self.forget[forget_idx]
         return item
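With this change, each draw is a pure function of (seed, rank, idx) rather than of global RNG state. A minimal self-contained sketch of the same scheme (companion_index is a made-up name for illustration, not a helper in this repo):

import torch

def companion_index(seed: int, rank: int, idx: int, n: int) -> int:
    # Mirrors the seeding in ForgetRetainDataset.__getitem__: a fresh
    # generator seeded with seed + rank + idx, used for one randint draw.
    g = torch.Generator()
    g.manual_seed(seed + rank + idx)
    return torch.randint(0, n, (1,), generator=g).item()

# The same (seed, rank, idx) triple always yields the same companion index,
# so re-running a job reproduces the forget/retain pairing exactly.
assert companion_index(0, 0, 7, 1000) == companion_index(0, 0, 7, 1000)
# A different base seed gives a different (but equally reproducible) pairing.
print(companion_index(0, 0, 7, 1000), companion_index(1, 0, 7, 1000))

Because the draw uses this dedicated per-item generator instead of the global RNG, dataloader worker count and any torch.manual_seed calls elsewhere no longer affect which retain example accompanies a given forget example.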
6 changes: 5 additions & 1 deletion src/train.py
@@ -23,7 +23,11 @@ def main(cfg: DictConfig):
     # Load Dataset
     data_cfg = cfg.data
     data = get_data(
-        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args
+        data_cfg,
+        mode=mode,
+        tokenizer=tokenizer,
+        template_args=template_args,
+        seed=cfg.trainer.args.seed,
     )

     # Load collator