
Commit fe19cf5

Vincent Moens authored
[Algorithm] RLHF end-to-end, clean (#1597)
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
1 parent f09b0c8 commit fe19cf5

File tree

26 files changed: +1402 −38 lines changed

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 3 additions & 1 deletion
```diff
@@ -282,8 +282,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac
   train.minibatch_size=100 \
   logger.backend=
 
-
 python .github/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100
 
+## RLHF
+# RLHF tests are executed in the dedicated workflow
+
 coverage combine
 coverage xml -i
```

.github/unittest/linux_libs/scripts_rlhf/run_test.sh

Lines changed: 9 additions & 0 deletions
```diff
@@ -22,5 +22,14 @@ conda deactivate && conda activate ./env
 python -c "import transformers, datasets"
 
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
+
+python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
+  sys.device=cuda:0 sys.ref_device=cuda:0 \
+  model.name_or_path=gpt2 train.max_epochs=2 \
+  data.batch_size=2 train.ppo.ppo_batch_size=2 \
+  train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \
+  train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \
+  data.block_size=110 io.logger=csv
+
 coverage combine
 coverage xml -i
```

examples/rlhf/.gitignore

Lines changed: 4 additions & 0 deletions
```
*.png
*.bin
*.pt
*.json
```

examples/rlhf/README.md

Lines changed: 57 additions & 0 deletions
````markdown
# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a
language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch>=2.0 installed. You can find installation instructions
[here](https://pytorch.org/get-started/locally/).

From this directory, you can install extra requirements for running these
examples with

```sh
pip install -r requirements.txt
```

## Training the models

### Training the transformer

Once the data has been prepared, you can train the GPT model.

```sh
python train.py
```

The default configuration can be found in `config/train.yaml`, and any option can
be overridden with command-line arguments, for example to run the training
script with a different batch size:

```sh
python train.py data.batch_size=128
```

> **_NOTE:_** Apple Silicon MacBook users should make sure to use `sys.device=mps`
> and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback.

### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model
checkpoint to `./out`, or update the config to point `model.name_or_path` at
the relevant checkpoint in the timestamped working directory created by Hydra.
You can then train the reward model with:

```sh
python train_reward.py
```

### Training the final model with RLHF

Once again, make sure you have either updated the configuration to point
`reward_model.name_or_path` at the relevant timestamped working directory, or
copied the checkpoint to `./out_reward`.
You can then train the final model by running

```sh
python train_rlhf.py
```
````
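For orientation, the three stages chain together roughly as follows. This is a sketch, not an official recipe: the override keys come from the config files added later in this commit, and the paths assume the default output directories (`./out` for the fine-tuned transformer, `./out_reward` for the reward model), so the explicit overrides are redundant unless you move the checkpoints.

```sh
# 1. Supervised fine-tuning (config/train.yaml)
python train.py model.name_or_path=gpt2

# 2. Reward model training, loading the SFT checkpoint (config/train_reward.yaml)
python train_reward.py model.name_or_path=./out

# 3. RLHF with PPO, loading both checkpoints (config/train_rlhf.yaml)
python train_rlhf.py model.name_or_path=./out reward_model.name_or_path=./out_reward
```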

examples/rlhf/config/train.yaml

Lines changed: 30 additions & 0 deletions
```yaml
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint
  out_dir: ./out
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
train:
  grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
  max_iters: 5000 # total number of training iterations
  gradient_accumulation_steps: 2 # used to simulate larger batch sizes
  always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir
  decay_lr: True # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 5000 # maximum number of iterations
    eta_min: 1.0e-6 # minimum learning rate
sys:
  device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks
  dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler
  compile: True # use PyTorch 2.0 to compile the model to be faster
```
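The `train.optimizer` and `train.scheduler` blocks are, per their comments, keyword arguments for `torch.optim.AdamW` and `torch.optim.lr_scheduler.CosineAnnealingLR`. A minimal sketch of that mapping; the `Linear` stand-in model and the direct `OmegaConf.load` call are illustrative assumptions, not the example's actual training code:

```python
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/train.yaml")
model = torch.nn.Linear(8, 8)  # stand-in for the GPT-2 model

# Convert the config sub-trees to plain dicts and pass them straight through.
optim_kwargs = OmegaConf.to_container(cfg.train.optimizer, resolve=True)
sched_kwargs = OmegaConf.to_container(cfg.train.scheduler, resolve=True)

optimizer = torch.optim.AdamW(model.parameters(), **optim_kwargs)  # lr, weight_decay, betas
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **sched_kwargs)  # T_max, eta_min
```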
examples/rlhf/config/train_reward.yaml

Lines changed: 32 additions & 0 deletions

```yaml
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: ./out
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  out_dir: ./out_reward
  init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward
train:
  grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
  max_iters: 20000 # total number of training iterations
  gradient_accumulation_steps: 2 # used to simulate larger batch sizes
  always_save_checkpoint: False # if True, always save a checkpoint after each eval
  decay_lr: False # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 20000
    eta_min: 1.0e-6
sys:
  device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
  compile: True # use PyTorch 2.0 to compile the model to be faster
```
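`gradient_accumulation_steps` and `grad_clip` follow the usual PyTorch pattern: gradients from several micro-batches of `data.batch_size` are accumulated before one optimizer step (effective batch size 16 × 2 = 32 here), and clipped unless `grad_clip == 0.0`. A generic sketch of that loop, illustrative only and not the example's actual `train_reward.py`:

```python
import torch

grad_accum_steps = 2  # train.gradient_accumulation_steps
grad_clip = 1.0       # train.grad_clip; 0.0 disables clipping

model = torch.nn.Linear(8, 1)  # stand-in model
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)

for step, batch in enumerate(torch.randn(8, 4, 8)):  # stand-in micro-batches
    loss = model(batch).pow(2).mean()
    (loss / grad_accum_steps).backward()  # average gradients over the accumulation window
    if (step + 1) % grad_accum_steps == 0:
        if grad_clip != 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()
        opt.zero_grad(set_to_none=True)
```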

examples/rlhf/config/train_rlhf.yaml

Lines changed: 39 additions & 0 deletions
```yaml
io:
  eval_interval: 6
  log_interval: 1
  eval_iters: 10
  logger: wandb
data:
  batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
  num_workers: 1
model:
  name_or_path: ./out
  out_dir: ./out_rlhf
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  name_or_path: ./out_reward
train:
  grad_clip: 1.0
  max_epochs: 1000 # total number of training iterations
  always_save_checkpoint: True # if True, always save a checkpoint after each eval
  decay_lr: True
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 5.0e-5
    weight_decay: 0.0 # 01
    betas: [0.9, 0.999]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size
    eta_min: 5.0e-6
  ppo:
    episode_length: 50
    ppo_batch_size: 16
    ppo_num_epochs: 3
    num_rollouts_per_epoch: 32
sys:
  device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  ref_device: cuda:1 # device of reference model
  dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
  compile: False # use PyTorch 2.0 to compile the model to be faster
```

examples/rlhf/data/__init__.py

Lines changed: 3 additions & 0 deletions
```python
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr

__all__ = ["get_prompt_dataloader_tldr"]
```

examples/rlhf/models/__init__.py

Lines changed: 4 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
```

examples/rlhf/models/actor_critic.py

Lines changed: 35 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
from torchrl.modules.tensordict_module.common import VmapModule

from .transformer import init_transformer

__all__ = ["init_actor_critic"]


def init_actor_critic(model_cfg, sys_cfg):

    transformer_name_or_path = model_cfg.name_or_path
    dropout = model_cfg.dropout

    device = sys_cfg.device
    compile_model = sys_cfg.compile
    base_model = init_transformer(
        transformer_name_or_path,
        dropout,
        device,
        as_tensordictmodule=False,
        compile_model=compile_model,
        inference=True,
    )
    model = LMHeadActorValueOperator(base_model)
    model.to(device)
    model.eval()
    actor = model.get_policy_operator()
    critic = model.get_value_operator()
    critic_head = model.get_value_head()

    return actor, VmapModule(critic), critic_head, base_model
```
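A usage sketch for `init_actor_critic`. The attribute names mirror the `model`/`sys` sections of `config/train_rlhf.yaml`; using `SimpleNamespace` instead of the Hydra config object, and running from the `examples/rlhf` directory, are illustrative assumptions:

```python
from types import SimpleNamespace

from models.actor_critic import init_actor_critic

model_cfg = SimpleNamespace(name_or_path="gpt2", dropout=0.1)
sys_cfg = SimpleNamespace(device="cpu", compile=False)

actor, critic, critic_head, base_model = init_actor_critic(model_cfg, sys_cfg)
# actor: policy operator used to generate tokens during PPO rollouts
# critic: value operator (wrapped in VmapModule); both share the GPT-2 backbone
```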

examples/rlhf/models/reward.py

Lines changed: 41 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings

import torch
from tensordict.nn import TensorDictModule

from torchrl.modules.models.rlhf import GPT2RewardModel


def init_reward_model(
    transformer_path=None, reward_model_path=None, device=None, compile_model=False
):
    if transformer_path is None and reward_model_path is None:
        warnings.warn(
            "You did not provide a path to the reward model, a naive reward model will be used instead."
        )
        model = GPT2RewardModel()
    else:
        if not ((transformer_path is None) ^ (reward_model_path is None)):
            raise ValueError(
                "Exactly one of transformer_path or reward_model_path should be specified."
            )
        if transformer_path is not None:
            model = GPT2RewardModel(transformer_path)
        else:
            model = GPT2RewardModel.from_pretrained(reward_model_path)

    model.to(device)
    if compile_model:
        print("Compiling the reward model...")
        model = torch.compile(model)

    model = TensorDictModule(
        model,
        in_keys=["input_ids", "attention_mask"],
        out_keys=["rewards", "end_scores"],
    )
    return model
```
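A usage sketch for `init_reward_model`. With neither path given it falls back to the naive `GPT2RewardModel` (emitting the warning above); the tokenizer, prompt, and relative import are illustrative assumptions:

```python
from tensordict import TensorDict
from transformers import GPT2Tokenizer

from models.reward import init_reward_model

# No checkpoint paths: the naive GPT2RewardModel is used (a warning is raised).
reward_model = init_reward_model(device="cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
enc = tokenizer("SUBREDDIT: r/test POST: ... TL;DR: a short summary", return_tensors="pt")

td = TensorDict(
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]},
    batch_size=[1],
)
td = reward_model(td)
print(td["rewards"].shape, td["end_scores"])  # per-token rewards and the end-of-sequence score
```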

examples/rlhf/models/transformer.py

Lines changed: 44 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from tensordict.nn import TensorDictModule
from transformers import GPT2LMHeadModel


def init_transformer(
    name_or_path,
    dropout,
    device,
    compile_model,
    as_tensordictmodule=True,
    inference=False,
):
    model_kwargs = {
        "resid_pdrop": dropout,
        "embd_pdrop": dropout,
        "attn_pdrop": dropout,
        "summary_first_dropout": dropout,
    }
    model = GPT2LMHeadModel.from_pretrained(
        name_or_path, return_dict=False, **model_kwargs
    )
    model.to(device)

    if compile_model:
        # TODO: logging instead of printing?
        print("Compiling transformer model...")
        model = torch.compile(model)

    if as_tensordictmodule:
        model = TensorDictModule(
            model,
            in_keys={
                "input_ids": "input_ids",
                "attention_mask": "attention_mask",
                "labels": "labels",
            },
            out_keys=["logits"] if inference else ["loss", "logits"],
        )
    return model
```
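A usage sketch for `init_transformer` with `as_tensordictmodule=False`, which is how `init_actor_critic` above consumes it; `return_dict=False` means the model returns plain tuples. The tokenizer, prompt, and running from `examples/rlhf` are illustrative assumptions:

```python
import torch
from transformers import GPT2Tokenizer

from models.transformer import init_transformer

model = init_transformer(
    "gpt2", dropout=0.1, device="cpu", compile_model=False, as_tensordictmodule=False
)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
enc = tokenizer("TL;DR:", return_tensors="pt")

with torch.no_grad():
    # return_dict=False: the first element of the output tuple is the logits
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])[0]
print(logits.shape)  # (1, sequence_length, vocab_size)
```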

examples/rlhf/requirements.txt

Lines changed: 11 additions & 0 deletions
```
datasets
hydra-core
matplotlib
numpy
PyYAML
requests
tiktoken
tqdm
transformers
git+https://github.com/pytorch/rl
git+https://github.com/pytorch-labs/tensordict
```
