
Commit 146af04

BY571 and vmoens authored
[Algorithm] Update SAC Example (#1524)
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 69d5343 commit 146af04

File tree

5 files changed: +197 −124 lines changed

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 6 additions & 8 deletions
@@ -118,11 +118,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
   collector.frames_per_batch=16 \
-  collector.num_workers=4 \
   collector.env_per_collector=2 \
   collector.collector_device=cuda:0 \
-  optimization.batch_size=10 \
-  optimization.utd_ratio=1 \
+  optim.batch_size=10 \
+  optim.utd_ratio=1 \
   replay_buffer.size=120 \
   env.name=Pendulum-v1 \
   network.device=cuda:0 \
@@ -221,17 +220,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
   collector.total_frames=48 \
   collector.init_random_frames=10 \
   collector.frames_per_batch=16 \
-  collector.num_workers=2 \
   collector.env_per_collector=1 \
   collector.collector_device=cuda:0 \
+  optim.batch_size=10 \
+  optim.utd_ratio=1 \
   network.device=cuda:0 \
-  optimization.batch_size=10 \
-  optimization.utd_ratio=1 \
+  optim.batch_size=10 \
+  optim.utd_ratio=1 \
   replay_buffer.size=120 \
   env.name=Pendulum-v1 \
   logger.backend=
-  # record_video=True \
-  # record_frames=4 \
 python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
   total_frames=48 \
   batch_size=10 \
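The overrides above follow the config rename from `optimization.*` to `optim.*`. For reference, a minimal sketch of how such dotted Hydra overrides resolve against the example's `config.yaml` via the compose API; the `config_path` and `version_base` values here are illustrative assumptions, not part of this commit:

# Hypothetical sketch: resolving dotted Hydra overrides against examples/sac/config.yaml.
# Assumes hydra-core >= 1.2 and that config.yaml sits in the current directory.
from hydra import compose, initialize

with initialize(config_path=".", version_base=None):
    cfg = compose(
        config_name="config",
        overrides=["optim.batch_size=10", "optim.utd_ratio=1"],
    )
print(cfg.optim.batch_size)  # 10 -- the override wins over the YAML default of 256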

examples/sac/config.yaml

Lines changed: 20 additions & 20 deletions
@@ -1,49 +1,49 @@
-# Environment
+# environment and task
 env:
   name: HalfCheetah-v3
   task: ""
-  exp_name: "HalfCheetah-SAC"
-  library: gym
-  frame_skip: 1
-  seed: 1
+  exp_name: ${env.name}_SAC
+  library: gymnasium
+  max_episode_steps: 1000
+  seed: 42
 
-# Collection
+# collector
 collector:
-  total_frames: 1000000
-  init_random_frames: 10000
+  total_frames: 1_000_000
+  init_random_frames: 25000
   frames_per_batch: 1000
-  max_frames_per_traj: 1000
   init_env_steps: 1000
-  async_collection: 1
   collector_device: cpu
   env_per_collector: 1
-  num_workers: 1
+  reset_at_each_iter: False
 
-# Replay Buffer
+# replay buffer
 replay_buffer:
   size: 1000000
   prb: 0 # use prioritized experience replay
+  scratch_dir: ${env.exp_name}_${env.seed}
 
-# Optimization
-optimization:
+# optim
+optim:
   utd_ratio: 1.0
   gamma: 0.99
-  loss_function: smooth_l1
-  lr: 3e-4
-  weight_decay: 2e-4
-  lr_scheduler: ""
+  loss_function: l2
+  lr: 3.0e-4
+  weight_decay: 0.0
   batch_size: 256
   target_update_polyak: 0.995
+  alpha_init: 1.0
+  adam_eps: 1.0e-8
 
-# Algorithm
+# network
 network:
   hidden_sizes: [256, 256]
   activation: relu
   default_policy_scale: 1.0
   scale_lb: 0.1
   device: "cuda:0"
 
-# Logging
+# logging
 logger:
   backend: wandb
   mode: online
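The new config derives `exp_name` and the replay-buffer `scratch_dir` from other keys through `${...}` interpolation. A small standalone sketch of how those references resolve, using OmegaConf directly rather than the full Hydra app (values copied from the YAML above):

# Sketch of the ${...} interpolations introduced in config.yaml, resolved with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "env": {"name": "HalfCheetah-v3", "exp_name": "${env.name}_SAC", "seed": 42},
        "replay_buffer": {"scratch_dir": "${env.exp_name}_${env.seed}"},
    }
)
print(cfg.env.exp_name)               # HalfCheetah-v3_SAC
print(cfg.replay_buffer.scratch_dir)  # HalfCheetah-v3_SAC_42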

examples/sac/sac.py

Lines changed: 94 additions & 64 deletions
@@ -11,17 +11,20 @@
 The helper functions are coded in the utils.py associated with this script.
 """
 
+import time
+
 import hydra
 
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
-
+from tensordict import TensorDict
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    log_metrics,
     make_collector,
     make_environment,
     make_loss_module,
@@ -35,6 +38,7 @@
 def main(cfg: "DictConfig"):  # noqa: F821
     device = torch.device(cfg.network.device)
 
+    # Create logger
     exp_name = generate_exp_name("SAC", cfg.env.exp_name)
     logger = None
     if cfg.logger.backend:
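The main-loop hunk below swaps the Python lists of scalar losses for a `TensorDict` preallocated with `batch_size=[num_updates]`, filled one row per gradient step and averaged at logging time. A standalone sketch of that accumulation pattern, mirroring the loop in the diff (the loss values here are random stand-ins, not SAC outputs):

# Sketch of the loss-accumulation pattern used in the updated training loop.
import torch
from tensordict import TensorDict

num_updates = 4
losses = TensorDict({}, batch_size=[num_updates])
for i in range(num_updates):
    # Stand-in for loss_module(sampled_tensordict).select(...).detach()
    loss_td = TensorDict(
        {
            "loss_actor": torch.rand(()),
            "loss_qvalue": torch.rand(()),
            "loss_alpha": torch.rand(()),
        },
        batch_size=[],
    )
    losses[i] = loss_td
print(losses.get("loss_qvalue").mean().item())  # mean q-loss over the updates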
@@ -48,132 +52,158 @@ def main(cfg: "DictConfig"):  # noqa: F821
     torch.manual_seed(cfg.env.seed)
     np.random.seed(cfg.env.seed)
 
-    # Create Environments
+    # Create environments
     train_env, eval_env = make_environment(cfg)
-    # Create Agent
+
+    # Create agent
     model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)
 
-    # Create TD3 loss
+    # Create SAC loss
     loss_module, target_net_updater = make_loss_module(cfg, model)
 
-    # Make Off-Policy Collector
+    # Create off-policy collector
     collector = make_collector(cfg, train_env, exploration_policy)
 
-    # Make Replay Buffer
+    # Create replay buffer
     replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
         device=device,
     )
 
-    # Make Optimizers
-    optimizer = make_sac_optimizer(cfg, loss_module)
-
-    rewards = []
-    rewards_eval = []
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)
 
     # Main loop
+    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None
 
     init_random_frames = cfg.collector.init_random_frames
     num_updates = int(
         cfg.collector.env_per_collector
         * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
     )
     prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
     eval_iter = cfg.logger.eval_iter
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_rollout_steps = cfg.env.max_episode_steps
 
+    sampling_start = time.time()
     for i, tensordict in enumerate(collector):
-        # update weights of the inference policy
+        sampling_time = time.time() - sampling_start
+
+        # Update weights of the inference policy
         collector.update_policy_weights_()
 
-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
         pbar.update(tensordict.numel())
 
-        tensordict = tensordict.view(-1)
+        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        # Add to replay buffer
         replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames
 
-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
         if collected_frames >= init_random_frames:
-            (actor_losses, q_losses, alpha_losses) = ([], [], [])
-            for _ in range(num_updates):
-                # sample from replay buffer
+            losses = TensorDict(
+                {},
+                batch_size=[
+                    num_updates,
+                ],
+            )
+            for i in range(num_updates):
+                # Sample from replay buffer
                 sampled_tensordict = replay_buffer.sample().clone()
 
+                # Compute loss
                 loss_td = loss_module(sampled_tensordict)
 
                 actor_loss = loss_td["loss_actor"]
                 q_loss = loss_td["loss_qvalue"]
                 alpha_loss = loss_td["loss_alpha"]
-                loss = actor_loss + q_loss + alpha_loss
 
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
+                # Update actor
+                optimizer_actor.zero_grad()
+                actor_loss.backward()
+                optimizer_actor.step()
 
-                q_losses.append(q_loss.item())
-                actor_losses.append(actor_loss.item())
-                alpha_losses.append(alpha_loss.item())
+                # Update critic
+                optimizer_critic.zero_grad()
+                q_loss.backward()
+                optimizer_critic.step()
 
-                # update qnet_target params
+                # Update alpha
+                optimizer_alpha.zero_grad()
+                alpha_loss.backward()
+                optimizer_alpha.step()
+
+                losses[i] = loss_td.select(
+                    "loss_actor", "loss_qvalue", "loss_alpha"
+                ).detach()
+
+                # Update qnet_target params
                 target_net_updater.step()
 
-                # update priority
+                # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)
 
-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
         )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                    "alpha_loss": np.mean(alpha_losses),
-                    "alpha": loss_td["alpha"],
-                    "entropy": loss_td["entropy"],
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
             )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
+            metrics_to_log["train/alpha"] = loss_td["alpha"].item()
+            metrics_to_log["train/entropy"] = loss_td["entropy"].item()
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
             with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-                if logger is not None:
-                    logger.log_scalar(
-                        "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                    )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()
 
     collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
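`make_sac_optimizer` now returns three optimizers and logging goes through `log_metrics`; both helpers live in the example's `utils.py`, which is not part of this excerpt. A hedged sketch of what the new call sites imply, assuming Adam optimizers over the loss module's actor, q-value, and alpha parameters; the actual helpers may differ:

# Hypothetical sketches consistent with the call sites in sac.py above; the real
# implementations are in examples/sac/utils.py, which this commit view does not show.
from torch import optim


def make_sac_optimizer(cfg, loss_module):
    # Separate Adam optimizers so actor, critic and alpha can be stepped independently.
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())

    optimizer_actor = optim.Adam(
        actor_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    optimizer_critic = optim.Adam(
        critic_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    optimizer_alpha = optim.Adam(
        [loss_module.log_alpha],
        lr=cfg.optim.lr,
    )
    return optimizer_actor, optimizer_critic, optimizer_alpha


def log_metrics(logger, metrics, step):
    # Mirrors the scalar logging the old inline loop performed.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)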
