Commit df03cac

[Algorithm] Update TD3 Example (#1523)

1 parent 95b7206 commit df03cac

File tree

7 files changed, +328 -249 lines changed

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 2 additions & 2 deletions
@@ -146,7 +146,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame
 python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
-  optimization.batch_size=10 \
+  optim.batch_size=10 \
   collector.frames_per_batch=16 \
   collector.num_workers=4 \
   collector.env_per_collector=2 \
@@ -247,7 +247,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online
 python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
-  optimization.batch_size=10 \
+  optim.batch_size=10 \
   collector.frames_per_batch=16 \
   collector.num_workers=2 \
   collector.env_per_collector=1 \
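The only change to the smoke tests is the override path: the config section is renamed from optimization to optim (see config.yaml below), so the Hydra command-line overrides follow suit. As a rough illustration of the mechanics (plain OmegaConf standing in for Hydra, with a trimmed-down, hypothetical two-section config), a dot-path override resolves against the renamed section like this:

# Illustration only: how a dot-path override such as optim.batch_size=10 from
# run_test.sh lands in the renamed config section. Hydra performs this merge
# internally; plain OmegaConf with a minimal stand-in config is used here.
from omegaconf import OmegaConf

base = OmegaConf.create(
    {"optim": {"batch_size": 256}, "collector": {"total_frames": 1_000_000}}
)
overrides = OmegaConf.from_dotlist(["optim.batch_size=10", "collector.total_frames=48"])
cfg = OmegaConf.merge(base, overrides)
assert cfg.optim.batch_size == 10  # the old optimization.batch_size path no longer exists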

examples/td3/config.yaml

Lines changed: 18 additions & 15 deletions
@@ -1,47 +1,50 @@
-# Environment
+# task and env
 env:
   name: HalfCheetah-v3
   task: ""
-  exp_name: "HalfCheetah-TD3"
-  library: gym
-  frame_skip: 1
+  exp_name: ${env.name}_TD3
+  library: gymnasium
   seed: 42
+  max_episode_steps: 1000

-# Collection
+# collector
 collector:
   total_frames: 1000000
-  init_random_frames: 10000
+  init_random_frames: 25_000
   init_env_steps: 1000
   frames_per_batch: 1000
-  max_frames_per_traj: 1000
-  async_collection: 1
+  reset_at_each_iter: False
   collector_device: cpu
   env_per_collector: 1
   num_workers: 1

-# Replay Buffer
+# replay buffer
 replay_buffer:
   prb: 0 # use prioritized experience replay
   size: 1000000
+  scratch_dir: ${env.exp_name}_${env.seed}

-# Optimization
-optimization:
+# optim
+optim:
   utd_ratio: 1.0
   gamma: 0.99
   loss_function: l2
-  lr: 3e-4
-  weight_decay: 2e-4
+  lr: 3.0e-4
+  weight_decay: 0.0
+  adam_eps: 1e-4
   batch_size: 256
   target_update_polyak: 0.995
   policy_update_delay: 2
+  policy_noise: 0.2
+  noise_clip: 0.5

-# Network
+# network
 network:
   hidden_sizes: [256, 256]
   activation: relu
   device: "cuda:0"

-# Logging
+# logging
 logger:
   backend: wandb
   mode: online
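Two of the new values are OmegaConf interpolations: exp_name is derived from the environment name, and the replay-buffer scratch_dir is derived from exp_name and the seed. A quick way to see how they resolve (illustration only, outside of Hydra, with just the relevant keys):

# Illustration only: resolving the new interpolations from config.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "env": {"name": "HalfCheetah-v3", "exp_name": "${env.name}_TD3", "seed": 42},
        "replay_buffer": {"scratch_dir": "${env.exp_name}_${env.seed}"},
    }
)
print(cfg.env.exp_name)               # HalfCheetah-v3_TD3
print(cfg.replay_buffer.scratch_dir)  # HalfCheetah-v3_TD3_42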

examples/td3/td3.py

Lines changed: 76 additions & 58 deletions
@@ -11,8 +11,9 @@
 The helper functions are coded in the utils.py associated with this script.
 """

-import hydra
+import time

+import hydra
 import numpy as np
 import torch
 import torch.cuda
@@ -22,6 +23,7 @@

 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    log_metrics,
     make_collector,
     make_environment,
     make_loss_module,
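log_metrics is a new helper in examples/td3/utils.py; utils.py is among the seven changed files, but its diff is not shown in this excerpt. Judging from the call site at the end of the training loop, log_metrics(logger, metrics_to_log, collected_frames), and from the logger.log_scalar calls this commit removes, a minimal version would look roughly like the following (hypothetical sketch, not the committed code):

# Hypothetical sketch of the log_metrics helper imported above; the actual
# implementation lives in examples/td3/utils.py, which is not part of this excerpt.
def log_metrics(logger, metrics: dict, step: int) -> None:
    # Push every entry of the metrics dict to the logger as a scalar at the given step.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)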
@@ -35,6 +37,7 @@
 def main(cfg: "DictConfig"):  # noqa: F821
     device = torch.device(cfg.network.device)

+    # Create logger
     exp_name = generate_exp_name("TD3", cfg.env.exp_name)
     logger = None
     if cfg.logger.backend:
@@ -45,140 +48,155 @@ def main(cfg: "DictConfig"):  # noqa: F821
             wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
         )

+    # Set seeds
     torch.manual_seed(cfg.env.seed)
     np.random.seed(cfg.env.seed)

-    # Create Environments
+    # Create environments
     train_env, eval_env = make_environment(cfg)

-    # Create Agent
+    # Create agent
     model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

     # Create TD3 loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Make Off-Policy Collector
+    # Create off-policy collector
     collector = make_collector(cfg, train_env, exploration_policy)

-    # Make Replay Buffer
+    # Create replay buffer
     replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
         device=device,
     )

-    # Make Optimizers
+    # Create optimizers
     optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

-    rewards = []
-    rewards_eval = []
-
     # Main loop
+    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None

     init_random_frames = cfg.collector.init_random_frames
     num_updates = int(
         cfg.collector.env_per_collector
         * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
     )
-    delayed_updates = cfg.optimization.policy_update_delay
+    delayed_updates = cfg.optim.policy_update_delay
     prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip
+    eval_rollout_steps = cfg.env.max_episode_steps
     eval_iter = cfg.logger.eval_iter
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
+    update_counter = 0

-    for i, tensordict in enumerate(collector):
+    sampling_start = time.time()
+    for tensordict in collector:
+        sampling_time = time.time() - sampling_start
         exploration_policy.step(tensordict.numel())
-        # update weights of the inference policy
+
+        # Update weights of the inference policy
         collector.update_policy_weights_()

-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
         pbar.update(tensordict.numel())

         tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        # Add to replay buffer
         replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames

-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
         if collected_frames >= init_random_frames:
             (
                 actor_losses,
                 q_losses,
             ) = ([], [])
-            for j in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
+            for _ in range(num_updates):
+
+                # Update actor every delayed_updates
+                update_counter += 1
+                update_actor = update_counter % delayed_updates == 0

-                loss_td = loss_module(sampled_tensordict)
+                # Sample from replay buffer
+                sampled_tensordict = replay_buffer.sample().clone()

-                actor_loss = loss_td["loss_actor"]
-                q_loss = loss_td["loss_qvalue"]
+                # Compute loss
+                q_loss, *_ = loss_module.value_loss(sampled_tensordict)

+                # Update critic
                 optimizer_critic.zero_grad()
-                update_actor = j % delayed_updates == 0
-                q_loss.backward(retain_graph=update_actor)
+                q_loss.backward()
                 optimizer_critic.step()
                 q_losses.append(q_loss.item())

+                # Update actor
                 if update_actor:
+                    actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
                     optimizer_actor.zero_grad()
                     actor_loss.backward()
                     optimizer_actor.step()
+
                     actor_losses.append(actor_loss.item())

-                    # update qnet_target params
+                    # Update target params
                     target_net_updater.step()

-                # update priority
+                # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)

-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
         )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
             )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = np.mean(q_losses)
+            if update_actor:
+                metrics_to_log["train/a_loss"] = np.mean(actor_losses)
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
             with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     exploration_policy,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-            if logger is not None:
-                logger.log_scalar(
-                    "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()

     collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
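The main structural change in the loop is that the critic and actor losses are now computed by separate loss-module calls (value_loss and actor_loss) instead of one combined forward pass, so the critic backward pass no longer needs retain_graph=True, and the actor update is gated by an update_counter that persists across collector batches instead of the per-batch index j. The toy snippet below (plain PyTorch, with made-up module and variable names, not the TorchRL API) shows the same pattern in isolation:

# Toy illustration of the new update scheme: separate critic/actor losses with
# their own forward passes (no retain_graph), a persistent counter for the
# delayed actor update, and Polyak averaging of the target network, which is
# what target_net_updater.step() performs in the example. All names here are
# made up for the illustration.
import torch
from torch import nn

state_dim, tau, policy_update_delay = 4, 0.995, 2
actor = nn.Linear(state_dim, state_dim)        # stands in for the policy network
critic = nn.Linear(state_dim, 1)               # stands in for the Q-network
critic_target = nn.Linear(state_dim, 1)
critic_target.load_state_dict(critic.state_dict())
opt_actor = torch.optim.Adam(actor.parameters(), lr=3e-4, eps=1e-4)
opt_critic = torch.optim.Adam(critic.parameters(), lr=3e-4, eps=1e-4)

update_counter = 0
for _ in range(10):                            # stands in for the num_updates loop
    batch = torch.randn(256, state_dim)        # stands in for replay_buffer.sample()
    update_counter += 1
    update_actor = update_counter % policy_update_delay == 0

    # Critic update: its own forward/backward graph, so no retain_graph is needed.
    q_loss = (critic(batch) - critic_target(batch).detach()).pow(2).mean()
    opt_critic.zero_grad()
    q_loss.backward()
    opt_critic.step()

    # Delayed actor update: a fresh forward pass through the updated critic.
    if update_actor:
        actor_loss = -critic(actor(batch)).mean()
        opt_actor.zero_grad()
        actor_loss.backward()
        opt_actor.step()

        # Polyak update of the target critic.
        with torch.no_grad():
            for p, p_target in zip(critic.parameters(), critic_target.parameters()):
                p_target.mul_(tau).add_((1.0 - tau) * p)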
