# PPO in MARL
In the context of Proximal Policy Optimization (PPO), batch processing refers to how PPO collects and processes data for training the policy and value networks. The total amount of data gathered per rollout is:

total rollout size = n_steps × n_envs × n_agents
## n_steps
In Stable-Baselines3's PPO, n_steps is the number of timesteps each environment collects during a single rollout before a policy update is performed. It is a key hyperparameter: together with the number of environments (and, in multi-agent settings, the number of agents), it determines how much data PPO gathers per rollout.
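As a quick illustration, assume 8 vectorized environment copies with 100 agents each (both counts are assumptions, chosen so the arithmetic matches the 1,638,400 figure used throughout this page):

```python
# Rollout-size arithmetic for the setup assumed on this page.
n_steps = 2048    # timesteps collected per environment per rollout
n_envs = 8        # parallel environment copies (assumed)
n_agents = 100    # agents per environment (assumed)

rollout_size = n_steps * n_envs * n_agents
print(rollout_size)  # 1638400
```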
## Difference between Total Rollout Size and Batch Size
| Feature | Total Rollout Size | Batch Size |
|---|---|---|
| Definition | Total timesteps collected in a single rollout. | Amount of data used for each gradient update. |
| Computation | `n_steps * n_envs * n_agents` | User-defined hyperparameter. |
| Scope | Covers all the data gathered during a rollout. | A subset of the rollout used in one update. |
| Role in PPO | Defines the total data collected before optimization. | Defines the data processed in each gradient step. |
| Hyperparameter? | Determined by `n_steps`, `n_envs`, and (if applicable) `n_agents`. | Set by the user, independent of rollout size. |
In Stable-Baselines3's PPO, batch_size defaults to 64 regardless of the rollout size. Explicitly specifying batch_size provides finer control over mini-batch updates and computational efficiency, especially for larger rollouts.
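A minimal check of this behavior (using CartPole as a stand-in environment):

```python
from stable_baselines3 import PPO

# Without an explicit batch_size, PPO falls back to the default of 64
model = PPO("MlpPolicy", "CartPole-v1", n_steps=2048)
print(model.batch_size)  # 64

# Passing batch_size explicitly controls the mini-batch split
model = PPO("MlpPolicy", "CartPole-v1", n_steps=2048, batch_size=512)
print(model.batch_size)  # 512
```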
- **Divisibility:** The `batch_size` should divide evenly into the total rollout size to avoid leftover data. Good choices for `batch_size` are divisors of 1,638,400.
- **Memory considerations:** Larger `batch_size` values require more memory for each gradient update; smaller values reduce memory usage but can make gradient updates noisier.
- **Gradient update efficiency:** Smaller `batch_size` values allow more mini-batches per epoch, increasing the number of gradient steps.
- **Typical values:** For rollouts of this scale, commonly used `batch_size` values fall between 16,000 and 128,000, depending on hardware capabilities.
| Batch Size | Explanation |
|---|---|
| 16,384 | Small; more mini-batches, less memory usage. |
| 32,768 | A balance between computation and memory. |
| 64,000 | Efficient for most modern GPUs. |
| 128,000 | Larger batch; stable gradients, higher memory. |
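A quick sketch of how each of these candidates splits the 1,638,400-step rollout:

```python
rollout_size = 1_638_400  # n_steps * n_envs * n_agents from above

for batch_size in (16_384, 32_768, 64_000, 128_000):
    full, remainder = divmod(rollout_size, batch_size)
    print(f"{batch_size:>7,}: {full} full mini-batches, remainder {remainder:,}")

#  16,384: 100 full mini-batches, remainder 0
#  32,768: 50 full mini-batches, remainder 0
#  64,000: 25 full mini-batches, remainder 38,400
# 128,000: 12 full mini-batches, remainder 102,400
```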
A `batch_size` of 64,000 looks reasonable for this setup:

- Requires moderate memory.
- Provides stable gradient updates without excessive noise.

Note, however, that it does not divide evenly into 1,638,400 (1,638,400 ÷ 64,000 = 25.6), so it produces roughly 25 mini-batches per epoch; the consequences are worked through below.
```python
from stable_baselines3 import PPO
import supersuit as ss

# Example: multi-agent custom environment with PettingZoo.
# CustomMultiAgentEnv is a placeholder for your own parallel env;
# with e.g. 100 agents per copy, the rollout buffer holds
# 2048 * 8 * 100 = 1,638,400 timesteps.
env = CustomMultiAgentEnv()
env = ss.pettingzoo_env_to_vec_env_v1(env)   # agents become vectorized envs
env = ss.concat_vec_envs_v1(env, 8, base_class="stable_baselines3")

# Initialize PPO with the chosen batch_size
model = PPO(
    "MlpPolicy",
    env,
    n_steps=2048,
    batch_size=64_000,  # chosen batch size
    n_epochs=10,        # default number of epochs
    verbose=1,
)

# Train the model
model.learn(total_timesteps=10_000_000)
```
With the chosen `batch_size = 64,000`:

- 1,638,400 ÷ 64,000 = 25.6, which yields per epoch:
  - 25 full mini-batches, and
  - 1 truncated mini-batch of size 38,400 (Stable-Baselines3 warns about this at model creation).

To avoid truncated mini-batches, choose a `batch_size` that is a factor of the rollout buffer size (1,638,400). The factors of 1,638,400 include:
| Valid Batch Sizes | Explanation |
|---|---|
| 2,048 | Small batch size; stable. |
| 8,192 | Medium batch size. |
| 32,768 | Balanced batch size. |
| 65,536 | Efficient for large rollouts. |
| 1,638,400 | Full rollout as one batch. |
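Since 1,638,400 = 2^16 × 5^2, it has (16 + 1) × (2 + 1) = 51 factors in total; the table entries can be verified by brute force:

```python
rollout_size = 1_638_400  # = 2**16 * 5**2

factors = [d for d in range(1, rollout_size + 1) if rollout_size % d == 0]
print(len(factors))  # 51
print(all(b in factors for b in (2048, 8192, 32_768, 65_536, 1_638_400)))  # True
```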
- Recommended value: **32,768**
  - Balances memory usage and training efficiency.
  - Produces 1,638,400 ÷ 32,768 = 50 mini-batches per epoch.
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Create the vectorized environment
# (rollout buffer size = n_steps * n_envs = 2048 * 800 = 1,638,400)
env = make_vec_env("CartPole-v1", n_envs=800)

# Initialize PPO with a batch_size that is a factor of the rollout size
model = PPO(
    "MlpPolicy",
    env,
    n_steps=2048,
    batch_size=32_768,  # a factor of 1,638,400 -> 50 mini-batches per epoch
    n_epochs=10,
    verbose=1,
)

# Train the model
model.learn(total_timesteps=10_000_000)
```
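With these settings, each rollout is reused for n_epochs optimization passes, so the number of gradient steps per rollout follows directly:

```python
n_steps, n_envs, n_epochs, batch_size = 2048, 800, 10, 32_768

buffer_size = n_steps * n_envs                     # 1,638,400
minibatches_per_epoch = buffer_size // batch_size  # 50
print(n_epochs * minibatches_per_epoch)            # 500 gradient steps per rollout
```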
- For large rollouts: use a `batch_size` that is just below the median factor of the rollout size.
- For smaller rollouts or under memory constraints: drop down to 1/4 or 1/8 of the rollout size, but avoid excessive fragmentation (both rules are sketched below).
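A hypothetical helper applying these rules of thumb (the function name and return format are illustrative, not part of any library):

```python
def suggest_batch_sizes(rollout_size: int) -> dict:
    """Suggest batch sizes per the rules of thumb above (illustrative only)."""
    factors = [d for d in range(1, rollout_size + 1) if rollout_size % d == 0]
    median = factors[len(factors) // 2]
    return {
        "below_median": max(d for d in factors if d < median),         # large rollouts
        "quarter": max(d for d in factors if d <= rollout_size // 4),  # memory-constrained
        "eighth": max(d for d in factors if d <= rollout_size // 8),   # memory-constrained
    }

print(suggest_batch_sizes(1_638_400))
# {'below_median': 1024, 'quarter': 409600, 'eighth': 204800}
```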
## Is there an optimal number of mini-batches?

Not necessarily. It depends on the problem, the hardware, and the rollout size:
- Larger rollout sizes: more mini-batches (e.g., 100–200) might be appropriate to process the data more thoroughly.
- Smaller rollout sizes: fewer mini-batches (e.g., 10–20) might suffice.
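Working backwards from a target mini-batch count is often the simplest way to apply these ranges; a sketch for the 1,638,400-step rollout:

```python
rollout_size = 1_638_400

for target in (10, 20, 50, 100, 200):
    batch_size = rollout_size // target
    exact = rollout_size % target == 0
    print(f"{target} mini-batches -> batch_size {batch_size:,}"
          + ("" if exact else " (truncated)"))

# All five targets divide 1,638,400 evenly for this rollout size.
```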