Commit 69d5343

BY571 and vmoens authored
[Algorithm] Update DDPG Example (#1525)
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent df03cac commit 69d5343

File tree

4 files changed, +164 -110 lines changed


.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 4 additions & 6 deletions
@@ -66,13 +66,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.
 python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
-  optimization.batch_size=10 \
+  optim.batch_size=10 \
   collector.frames_per_batch=16 \
-  collector.num_workers=4 \
   collector.env_per_collector=2 \
   collector.collector_device=cuda:0 \
   network.device=cuda:0 \
-  optimization.utd_ratio=1 \
+  optim.utd_ratio=1 \
   replay_buffer.size=120 \
   env.name=Pendulum-v1 \
   logger.backend=
@@ -183,13 +182,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame
 python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
-  optimization.batch_size=10 \
+  optim.batch_size=10 \
   collector.frames_per_batch=16 \
-  collector.num_workers=2 \
   collector.env_per_collector=1 \
   collector.collector_device=cuda:0 \
   network.device=cuda:0 \
-  optimization.utd_ratio=1 \
+  optim.utd_ratio=1 \
   replay_buffer.size=120 \
   env.name=Pendulum-v1 \
   logger.backend=
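
Both smoke tests above now override the renamed optim config group (formerly optimization) and no longer pass collector.num_workers. As a quick sanity check that the new override keys resolve against examples/ddpg/config.yaml, here is a minimal sketch using Hydra's compose API; it assumes Hydra >= 1.2 and a working directory at the repository root, so treat the config path and overrides shown as illustrative assumptions, not part of the commit:

# Sketch only: check that the renamed optim.* override keys resolve.
# Assumes Hydra >= 1.2 and that this is run from the repository root.
from hydra import compose, initialize

with initialize(version_base=None, config_path="examples/ddpg"):
    cfg = compose(
        config_name="config",
        overrides=["optim.batch_size=10", "optim.utd_ratio=1", "env.name=Pendulum-v1"],
    )
print(cfg.optim.batch_size)  # 10 -- the old optimization.* keys would no longer resolve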

examples/ddpg/config.yaml

Lines changed: 20 additions & 18 deletions
@@ -1,45 +1,47 @@
-# Environment
+# environment and task
 env:
   name: HalfCheetah-v3
   task: ""
-  exp_name: "HalfCheetah-DDPG"
-  library: gym
-  frame_skip: 1
-  seed: 1
+  exp_name: ${env.name}_DDPG
+  library: gymnasium
+  max_episode_steps: 1000
+  seed: 42

-# Collection
+# collector
 collector:
-  total_frames: 1000000
-  init_random_frames: 10000
+  total_frames: 1_000_000
+  init_random_frames: 25_000
   frames_per_batch: 1000
-  max_frames_per_traj: 1000
   init_env_steps: 1000
-  async_collection: 1
+  reset_at_each_iter: False
   collector_device: cpu
   env_per_collector: 1
-  num_workers: 1

-# Replay Buffer
+
+# replay buffer
 replay_buffer:
   size: 1000000
   prb: 0 # use prioritized experience replay
+  scratch_dir: ${env.exp_name}_${env.seed}

-# Optimization
-optimization:
+# optimization
+optim:
   utd_ratio: 1.0
   gamma: 0.99
-  loss_function: smooth_l1
-  lr: 3e-4
-  weight_decay: 2e-4
+  loss_function: l2
+  lr: 3.0e-4
+  weight_decay: 1e-4
   batch_size: 256
   target_update_polyak: 0.995

+# network
 network:
   hidden_sizes: [256, 256]
   activation: relu
   device: "cuda:0"
+  noise_type: "ou" # ou or gaussian

-# Logging
+# logging
 logger:
   backend: wandb
   mode: online
examples/ddpg/ddpg.py

Lines changed: 72 additions & 55 deletions
@@ -11,15 +11,19 @@
 The helper functions are coded in the utils.py associated with this script.
 """

+import time
+
 import hydra

 import numpy as np
 import torch
 import torch.cuda
 import tqdm
+
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    log_metrics,
     make_collector,
     make_ddpg_agent,
     make_environment,
@@ -33,6 +37,7 @@
 def main(cfg: "DictConfig"):  # noqa: F821
     device = torch.device(cfg.network.device)

+    # Create logger
     exp_name = generate_exp_name("DDPG", cfg.env.exp_name)
     logger = None
     if cfg.logger.backend:
@@ -43,137 +48,149 @@ def main(cfg: "DictConfig"):  # noqa: F821
             wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
         )

+    # Set seeds
     torch.manual_seed(cfg.env.seed)
     np.random.seed(cfg.env.seed)

-    # Create Environments
+    # Create environments
     train_env, eval_env = make_environment(cfg)

-    # Create Agent
+    # Create agent
     model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device)

-    # Create Loss Module and Target Updater
+    # Create DDPG loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Make Off-Policy Collector
+    # Create off-policy collector
     collector = make_collector(cfg, train_env, exploration_policy)

-    # Make Replay Buffer
+    # Create replay buffer
     replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
         device=device,
     )

-    # Make Optimizers
+    # Create optimizers
     optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

-    rewards = []
-    rewards_eval = []
-
     # Main loop
+    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None

     init_random_frames = cfg.collector.init_random_frames
     num_updates = int(
         cfg.collector.env_per_collector
         * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
     )
     prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
     eval_iter = cfg.logger.eval_iter
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
+    eval_rollout_steps = cfg.env.max_episode_steps

-    for i, tensordict in enumerate(collector):
+    sampling_start = time.time()
+    for _, tensordict in enumerate(collector):
+        sampling_time = time.time() - sampling_start
+        # Update exploration policy
         exploration_policy.step(tensordict.numel())
-        # update weights of the inference policy
+
+        # Update weights of the inference policy
         collector.update_policy_weights_()

-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
         pbar.update(tensordict.numel())

         tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        # Add to replay buffer
         replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames

-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
         if collected_frames >= init_random_frames:
             (
                 actor_losses,
                 q_losses,
             ) = ([], [])
             for _ in range(num_updates):
-                # sample from replay buffer
+                # Sample from replay buffer
                 sampled_tensordict = replay_buffer.sample().clone()

+                # Compute loss
                 loss_td = loss_module(sampled_tensordict)

-                optimizer_critic.zero_grad()
-                optimizer_actor.zero_grad()
-
                 actor_loss = loss_td["loss_actor"]
                 q_loss = loss_td["loss_value"]
-                (actor_loss + q_loss).backward()

+                # Update critic
+                optimizer_critic.zero_grad()
+                q_loss.backward()
                 optimizer_critic.step()
-                q_losses.append(q_loss.item())

+                # Update actor
+                optimizer_actor.zero_grad()
+                actor_loss.backward()
                 optimizer_actor.step()
+
+                q_losses.append(q_loss.item())
                 actor_losses.append(actor_loss.item())

-                # update qnet_target params
+                # Update qnet_target params
                 target_net_updater.step()

-                # update priority
+                # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)

-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
         )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
             )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = np.mean(q_losses)
+            metrics_to_log["train/a_loss"] = np.mean(actor_losses)
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
             with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     exploration_policy,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-            if logger is not None:
-                logger.log_scalar(
-                    "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()

     collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
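
The per-key logger.log_scalar calls from the old loop are now funneled through a single log_metrics(logger, metrics_to_log, collected_frames) call imported from utils.py, which is part of this commit but not shown in this excerpt. Below is a hypothetical sketch of such a helper, assuming it simply fans the metrics dict out to log_scalar the way the removed code did:

# Hypothetical sketch of the log_metrics helper imported from examples/ddpg/utils.py;
# the real implementation ships with the commit but is not shown above.
def log_metrics(logger, metrics: dict, step: int) -> None:
    # Log each (name, value) pair at the given collected-frames step.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)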
