
Commit 2cfc2ab
Author: Vincent Moens
Parent: 6482766

[Feature] IQL compatibility with compile

ghstack-source-id: 77bca16
Pull Request resolved: #2649

File tree

9 files changed: +327 −244 lines changed
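The core of the change, in both scripts below: the three per-loss optimizers are merged via group_optimizers and driven by a single backward pass over the summed losses, so that the whole update can later be handed to torch.compile as one graph. A minimal sketch of that pattern, with toy networks and an assumed learning rate standing in for the scripts' models (group_optimizers and the grouped step()/zero_grad() calls are as in the diffs):

import torch
from torch import nn
from torchrl.objectives import group_optimizers

# hypothetical stand-ins for the actor/critic/value networks of the scripts
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
value = nn.Linear(4, 1)
optimizer = group_optimizers(
    torch.optim.Adam(actor.parameters(), lr=3e-4),
    torch.optim.Adam(critic.parameters(), lr=3e-4),
    torch.optim.Adam(value.parameters(), lr=3e-4),
)

x = torch.randn(8, 4)
optimizer.zero_grad(set_to_none=True)
# one backward over the summed losses fills every grad at once,
# one step updates all three parameter sets
loss = actor(x).pow(2).mean() + critic(x).pow(2).mean() + value(x).pow(2).mean()
loss.backward()
optimizer.step()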

sota-implementations/iql/discrete_iql.py

Lines changed: 99 additions & 75 deletions

@@ -13,16 +13,20 @@
 """
 from __future__ import annotations

-import time
+import warnings

 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit

 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger

 from utils import (
@@ -37,6 +41,9 @@
 )


+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="discrete_iql")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
@@ -87,16 +94,54 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create model
     model = make_discrete_iql_model(cfg, train_env, eval_env, device)

+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0])
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )

     # Create loss
-    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

     # Create optimizer
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
+    del optimizer_actor, optimizer_critic, optimizer_value
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+        # compute losses
+        actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
+        value_loss, _ = loss_module.value_loss(sampled_tensordict)
+        q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
+
+        # update qnet_target params
+        target_net_updater.step()
+        metadata.update(
+            {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
+        )
+        return TensorDict(metadata).detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)

     # Main loop
     collected_frames = 0
@@ -112,103 +157,82 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_iter = cfg.logger.eval_iter
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.collector.max_frames_per_traj
-    sampling_start = start_time = time.time()
-    for tensordict in collector:
-        sampling_time = time.time() - sampling_start
-        pbar.update(tensordict.numel())
+
+    collector_iter = iter(collector)
+    for _ in range(len(collector)):
+        with timeit("collection"):
+            tensordict = next(collector_iter)
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
         # update weights of the inference policy
         collector.update_policy_weights_()

-        tensordict = tensordict.reshape(-1)
-        current_frames = tensordict.numel()
-        # add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("buffer - extend"):
+            tensordict = tensordict.reshape(-1)
+
+            # add to replay buffer
+            replay_buffer.extend(tensordict)
         collected_frames += current_frames

         # optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            for _ in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict
-                # compute losses
-                actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                value_loss, _ = loss_module.value_loss(sampled_tensordict)
-                optimizer_value.zero_grad()
-                value_loss.backward()
-                optimizer_value.step()
-
-                q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # update qnet_target params
-                target_net_updater.step()
-
-                # update priority
-                if prb:
-                    sampled_tensordict.set(
-                        loss_module.tensor_keys.priority,
-                        metadata.pop("td_error").detach().max(0).values,
-                    )
-                    replay_buffer.update_priority(sampled_tensordict)
-
-        training_time = time.time() - training_start
+        with timeit("training"):
+            if collected_frames >= init_random_frames:
+                for _ in range(num_updates):
+                    # sample from replay buffer
+                    with timeit("buffer - sample"):
+                        sampled_tensordict = replay_buffer.sample().to(device)
+
+                    with timeit("training - update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        metadata = update(sampled_tensordict)
+                    # update priority
+                    if prb:
+                        sampled_tensordict.set(
+                            loss_module.tensor_keys.priority,
+                            metadata.pop("td_error").detach().max(0).values,
+                        )
+                        replay_buffer.update_priority(sampled_tensordict)
+
         episode_rewards = tensordict["next", "episode_reward"][
             tensordict["next", "done"]
         ]

-        # Logging
         metrics_to_log = {}
-        if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][
-                tensordict["next", "done"]
-            ]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
-                episode_length
-            )
-        if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = q_loss.detach()
-            metrics_to_log["train/actor_loss"] = actor_loss.detach()
-            metrics_to_log["train/value_loss"] = value_loss.detach()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
-
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+
+        # Logging
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][
+                tensordict["next", "done"]
+            ]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = metadata["q_loss"]
+            metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
+            metrics_to_log["train/value_loss"] = metadata["value_loss"]
+        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
+        timeit.erase()

     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
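A note on the wrapping order used above: the update function is compiled first, then captured by CudaGraphModule, and every optimization step begins with torch.compiler.cudagraph_mark_step_begin(). Below is a hedged, self-contained sketch of that sequence with a toy parameter and loss; the capturable=True flag on Adam is an assumption needed so the optimizer step can be captured in a CUDA graph, and is not part of this diff:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.randn(4, 4, device=device, requires_grad=True)
# capturable=True keeps the optimizer state on-device for graph capture
optim = torch.optim.Adam([weight], lr=1e-3, capturable=device == "cuda")

def update(td):
    optim.zero_grad(set_to_none=True)
    loss = (td["obs"] @ weight).pow(2).mean()
    loss.backward()
    optim.step()
    return TensorDict({"loss": loss}).detach()

update = torch.compile(update, mode="default")
if device == "cuda":
    # same order as the scripts: compile first, then graph-capture the result
    update = CudaGraphModule(update, warmup=50)

td = TensorDict({"obs": torch.randn(8, 4, device=device)})
for _ in range(60):
    # marks a new iteration for the cudagraph machinery, so cached
    # outputs from the previous step may be overwritten
    torch.compiler.cudagraph_mark_step_begin()
    metadata = update(td)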

sota-implementations/iql/discrete_iql.yaml

Lines changed: 5 additions & 0 deletions

@@ -59,3 +59,8 @@ loss:
   # IQL specific hyperparameter
   temperature: 100
   expectile: 0.8
+
+compile:
+  compile: False
+  compile_mode: default
+  cudagraphs: False
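Note how the new keys interact: when compile_mode is left empty (or null), the scripts fall back to "reduce-overhead", or to plain "default" when CUDA graphs are requested, since CudaGraphModule already removes per-step overhead. A small sketch of that resolution, assuming OmegaConf (which Hydra uses for these configs):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"compile": {"compile": True, "compile_mode": "", "cudagraphs": False}}
)

compile_mode = None
if cfg.compile.compile:
    compile_mode = cfg.compile.compile_mode
    if compile_mode in ("", None):
        # fallback shared by both scripts
        compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"

print(compile_mode)  # -> reduce-overhead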

sota-implementations/iql/iql_offline.py

Lines changed: 52 additions & 38 deletions

@@ -11,16 +11,19 @@
 """
 from __future__ import annotations

-import time
+import warnings

 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit

 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger

 from utils import (
@@ -34,6 +37,9 @@
 )


+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="offline_config")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
@@ -79,60 +85,69 @@ def main(cfg: "DictConfig"):  # noqa: F821
     model = make_iql_model(cfg, train_env, eval_env, device)

     # Create loss
-    loss_module, target_net_updater = make_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)

     # Create optimizer
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)

-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
-
-    gradient_steps = cfg.optim.gradient_steps
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-
-        if data.device != device:
-            data = data.to(device, non_blocking=True)
-
+    def update(data):
+        optimizer.zero_grad(set_to_none=True)
         # compute losses
         loss_info = loss_module(data)
         actor_loss = loss_info["loss_actor"]
         value_loss = loss_info["loss_value"]
         q_loss = loss_info["loss_qvalue"]

-        optimizer_actor.zero_grad()
-        actor_loss.backward()
-        optimizer_actor.step()
-
-        optimizer_value.zero_grad()
-        value_loss.backward()
-        optimizer_value.step()
-
-        optimizer_critic.zero_grad()
-        q_loss.backward()
-        optimizer_critic.step()
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()

         # update qnet_target params
         target_net_updater.step()
+        return loss_info.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(range(cfg.optim.gradient_steps))
+
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+
+    # Training loop
+    for i in pbar:
+        # sample data
+        with timeit("sample"):
+            data = replay_buffer.sample()
+            data = data.to(device)

-        # log metrics
-        to_log = {
-            "loss_actor": actor_loss.item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_value": value_loss.item(),
-        }
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            loss_info = update(data)

         # evaluation
+        to_log = loss_info.to_dict()
         if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_td = eval_env.rollout(
                     max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                 )
@@ -147,7 +162,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")


 if __name__ == "__main__":
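Both scripts swap the hand-rolled time.time() bookkeeping for torchrl's timeit helper: each with timeit("key") block accumulates a timing, timeit.todict(prefix="time") exports the aggregates for the logger, and timeit.erase() resets the registry for the next interval. A minimal sketch of that cycle (the sleep durations and the exact exported key names are assumptions):

import time

from torchrl._utils import timeit

for _ in range(3):
    with timeit("sample"):
        time.sleep(0.01)
    with timeit("update"):
        time.sleep(0.02)

# export the aggregated timings for the logger, then reset the registry
to_log = timeit.todict(prefix="time")
print(sorted(to_log))  # assumed to contain keys like "time/sample", "time/update"
timeit.erase()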
