
Commit 91064bc

Author: Vincent Moens (authored and committed)
[Feature] TD3-bc compatibility with compile
ghstack-source-id: 8a33e39 Pull Request resolved: #2657
1 parent 1b7eda1 commit 91064bc

File tree: 3 files changed, +92 -68 lines


sota-implementations/td3_bc/config.yaml

Lines changed: 5 additions & 0 deletions
@@ -43,3 +43,8 @@ logger:
   eval_steps: 1000
   eval_envs: 1
   video: False
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
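
Note: the new `compile` block defaults to eager execution. As a rough illustration of how the keys resolve (OmegaConf stands in for the Hydra-loaded config; the fallback logic mirrors what td3_bc.py does below), an empty `compile_mode` becomes "reduce-overhead", or "default" when cuda graphs are requested:

# Illustrative only; OmegaConf is assumed as the config backend (Hydra uses it).
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
compile:
  compile: True
  compile_mode:
  cudagraphs: False
"""
)

compile_mode = None
if cfg.compile.compile:
    compile_mode = cfg.compile.compile_mode
    if compile_mode in ("", None):
        # Empty mode: "default" when cuda graphs are on, else "reduce-overhead".
        compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"

print(compile_mode)  # -> reduce-overhead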

sota-implementations/td3_bc/td3_bc.py

Lines changed: 66 additions & 35 deletions
@@ -11,13 +11,16 @@
 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import compile_with_warmup, timeit
 
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -72,7 +75,16 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
 
     # Create replay buffer
-    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, device=device)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
 
     # Create agent
     model, _ = make_td3_agent(cfg, eval_env, device)
@@ -83,67 +95,86 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create optimizer
     optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)
 
-    gradient_steps = cfg.optim.gradient_steps
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-    delayed_updates = cfg.optim.policy_update_delay
-    update_counter = 0
-    pbar = tqdm.tqdm(range(gradient_steps))
-    # Training loop
-    start_time = time.time()
-    for i in pbar:
-        pbar.update(1)
-        # Update actor every delayed_updates
-        update_counter += 1
-        update_actor = update_counter % delayed_updates == 0
-
-        # Sample from replay buffer
-        sampled_tensordict = replay_buffer.sample()
-        if sampled_tensordict.device != device:
-            sampled_tensordict = sampled_tensordict.to(device)
-        else:
-            sampled_tensordict = sampled_tensordict.clone()
-
+    def update(sampled_tensordict, update_actor):
         # Compute loss
         q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
 
         # Update critic
-        optimizer_critic.zero_grad()
         q_loss.backward()
         optimizer_critic.step()
-        q_loss.item()
-
-        to_log = {"q_loss": q_loss.item()}
+        optimizer_critic.zero_grad(set_to_none=True)
 
         # Update actor
         if update_actor:
             actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
-            optimizer_actor.zero_grad()
             actor_loss.backward()
             optimizer_actor.step()
+            optimizer_actor.zero_grad(set_to_none=True)
 
             # Update target params
             target_net_updater.step()
+        else:
+            actorloss_metadata = {}
+            actor_loss = q_loss.new_zeros(())
+        metadata = TensorDict(actorloss_metadata)
+        metadata.set("q_loss", q_loss.detach())
+        metadata.set("actor_loss", actor_loss.detach())
+        return metadata
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+
+    gradient_steps = cfg.optim.gradient_steps
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+    delayed_updates = cfg.optim.policy_update_delay
+    pbar = tqdm.tqdm(range(gradient_steps))
+    # Training loop
+    for update_counter in pbar:
+        timeit.printevery(num_prints=1000, total_count=gradient_steps, erase=True)
 
-        to_log["actor_loss"] = actor_loss.item()
-        to_log.update(actorloss_metadata)
+        # Update actor every delayed_updates
+        update_actor = update_counter % delayed_updates == 0
+
+        with timeit("rb - sample"):
+            # Sample from replay buffer
+            sampled_tensordict = replay_buffer.sample()
+
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            metadata = update(sampled_tensordict, update_actor).clone()
+
+        to_log = {}
+        if update_actor:
+            to_log.update(metadata.to_dict())
+        else:
+            to_log.update(metadata.exclude("actor_loss").to_dict())
 
         # evaluation
-        if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+        if update_counter % evaluation_interval == 0:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_td = eval_env.rollout(
                     max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                 )
                 eval_env.apply(dump_video)
                 eval_reward = eval_td["next", "reward"].sum(1).mean().item()
                 to_log["evaluation_reward"] = eval_reward
         if logger is not None:
-            log_metrics(logger, to_log, i)
+            to_log.update(timeit.todict(prefix="time"))
+            log_metrics(logger, to_log, update_counter)
 
     if not eval_env.is_closed:
         eval_env.close()
     pbar.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
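
The core of the change: the gradient step is factored into a local `update` function that returns detached metrics in a TensorDict, so the whole step can be wrapped with `compile_with_warmup` and, optionally, `CudaGraphModule`, while the training loop only samples, calls `update`, and logs (with `timeit` replacing the manual timing). A standalone sketch of the same wrapping pattern on a toy objective (network, optimizer, data and sizes are made up; only the wrappers and the mark-step/clone calls follow the script above):

# Standalone sketch, not the TD3-BC losses: a toy regression step wrapped the
# same way as the `update` function in the diff above.
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(8, 1, device=device)
optim = torch.optim.Adam(net.parameters())


def update(batch):
    # Return detached losses in a TensorDict instead of calling .item(),
    # so the step stays graph-friendly (no host sync inside the update).
    loss = (net(batch["obs"]) - batch["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    return TensorDict({"loss": loss.detach()})


use_cudagraphs = device.type == "cuda"
# Same fallback as the script: "default" under cuda graphs, else "reduce-overhead".
update = compile_with_warmup(
    update, mode="default" if use_cudagraphs else "reduce-overhead", warmup=1
)
if use_cudagraphs:
    update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

for _ in range(10):
    batch = TensorDict(
        {
            "obs": torch.randn(32, 8, device=device),
            "target": torch.randn(32, 1, device=device),
        }
    )
    torch.compiler.cudagraph_mark_step_begin()
    metrics = update(batch).clone()  # clone so cudagraph output buffers are not reused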

sota-implementations/td3_bc/utils.py

Lines changed: 21 additions & 33 deletions
@@ -7,7 +7,7 @@
 import functools
 
 import torch
-from tensordict.nn import TensorDictSequential
+from tensordict.nn import TensorDictModule, TensorDictSequential
 
 from torch import nn, optim
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
@@ -26,14 +26,7 @@
 )
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.modules import (
-    AdditiveGaussianModule,
-    MLP,
-    SafeModule,
-    SafeSequential,
-    TanhModule,
-    ValueOperator,
-)
+from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator
 
 from torchrl.objectives import SoftUpdate
 from torchrl.objectives.td3_bc import TD3BCLoss
@@ -98,17 +91,19 @@ def make_environment(cfg, logger=None):
 # ---------------------------
 
 
-def make_offline_replay_buffer(rb_cfg):
+def make_offline_replay_buffer(rb_cfg, device):
     data = D4RLExperienceReplay(
         dataset_id=rb_cfg.dataset,
         split_trajs=False,
         batch_size=rb_cfg.batch_size,
-        sampler=SamplerWithoutReplacement(drop_last=False),
+        # drop_last for compile
+        sampler=SamplerWithoutReplacement(drop_last=True),
         prefetch=4,
         direct_download=True,
     )
 
     data.append_transform(DoubleToFloat())
+    data.append_transform(lambda td: td.to(device))
 
     return data
 
@@ -122,26 +117,22 @@ def make_td3_agent(cfg, train_env, device):
     """Make TD3 agent."""
     # Define Actor Network
    in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
-    actor_net_kwargs = {
-        "num_cells": cfg.network.hidden_sizes,
-        "out_features": action_spec.shape[-1],
-        "activation_class": get_activation(cfg),
-    }
+    action_spec = train_env.action_spec_unbatched.to(device)
 
-    actor_net = MLP(**actor_net_kwargs)
+    actor_net = MLP(
+        num_cells=cfg.network.hidden_sizes,
+        out_features=action_spec.shape[-1],
+        activation_class=get_activation(cfg),
+        device=device,
+    )
 
     in_keys_actor = in_keys
-    actor_module = SafeModule(
+    actor_module = TensorDictModule(
         actor_net,
         in_keys=in_keys_actor,
-        out_keys=[
-            "param",
-        ],
+        out_keys=["param"],
     )
-    actor = SafeSequential(
+    actor = TensorDictSequential(
         actor_module,
         TanhModule(
             in_keys=["param"],
@@ -151,22 +142,19 @@ def make_td3_agent(cfg, train_env, device):
     )
 
     # Define Critic Network
-    qvalue_net_kwargs = {
-        "num_cells": cfg.network.hidden_sizes,
-        "out_features": 1,
-        "activation_class": get_activation(cfg),
-    }
-
     qvalue_net = MLP(
-        **qvalue_net_kwargs,
+        num_cells=cfg.network.hidden_sizes,
+        out_features=1,
+        activation_class=get_activation(cfg),
+        device=device,
     )
 
     qvalue = ValueOperator(
         in_keys=["action"] + in_keys,
         module=qvalue_net,
     )
 
-    model = nn.ModuleList([actor, qvalue]).to(device)
+    model = nn.ModuleList([actor, qvalue])
 
     # init nets
     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):