Commit fbfe104

Author: Vincent Moens (committed)
[Feature] DT compatibility with compile
ghstack-source-id: 362b6e8
Pull Request resolved: #2556
1 parent: 7d7cd95 · commit: fbfe104

File tree

18 files changed: +237, -148 lines


sota-implementations/a2c/utils_atari.py

Lines changed: 4 additions & 5 deletions

@@ -8,7 +8,6 @@
 import torch.nn
 import torch.optim
 from tensordict.nn import TensorDictModule
-from torchrl.data import Composite
 from torchrl.data.tensor_specs import CategoricalBox
 from torchrl.envs import (
     CatFrames,
@@ -94,12 +93,12 @@ def make_ppo_modules_pixels(proof_environment, device):
     input_shape = proof_environment.observation_spec["pixels"].shape
 
     # Define distribution class and kwargs
-    if isinstance(proof_environment.action_spec.space, CategoricalBox):
-        num_outputs = proof_environment.action_spec.space.n
+    if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
+        num_outputs = proof_environment.single_action_spec.space.n
         distribution_class = OneHotCategorical
         distribution_kwargs = {}
     else:  # is ContinuousBox
-        num_outputs = proof_environment.action_spec.shape
+        num_outputs = proof_environment.single_action_spec.shape
         distribution_class = TanhNormal
         distribution_kwargs = {
             "low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -153,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device):
     policy_module = ProbabilisticActor(
         policy_module,
         in_keys=["logits"],
-        spec=Composite(action=proof_environment.action_spec.to(device)),
+        spec=proof_environment.single_full_action_spec.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
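
Note: the action_spec -> single_action_spec change above reads the per-environment (unbatched) action spec, so the policy head size no longer depends on how many environments are batched together. A minimal illustrative sketch, not part of this commit, assuming gymnasium's Pendulum-v1 is installed and that single_action_spec behaves like the action_spec_unbatched property already used in this file:

from torchrl.envs import GymEnv, SerialEnv

# A batched env stacks a leading batch dimension onto its specs...
env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))
print(env.action_spec.shape)         # e.g. torch.Size([4, 1]) -- includes the batch dim
# ...while the "single" (unbatched) view exposes one sub-environment's spec.
print(env.single_action_spec.shape)  # e.g. torch.Size([1])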

sota-implementations/a2c/utils_mujoco.py

Lines changed: 3 additions & 4 deletions

@@ -9,7 +9,6 @@
 import torch.optim
 
 from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
-from torchrl.data import Composite
 from torchrl.envs import (
     ClipTransform,
     DoubleToFloat,
@@ -55,7 +54,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     input_shape = proof_environment.observation_spec["observation"].shape
 
     # Define policy output distribution class
-    num_outputs = proof_environment.action_spec.shape[-1]
+    num_outputs = proof_environment.single_action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
         "low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -83,7 +82,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     policy_mlp = torch.nn.Sequential(
         policy_mlp,
         AddStateIndependentNormalScale(
-            proof_environment.action_spec.shape[-1], device=device
+            proof_environment.single_action_spec.shape[-1], device=device
         ),
     )
 
@@ -95,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
             out_keys=["loc", "scale"],
         ),
         in_keys=["loc", "scale"],
-        spec=Composite(action=proof_environment.action_spec.to(device)),
+        spec=proof_environment.single_full_action_spec.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
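
The spec=Composite(action=...) -> spec=...single_full_action_spec replacement works because the full action spec is already a Composite keyed by "action", which is also why the Composite import could be dropped. A hedged, illustrative sketch (not from the commit; assumes gymnasium's Pendulum-v1):

from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
# full_action_spec is a Composite holding the "action" entry, so there is
# no need to wrap action_spec in Composite(action=...) by hand.
print(type(env.full_action_spec).__name__)  # Composite
print(list(env.full_action_spec.keys()))    # ['action']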

sota-implementations/cql/utils.py

Lines changed: 1 addition & 1 deletion

@@ -298,7 +298,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):
 
 
 def make_cql_modules_state(model_cfg, proof_environment):
-    action_spec = proof_environment.action_spec
+    action_spec = proof_environment.single_action_spec
 
     actor_net_kwargs = {
         "num_cells": model_cfg.hidden_sizes,

sota-implementations/decision_transformer/dt.py

Lines changed: 52 additions & 26 deletions

@@ -6,15 +6,18 @@
 This is a self-contained example of an offline Decision Transformer training script.
 The helper functions are coded in the utils.py associated with this script.
 """
+
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.libs.gym import set_gym_backend
 
 from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -67,58 +70,77 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # Create policy model
-    actor = make_dt_model(cfg)
-    policy = actor.to(model_device)
+    actor = make_dt_model(cfg, device=model_device)
 
     # Create loss
-    loss_module = make_dt_loss(cfg.loss, actor)
+    loss_module = make_dt_loss(cfg.loss, actor, device=model_device)
 
     # Create optimizer
     transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
 
     # Create inference policy
     inference_policy = DecisionTransformerInferenceWrapper(
-        policy=policy,
+        policy=actor,
         inference_context=cfg.env.inference_context,
-    ).to(model_device)
+        device=model_device,
+    )
     inference_policy.set_tensor_keys(
         observation="observation_cat",
         action="action_cat",
         return_to_go="return_to_go_cat",
     )
 
-    pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
-
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
     clip_grad = cfg.optim.clip_grad
-    eval_steps = cfg.logger.eval_steps
-    pretrain_log_interval = cfg.logger.pretrain_log_interval
-    reward_scaling = cfg.env.reward_scaling
 
-    torchrl_logger.info(" ***Pretraining*** ")
-    # Pretraining
-    start_time = time.time()
-    for i in range(pretrain_gradient_steps):
-        pbar.update(1)
-
-        # Sample data
-        data = offline_buffer.sample()
+    def update(data: TensorDict) -> TensorDict:
+        transformer_optim.zero_grad(set_to_none=True)
         # Compute loss
-        loss_vals = loss_module(data.to(model_device))
+        loss_vals = loss_module(data)
         transformer_loss = loss_vals["loss"]
 
-        transformer_optim.zero_grad()
-        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
         transformer_loss.backward()
+        torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad)
         transformer_optim.step()
 
-        scheduler.step()
+        return loss_vals
+
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode, dynamic=True)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    eval_steps = cfg.logger.eval_steps
+    pretrain_log_interval = cfg.logger.pretrain_log_interval
+    reward_scaling = cfg.env.reward_scaling
 
+    torchrl_logger.info(" ***Pretraining*** ")
+    # Pretraining
+    pbar = tqdm.tqdm(range(pretrain_gradient_steps))
+    for i in pbar:
+        # Sample data
+        with timeit("rb - sample"):
+            data = offline_buffer.sample().to(model_device)
+        with timeit("update"):
+            loss_vals = update(data)
+        scheduler.step()
         # Log metrics
        to_log = {"train/loss": loss_vals["loss"]}

        # Evaluation
-        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+        with set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), torch.no_grad(), timeit("eval"):
            if i % pretrain_log_interval == 0:
                eval_td = test_env.rollout(
                    max_steps=eval_steps,
@@ -129,13 +151,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
                to_log["eval/reward"] = (
                    eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                )
+        if i % 200 == 0:
+            to_log.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
+
        if logger is not None:
            log_metrics(logger, to_log, i)

    pbar.close()
    if not test_env.is_closed:
        test_env.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
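
The core of the dt.py change is to factor the optimization step into an update function and then conditionally wrap it with torch.compile and CudaGraphModule. Below is a hedged, self-contained sketch of that pattern with a toy linear model and Adam standing in for the Decision Transformer loss and its optimizer; the names model, optim and the compile_cfg dict are illustrative, not from the commit:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

model = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

def update(data: TensorDict) -> TensorDict:
    # One optimization step: zero grads, compute loss, backprop, clip, step.
    optim.zero_grad(set_to_none=True)
    loss = (model(data["obs"]) - data["target"]).pow(2).mean()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step()
    return TensorDict({"loss": loss.detach()}, batch_size=[])

# Mirrors the cfg.compile block added in dt_config.yaml (values are illustrative).
compile_cfg = {"compile": True, "compile_mode": None, "cudagraphs": False}
if compile_cfg["compile"]:
    mode = compile_cfg["compile_mode"]
    if mode in ("", None):
        mode = "default" if compile_cfg["cudagraphs"] else "reduce-overhead"
    update = torch.compile(update, mode=mode, dynamic=True)
if compile_cfg["cudagraphs"]:
    # CUDA graphs need CUDA tensors and a warmed-up, shape-stable function.
    update = CudaGraphModule(update, warmup=50)

data = TensorDict({"obs": torch.randn(32, 4), "target": torch.randn(32, 1)}, batch_size=[32])
loss_vals = update(data)
print(loss_vals["loss"])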

sota-implementations/decision_transformer/dt_config.yaml

Lines changed: 6 additions & 1 deletion

@@ -55,7 +55,12 @@ optim:
 # loss
 loss:
   loss_function: "l2"
-
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
+
 # transformer model
 transformer:
   n_embd: 128
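
For reference, a hedged sketch of how the new compile block can be read back on the Python side. OmegaConf is what Hydra uses under the hood; the snippet just mirrors the YAML added above and is not part of the commit:

from omegaconf import OmegaConf

yaml_src = """\
compile:
  compile: False
  compile_mode:
  cudagraphs: False
"""
cfg = OmegaConf.create(yaml_src)
print(cfg.compile.compile)       # False
print(cfg.compile.compile_mode)  # None -> the scripts fall back to a default compile mode
print(cfg.compile.cudagraphs)    # False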

sota-implementations/decision_transformer/odt_config.yaml

Lines changed: 6 additions & 0 deletions

@@ -42,6 +42,7 @@ replay_buffer:
 
 # optimizer
 optim:
+  optimizer: lamb
   device: null
   lr: 1.0e-4
   weight_decay: 5.0e-4
@@ -56,6 +57,11 @@ loss:
   alpha_init: 0.1
   target_entropy: auto
 
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
+
 # transformer model
 transformer:
   n_embd: 512

sota-implementations/decision_transformer/online_dt.py

Lines changed: 57 additions & 25 deletions

@@ -9,14 +9,15 @@
 from __future__ import annotations
 
 import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.libs.gym import set_gym_backend
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 from torchrl.record import VideoRecorder
@@ -65,8 +66,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # Create policy model
-    actor = make_odt_model(cfg)
-    policy = actor.to(model_device)
+    policy = make_odt_model(cfg, device=model_device)
 
     # Create loss
     loss_module = make_odt_loss(cfg.loss, policy)
@@ -80,13 +80,46 @@ def main(cfg: "DictConfig"):  # noqa: F821
     inference_policy = DecisionTransformerInferenceWrapper(
         policy=policy,
         inference_context=cfg.env.inference_context,
-    ).to(model_device)
+        device=model_device,
+    )
     inference_policy.set_tensor_keys(
         observation="observation_cat",
         action="action_cat",
         return_to_go="return_to_go_cat",
     )
 
+    def update(data):
+        transformer_optim.zero_grad(set_to_none=True)
+        temperature_optim.zero_grad(set_to_none=True)
+        # Compute loss
+        loss_vals = loss_module(data.to(model_device))
+        transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
+        temperature_loss = loss_vals["loss_alpha"]
+
+        (temperature_loss + transformer_loss).backward()
+        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
+
+        transformer_optim.step()
+        temperature_optim.step()
+
+        return loss_vals.detach()
+
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            compile_mode = "default"
+        update = torch.compile(update, mode=compile_mode, dynamic=False)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        if cfg.optim.optimizer == "lamb":
+            raise ValueError(
+                "cudagraphs isn't compatible with the Lamb optimizer. Use optim.optimizer=Adam instead."
+            )
+        update = CudaGraphModule(update, warmup=50)
+
     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
 
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
@@ -100,35 +133,29 @@ def main(cfg: "DictConfig"):  # noqa: F821
     start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(1)
-        # Sample data
-        data = offline_buffer.sample()
-        # Compute loss
-        loss_vals = loss_module(data.to(model_device))
-        transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
-        temperature_loss = loss_vals["loss_alpha"]
-
-        transformer_optim.zero_grad()
-        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
-        transformer_loss.backward()
-        transformer_optim.step()
+        with timeit("sample"):
+            # Sample data
+            data = offline_buffer.sample()
 
-        temperature_optim.zero_grad()
-        temperature_loss.backward()
-        temperature_optim.step()
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            loss_vals = update(data.to(model_device))
 
         scheduler.step()
 
         # Log metrics
         to_log = {
-            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
-            "train/loss_entropy": loss_vals["loss_entropy"].item(),
-            "train/loss_alpha": loss_vals["loss_alpha"].item(),
-            "train/alpha": loss_vals["alpha"].item(),
-            "train/entropy": loss_vals["entropy"].item(),
+            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
+            "train/loss_entropy": loss_vals["loss_entropy"],
+            "train/loss_alpha": loss_vals["loss_alpha"],
+            "train/alpha": loss_vals["alpha"],
+            "train/entropy": loss_vals["entropy"],
         }
 
         # Evaluation
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             inference_policy.eval()
             if i % pretrain_log_interval == 0:
                 eval_td = test_env.rollout(
@@ -143,6 +170,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                 )
 
+        if i % 200 == 0:
+            to_log.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
+
         if logger is not None:
             log_metrics(logger, to_log, i)
 