
Commit 87a59fb

Vincent Moens committed
[Feature] SAC compatibility with compile
ghstack-source-id: b57caea Pull Request resolved: #2655
1 parent 526b38d commit 87a59fb

File tree: 3 files changed (+128, -100 lines)


sota-implementations/sac/config.yaml

Lines changed: 6 additions & 1 deletion

@@ -20,7 +20,7 @@ collector:
 replay_buffer:
   size: 1000000
   prb: 0 # use prioritized experience replay
-  scratch_dir: null
+  scratch_dir:

 # optim
 optim:
@@ -51,3 +51,8 @@ logger:
   mode: online
   eval_iter: 25000
   video: False
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
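Note: the new compile block defaults to plain eager execution. Below is a minimal sketch of how these keys resolve at runtime, mirroring the mode-selection logic added to sac.py further down; it assumes omegaconf is installed and that the file is read directly from the repository root rather than through Hydra:

# Hedged sketch: read the new compile block and resolve the torch.compile mode
# the same way sac.py does. Only the keys shown in the YAML above are used.
from omegaconf import OmegaConf

cfg = OmegaConf.load("sota-implementations/sac/config.yaml")

compile_mode = None
if cfg.compile.compile:
    compile_mode = cfg.compile.compile_mode
    if compile_mode in ("", None):
        # With CUDA graphs enabled, "default" is enough; otherwise let
        # torch.compile itself reduce launch overhead.
        compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"

print(compile_mode)  # None with the defaults above (compile.compile=False)

# Typical Hydra overrides when launching training (hypothetical invocation):
#   python sac.py compile.compile=True compile.cudagraphs=True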

sota-implementations/sac/sac.py

Lines changed: 91 additions & 74 deletions
@@ -12,7 +12,7 @@
 """
 from __future__ import annotations

-import time
+import warnings

 import hydra

@@ -21,8 +21,11 @@
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import compile_with_warmup, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers

 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
@@ -36,6 +39,8 @@
     make_sac_optimizer,
 )

+torch.set_float32_matmul_precision("high")
+

 @hydra.main(version_base="1.1", config_path="", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
@@ -75,16 +80,27 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create SAC loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create off-policy collector
-    collector = make_collector(cfg, train_env, exploration_policy)
+    collector = make_collector(
+        cfg, train_env, exploration_policy, compile_mode=compile_mode
+    )

     # Create replay buffer
     replay_buffer = make_replay_buffer(
         batch_size=cfg.optim.batch_size,
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
         scratch_dir=cfg.replay_buffer.scratch_dir,
-        device="cpu",
+        device=device,
     )

     # Create optimizers
@@ -93,9 +109,36 @@ def main(cfg: "DictConfig"): # noqa: F821
         optimizer_critic,
         optimizer_alpha,
     ) = make_sac_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        # Compute loss
+        loss_td = loss_module(sampled_tensordict)
+
+        actor_loss = loss_td["loss_actor"]
+        q_loss = loss_td["loss_qvalue"]
+        alpha_loss = loss_td["loss_alpha"]
+
+        (actor_loss + q_loss + alpha_loss).sum().backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return loss_td.detach()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

     # Main loop
-    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -110,69 +153,48 @@ def main(cfg: "DictConfig"): # noqa: F821
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.env.max_episode_steps

-    sampling_start = time.time()
-    for i, tensordict in enumerate(collector):
-        sampling_time = time.time() - sampling_start
+    collector_iter = iter(collector)
+    total_iter = len(collector)
+
+    for i in range(total_iter):
+        timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
+
+        with timeit("collect"):
+            tensordict = next(collector_iter)

         # Update weights of the inference policy
         collector.update_policy_weights_()

-        pbar.update(tensordict.numel())
-
-        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
-        # Add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        pbar.update(current_frames)
+
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
         collected_frames += current_frames

         # Optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            losses = TensorDict(batch_size=[num_updates])
-            for i in range(num_updates):
-                # Sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
+        with timeit("train"):
+            if collected_frames >= init_random_frames:
+                losses = TensorDict(batch_size=[num_updates])
+                for i in range(num_updates):
+                    with timeit("rb - sample"):
+                        # Sample from replay buffer
+                        sampled_tensordict = replay_buffer.sample()
+
+                    with timeit("update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        loss_td = update(sampled_tensordict).clone()
+                    losses[i] = loss_td.select(
+                        "loss_actor", "loss_qvalue", "loss_alpha"
                     )
-                else:
-                    sampled_tensordict = sampled_tensordict.clone()
-
-                # Compute loss
-                loss_td = loss_module(sampled_tensordict)
-
-                actor_loss = loss_td["loss_actor"]
-                q_loss = loss_td["loss_qvalue"]
-                alpha_loss = loss_td["loss_alpha"]
-
-                # Update actor
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                # Update critic
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # Update alpha
-                optimizer_alpha.zero_grad()
-                alpha_loss.backward()
-                optimizer_alpha.step()
-
-                losses[i] = loss_td.select(
-                    "loss_actor", "loss_qvalue", "loss_alpha"
-                ).detach()
-
-                # Update qnet_target params
-                target_net_updater.step()

-                # Update priority
-                if prb:
-                    replay_buffer.update_priority(sampled_tensordict)
+                    # Update priority
+                    if prb:
+                        replay_buffer.update_priority(sampled_tensordict)

-        training_time = time.time() - training_start
         episode_end = (
             tensordict["next", "done"]
             if tensordict["next", "done"].any()
@@ -184,46 +206,41 @@ def main(cfg: "DictConfig"): # noqa: F821
         metrics_to_log = {}
         if len(episode_rewards) > 0:
             episode_length = tensordict["next", "step_count"][episode_end]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+            metrics_to_log["train/reward"] = episode_rewards
+            metrics_to_log["train/episode_length"] = episode_length.sum() / len(
                 episode_length
             )
         if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
-            metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
-            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
-            metrics_to_log["train/alpha"] = loss_td["alpha"].item()
-            metrics_to_log["train/entropy"] = loss_td["entropy"].item()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
+            losses = losses.mean()
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue")
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor")
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha")
+            metrics_to_log["train/alpha"] = loss_td["alpha"]
+            metrics_to_log["train/entropy"] = loss_td["entropy"]

         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
            log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()

     collector.shutdown()
     if not eval_env.is_closed:
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
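Note: the explicit time.time() bookkeeping is replaced by torchrl's timeit context manager, whose aggregated timings are merged into the logged metrics. A minimal sketch of that pattern, with sleeps standing in for the actual collection and training work:

# Hedged sketch of the timeit instrumentation used in the new training loop;
# only the timeit calls that appear in the diff above are relied upon.
import time

from torchrl._utils import timeit

total_iter = 4
for i in range(total_iter):
    # Periodically print aggregated timings (and erase the counters afterwards)
    timeit.printevery(num_prints=2, total_count=total_iter, erase=True)
    with timeit("collect"):
        time.sleep(0.01)  # stand-in for next(collector_iter)
    with timeit("train"):
        with timeit("rb - sample"):
            time.sleep(0.005)  # stand-in for replay_buffer.sample()
        with timeit("update"):
            time.sleep(0.02)  # stand-in for the compiled update()

# Timings can be folded into the logged metrics, as sac.py now does
metrics_to_log = {}
metrics_to_log.update(timeit.todict(prefix="time"))
print(metrics_to_log)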
