
Commit 169fe1f

[Algorithm] unify grpo sync/async implementations (#3006)
1 parent fac4b3e commit 169fe1f

File tree: 25 files changed (+958, −423 lines)

sota-implementations/grpo/README.md

Lines changed: 103 additions & 19 deletions
@@ -37,21 +37,13 @@ export VLLM_USE_V1=0 # Required for vLLM compatibility
 
 ### Device Management
 
-There are two ways to specify device allocation:
+The number of devices for each model component is specified using `num_devices`:
 
-1. Using `num_devices` (Recommended):
 ```bash
 train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
 ```
-This approach automatically manages device allocation based on the training mode (sync/async) and prevents device conflicts.
 
-2. Using `devices` (Manual):
-```bash
-train_model.devices=[0,1] ref_model.devices=[2,3] inference_model.devices=[4,5]
-```
-This approach requires manual device management and is more error-prone.
-
-The `num_devices` approach is recommended as it:
+This approach:
 - Automatically handles device allocation
 - Works correctly in both sync and async modes
 - Prevents device conflicts between model components
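
For illustration, a minimal sketch of what `num_devices`-driven allocation can amount to in the simplest case: each component receives a disjoint slice of GPU indices, which is what prevents device conflicts. The `allocate_devices` helper below is hypothetical and is not the repository's `compute_device_allocation`; the real logic also depends on the sync/async mode.

```python
from typing import Dict, List


def allocate_devices(num_devices: Dict[str, int]) -> Dict[str, List[int]]:
    """Hypothetical helper: give each model component a disjoint range of GPU indices."""
    allocation: Dict[str, List[int]] = {}
    next_idx = 0
    for component, count in num_devices.items():
        allocation[component] = list(range(next_idx, next_idx + count))
        next_idx += count
    return allocation


# Mirrors train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
print(allocate_devices({"train_model": 2, "ref_model": 2, "inference_model": 2}))
# {'train_model': [0, 1], 'ref_model': [2, 3], 'inference_model': [4, 5]}
```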
@@ -71,47 +63,139 @@ There are two training modes available:
 
 #### Synchronous Mode (Default)
 ```bash
-VLLM_USE_V1=0 python sota-implementations/grpo/grpo.py train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
+VLLM_USE_V1=0 python sota-implementations/grpo/grpo-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
 ```
 
 #### Asynchronous Mode (Recommended)
 ```bash
-VLLM_USE_V1=0 python sota-implementations/grpo/grpo-async.py train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
+VLLM_USE_V1=0 python sota-implementations/grpo/grpo-async.py mode=async train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
+```
+
+The key difference between sync and async modes is how data collection and optimization are handled:
+
+**Synchronous Mode (grpo-sync.py)**:
+```python
+# Three nested loops:
+for data in collector:              # Data collection loop
+    for epoch in range(epochs):     # Epoch loop
+        for batch in replay_buffer: # Buffer consumption loop
+            # Optimize on batch
+            loss = loss_fn(batch)
+            loss.backward()
+            optimizer.step()
+    # Weight update
+    weight_updater.push_weights(policy_training)
+```
+
+**Asynchronous Mode (grpo-async.py)**:
+```python
+# Start data collection in background
+collector.start()
+
+# Single optimization loop
+for step in range(total_steps):
+    # Sample and optimize
+    batch = replay_buffer.sample()
+    loss = loss_fn(batch)
+    loss.backward()
+    optimizer.step()
+    # Update weights once in a while
+    if cond():
+        weight_updater.push_weights(policy_training)
+
 ```
 
+Key differences:
+1. **Data Collection**:
+   - Sync: Data collection and optimization happen sequentially
+   - Async: Data collection runs in the background while optimization happens
+
+2. **Buffer Size**:
+   - Sync: Buffer size must equal the batch size returned by the collector (`buffer_size = steps_per_batch`)
+   - Async: Buffer can be larger than the batch size, allowing for more diverse sampling
+
+3. **Data Processing**:
+   - Sync: Processes the same data multiple times (epochs)
+   - Async: Each piece of data is processed a non-deterministic number of times
+
+4. **Weight Updates**:
+   - Sync: Weights are updated before every collection of data
+   - Async: Weights are updated at a given interval (in gradient steps)
+
 The async mode offers better performance by:
 - Running data collection and optimization concurrently
 - More efficient GPU utilization
 - Reduced memory overhead
 - Better throughput
+- More flexible buffer management
+
+### KL Divergences in PPO: Reference vs Inference
+
+KL divergence is a key regularization term in policy optimization algorithms like PPO and in LLM post-training. It measures how much the updated policy diverges from a baseline or reference policy, helping to prevent the new policy from drifting too far and ensuring stable learning.
+
+There are two main types of KL divergences commonly used:
+
+#### 1. KL to Reference Policy (KL[ref || policy])
+- **Definition:** Measures how much the new (learned) policy diverges from a fixed reference policy (often the original, pre-trained model).
+- **Implementation:** In GRPO, this is computed as `(ref_log_prob - cur_log_prob).expm1() - (ref_log_prob - cur_log_prob)`, which is a numerically stable way to compute the KL from log-probabilities.
+- **Usage:**
+  - **LLM Post-Training:** This is the canonical choice in LLM post-training (e.g., RLHF, DPO, GRPO). The reference is usually the original language model before any RL fine-tuning. Penalizing KL[ref || policy] ensures the fine-tuned model stays close to the original, preserving language quality and preventing over-optimization.
+  - **Effect:** Encourages the new policy not to deviate too much from the reference, maintaining fluency and generalization.
+
+#### 2. KL to Inference Policy (KL[policy || inference])
+- **Definition:** Measures how much the current policy diverges from the policy used to generate the data (the inference policy, sometimes called the behavior policy).
+- **Implementation:** In GRPO, this is approximated as `prev_log_prob - cur_log_prob`, where `prev_log_prob` is from the inference policy that generated the data.
+- **Usage:**
+  - **Canonical PPO:** In standard PPO (especially in RL for control), this is the canonical KL: KL[policy || inference]. The inference policy is the one that generated the trajectories in the replay buffer. Penalizing this KL ensures that the updated policy does not move too far from the data distribution, stabilizing importance sampling and learning.
+  - **Effect:** Prevents the policy from making large, unstable updates relative to the data it was trained on.
+
+#### Summary Table
+| Setting            | Canonical KL Term       | Purpose                                    |
+|--------------------|-------------------------|--------------------------------------------|
+| PPO (RL control)   | KL[policy || inference] | Stabilize updates, match data distribution |
+| LLM Post-Training  | KL[ref || policy]       | Stay close to pre-trained model            |
+
+In GRPO, both types of KL can be used and controlled via configuration. Typically, for LLM post-training, the KL to reference is the most important for preserving model quality, while the KL to inference is more about stabilizing the optimization process.
+
+The KL contributions to the loss can be controlled via `train.kl_to_ref_coeff` and `train.kl_to_inference_coeff`, respectively.
+
+Additionally, the KL-to-reference contribution can either be added to the reward during the grading of the LLM response, or added directly to the loss, as controlled by the `train.kl_coef_in_loss` config option.
+
+In the original GRPO paper, the KL to reference (KL[ref || policy]) is added **directly to the loss function**, not to the reward. This means that the KL penalty acts as a regularizer during optimization, discouraging the policy from drifting too far from the reference model at every update step. This is in contrast to some RLHF-style approaches, where the KL penalty is added to the reward signal during data collection (i.e., the environment's reward is modified).
+
+**Why does this matter?**
+- **KL in the loss (as in GRPO):** The optimization explicitly balances the policy objective and the KL penalty at each gradient step, making the trade-off more direct and stable. This is the canonical approach in GRPO and is controlled by setting `train.kl_coef_in_loss=True` in the config.
+- **KL in the reward:** The KL penalty is treated as part of the environment's reward, so the policy is trained to maximize this modified reward. This can sometimes make the effect of the KL less direct, as it is mixed with the task reward during data collection.
+
+In summary, GRPO's approach of adding the KL to reference directly to the loss provides more explicit and stable regularization, and is the recommended setting for most LLM post-training scenarios.
 
 ### Run with IFEval Config
 
 ```bash
-python grpo.py --config-name grpo_ifeval
+python grpo-sync.py mode=sync --config-name grpo_ifeval
 ```
 
 ### Override Config Values
 
 ```bash
 # Change dataset
-python grpo.py env.dataset=ifeval
+python grpo-sync.py mode=sync env.dataset=ifeval
 
 # Modify training parameters
-python grpo.py optimizer.lr=2e-5 optimizer.weight_decay=0.01
+python grpo-sync.py mode=sync optimizer.lr=2e-5 optimizer.weight_decay=0.01
 
 # Change model
-python grpo.py model.name=meta-llama/Llama-2-7b-hf
+python grpo-sync.py mode=sync model.name=meta-llama/Llama-2-7b-hf
 ```
 
 ### Hyperparameter Sweeps
 
 ```bash
 # Learning rate sweep
-python grpo.py --multirun optimizer.lr=1e-4,1e-5,1e-6
+python grpo-sync.py mode=sync --multirun optimizer.lr=1e-4,1e-5,1e-6
 
 # Multiple parameters
-python grpo.py --multirun \
+python grpo-sync.py mode=sync --multirun \
     optimizer.lr=1e-4,1e-5 \
     policy.kl_coef=0.01,0.1
 ```
@@ -153,7 +237,7 @@ sota-implementations/grpo/
 ├── config/
 │   └── grpo_gsm8k.yaml    # Main configuration file
 │   └── grpo_ifeval.yaml   # config file for IFEval task
-├── grpo.py                # Synchronous training script
+├── grpo-sync.py           # Synchronous training script
 ├── grpo-async.py          # Asynchronous training script
 ├── grpo_utils.py          # Utility functions
 └── README.md              # This file
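
To make the two KL terms described in the README concrete, here is a minimal PyTorch sketch, assuming per-token log-probability tensors from the training, reference, and inference policies. It only reproduces the expressions quoted above (`(ref_log_prob - cur_log_prob).expm1() - (ref_log_prob - cur_log_prob)` and `prev_log_prob - cur_log_prob`) and weighs them with the two coefficients; the function name and the mean reduction are illustrative assumptions, not the loss module's actual code.

```python
import torch


def kl_penalty(
    cur_log_prob: torch.Tensor,   # log-probs under the policy being trained
    ref_log_prob: torch.Tensor,   # log-probs under the frozen reference model
    prev_log_prob: torch.Tensor,  # log-probs under the inference (behavior) policy
    kl_to_ref_coeff: float = 1e-2,
    kl_to_inference_coeff: float = 0.0,
) -> torch.Tensor:
    # KL to reference, computed with the stable expm1-based expression quoted above.
    delta_ref = ref_log_prob - cur_log_prob
    kl_to_ref = delta_ref.expm1() - delta_ref

    # KL to the inference policy, approximated per token as prev_log_prob - cur_log_prob.
    kl_to_inference = prev_log_prob - cur_log_prob

    # Weighted penalty that would be added to the policy loss.
    return kl_to_ref_coeff * kl_to_ref.mean() + kl_to_inference_coeff * kl_to_inference.mean()
```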

sota-implementations/grpo/config/grpo_gsm8k.yaml

Lines changed: 10 additions & 10 deletions
@@ -1,6 +1,6 @@
 # @package _global_
 defaults:
-  - mode: async # Default to async mode, will be overridden by grpo.py
+  - mode: ${mode:async} # Default to async mode, can be overridden by scripts
   - _self_
   - override hydra/hydra_logging: disabled
   - override hydra/job_logging: disabled
@@ -35,15 +35,19 @@ train:
 
   # Fields used by both scripts but with different semantics
   checkpoint_frequency: 100 # Save checkpoint every N steps/batches
+
+  # KL coefficients for the KL divergence to the reference and inference policies
+  kl_to_ref_coeff: 1e-2
+  kl_to_inference_coeff: 0.0
+  entropy_coeff: 0.01
 
-  # Fields used only by grpo-async.py
-  weight_update_frequency: 10 # Update policy weights every N steps
+  # Fields used only by grpo-async.py / grpo-sync.py
   logging_frequency: 10 # Log metrics every N steps
+
 # Training model configuration
 train_model:
   gradient_checkpointing: true # Enabled for memory efficiency
   num_devices: 1 # Number of devices to use
-  devices: null # Will be computed by compute_device_allocation
   lora:
     enabled: true # Using LoRA for memory efficiency
     r: 8 # LoRA rank - controls capacity of adaptations
@@ -57,7 +61,6 @@ train_model:
 # Inference model configuration
 inference_model:
   num_devices: 1 # Number of devices to use
-  devices: null # Will be computed by compute_device_allocation
   quantization:
     enabled: false # Enable 4-bit quantization for base model
   attn_implementation: sdpa # Using flash attention for memory efficiency
@@ -72,7 +75,6 @@ inference_model:
 ref_model:
   gradient_checkpointing: false # Always false, no backprop
   num_devices: 1 # Number of devices to use
-  devices: null # Will be computed by compute_device_allocation
   lora:
     enabled: true # Using LoRA for memory efficiency
     r: 8 # LoRA rank - controls capacity of adaptations
@@ -83,16 +85,13 @@ ref_model:
   attn_implementation: sdpa # Using flash attention for memory efficiency
   torch_dtype: bfloat16
 
-# Policy configuration
-policy:
-  kl_coef: 1e-2
-
 # Optimizer configuration
 optimizer:
   name: AdamW
   lr: 1e-5
   clip_grad_norm: 1.0
   weight_decay: 0.0
+
 # Ray configuration
 ray:
   init_config:
@@ -113,6 +112,7 @@ ray:
   replay_buffer_config:
     num_cpus: 24 # CPUs for replay buffer
     num_gpus: 0.0 # No GPU needed for replay buffer
+
 # Logging configuration
 logging:
   experiment_name: null # Will be auto-generated if not provided
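
As a sketch of what the `kl_coef_in_loss` option described in the README controls, the snippet below shows the two possible placements of the KL-to-reference penalty: folded into the reward at grading time, or added to the loss at each gradient step (the GRPO-paper choice). The `apply_kl_to_ref` function and its tensor shapes are assumptions for illustration, not the scripts' actual code.

```python
from typing import Tuple

import torch


def apply_kl_to_ref(
    reward: torch.Tensor,
    policy_loss: torch.Tensor,
    kl_to_ref: torch.Tensor,
    kl_to_ref_coeff: float,
    kl_coef_in_loss: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if kl_coef_in_loss:
        # kl_coef_in_loss: true — penalize the loss directly at every update (GRPO paper).
        policy_loss = policy_loss + kl_to_ref_coeff * kl_to_ref.mean()
    else:
        # kl_coef_in_loss: false — fold the penalty into the reward during grading.
        reward = reward - kl_to_ref_coeff * kl_to_ref
    return reward, policy_loss
```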

sota-implementations/grpo/config/grpo_ifeval.yaml

Lines changed: 7 additions & 10 deletions
@@ -1,6 +1,6 @@
 # @package _global_
 defaults:
-  - mode: async # Default to async mode, will be overridden by grpo.py
+  - mode: ${mode:async} # Default to async mode, can be overridden by scripts
   - _self_
   - override hydra/hydra_logging: disabled
   - override hydra/job_logging: disabled
@@ -35,16 +35,19 @@ train:
 
   # Fields used by both scripts but with different semantics
   checkpoint_frequency: 100 # Save checkpoint every N steps/batches
+
+  # KL coefficients for the KL divergence to the reference and inference policies
+  kl_to_ref_coeff: 1e-2
+  kl_to_inference_coeff: 0.0
+  entropy_coeff: 0.01
 
-  # Fields used only by grpo-async.py
-  weight_update_frequency: 10 # Update policy weights every N steps
+  # Fields used only by grpo-async.py / grpo-sync.py
   logging_frequency: 10 # Log metrics every N steps
 
 # Training model configuration
 train_model:
   gradient_checkpointing: true # Enabled for memory efficiency
   num_devices: 1 # Number of devices to use
-  devices: null # Will be computed by compute_device_allocation
   lora:
     enabled: true # Using LoRA for memory efficiency
     r: 8 # LoRA rank - controls capacity of adaptations
@@ -58,7 +61,6 @@ train_model:
 # Inference model configuration
 inference_model:
   num_devices: 1 # Number of devices to use
-  devices: null # Will be computed by compute_device_allocation
   quantization:
     enabled: false # Enable 4-bit quantization for base model
   attn_implementation: sdpa # Using flash attention for memory efficiency
@@ -73,7 +75,6 @@ inference_model:
 ref_model:
   gradient_checkpointing: false # Always false, no backprop
   num_devices: 1 # Number of devices to use
-  devices: null # Will be computed by compute_device_allocation
   lora:
     enabled: true # Using LoRA for memory efficiency
     r: 8 # LoRA rank - controls capacity of adaptations
@@ -84,10 +85,6 @@ ref_model:
   attn_implementation: sdpa # Using flash attention for memory efficiency
   torch_dtype: bfloat16
 
-# Policy configuration
-policy:
-  kl_coef: 1e-2
-
 # Optimizer configuration
 optimizer:
   name: AdamW

sota-implementations/grpo/config/mode/async.yaml

Lines changed: 14 additions & 3 deletions
@@ -4,11 +4,22 @@ train:
   sync: false # Force asynchronous mode
 
   # Shared training settings
+  # Whether to use mixed precision training.
   mixed_precision: true
+  # Number of epochs to train on each collected batch. Not directly used in async mode, aside from computing the total number of steps.
   epochs: 1
+  # Number of steps in each batch. Higher values will cause the inference step to be slower, but won't use more GPU memory.
   steps_per_batch: 16
-  buffer_size: 32
+  # Leave buffer_size empty to use steps_per_batch in async mode
+  buffer_size:
+  # Total number of dialog turns to collect during training.
   total_dialog_turns: 100_000
-  optim_batch_size: 4
+  # Batch size for optimization. Higher values will use more GPU memory.
+  optim_batch_size: 1
+  # Number of gradient accumulation steps. Higher values will use less GPU memory (compared with bigger batches and a lower gradient_accumulation_steps),
+  # but will make the optimization step slower.
   gradient_accumulation_steps: 4
-  kl_coef_in_loss: false
+  # Whether to include the KL term in the loss function. Alternatively, the KL to the reference model is added to the reward.
+  kl_coef_in_loss: true
+  # Update policy weights every N steps - can be set to any positive integer in async mode
+  weight_update_frequency: 10
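
To illustrate how `optim_batch_size`, `gradient_accumulation_steps`, and `weight_update_frequency` interact in async mode, here is a small self-contained sketch with a dummy model and data; only the control flow mirrors the config, everything else is a stand-in.

```python
import torch
from torch import nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

optim_batch_size = 1            # as in the config above
gradient_accumulation_steps = 4
weight_update_frequency = 10
total_steps = 40

for step in range(total_steps):
    x = torch.randn(optim_batch_size, 8)                         # stand-in for replay_buffer.sample()
    loss = model(x).pow(2).mean() / gradient_accumulation_steps  # scale so gradients average over the window
    loss.backward()

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()        # one optimizer update per accumulation window
        optimizer.zero_grad()

    if (step + 1) % weight_update_frequency == 0:
        print(f"step {step + 1}: push updated weights to the inference workers here")
```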

sota-implementations/grpo/config/mode/sync.yaml

Lines changed: 11 additions & 1 deletion
@@ -4,12 +4,22 @@ train:
   sync: true # Force synchronous mode
 
   # Shared training settings
+  # Whether to use mixed precision training.
   mixed_precision: true
+  # Number of epochs to train on each collected batch.
   epochs: 1
-  steps_per_batch: 32
+  # Number of steps in each batch. Higher values will cause the inference step to be slower, but won't use more GPU memory.
+  steps_per_batch: 64
   # Leave buffer_size empty to use steps_per_batch in sync mode
   buffer_size:
+  # Total number of dialog turns to collect during training.
   total_dialog_turns: 100_000
+  # Batch size for optimization. Higher values will use more GPU memory.
   optim_batch_size: 1
+  # Number of gradient accumulation steps. Higher values will use less GPU memory (compared with bigger batches and a lower gradient_accumulation_steps),
+  # but will make the optimization step slower.
   gradient_accumulation_steps: 1
+  # Whether to include the KL term in the loss function. Alternatively, the KL to the reference model is added to the reward.
   kl_coef_in_loss: true
+  # Update policy weights every N steps - must be left empty in sync mode
   weight_update_frequency:
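
A small sketch, assuming a hypothetical `resolve_buffer_size` helper, of the rule stated in the comments above and in the README: in sync mode the buffer must hold exactly one collector batch, while in async mode it may be larger to allow more diverse sampling.

```python
from typing import Optional


def resolve_buffer_size(sync: bool, steps_per_batch: int, buffer_size: Optional[int]) -> int:
    if sync:
        # Sync mode: the buffer must hold exactly one collector batch.
        if buffer_size not in (None, steps_per_batch):
            raise ValueError("in sync mode, leave buffer_size empty or set it to steps_per_batch")
        return steps_per_batch
    # Async mode: default to steps_per_batch, but a larger buffer is allowed.
    return buffer_size if buffer_size is not None else steps_per_batch
```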
