Commit 7d7cd95

Author: Vincent Moens

[Feature] DDPG compatibility with compile

ghstack-source-id: f18928a
Pull Request resolved: #2555

1 parent: 01a421e

5 files changed (+115 / -69 lines)

sota-implementations/ddpg/config.yaml

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,7 @@ collector:
   frames_per_batch: 1000
   init_env_steps: 1000
   reset_at_each_iter: False
-  device: cpu
+  device:
   env_per_collector: 1


@@ -40,6 +40,11 @@ network:
   activation: relu
   noise_type: "ou" # ou or gaussian

+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
+
 # logging
 logger:
   backend: wandb
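
The new compile block and the now-empty collector.device are consumed by the training script (ddpg.py below). A minimal sketch of that resolution, assuming the YAML is loaded with OmegaConf (the script itself receives cfg through Hydra):

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("sota-implementations/ddpg/config.yaml")

# An empty collector.device now means "pick automatically".
collector_device = cfg.collector.device or (
    "cuda:0" if torch.cuda.is_available() else "cpu"
)

# compile.compile toggles torch.compile; an empty compile_mode falls back to
# "reduce-overhead", or to "default" when cudagraphs is also requested
# (see the mode-selection hunk in ddpg.py below).
compile_mode = None
if cfg.compile.compile:
    compile_mode = cfg.compile.compile_mode or (
        "default" if cfg.compile.cudagraphs else "reduce-overhead"
    )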

sota-implementations/ddpg/ddpg.py

Lines changed: 81 additions & 54 deletions
@@ -12,17 +12,21 @@
 """
 from __future__ import annotations

-import time
+import warnings

 import hydra

 import numpy as np
 import torch
 import torch.cuda
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit

 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,

@@ -46,6 +50,14 @@ def main(cfg: "DictConfig"): # noqa: F821
             device = "cpu"
     device = torch.device(device)

+    collector_device = cfg.collector.device
+    if collector_device in ("", None):
+        if torch.cuda.is_available():
+            collector_device = "cuda:0"
+        else:
+            collector_device = "cpu"
+    collector_device = torch.device(collector_device)
+
     # Create logger
     exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
     logger = None
@@ -75,8 +87,25 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create DDPG loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+
     # Create off-policy collector
-    collector = make_collector(cfg, train_env, exploration_policy)
+    collector = make_collector(
+        cfg,
+        train_env,
+        exploration_policy,
+        compile=cfg.compile.compile,
+        compile_mode=compile_mode,
+        cudagraph=cfg.compile.cudagraphs,
+        device=collector_device,
+    )

     # Create replay buffer
     replay_buffer = make_replay_buffer(
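
A note on the mode selection above: torch.compile's "reduce-overhead" mode relies on CUDA graphs internally, so when cfg.compile.cudagraphs is set the script falls back to plain "default", presumably to avoid stacking graph capture with the explicit CudaGraphModule / cudagraph_policy introduced in this commit. A toy sketch of the same choice (function and flag names are illustrative, not part of the commit):

import torch

def policy_step(x):
    # stand-in for a policy or update function
    return torch.tanh(2.0 * x)

cudagraphs = True  # mirrors cfg.compile.cudagraphs
mode = "default" if cudagraphs else "reduce-overhead"
compiled_step = torch.compile(policy_step, mode=mode)
print(compiled_step(torch.randn(4)))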
@@ -89,9 +118,29 @@ def main(cfg: "DictConfig"): # noqa: F821

     # Create optimizers
     optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic)
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+
+        td_loss: TensorDict = loss_module(sampled_tensordict)
+        td_loss.sum(reduce=True).backward()
+        optimizer.step()
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return td_loss.detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)

     # Main loop
-    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
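
The two wrappers stack: torch.compile fuses the update into optimized kernels, and CudaGraphModule then records the whole step and replays it as a single CUDA graph after a warmup phase (50 calls in the commit). A minimal sketch of the same pattern on a toy function; it assumes a CUDA device, all names are illustrative, and the output is cloned before being stored because a CUDA-graph output lives in a static buffer that is overwritten on the next call (the training loop below does the same with td_loss.clone()):

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

def toy_update(td):
    # stand-in for update(): turn a batch into a scalar "loss"
    loss = (td["x"] * 2.0 + 1.0).sum()
    return TensorDict({"loss": loss.detach()}, batch_size=[])

toy_update = torch.compile(toy_update, mode="default")
toy_update = CudaGraphModule(toy_update, warmup=5)  # the commit uses warmup=50

results = []
for _ in range(10):
    torch.compiler.cudagraph_mark_step_begin()
    td = TensorDict({"x": torch.randn(8, device="cuda")}, batch_size=[], device="cuda")
    out = toy_update(td)
    results.append(out.clone())  # clone: the graph reuses its output buffer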

@@ -106,63 +155,43 @@ def main(cfg: "DictConfig"): # noqa: F821
     eval_iter = cfg.logger.eval_iter
     eval_rollout_steps = cfg.env.max_episode_steps

-    sampling_start = time.time()
-    for _, tensordict in enumerate(collector):
-        sampling_time = time.time() - sampling_start
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            tensordict = next(c_iter)
         # Update exploration policy
         exploration_policy[1].step(tensordict.numel())

         # Update weights of the inference policy
         collector.update_policy_weights_()

-        pbar.update(tensordict.numel())
-
-        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
         # Add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("rb - extend"):
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
         collected_frames += current_frames

         # Optimization steps
-        training_start = time.time()
         if collected_frames >= init_random_frames:
-            (
-                actor_losses,
-                q_losses,
-            ) = ([], [])
+            tds = []
             for _ in range(num_updates):
                 # Sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict.clone()
-
-                # Update critic
-                q_loss, *_ = loss_module.loss_value(sampled_tensordict)
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # Update actor
-                actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                q_losses.append(q_loss.item())
-                actor_losses.append(actor_loss.item())
-
-                # Update qnet_target params
-                target_net_updater.step()
+                with timeit("rb - sample"):
+                    sampled_tensordict = replay_buffer.sample().to(device)
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    td_loss = update(sampled_tensordict)
+                tds.append(td_loss.clone())

                 # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)
+            tds = torch.stack(tds)

-        training_time = time.time() - training_start
         episode_end = (
             tensordict["next", "done"]
             if tensordict["next", "done"].any()
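
The ad-hoc time.time() bookkeeping is replaced by torchrl's timeit helper: each named context accumulates wall-clock time, and every 20 iterations the totals are exported to the logger, printed and reset (see the logging hunk below). A small self-contained sketch of that pattern, with toy payloads and the same key names as the loop above:

import time

from torchrl._utils import timeit

for _ in range(3):
    with timeit("collecting"):
        time.sleep(0.01)  # stand-in for tensordict = next(c_iter)
    with timeit("update"):
        time.sleep(0.005)  # stand-in for the compiled update(...)

metrics = timeit.todict(prefix="time")  # e.g. {"time/collecting": ..., "time/update": ...}
timeit.print()  # human-readable summary
timeit.erase()  # reset the accumulators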
@@ -180,38 +209,36 @@ def main(cfg: "DictConfig"): # noqa: F821
         )

         if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = np.mean(q_losses)
-            metrics_to_log["train/a_loss"] = np.mean(actor_losses)
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
+            tds = TensorDict(train=tds).flatten_keys("/").mean()
+            metrics_to_log.update(tds.to_dict())

         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     exploration_policy,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+        if i % 20 == 0:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
+
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()

     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
     if not eval_env.is_closed:
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
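
With the fused update() returning a loss TensorDict, the manual q_losses / actor_losses lists disappear: the per-update results are stacked, nested under a train prefix, flattened to train/<key> names and averaged before being handed to the logger. A short sketch of that aggregation (the loss key names are illustrative stand-ins for whatever the DDPG loss module returns):

import torch
from tensordict import TensorDict

# pretend these came out of update() across num_updates iterations
tds = [
    TensorDict({"loss_actor": torch.rand(()), "loss_value": torch.rand(())}, batch_size=[])
    for _ in range(4)
]
tds = torch.stack(tds)                                    # batch of per-update losses
metrics = TensorDict(train=tds).flatten_keys("/").mean()  # average over the updates
print(metrics.to_dict())  # e.g. {"train/loss_actor": ..., "train/loss_value": ...}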

sota-implementations/ddpg/utils.py

Lines changed: 19 additions & 14 deletions
@@ -8,7 +8,7 @@

 import torch

-from tensordict.nn import TensorDictSequential
+from tensordict.nn import TensorDictModule, TensorDictSequential

 from torch import nn, optim
 from torchrl.collectors import SyncDataCollector

@@ -32,8 +32,6 @@
     AdditiveGaussianModule,
     MLP,
     OrnsteinUhlenbeckProcessModule,
-    SafeModule,
-    SafeSequential,
     TanhModule,
     ValueOperator,
 )

@@ -115,7 +113,15 @@ def make_environment(cfg, logger):
 # ---------------------------


-def make_collector(cfg, train_env, actor_model_explore):
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    compile=False,
+    compile_mode=None,
+    cudagraph=False,
+    device: torch.device | None = None,
+):
     """Make collector."""
     collector = SyncDataCollector(
         train_env,

@@ -124,7 +130,9 @@ def make_collector(cfg, train_env, actor_model_explore):
         init_random_frames=cfg.collector.init_random_frames,
         reset_at_each_iter=cfg.collector.reset_at_each_iter,
         total_frames=cfg.collector.total_frames,
-        device=cfg.collector.device,
+        device=device,
+        compile_policy={"mode": compile_mode, "fullgraph": True} if compile else False,
+        cudagraph_policy=cudagraph,
     )
     collector.set_seed(cfg.env.seed)
     return collector

@@ -174,9 +182,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
     """Make DDPG agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     actor_net_kwargs = {
         "num_cells": cfg.network.hidden_sizes,
         "out_features": action_spec.shape[-1],
@@ -186,19 +192,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
     actor_net = MLP(**actor_net_kwargs)

     in_keys_actor = in_keys
-    actor_module = SafeModule(
+    actor_module = TensorDictModule(
         actor_net,
         in_keys=in_keys_actor,
-        out_keys=[
-            "param",
-        ],
+        out_keys=["param"],
     )
-    actor = SafeSequential(
+    actor = TensorDictSequential(
         actor_module,
         TanhModule(
             in_keys=["param"],
             out_keys=["action"],
-            spec=action_spec,
         ),
     )
204207

@@ -237,6 +240,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
                 spec=action_spec,
                 annealing_num_steps=1_000_000,
                 device=device,
+                safe=False,
             ),
         )
     elif cfg.network.noise_type == "gaussian":

@@ -249,6 +253,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
                 mean=0.0,
                 std=0.1,
                 device=device,
+                safe=False,
             ),
         )
     else:

torchrl/collectors/collectors.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@
     set_exploration_type,
 )

+
 try:
     from torch.compiler import cudagraph_mark_step_begin
 except ImportError:
