
Commit d66b7ba

Initial commit.
1 parent ec3c30f commit d66b7ba


46 files changed, +9397 -1 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/
+
+.idea/
+.DS_Store
+

README.md

Lines changed: 164 additions & 1 deletion
@@ -1 +1,164 @@

# private-transformers

This codebase facilitates fast experimentation with differentially private training
of [Hugging Face transformers](https://huggingface.co/transformers/).

---
<p align="center">
  <img width="950" height="450" src="./assets/fig1.png">
</p>

## What is this? Why an extra codebase?

- This codebase provides a privacy engine that builds off [Opacus](https://github.com/pytorch/opacus), but works way
  more smoothly with [Hugging Face's transformers library](https://github.com/huggingface/transformers).
- Additionally, we support the *ghost clipping* technique (see Section 4 of [this](https://arxiv.org/pdf/2110.05679.pdf)
  preprint on how it works), which allows privately training large transformers with considerably reduced memory cost --
  in many cases, almost as light as non-private training -- at a modest run-time overhead.
- **With this codebase, we have fine-tuned very large pretrained models, yielding some of the best performing
  differentially private NLP models to date. Some of these models have performance matching strong non-private baseline
  approaches. We see strong empirical evidence that highly performant DP NLP models could be built on modest datasets.**

## Installation

Make sure you have `python>=3.8`; then run the following command:

```bash
pip install git+ssh://git@github.com/lxuechen/private-transformers.git
```

## Usage

### Basic usage

Privately training Hugging Face transformers with our codebase simply consists of 4 steps:

1. Create your favourite transformer model and optimizer; attach this optimizer to a `PrivacyEngine`.
2. Compute a per-example loss (1-D tensor) for a mini-batch of data.
3. Pass the loss to `optimizer.step` or `optimizer.virtual_step` as a keyword argument.
4. Repeat from step 2.

Below is a quick example:

```python
import transformers, torch
from private_transformers import PrivacyEngine
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
privacy_engine = PrivacyEngine(
    model,
    batch_size=10,
    sample_size=50000,
    epochs=3,
    max_grad_norm=0.1,
    target_epsilon=3,
)
privacy_engine.attach(optimizer)

batch_size, seq_len = 10, 20
# Inputs are in batch-first format, i.e., the first dimension of tensors must be the batch dimension.
input_ids = torch.randint(size=[batch_size, seq_len], low=0, high=100, device=device)
# Calling `.train()` is very important; otherwise the underlying forward and backward hooks don't run.
model.train()
outputs = model(input_ids=input_ids, return_dict=True)
labels = input_ids[:, 1:, ]
logits = outputs.logits[:, :-1, :].permute(0, 2, 1)
# `loss` is a 1-D tensor of shape (batch_size,).
loss = F.cross_entropy(logits, labels, reduction="none").mean(dim=1)
# This step is different from existing workflows:
# Don't call `loss.backward`; leave it to `optimizer.step` to handle the backward pass.
optimizer.step(loss=loss)
```

The biggest differences compared to Opacus are:

- We require the per-example loss (a 1-D tensor) to be passed into `optimizer.step` (or `optimizer.virtual_step`;
  see the sketch below).
- The per-example loss must be passed in as a *keyword argument*.
- `loss.backward()` shouldn't be called on the user end; it's called internally in `optimizer.step`
  (or `optimizer.virtual_step`).
- Inputs should be in batch-first format; there isn't a toggle to switch between different formats in the engine.
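
The bullets above mention `optimizer.virtual_step`, which is useful when a logical batch is too large to process in
one go. Below is a minimal gradient-accumulation sketch, continuing from the quick example (so `model`, `optimizer`,
and `input_ids` are as defined there); it assumes `virtual_step` accumulates the processed per-example gradients
without updating parameters (as `virtual_step` does in Opacus), and it reuses the same micro-batch purely for
illustration.

```python
import torch.nn.functional as F

gradient_accumulation_steps = 4  # Hypothetical: accumulate over 4 micro-batches per noisy update.
for micro_step in range(gradient_accumulation_steps):
    outputs = model(input_ids=input_ids, return_dict=True)
    logits = outputs.logits[:, :-1, :].permute(0, 2, 1)
    labels = input_ids[:, 1:]
    # Per-example loss: one entry per example in the micro-batch.
    per_example_loss = F.cross_entropy(logits, labels, reduction="none").mean(dim=1)
    if micro_step + 1 < gradient_accumulation_steps:
        optimizer.virtual_step(loss=per_example_loss)  # Accumulate; no parameter update yet.
    else:
        optimizer.step(loss=per_example_loss)  # Clip, noise, and update on the final micro-batch.
```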

### Ghost clipping: memory-saving differentially private learning

Turning on ghost clipping requires changing only 1 line. You should notice a drastic reduction in peak GPU memory
usage once this is turned on, at the potential cost of slower training. This is especially useful when you are
constrained to older GPUs with small VRAM or are fitting very large models.

```python
import transformers, torch
from private_transformers import PrivacyEngine

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
privacy_engine = PrivacyEngine(
    model,
    batch_size=10,
    sample_size=50000,
    epochs=3,
    max_grad_norm=0.1,
    target_epsilon=3,
    ghost_clipping=True,  # The only change you need to make!
)
privacy_engine.attach(optimizer)
```

We ran stringent numerical tests to ensure the double-backward implementation is correct. Check out the files in the
`tests` folder for more on this.

### Examples

Code in the `examples` folder roughly reproduces our results for the table-to-text and classification tasks. There may
be some minor discrepancies, since the hyperparameters there aren't exactly what's used in the paper. Nevertheless, it
should be sufficient to get things started. Detailed instructions are in the readme file of each subfolder.

### Currently supported [Hugging Face models](https://huggingface.co/transformers/pretrained_models.html)

- [OpenAIGPTLMHeadModel](https://huggingface.co/transformers/_modules/transformers/models/openai/modeling_openai.html#OpenAIGPTLMHeadModel)
- [OpenAIGPTDoubleHeadsModel](https://huggingface.co/transformers/_modules/transformers/models/openai/modeling_openai.html#OpenAIGPTDoubleHeadsModel)
- [GPT2LMHeadModel](https://huggingface.co/transformers/_modules/transformers/models/gpt2/modeling_gpt2.html#GPT2LMHeadModel)
- [GPT2DoubleHeadsModel](https://huggingface.co/transformers/_modules/transformers/models/gpt2/modeling_gpt2.html#GPT2DoubleHeadsModel)
- [BertForSequenceClassification](https://huggingface.co/transformers/_modules/transformers/models/bert/modeling_bert.html#BertForSequenceClassification)
- [RobertaForSequenceClassification](https://huggingface.co/transformers/model_doc/roberta.html#robertaforsequenceclassification)
- [AlbertForSequenceClassification](https://huggingface.co/transformers/_modules/transformers/models/albert/modeling_albert.html#AlbertForSequenceClassification)
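
For the sequence classification models above, the per-example loss in step 2 is just an unreduced cross-entropy over
the classification logits. A minimal sketch with random inputs (the checkpoint, sequence length, and label count here
are purely illustrative):

```python
import torch
import torch.nn.functional as F
import transformers

model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.train()
input_ids = torch.randint(size=[10, 20], low=0, high=100)
labels = torch.randint(size=[10], low=0, high=2)
logits = model(input_ids=input_ids, return_dict=True).logits  # Shape: (batch_size, num_labels).
# `reduction="none"` keeps one loss value per example, as `optimizer.step(loss=...)` expects.
per_example_loss = F.cross_entropy(logits, labels, reduction="none")
```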

Not all models in the Hugging Face library are supported. The main additional work here is to

1. support per-example gradients for bespoke modules (e.g., [T5LayerNorm](https://huggingface.co/transformers/_modules/transformers/modeling_t5.html)), and
2. ensure `position_ids` are repeated.

We plan to support more models in the future if there's such a need. Feel free to open an issue if you want to try out
specific models that aren't in the current list.

## Acknowledgements

It would have been impossible to develop this codebase without cool past works and existing codebases. We roughly
follow the `PrivacyEngine` design in `Opacus==0.13.0`. We directly use
an [off-the-shelf package](https://github.com/microsoft/prv_accountant) for tightly tracking tradeoff functions while
composing multiple private mechanisms.

## Disclaimer

- This codebase is not yet production-grade, e.g., cryptographically secure PRNGs are required for sampling noise --
  our codebase currently does not use these strong PRNGs.
- This codebase is born out of the need to experiment with various things for differentially private NLP in rapid
  succession. I've tried my best to write clean code, though parts of this codebase may be less tidy than I had hoped
  given the extremely tight timeline.

## Citation

If you found this codebase useful in your research, please consider citing:

```
@misc{li2021large,
    title={Large Language Models Can Be Strong Differentially Private Learners},
    author={Xuechen Li and Florian Tramèr and Percy Liang and Tatsunori Hashimoto},
    year={2021},
    eprint={2110.05679},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```

WIP.md

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@

## Work-in-progress

While this codebase is born out of a specific research need, I plan to build off it to make it generally useful. This
doc details what's planned and what's in progress.

- [ ] Support mixed-precision for ghost clipping
- [ ] Support additional HF models
  - [ ] BART
  - [ ] T5
- [ ] Release code for dialog generation experiments.

assets/fig1.png

512 KB

examples/__init__.py

Whitespace-only changes.

examples/classification/README.md

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@

## Reproducing results for sentence classification

### Requirements

In addition to the requirements of the `private-transformers` package, install the additional requirements by running
the following from the `examples` folder of this repo:

```bash
pip install -r classification/requirements.txt
```

This code is tested against `transformers==4.11.3`, but should also work for slightly earlier versions.

### Getting the data

This part of the codebase is adapted from the excellent work
by [[Gao et al., 2021](https://arxiv.org/pdf/2012.15723.pdf)]. We reuse their data pipeline. To obtain the data, run
the following:

```bash
cd data
bash download_dataset.sh
```

This should produce a `data/original` subfolder that contains all the data that we need.

### Running

Use the `run_wrapper.py` script in the folder. This Python script produces a text string for the command and runs it.

Supply at least 2 arguments:

- `--output_dir`: path to a folder where results will be written
- `--task_name`: name of the task; one of `sst-2`, `qnli`, `qqp`, `mnli`

For instance, run the following under the `examples/` folder:

```bash
python -m classification.run_wrapper --output_dir <output_dir> --task_name <task_name>
```

The script by default uses ghost clipping, and the micro batch size is tweaked so that things should run smoothly even
on a Titan Xp with 12 GB of VRAM. For SST-2, the run time of this script on an RTX 3090 is roughly less than one and a
half hours. Larger datasets take longer to train.

Additional arguments:

- `--target_epsilon`: Target privacy spending
- `--model_name_or_path`: The pretrained model; one of `distilbert-base-uncased`, `bert-base-uncased`,
  `bert-large-uncased`, `distilroberta-base`, `roberta-base`, `roberta-large`
- `--few_shot_type`: Whether to use the generic prompt formatter described in Section 3.2 of our paper; `prompt` uses
  it, `finetune` does not
- `--ghost_clipping`: Whether to use ghost clipping for memory saving; one of `yes`, `no`. Note that with the other
  training hyperparameters (e.g., number of training epochs, clipping norm, learning rate) kept the same, things
  should still work
- `--data_dir`: Path to where the data is stored; if the data was obtained via the procedure described above, just
  stick to the defaults
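
For example, a fuller command combining these arguments might look like the following; the values below are purely
illustrative rather than recommended settings, and it is assumed that each argument is passed with a `--` prefix:

```bash
python -m classification.run_wrapper \
  --output_dir <output_dir> \
  --task_name sst-2 \
  --model_name_or_path roberta-base \
  --target_epsilon 8 \
  --few_shot_type prompt \
  --ghost_clipping yes
```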
Training on the larger datasets for even more epochs should bring further performance gains.

examples/classification/__init__.py

Whitespace-only changes.

examples/classification/data/download_dataset.sh

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar
tar xvf datasets.tar
