Skip to content

Commit 01a421e

Browse files
author
Vincent Moens
committed
[Feature] CROSSQ compatibility with compile
ghstack-source-id: 98a2b30 Pull Request resolved: #2554
1 parent e2be42e commit 01a421e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+299
-102
lines changed

sota-implementations/a2c/a2c_atari.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import warnings
8+
59
import hydra
610
import torch
711

@@ -149,6 +153,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
149153
adv_module = torch.compile(adv_module, mode=compile_mode)
150154

151155
if cfg.compile.cudagraphs:
156+
warnings.warn(
157+
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
158+
category=UserWarning,
159+
)
152160
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
153161
adv_module = CudaGraphModule(adv_module)
154162

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import warnings
8+
59
import hydra
610
import torch
711

@@ -145,6 +149,10 @@ def update(batch):
145149
adv_module = torch.compile(adv_module, mode=compile_mode)
146150

147151
if cfg.compile.cudagraphs:
152+
warnings.warn(
153+
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
154+
category=UserWarning,
155+
)
148156
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20)
149157
adv_module = CudaGraphModule(adv_module, warmup=20)
150158

sota-implementations/a2c/utils_atari.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
import numpy as np
78
import torch.nn

sota-implementations/a2c/utils_mujoco.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
import numpy as np
78
import torch.nn

sota-implementations/bandits/dqn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
import argparse
78

sota-implementations/cql/cql_offline.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
The helper functions are coded in the utils.py associated with this script.
1010
1111
"""
12+
from __future__ import annotations
13+
1214
import time
15+
import warnings
1316

1417
import hydra
1518
import numpy as np
19+
1620
import torch
1721
import tqdm
1822
from tensordict.nn import CudaGraphModule
@@ -32,6 +36,8 @@
3236
make_offline_replay_buffer,
3337
)
3438

39+
torch.set_float32_matmul_precision("high")
40+
3541

3642
@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
3743
def main(cfg: "DictConfig"): # noqa: F821
@@ -77,7 +83,9 @@ def main(cfg: "DictConfig"): # noqa: F821
7783
eval_env.start()
7884

7985
# Create loss
80-
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
86+
loss_module, target_net_updater = make_continuous_loss(
87+
cfg.loss, model, device=device
88+
)
8189

8290
# Create Optimizer
8391
(
@@ -134,6 +142,10 @@ def update(data, policy_eval_start, iteration):
134142
compile_mode = "reduce-overhead"
135143
update = torch.compile(update, mode=compile_mode)
136144
if cfg.compile.cudagraphs:
145+
warnings.warn(
146+
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
147+
category=UserWarning,
148+
)
137149
update = CudaGraphModule(update, warmup=50)
138150

139151
pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
@@ -154,6 +166,7 @@ def update(data, policy_eval_start, iteration):
154166

155167
with timeit("update"):
156168
# compute loss
169+
torch.compiler.cudagraph_mark_step_begin()
157170
i_device = torch.tensor(i, device=device)
158171
loss, loss_vals = update(
159172
data.to(device), policy_eval_start=policy_eval_start, iteration=i_device

sota-implementations/cql/cql_online.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
The helper functions are coded in the utils.py associated with this script.
1212
1313
"""
14+
from __future__ import annotations
15+
16+
import warnings
17+
1418
import hydra
1519
import numpy as np
1620
import torch
@@ -34,6 +38,8 @@
3438
make_replay_buffer,
3539
)
3640

41+
torch.set_float32_matmul_precision("high")
42+
3743

3844
@hydra.main(version_base="1.1", config_path="", config_name="online_config")
3945
def main(cfg: "DictConfig"): # noqa: F821
@@ -103,7 +109,9 @@ def main(cfg: "DictConfig"): # noqa: F821
103109
)
104110

105111
# Create loss
106-
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
112+
loss_module, target_net_updater = make_continuous_loss(
113+
cfg.loss, model, device=device
114+
)
107115

108116
# Create optimizer
109117
(
@@ -140,6 +148,10 @@ def update(sampled_tensordict):
140148
if compile_mode:
141149
update = torch.compile(update, mode=compile_mode)
142150
if cfg.compile.cudagraphs:
151+
warnings.warn(
152+
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
153+
category=UserWarning,
154+
)
143155
update = CudaGraphModule(update, warmup=50)
144156

145157
# Main loop

sota-implementations/cql/discrete_cql_online.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
1111
The helper functions are coded in the utils.py associated with this script.
1212
"""
13+
from __future__ import annotations
14+
15+
import warnings
1316

1417
import hydra
1518
import numpy as np
19+
1620
import torch
1721
import torch.cuda
1822
import tqdm
@@ -33,6 +37,8 @@
3337
make_replay_buffer,
3438
)
3539

40+
torch.set_float32_matmul_precision("high")
41+
3642

3743
@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
3844
def main(cfg: "DictConfig"): # noqa: F821
@@ -70,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
7076
model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)
7177

7278
# Create loss
73-
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
79+
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)
7480

7581
compile_mode = None
7682
if cfg.compile.compile:
@@ -123,6 +129,10 @@ def update(sampled_tensordict):
123129
if compile_mode:
124130
update = torch.compile(update, mode=compile_mode)
125131
if cfg.compile.cudagraphs:
132+
warnings.warn(
133+
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
134+
category=UserWarning,
135+
)
126136
update = CudaGraphModule(update, warmup=50)
127137

128138
# Main loop
@@ -170,6 +180,7 @@ def update(sampled_tensordict):
170180
sampled_tensordict = replay_buffer.sample()
171181
sampled_tensordict = sampled_tensordict.to(device)
172182
with timeit("update"):
183+
torch.compiler.cudagraph_mark_step_begin()
173184
loss_dict = update(sampled_tensordict)
174185
tds.append(loss_dict)
175186

sota-implementations/cql/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
57
import functools
68

79
import torch.nn
@@ -221,8 +223,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
221223
# distribution_kwargs=TensorDictParams(
222224
# TensorDict(
223225
# {
224-
# "low": action_spec.space.low,
225-
# "high": action_spec.space.high,
226+
# "low": torch.as_tensor(action_spec.space.low, device=device),
227+
# "high": torch.as_tensor(action_spec.space.high, device=device),
226228
# "tanh_loc": NonTensorData(False),
227229
# }
228230
# ),
@@ -326,7 +328,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
326328
# ---------
327329

328330

329-
def make_continuous_loss(loss_cfg, model):
331+
def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
330332
loss_module = CQLLoss(
331333
model[0],
332334
model[1],
@@ -339,19 +341,19 @@ def make_continuous_loss(loss_cfg, model):
339341
with_lagrange=loss_cfg.with_lagrange,
340342
lagrange_thresh=loss_cfg.lagrange_thresh,
341343
)
342-
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
344+
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
343345
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
344346

345347
return loss_module, target_net_updater
346348

347349

348-
def make_discrete_loss(loss_cfg, model):
350+
def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
349351
loss_module = DiscreteCQLLoss(
350352
model,
351353
loss_function=loss_cfg.loss_function,
352354
delay_value=True,
353355
)
354-
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
356+
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
355357
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
356358

357359
return loss_module, target_net_updater

sota-implementations/crossq/config.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ collector:
1212
init_random_frames: 25000
1313
frames_per_batch: 1000
1414
init_env_steps: 1000
15-
device: cpu
15+
device:
1616
env_per_collector: 1
1717
reset_at_each_iter: False
1818

@@ -46,7 +46,12 @@ network:
4646
actor_activation: relu
4747
default_policy_scale: 1.0
4848
scale_lb: 0.1
49-
device: "cuda:0"
49+
device:
50+
51+
compile:
52+
compile: False
53+
compile_mode:
54+
cudagraphs: False
5055

5156
# logging
5257
logger:

0 commit comments

Comments
 (0)