Commit e2be42e

Author: Vincent Moens
[Feature] CQL compatibility with compile
ghstack-source-id: d362d6c
Pull Request resolved: #2553
Parent commit: e3c3047
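The heart of the change is that each gradient step is folded into a single update() function that torch.compile and tensordict's CudaGraphModule can wrap. Below is a minimal sketch of that wrapping pattern only; the flags stand in for the script's cfg.compile.* options and the toy update body is a placeholder, not the real CQL step.

import torch
from tensordict.nn import CudaGraphModule

# Placeholder flags standing in for cfg.compile.compile, cfg.compile.compile_mode
# and cfg.compile.cudagraphs in the diff below.
use_compile = True
requested_mode = ""  # empty/None means "pick a default mode"
use_cudagraphs = False


def update(x):
    # Toy body; the real update computes the CQL losses and steps the optimizer.
    return x * 2


compile_mode = None
if use_compile:
    if requested_mode not in (None, ""):
        compile_mode = requested_mode
    elif use_cudagraphs:
        # Let CudaGraphModule handle graph capture; keep the default compile mode.
        compile_mode = "default"
    else:
        compile_mode = "reduce-overhead"
    update = torch.compile(update, mode=compile_mode)
if use_cudagraphs:
    # Roughly 50 warmup calls run before the CUDA graph is captured.
    update = CudaGraphModule(update, warmup=50)

print(update(torch.ones(2)))  # tensor([2., 2.])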

File tree: 9 files changed (+374, -279 lines)
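Beyond compile support, the diff below threads torchrl's timeit helper through the sampling, update and logging phases of the training loop. A small usage sketch, with a toy workload standing in for the real replay-buffer sample and update:

import torch
from torchrl._utils import timeit

for _ in range(3):
    with timeit("sample"):
        data = torch.randn(256, 4)  # stand-in for replay_buffer.sample()
    with timeit("update"):
        (data * 2).sum()  # stand-in for the compiled update()

# Aggregated timings can be merged into the logging dict or printed directly.
print(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()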


sota-implementations/cql/cql_offline.py

Lines changed: 83 additions & 54 deletions
@@ -15,8 +15,11 @@
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -69,6 +72,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create agent
     model = make_cql_model(cfg, train_env, eval_env, device)
     del train_env
+    if hasattr(eval_env, "start"):
+        # To set the number of threads to the definitive value
+        eval_env.start()
 
     # Create loss
     loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
@@ -81,81 +87,104 @@ def main(cfg: "DictConfig"):  # noqa: F821
         alpha_prime_optim,
     ) = make_continuous_cql_optimizer(cfg, loss_module)
 
-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+    # Group optimizers
+    optimizer = group_optimizers(
+        policy_optim, critic_optim, alpha_optim, alpha_prime_optim
+    )
 
-    gradient_steps = cfg.optim.gradient_steps
-    policy_eval_start = cfg.optim.policy_eval_start
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-        # compute loss
-        loss_vals = loss_module(data.clone().to(device))
+    def update(data, policy_eval_start, iteration):
+        loss_vals = loss_module(data.to(device))
 
         # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
-        if i >= policy_eval_start:
-            actor_loss = loss_vals["loss_actor"]
-        else:
-            actor_loss = loss_vals["loss_actor_bc"]
+        actor_loss = torch.where(
+            iteration >= policy_eval_start,
+            loss_vals["loss_actor"],
+            loss_vals["loss_actor_bc"],
+        )
         q_loss = loss_vals["loss_qvalue"]
         cql_loss = loss_vals["loss_cql"]
 
         q_loss = q_loss + cql_loss
+        loss_vals["q_loss"] = q_loss
 
         # update model
         alpha_loss = loss_vals["loss_alpha"]
         alpha_prime_loss = loss_vals["loss_alpha_prime"]
+        if alpha_prime_loss is None:
+            alpha_prime_loss = 0
 
-        alpha_optim.zero_grad()
-        alpha_loss.backward()
-        alpha_optim.step()
+        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
 
-        policy_optim.zero_grad()
-        actor_loss.backward()
-        policy_optim.step()
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)
 
-        if alpha_prime_optim is not None:
-            alpha_prime_optim.zero_grad()
-            alpha_prime_loss.backward(retain_graph=True)
-            alpha_prime_optim.step()
+        # update qnet_target params
+        target_net_updater.step()
 
-        critic_optim.zero_grad()
-        # TODO: we have the option to compute losses independently retain is not needed?
-        q_loss.backward(retain_graph=False)
-        critic_optim.step()
+        return loss.detach(), loss_vals.detach()
 
-        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+
+    gradient_steps = cfg.optim.gradient_steps
+    policy_eval_start = cfg.optim.policy_eval_start
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+
+    # Training loop
+    start_time = time.time()
+    policy_eval_start = torch.tensor(policy_eval_start, device=device)
+    for i in range(gradient_steps):
+        pbar.update(1)
+        # sample data
+        with timeit("sample"):
+            data = replay_buffer.sample()
+
+        with timeit("update"):
+            # compute loss
+            i_device = torch.tensor(i, device=device)
+            loss, loss_vals = update(
+                data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
+            )
 
         # log metrics
         to_log = {
-            "loss": loss.item(),
-            "loss_actor_bc": loss_vals["loss_actor_bc"].item(),
-            "loss_actor": loss_vals["loss_actor"].item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_cql": cql_loss.item(),
-            "loss_alpha": alpha_loss.item(),
-            "loss_alpha_prime": alpha_prime_loss.item(),
+            "loss": loss.cpu(),
+            **loss_vals.cpu(),
         }
 
-        # update qnet_target params
-        target_net_updater.step()
-
         # evaluation
-        if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_td = eval_env.rollout(
-                    max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
-                )
-                eval_env.apply(dump_video)
-                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-                to_log["evaluation_reward"] = eval_reward
-
-        log_metrics(logger, to_log, i)
+        with timeit("log/eval"):
+            if i % evaluation_interval == 0:
+                with set_exploration_type(
+                    ExplorationType.DETERMINISTIC
+                ), torch.no_grad():
+                    eval_td = eval_env.rollout(
+                        max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
+                    )
+                    eval_env.apply(dump_video)
+                    eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+                    to_log["evaluation_reward"] = eval_reward
+
+        with timeit("log"):
+            if i % 200 == 0:
+                to_log.update(timeit.todict(prefix="time"))
+            log_metrics(logger, to_log, i)
+            if i % 200 == 0:
+                timeit.print()
+                timeit.erase()
 
     pbar.close()
     torchrl_logger.info(f"Training time: {time.time() - start_time}")
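One compile-oriented detail in the hunk above: the Python branch `if i >= policy_eval_start` is replaced by a torch.where over device tensors, so the traced graph is identical at every step instead of depending on the loop counter. A standalone illustration with made-up loss values:

import torch

# Made-up values standing in for loss_vals["loss_actor"] and loss_vals["loss_actor_bc"].
loss_actor = torch.tensor(1.3)
loss_actor_bc = torch.tensor(0.7)

policy_eval_start = torch.tensor(10)  # threshold kept as a tensor
iteration = torch.tensor(3)  # loop counter passed as a tensor (i_device above)

# Both candidate losses are computed and the comparison stays on-device, so there is
# no data-dependent Python control flow for torch.compile / CUDA graphs to specialize on.
actor_loss = torch.where(iteration >= policy_eval_start, loss_actor, loss_actor_bc)
print(actor_loss)  # tensor(0.7000) -- behavior-cloning loss before policy_eval_start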

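The other structural change above is that the per-loss backward/step calls collapse into one summed loss driven through a grouped optimizer (torchrl.objectives.group_optimizers), which also removes the retain_graph bookkeeping the old version needed. A self-contained sketch; the two linear modules and their Adam optimizers are hypothetical stand-ins for what make_continuous_cql_optimizer builds:

import torch
from torch import nn
from torchrl.objectives import group_optimizers

# Hypothetical stand-ins for the policy / critic / alpha optimizers.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)

# A single optimizer whose step() / zero_grad() act on every grouped parameter set.
optimizer = group_optimizers(actor_optim, critic_optim)

x = torch.randn(8, 4)
loss = actor(x).pow(2).mean() + critic(x).pow(2).mean()

# One backward and one step over everything -- a single region that
# torch.compile / CudaGraphModule can capture.
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)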