
Commit 9e2d214

Vincent Moens authored and committed
[Feature] Discrete SAC compatibility with compile
ghstack-source-id: ddc131a
Pull Request resolved: #2569
1 parent fbfe104 commit 9e2d214

13 files changed: +127 / -104 lines


sota-implementations/a2c/utils_atari.py

Lines changed: 3 additions & 3 deletions
@@ -93,12 +93,12 @@ def make_ppo_modules_pixels(proof_environment, device):
     input_shape = proof_environment.observation_spec["pixels"].shape

     # Define distribution class and kwargs
-    if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
-        num_outputs = proof_environment.single_action_spec.space.n
+    if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
+        num_outputs = proof_environment.action_spec_unbatched.space.n
         distribution_class = OneHotCategorical
         distribution_kwargs = {}
     else:  # is ContinuousBox
-        num_outputs = proof_environment.single_action_spec.shape
+        num_outputs = proof_environment.action_spec_unbatched.shape
         distribution_class = TanhNormal
         distribution_kwargs = {
             "low": proof_environment.action_spec_unbatched.space.low.to(device),

sota-implementations/a2c/utils_mujoco.py

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     input_shape = proof_environment.observation_spec["observation"].shape

     # Define policy output distribution class
-    num_outputs = proof_environment.single_action_spec.shape[-1]
+    num_outputs = proof_environment.action_spec_unbatched.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
         "low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -82,7 +82,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     policy_mlp = torch.nn.Sequential(
         policy_mlp,
         AddStateIndependentNormalScale(
-            proof_environment.single_action_spec.shape[-1], device=device
+            proof_environment.action_spec_unbatched.shape[-1], device=device
         ),
     )

sota-implementations/cql/cql_online.py

Lines changed: 0 additions & 1 deletion
@@ -172,7 +172,6 @@ def update(sampled_tensordict):
     c_iter = iter(collector)
     for i in range(len(collector)):
         with timeit("collecting"):
-            torch.compiler.cudagraph_mark_step_begin()
             tensordict = next(c_iter)
         pbar.update(tensordict.numel())
         # update weights of the inference policy

sota-implementations/cql/utils.py

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):


 def make_cql_modules_state(model_cfg, proof_environment):
-    action_spec = proof_environment.single_action_spec
+    action_spec = proof_environment.action_spec_unbatched

     actor_net_kwargs = {
         "num_cells": model_cfg.hidden_sizes,

sota-implementations/discrete_sac/config.yaml

Lines changed: 5 additions & 0 deletions
@@ -44,6 +44,11 @@ network:
   activation: relu
   device: null

+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
+
 # logging
 logger:
   backend: wandb
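The new compile block defaults to eager execution, with compile_mode left empty so the training script can pick a sensible default. A small sketch of how those defaults parse (OmegaConf is the config backend Hydra uses; the inline YAML simply mirrors the block added above):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    compile:
      compile: False
      compile_mode:
      cudagraphs: False
    """
)
assert cfg.compile.compile is False
assert cfg.compile.compile_mode is None  # an empty YAML value parses to None
assert cfg.compile.cudagraphs is False

That None value is what triggers the fallback to "reduce-overhead" (or "default" under CUDA graphs) in discrete_sac.py below.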

sota-implementations/discrete_sac/discrete_sac.py

Lines changed: 90 additions & 80 deletions
@@ -10,19 +10,20 @@

 The helper functions are coded in the utils.py associated with this script.
 """
+
 from __future__ import annotations

-import time
+import warnings

 import hydra
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
-from torchrl._utils import logger as torchrl_logger
-
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,
@@ -75,9 +76,6 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create TD3 loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Create off-policy collector
-    collector = make_collector(cfg, train_env, model[0])
-
     # Create replay buffer
     replay_buffer = make_replay_buffer(
         batch_size=cfg.optim.batch_size,
@@ -91,9 +89,57 @@ def main(cfg: "DictConfig"): # noqa: F821
     optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
         cfg, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+
+        # Compute loss
+        loss_out = loss_module(sampled_tensordict)
+
+        actor_loss, q_loss, alpha_loss = (
+            loss_out["loss_actor"],
+            loss_out["loss_qvalue"],
+            loss_out["loss_alpha"],
+        )
+
+        # Update critic
+        (q_loss + actor_loss + alpha_loss).backward()
+        optimizer.step()
+
+        # Update target params
+        target_net_updater.step()
+
+        return loss_out.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg,
+        train_env,
+        model[0],
+        compile=compile_mode is not None,
+        compile_mode=compile_mode,
+        cudagraphs=cfg.compile.cudagraphs,
+    )

     # Main loop
-    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -108,129 +154,93 @@ def main(cfg: "DictConfig"): # noqa: F821
     eval_iter = cfg.logger.eval_iter
     frames_per_batch = cfg.collector.frames_per_batch

-    sampling_start = time.time()
-    for i, tensordict in enumerate(collector):
-        sampling_time = time.time() - sampling_start
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            collected_data = next(c_iter)

         # Update weights of the inference policy
         collector.update_policy_weights_()
+        current_frames = collected_data.numel()

-        pbar.update(tensordict.numel())
+        pbar.update(current_frames)

-        tensordict = tensordict.reshape(-1)
-        current_frames = tensordict.numel()
-        # Add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        collected_data = collected_data.reshape(-1)
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            replay_buffer.extend(collected_data)
         collected_frames += current_frames

         # Optimization steps
-        training_start = time.time()
         if collected_frames >= init_random_frames:
-            (
-                actor_losses,
-                q_losses,
-                alpha_losses,
-            ) = ([], [], [])
+            tds = []
             for _ in range(num_updates):
-                # Sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict.clone()
-
-                # Compute loss
-                loss_out = loss_module(sampled_tensordict)
-
-                actor_loss, q_loss, alpha_loss = (
-                    loss_out["loss_actor"],
-                    loss_out["loss_qvalue"],
-                    loss_out["loss_alpha"],
-                )
-
-                # Update critic
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-                q_losses.append(q_loss.item())
+                with timeit("rb - sample"):
+                    # Sample from replay buffer
+                    sampled_tensordict = replay_buffer.sample()

-                # Update actor
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    sampled_tensordict = sampled_tensordict.to(device)
+                    loss_out = update(sampled_tensordict).clone()

-                actor_losses.append(actor_loss.item())
-
-                # Update alpha
-                optimizer_alpha.zero_grad()
-                alpha_loss.backward()
-                optimizer_alpha.step()
-
-                alpha_losses.append(alpha_loss.item())
-
-                # Update target params
-                target_net_updater.step()
+                tds.append(loss_out)

                 # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)
+            tds = torch.stack(tds).mean()

-        training_time = time.time() - training_start
+        # Logging
         episode_end = (
-            tensordict["next", "done"]
-            if tensordict["next", "done"].any()
-            else tensordict["next", "truncated"]
+            collected_data["next", "done"]
+            if collected_data["next", "done"].any()
+            else collected_data["next", "truncated"]
         )
-        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+        episode_rewards = collected_data["next", "episode_reward"][episode_end]

-        # Logging
         metrics_to_log = {}
         if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][episode_end]
+            episode_length = collected_data["next", "step_count"][episode_end]
             metrics_to_log["train/reward"] = episode_rewards.mean().item()
             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                 episode_length
             )

         if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = np.mean(q_losses)
-            metrics_to_log["train/a_loss"] = np.mean(actor_losses)
-            metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses)
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
+            metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+            metrics_to_log["train/a_loss"] = tds["loss_actor"]
+            metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

         # Evaluation
         prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
         cur_test_frame = (i * frames_per_batch) // eval_iter
         final = current_frames >= collector.total_frames
         if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+        if i % 50 == 0:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()

     collector.shutdown()
     if not eval_env.is_closed:
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
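Stripped of the training-loop plumbing, the refactor above follows one pattern: fuse the per-module optimizers with group_optimizers, express a full SAC update (losses, single backward, optimizer step, target update) as one function, then optionally hand that function to torch.compile and/or CudaGraphModule. A condensed, hypothetical helper illustrating the same pattern (build_update_fn and its arguments are illustrative names; the config fields, modes, and warmup value mirror the diff):

import warnings

import torch
from tensordict.nn import CudaGraphModule
from torchrl.objectives import group_optimizers


def build_update_fn(loss_module, target_net_updater, optimizers, compile_cfg):
    """Fuse per-module optimizers and optionally compile/capture the update step."""
    optimizer = group_optimizers(*optimizers)

    def update(sampled_tensordict):
        optimizer.zero_grad(set_to_none=True)
        loss_out = loss_module(sampled_tensordict)
        # One backward over the summed losses, one optimizer step for all modules.
        (loss_out["loss_actor"] + loss_out["loss_qvalue"] + loss_out["loss_alpha"]).backward()
        optimizer.step()
        # Update the target network parameters.
        target_net_updater.step()
        return loss_out.detach()

    compile_mode = None
    if compile_cfg.compile:
        compile_mode = compile_cfg.compile_mode
        if compile_mode in ("", None):
            # "reduce-overhead" already uses CUDA graphs internally, so fall back to
            # "default" when CudaGraphModule handles the capture itself.
            compile_mode = "default" if compile_cfg.cudagraphs else "reduce-overhead"
        update = torch.compile(update, mode=compile_mode)
    if compile_cfg.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)
    return update, compile_mode

With this shape, the inner training loop reduces to the single update(sampled_tensordict) call seen in the diff, which is what makes both compilation and CUDA-graph capture practical.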

sota-implementations/discrete_sac/utils.py

Lines changed: 10 additions & 1 deletion
@@ -113,7 +113,14 @@ def make_environment(cfg, logger=None):
 # ---------------------------


-def make_collector(cfg, train_env, actor_model_explore):
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    compile=False,
+    compile_mode=None,
+    cudagraphs=False,
+):
     """Make collector."""
     device = cfg.collector.device
     if device in ("", None):
@@ -131,6 +138,8 @@ def make_collector(cfg, train_env, actor_model_explore):
         reset_at_each_iter=cfg.collector.reset_at_each_iter,
         device=device,
         storing_device="cpu",
+        compile_policy=False if not compile else {"mode": compile_mode},
+        cudagraph_policy=cudagraphs,
     )
     collector.set_seed(cfg.env.seed)
     return collector
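The two keyword arguments added above are forwarded to the collector so the inference policy gets the same treatment as the update function: compile_policy takes either False or a dict of torch.compile keyword arguments, and cudagraph_policy toggles CudaGraphModule wrapping of the policy. A hedged, stand-alone usage sketch (the SyncDataCollector/GymEnv/Pendulum setup and the toy policy are illustrative assumptions, not code from this commit):

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
# Toy deterministic policy: observation (3-dim) -> action (1-dim).
policy = TensorDictModule(
    nn.Linear(
        env.observation_spec["observation"].shape[-1],
        env.action_spec_unbatched.shape[-1],
    ),
    in_keys=["observation"],
    out_keys=["action"],
)
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=200,
    total_frames=1_000,
    # False disables compilation; a dict is passed through to torch.compile.
    compile_policy={"mode": "reduce-overhead"},
    # True would wrap the policy in a CudaGraphModule (CUDA devices only).
    cudagraph_policy=False,
)
for batch in collector:
    print(batch.shape)  # e.g. torch.Size([200])
    break
collector.shutdown()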

sota-implementations/dreamer/dreamer_utils.py

Lines changed: 7 additions & 7 deletions
@@ -475,12 +475,12 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
             spec=Composite(
                 **{
                     "loc": Unbounded(
-                        proof_environment.single_action_spec.shape,
-                        device=proof_environment.single_action_spec.device,
+                        proof_environment.action_spec_unbatched.shape,
+                        device=proof_environment.action_spec_unbatched.device,
                     ),
                     "scale": Unbounded(
-                        proof_environment.single_action_spec.shape,
-                        device=proof_environment.single_action_spec.device,
+                        proof_environment.action_spec_unbatched.shape,
+                        device=proof_environment.action_spec_unbatched.device,
                     ),
                 }
             ),
@@ -491,7 +491,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
             default_interaction_type=InteractionType.RANDOM,
             distribution_class=TanhNormal,
             distribution_kwargs={"tanh_loc": True},
-            spec=Composite(**{action_key: proof_environment.single_action_spec}),
+            spec=Composite(**{action_key: proof_environment.action_spec_unbatched}),
         ),
     )
     return actor_simulator
@@ -532,10 +532,10 @@ def _dreamer_make_actor_real(
             spec=Composite(
                 **{
                     "loc": Unbounded(
-                        proof_environment.single_action_spec.shape,
+                        proof_environment.action_spec_unbatched.shape,
                     ),
                     "scale": Unbounded(
-                        proof_environment.single_action_spec.shape,
+                        proof_environment.action_spec_unbatched.shape,
                     ),
                 }
             ),
