
Commit fc9794d

albertbou92 and vmoens authored
[Algorithm] Update PPO examples (#1495)
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 99dcae8 commit fc9794d

16 files changed (+932, -778 lines)

.github/unittest/linux_examples/scripts/environment.yml

Lines changed: 2 additions & 0 deletions
@@ -27,3 +27,5 @@ dependencies:
   - coverage
   - vmas
   - transformers
+  - gym[atari]
+  - gym[accept-rom-license]

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 12 additions & 21 deletions
@@ -48,18 +48,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
 # ==================================================================================== #
 # ================================ Gymnasium ========================================= #

-python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
-  env.num_envs=1 \
-  env.device=cuda:0 \
-  collector.total_frames=48 \
-  collector.frames_per_batch=16 \
-  collector.collector_device=cuda:0 \
-  optim.device=cuda:0 \
+python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \
+  env.env_name=HalfCheetah-v4 \
+  collector.total_frames=40 \
+  collector.frames_per_batch=20 \
   loss.mini_batch_size=10 \
   loss.ppo_epochs=1 \
   logger.backend= \
-  logger.log_interval=4 \
-  optim.lr_scheduler=False
+  logger.test_interval=40
+python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \
+  collector.total_frames=80 \
+  collector.frames_per_batch=20 \
+  loss.mini_batch_size=20 \
+  loss.ppo_epochs=1 \
+  logger.backend= \
+  logger.test_interval=40
 python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
@@ -208,18 +211,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
   record_video=True \
   record_frames=4 \
   buffer_size=120
-python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
-  env.num_envs=1 \
-  env.device=cuda:0 \
-  collector.total_frames=48 \
-  collector.frames_per_batch=16 \
-  collector.collector_device=cuda:0 \
-  optim.device=cuda:0 \
-  loss.mini_batch_size=10 \
-  loss.ppo_epochs=1 \
-  logger.backend= \
-  logger.log_interval=4 \
-  optim.lr_scheduler=False
 python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
   total_frames=48 \
   init_random_frames=10 \

examples/ppo/README.md

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+## Reproducing Proximal Policy Optimization (PPO) Algorithm Results
+
+This repository contains scripts for training agents with the Proximal Policy Optimization (PPO) algorithm on MuJoCo and Atari environments. We follow the original paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Schulman et al. (2017), with one improvement: the Generalised Advantage Estimator (GAE) is recomputed at every epoch.
+
+
+## Examples Structure
+
+Please note that each example is independent of the others for the sake of simplicity. Each example contains the following files:
+
+1. **Main Script:** The definition of the algorithm components and the training loop can be found in the main script (e.g. ppo_atari.py).
+
+2. **Utils File:** A utility file provides various helper functions, mainly to create the environment and the models (e.g. utils_atari.py).
+
+3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml).
+
+
+## Running the Examples
+
+You can execute the PPO algorithm on Atari environments by running the following command:
+
+```bash
+python ppo_atari.py
+```
+
+You can execute the PPO algorithm on MuJoCo environments by running the following command:
+
+```bash
+python ppo_mujoco.py
+```
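The README's note about recomputing GAE at every epoch means the advantage estimator is re-run on the collected batch with the current value network before each pass over the data, rather than once per batch. Below is a plain-PyTorch sketch of the underlying recursion only; the tensor names are assumptions and the examples themselves presumably rely on TorchRL's own advantage modules rather than this code.

```python
import torch

def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalised Advantage Estimation over one trajectory of length T (sketch)."""
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    running = torch.zeros(())
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t].float()
        # one-step TD error; the bootstrap term is cut when the episode ends
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = advantages + values
    return advantages, value_targets
```

Recomputing this quantity before every PPO epoch keeps the advantages consistent with the value network as it is updated within the epochs.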

examples/ppo/config.yaml

Lines changed: 0 additions & 46 deletions
This file was deleted.

examples/ppo/config_atari.yaml

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Environment
+env:
+  env_name: PongNoFrameskip-v4
+  num_envs: 8
+
+# collector
+collector:
+  frames_per_batch: 4096
+  total_frames: 40_000_000
+
+# logger
+logger:
+  backend: wandb
+  exp_name: Atari_Schulman17
+  test_interval: 40_000_000
+  num_test_episodes: 3
+
+# Optim
+optim:
+  lr: 2.5e-4
+  eps: 1.0e-6
+  weight_decay: 0.0
+  max_grad_norm: 0.5
+  anneal_lr: True
+
+# loss
+loss:
+  gamma: 0.99
+  mini_batch_size: 1024
+  ppo_epochs: 3
+  gae_lambda: 0.95
+  clip_epsilon: 0.1
+  anneal_clip_epsilon: True
+  critic_coef: 1.0
+  entropy_coef: 0.01
+  loss_critic_type: l2
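Each of the keys above can be overridden from the command line with the same dotted-key syntax the CI script run_test.sh uses; for example, a short smoke-test run (the values here are illustrative only):

```bash
python ppo_atari.py \
  collector.total_frames=80 \
  collector.frames_per_batch=20 \
  loss.mini_batch_size=20 \
  loss.ppo_epochs=1 \
  logger.backend=
```

Passing an empty `logger.backend=` mirrors the CI script, presumably so that no logging backend is required during the test run.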

examples/ppo/config_example2.yaml

Lines changed: 0 additions & 43 deletions
This file was deleted.

examples/ppo/config_mujoco.yaml

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# task and env
+env:
+  env_name: HalfCheetah-v3
+
+# collector
+collector:
+  frames_per_batch: 2048
+  total_frames: 1_000_000
+
+# logger
+logger:
+  backend: wandb
+  exp_name: Mujoco_Schulman17
+  test_interval: 1_000_000
+  num_test_episodes: 5
+
+# Optim
+optim:
+  lr: 3e-4
+  weight_decay: 0.0
+  anneal_lr: False
+
+# loss
+loss:
+  gamma: 0.99
+  mini_batch_size: 64
+  ppo_epochs: 10
+  gae_lambda: 0.95
+  clip_epsilon: 0.2
+  anneal_clip_epsilon: False
+  critic_coef: 0.25
+  entropy_coef: 0.0
+  loss_critic_type: l2
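The dotted overrides used in run_test.sh suggest a Hydra-style entry point for these configs. The sketch below shows how a script might load a config like the one above; this is an assumption about the wiring rather than the examples' actual code, and the decorator arguments (config_path, config_name) are illustrative.

```python
import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="config_mujoco")
def main(cfg: DictConfig):
    # dotted command-line overrides such as loss.ppo_epochs=1 end up in cfg
    print(cfg.env.env_name, cfg.collector.total_frames, cfg.loss.ppo_epochs)

if __name__ == "__main__":
    main()
```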
