
Commit cb06ea3

Vincent Moens authored and committed
[Algorithm] Async SAC
ghstack-source-id: 84d845d Pull-Request-resolved: #2946
1 parent ccc31b5 commit cb06ea3

8 files changed: +558 -67 lines changed
sota-implementations/sac/config-async.yaml

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# environment and task
env:
  name: HalfCheetah-v4
  task: ""
  library: gymnasium
  max_episode_steps: 1000
  seed: 42

# collector
collector:
  total_frames: -1
  init_random_frames: 25000
  frames_per_batch: 8000
  init_env_steps: 1000
  device: cuda:1
  env_per_collector: 16
  reset_at_each_iter: False
  update_freq: 10_000

# replay buffer
replay_buffer:
  size: 100_000  # Small buffer size to keep only recent elements
  prb: 0  # use prioritized experience replay
  scratch_dir:

# optim
optim:
  utd_ratio: 1.0
  gamma: 0.99
  loss_function: l2
  lr: 3.0e-4
  weight_decay: 0.0
  batch_size: 256
  target_update_polyak: 0.995
  alpha_init: 1.0
  adam_eps: 1.0e-8

# network
network:
  hidden_sizes: [256, 256]
  activation: relu
  default_policy_scale: 1.0
  scale_lb: 0.1
  device:

# logging
logger:
  backend: wandb
  project_name: torchrl_example_sac
  group_name: null
  exp_name: ${env.name}_SAC
  mode: online
  log_freq: 25000  # logging freq in updates
  video: False

compile:
  compile: False
  compile_mode:
  cudagraphs: False
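The training script below consumes this file through `@hydra.main(config_path="", config_name="config-async")`. As a quick standalone illustration (not part of the commit), the same file can be inspected with OmegaConf, the configuration library Hydra builds on; this is a minimal sketch assuming the YAML above is saved as `config-async.yaml` in the working directory:

    from omegaconf import OmegaConf

    # Load the file directly (the actual script relies on @hydra.main instead).
    cfg = OmegaConf.load("config-async.yaml")

    print(cfg.collector.update_freq)   # 10000 -> how often (in updates) the inference policy is refreshed
    print(cfg.optim.batch_size)        # 256

    # Interpolations such as ${env.name}_SAC resolve on access:
    print(cfg.logger.exp_name)         # HalfCheetah-v4_SAC

With Hydra, any of these keys can also be overridden from the command line using the usual `section.key=value` syntax.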

sota-implementations/sac/sac-async.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Async SAC Example.

WARNING: This is not a SOTA implementation but a rudimentary implementation of SAC where inference
and training are entirely decoupled. It can achieve a 20x speedup if compile and cudagraph are used.
Two GPUs are required for this script to run.
The API is currently being perfected, and contributions are welcome (as usual!) - see the TODOs in this script.

This is a simple, self-contained example of an async SAC training script.

It supports state environments like MuJoCo.

The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import time
import warnings
from functools import partial

import hydra
import numpy as np
import tensordict
import torch
import torch.cuda
import tqdm
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup, logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
    log_metrics,
    make_collector_async,
    make_environment,
    make_loss_module,
    make_replay_buffer,
    make_sac_agent,
    make_sac_optimizer,
    make_train_environment,
)

torch.set_float32_matmul_precision("high")
tensordict.nn.functional_modules._exclude_td_from_pytree().set()


@hydra.main(version_base="1.1", config_path="", config_name="config-async")
def main(cfg: DictConfig):  # noqa: F821
    device = cfg.network.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    device = torch.device(device)

    # Create logger
    exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="async_sac_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    _, eval_env = make_environment(cfg, logger=logger)

    # TODO: This should be simplified. We need to create the policy on cuda:1 directly because of the bounds
    # of the TanhDistribution which cannot be sent to cuda:1 within the distribution construction (ie, the
    # distribution kwargs need to have access to the low / high values on the right device for compile and
    # cudagraph to work).
    # Create agent
    dummy_train_env = make_train_environment(cfg)
    model, _ = make_sac_agent(cfg, dummy_train_env, eval_env, device)
    _, exploration_policy = make_sac_agent(cfg, dummy_train_env, eval_env, "cuda:1")
    dummy_train_env.close(raise_if_closed=False)
    del dummy_train_env
    exploration_policy.load_state_dict(model[0].state_dict())

    # Create SAC loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

    compile_mode = None
    if cfg.compile.compile:
        compile_mode = cfg.compile.compile_mode
        if compile_mode in ("", None):
            if cfg.compile.cudagraphs:
                compile_mode = "default"
            else:
                compile_mode = "reduce-overhead"
    compile_mode_collector = compile_mode  # "reduce-overhead"

    # TODO: enabling prefetch for mp RBs would speed up sampling which is currently responsible for
    # half of the compute time on the trainer side.
    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        scratch_dir=cfg.replay_buffer.scratch_dir,
        device=device,
        shared=True,
        prefetch=0,
    )

    # TODO: Simplify this - ideally we'd like to share the uninitialized lazy tensor storage and fetch it once
    # it's initialized
    replay_buffer.extend(make_train_environment(cfg).rollout(1).view(-1))
    replay_buffer.empty()

    # Create off-policy collector and start it
    collector = make_collector_async(
        cfg,
        partial(make_train_environment, cfg),
        exploration_policy,
        compile_mode=compile_mode_collector,
        replay_buffer=replay_buffer,
    )

    # Create optimizers
    (
        optimizer_actor,
        optimizer_critic,
        optimizer_alpha,
    ) = make_sac_optimizer(cfg, loss_module)
    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
    del optimizer_actor, optimizer_critic, optimizer_alpha

    def update(sampled_tensordict):
        # Compute loss
        loss_td = loss_module(sampled_tensordict)

        actor_loss = loss_td["loss_actor"]
        q_loss = loss_td["loss_qvalue"]
        alpha_loss = loss_td["loss_alpha"]

        (actor_loss + q_loss + alpha_loss).sum().backward()
        optimizer.step()

        # Update qnet_target params
        target_net_updater.step()

        optimizer.zero_grad(set_to_none=True)
        return loss_td.detach()

    if cfg.compile.compile:
        update = compile_with_warmup(update, mode=compile_mode, warmup=2)

    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)

    # Main loop
    init_random_frames = cfg.collector.init_random_frames

    prb = cfg.replay_buffer.prb
    update_freq = cfg.collector.update_freq

    eval_rollout_steps = cfg.env.max_episode_steps
    log_freq = cfg.logger.log_freq

    # TODO: customize this
    num_updates = 1000
    total_iter = 1000
    pbar = tqdm.tqdm(total=total_iter * num_updates)
    params = TensorDict.from_module(model[0]).data

    # Wait till we have enough data to start training
    while replay_buffer.write_count <= init_random_frames:
        time.sleep(0.01)

    losses = []
    for i in range(total_iter * num_updates):
        timeit.printevery(
            num_prints=total_iter * num_updates // log_freq,
            total_count=total_iter * num_updates,
            erase=True,
        )

        if (i % update_freq) == 0:
            # Update weights of the inference policy
            torchrl_logger.info("Updating weights")
            collector.update_policy_weights_(params)

        pbar.update(1)

        # Optimization steps
        with timeit("train"):
            with timeit("train - rb - sample"):
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()

            with timeit("train - update"):
                torch.compiler.cudagraph_mark_step_begin()
                loss_td = update(sampled_tensordict).clone()
            losses.append(loss_td.select("loss_actor", "loss_qvalue", "loss_alpha"))

            # Update priority
            if prb:
                replay_buffer.update_priority(sampled_tensordict)

        # Logging
        if (i % log_freq) == (log_freq - 1):
            torchrl_logger.info("Logging")
            collected_frames = replay_buffer.write_count
            metrics_to_log = {}
            if collected_frames >= init_random_frames:
                losses_m = torch.stack(losses).mean()
                losses = []
                metrics_to_log["train/q_loss"] = losses_m.get("loss_qvalue")
                metrics_to_log["train/actor_loss"] = losses_m.get("loss_actor")
                metrics_to_log["train/alpha_loss"] = losses_m.get("loss_alpha")
                metrics_to_log["train/alpha"] = loss_td["alpha"]
                metrics_to_log["train/entropy"] = loss_td["entropy"]
                metrics_to_log["train/collected_frames"] = int(collected_frames)

            # Evaluation
            with set_exploration_type(
                ExplorationType.DETERMINISTIC
            ), torch.no_grad(), timeit("eval"):
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_env.apply(dump_video)
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
            torchrl_logger.info(f"Logs: {metrics_to_log}")
            if logger is not None:
                metrics_to_log.update(timeit.todict(prefix="time"))
                metrics_to_log["time/speed"] = pbar.format_dict["rate"]
                log_metrics(logger, metrics_to_log, collected_frames)

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()


if __name__ == "__main__":
    main()
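The training loop above never blocks on data collection: the async collector writes into the shared replay buffer in the background, and every `update_freq` updates the trainer pushes fresh weights to the inference policy via `collector.update_policy_weights_(params)`, where `params = TensorDict.from_module(model[0]).data`. The sketch below (not part of the commit, using toy `nn.Linear` modules as stand-ins for the SAC actor and the exploration policy) illustrates the TensorDict-based weight handover that this pattern builds on:

    import torch
    from tensordict import TensorDict
    from torch import nn

    # Toy stand-ins: trainer_net plays the role of model[0] on the trainer device,
    # inference_net the role of the collector's exploration_policy.
    trainer_net = nn.Linear(4, 2)
    inference_net = nn.Linear(4, 2)

    # Snapshot the trainer's parameters as a detached TensorDict, as the script does
    # with `TensorDict.from_module(model[0]).data`.
    params = TensorDict.from_module(trainer_net).data

    # Writing those tensors into the inference module is the kind of operation the
    # collector performs when it receives updated policy weights.
    params.to_module(inference_net)

    assert torch.equal(trainer_net.weight.data, inference_net.weight)

Because only this small parameter TensorDict crosses between the trainer and the collector, the two sides can run on separate devices (here `cuda:0` for training and `cuda:1` for inference, per the config) without synchronizing on every step.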

0 commit comments