Commit f149811
Author: Vincent Moens

[Feature] DQN compatibility with compile

ghstack-source-id: 113dc8c
Pull Request resolved: #2571

Parent: bb6f87a

8 files changed (+193, -115 lines)
sota-implementations/dqn/config_atari.yaml

Lines changed: 5 additions & 0 deletions
@@ -39,3 +39,8 @@ loss:
   gamma: 0.99
   hard_update_freq: 10_000
   num_updates: 1
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False

sota-implementations/dqn/config_cartpole.yaml

Lines changed: 5 additions & 0 deletions
@@ -38,3 +38,8 @@ loss:
   gamma: 0.99
   hard_update_freq: 50
   num_updates: 1
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
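
Both configs gain the same top-level `compile` block. As a rough illustrative sketch (not part of the diff), this is how those fields resolve once Hydra/OmegaConf parses the YAML; note that the empty `compile_mode:` value loads as `None`, which is the sentinel the training script checks for later:

```python
# Illustrative only: how the new `compile` block resolves after YAML parsing.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    compile:
      compile: False
      compile_mode:
      cudagraphs: False
    """
)
assert cfg.compile.compile is False
assert cfg.compile.compile_mode is None  # empty YAML value -> None
assert cfg.compile.cudagraphs is False
```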

sota-implementations/dqn/dqn_atari.py

Lines changed: 71 additions & 52 deletions
@@ -10,14 +10,14 @@
 from __future__ import annotations

 import tempfile
-import time
+import warnings

 import hydra
 import torch.nn
 import torch.optim
 import tqdm
-from tensordict.nn import TensorDictSequential
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule, TensorDictSequential
+from torchrl._utils import timeit

 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
@@ -48,28 +48,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip)
+    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
     greedy_module = EGreedyModule(
         annealing_num_steps=cfg.collector.annealing_frames,
         eps_init=cfg.collector.eps_start,
         eps_end=cfg.collector.eps_end,
         spec=model.spec,
+        device=device,
     )
     model_explore = TensorDictSequential(
         model,
         greedy_module,
-    ).to(device)
-
-    # Create the collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
-        policy=model_explore,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-        init_random_frames=init_random_frames,
     )

     # Create the replay buffer
@@ -129,25 +118,70 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
     test_env.eval()

+    def update(sampled_tensordict):
+        loss_td = loss_module(sampled_tensordict)
+        q_loss = loss_td["loss"]
+        optimizer.zero_grad()
+        q_loss.backward()
+        torch.nn.utils.clip_grad_norm_(
+            list(loss_module.parameters()), max_norm=max_grad
+        )
+        optimizer.step()
+        target_net_updater.step()
+        return q_loss.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create the collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
+        policy=model_explore,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+        init_random_frames=init_random_frames,
+        compile_policy={"mode": compile_mode, "fullgraph": True}
+        if compile_mode is not None
+        else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
-    start_time = time.time()
-    sampling_start = time.time()
     num_updates = cfg.loss.num_updates
     max_grad = cfg.optim.max_grad_norm
     num_test_episodes = cfg.logger.num_test_episodes
     q_losses = torch.zeros(num_updates, device=device)
     pbar = tqdm.tqdm(total=total_frames)
-    for i, data in enumerate(collector):

+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(c_iter)
         log_info = {}
-        sampling_time = time.time() - sampling_start
         pbar.update(data.numel())
         data = data.reshape(-1)
         current_frames = data.numel() * frame_skip
         collected_frames += current_frames
         greedy_module.step(current_frames)
-        replay_buffer.extend(data)
+        with timeit("rb - extend"):
+            replay_buffer.extend(data)

         # Get and log training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
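
For readability, the compile/cudagraphs wiring added in this hunk can be summarized as the standalone sketch below; the `maybe_compile_update` helper and its `compile_cfg` argument are illustrative names, not part of the commit:

```python
# Sketch of the pattern above: optionally wrap the update step with
# torch.compile and/or tensordict's CudaGraphModule, driven by the config.
import warnings

import torch
from tensordict.nn import CudaGraphModule


def maybe_compile_update(update, compile_cfg):
    compile_mode = None
    if compile_cfg.compile:
        compile_mode = compile_cfg.compile_mode
        if compile_mode in ("", None):
            # CUDA graphs already remove launch overhead, so "reduce-overhead"
            # is only chosen when cudagraphs is off.
            compile_mode = "default" if compile_cfg.cudagraphs else "reduce-overhead"
        update = torch.compile(update, mode=compile_mode)
    if compile_cfg.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results.",
            category=UserWarning,
        )
        # Warmup iterations let CUDA graphs capture stable shapes/addresses first.
        update = CudaGraphModule(update, warmup=50)
    return update, compile_mode
```

The same `compile_mode` is then reused for the collector's `compile_policy`/`cudagraph_policy` arguments, so the inference policy and the update step are compiled consistently.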
@@ -169,74 +203,59 @@ def main(cfg: "DictConfig"):  # noqa: F821
             continue

         # optimization steps
-        training_start = time.time()
         for j in range(num_updates):
-
-            sampled_tensordict = replay_buffer.sample()
-            sampled_tensordict = sampled_tensordict.to(device)
-
-            loss_td = loss_module(sampled_tensordict)
-            q_loss = loss_td["loss"]
-            optimizer.zero_grad()
-            q_loss.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=max_grad
-            )
-            optimizer.step()
-            target_net_updater.step()
-            q_losses[j].copy_(q_loss.detach())
-
-        training_time = time.time() - training_start
+            with timeit("rb - sample"):
+                sampled_tensordict = replay_buffer.sample()
+                sampled_tensordict = sampled_tensordict.to(device)
+            with timeit("update"):
+                q_loss = update(sampled_tensordict)
+            q_losses[j].copy_(q_loss)

         # Get and log q-values, loss, epsilon, sampling time and training time
         log_info.update(
             {
-                "train/q_values": (data["action_value"] * data["action"]).sum().item()
-                / frames_per_batch,
-                "train/q_loss": q_losses.mean().item(),
+                "train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
+                "train/q_loss": q_losses.mean(),
                 "train/epsilon": greedy_module.eps,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
             }
         )

         # Get and log evaluation rewards and eval time
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
             cur_test_frame = (i * frames_per_batch) // test_interval
             final = current_frames >= collector.total_frames
             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                 model.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     model, test_env, num_episodes=num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "eval/reward": test_rewards,
-                        "eval/eval_time": eval_time,
                     }
                 )
                 model.train()

+        if i % 200 == 0:
+            timeit.print()
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.erase()
+
         # Log all the information
         if logger:
             for key, value in log_info.items():
                 logger.log_scalar(key, value, step=collected_frames)

         # update weights of the inference policy
         collector.update_policy_weights_()
-        sampling_start = time.time()

     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()

-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-

 if __name__ == "__main__":
     main()
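
Throughout this hunk, the manual `time.time()` bookkeeping (and the per-batch `train/sampling_time`, `train/training_time`, `eval/eval_time` metrics) is replaced by `torchrl._utils.timeit`. A minimal usage sketch, assuming the same calls as in the diff (`timeit(...)` as a context manager plus `print`/`todict`/`erase`):

```python
# Sketch: timings are accumulated per key and periodically flushed to the logger.
from torchrl._utils import timeit

for i in range(3):
    with timeit("collecting"):
        pass  # e.g. data = next(c_iter)
    with timeit("update"):
        pass  # e.g. q_loss = update(sampled_tensordict)
    if i % 200 == 0:
        timeit.print()                           # human-readable timing summary
        log_info = timeit.todict(prefix="time")  # e.g. {"time/collecting": ..., "time/update": ...}
        timeit.erase()                           # reset the accumulated timings
```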
