
Commit 6482766

Vincent Moens authored and committed
[Feature] GAIL compatibility with compile
ghstack-source-id: 98c7602 Pull Request resolved: #2573
1 parent: f149811 · commit: 6482766

24 files changed: +249 −119 lines

sota-implementations/a2c/utils_atari.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device):
     policy_module = ProbabilisticActor(
         policy_module,
         in_keys=["logits"],
-        spec=proof_environment.single_full_action_spec.to(device),
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
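
This rename recurs across the diff: `single_full_action_spec` becomes `full_action_spec_unbatched`, i.e. the action spec with the environment batch dimension stripped, which is what the `ProbabilisticActor` expects as its `spec`. A small sketch of the difference on a batched env (assuming torchrl with a gym backend installed; "Pendulum-v1" is only an example):

```python
from torchrl.envs import GymEnv, SerialEnv

# Two copies of the same env: the env specs gain a leading batch dimension of 2.
env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))

print(env.full_action_spec.shape)            # torch.Size([2]): batched spec
print(env.full_action_spec_unbatched.shape)  # torch.Size([]): per-env spec, as used above
```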

sota-implementations/a2c/utils_mujoco.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
             out_keys=["loc", "scale"],
         ),
         in_keys=["loc", "scale"],
-        spec=proof_environment.single_full_action_spec.to(device),
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,

sota-implementations/cql/discrete_cql_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ collector:
   multi_step: 0
   init_random_frames: 1000
   env_per_collector: 1
-  device: cpu
+  device:
   max_frames_per_traj: 200
   annealing_frames: 10000
   eps_start: 1.0

sota-implementations/cql/online_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ collector:
   multi_step: 0
   init_random_frames: 5_000
   env_per_collector: 1
-  device: cpu
+  device:
   max_frames_per_traj: 1000

sota-implementations/cql/utils.py

Lines changed: 7 additions & 1 deletion
@@ -124,14 +124,20 @@ def make_collector(
     cudagraph=False,
 ):
     """Make collector."""
+    device = cfg.collector.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
     collector = SyncDataCollector(
         train_env,
         actor_model_explore,
         init_random_frames=cfg.collector.init_random_frames,
         frames_per_batch=cfg.collector.frames_per_batch,
         max_frames_per_traj=cfg.collector.max_frames_per_traj,
         total_frames=cfg.collector.total_frames,
-        device=cfg.collector.device,
+        device=device,
         compile_policy={"mode": compile_mode} if compile else False,
         cudagraph_policy=cudagraph,
     )
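
Together with the emptied `device:` keys in the two CQL configs above, this gives the collector a runtime fallback: when the config leaves the device unset, CUDA is used if available, otherwise CPU. The same logic in isolation (a standalone sketch; `resolve_collector_device` is a name invented for this example):

```python
import torch

def resolve_collector_device(cfg_device) -> torch.device:
    """Return the configured device, defaulting to cuda:0 when available."""
    if cfg_device in ("", None):
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return torch.device(cfg_device)

print(resolve_collector_device(None))   # cuda:0 on a GPU machine, cpu otherwise
print(resolve_collector_device("cpu"))  # cpu, regardless of hardware
```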

sota-implementations/dreamer/dreamer_utils.py

Lines changed: 1 addition & 1 deletion
@@ -546,7 +546,7 @@ def _dreamer_make_actor_real(
             default_interaction_type=InteractionType.DETERMINISTIC,
             distribution_class=TanhNormal,
             distribution_kwargs={"tanh_loc": True},
-            spec=proof_environment.single_full_action_spec.to("cpu"),
+            spec=proof_environment.full_action_spec_unbatched.to("cpu"),
         ),
     ),
     SafeModule(

sota-implementations/gail/config.yaml

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ gail:
   gp_lambda: 10.0
   device: null
 
+compile:
+  compile: False
+  compile_mode: default
+  cudagraphs: False
+
 replay_buffer:
   dataset: halfcheetah-expert-v2
   batch_size: 256
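
A minimal sketch of how such a `compile` section is read once the config is loaded (using `OmegaConf` directly with an inline dict so the example is self-contained; the real script receives `cfg` from `@hydra.main`):

```python
from omegaconf import OmegaConf

# Mirrors the block added to config.yaml above.
cfg = OmegaConf.create(
    {"compile": {"compile": False, "compile_mode": "default", "cudagraphs": False}}
)

if cfg.compile.compile:
    print(f"torch.compile enabled (mode={cfg.compile.compile_mode})")
if cfg.compile.cudagraphs:
    print("CUDA graph capture of the update enabled")
```

From the command line the same knobs can be flipped with Hydra overrides, e.g. `python gail.py compile.compile=True compile.cudagraphs=True`.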

sota-implementations/gail/gail.py

Lines changed: 113 additions & 72 deletions
@@ -11,25 +11,33 @@
 """
 from __future__ import annotations
 
+import warnings
+
 import hydra
 import numpy as np
 import torch
 import tqdm
 
 from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
 from ppo_utils import eval_model, make_env, make_ppo_models
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import compile_with_warmup
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.objectives import ClipPPOLoss, GAILLoss
+from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
 from torchrl.objectives.value.advantages import GAE
 from torchrl.record import VideoRecorder
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 
+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
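
The new module-level `torch.set_float32_matmul_precision("high")` call relaxes float32 matmul precision (allowing TF32-style kernels on GPUs that support them) in exchange for throughput, a common companion to `torch.compile`. A tiny standalone illustration of the knob:

```python
import torch

# "highest" keeps full float32 matmuls; "high" permits faster, slightly
# lower-precision kernels (such as TF32) on GPUs that support them.
torch.set_float32_matmul_precision("high")

x = torch.randn(512, 512)
y = x @ x  # on CUDA, this matmul may now take the relaxed-precision path
```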
@@ -71,25 +79,20 @@ def main(cfg: "DictConfig"):  # noqa: F821
     np.random.seed(cfg.env.seed)
 
     # Create models (check utils_mujoco.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
-    actor, critic = actor.to(device), critic.to(device)
-
-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, device),
-        policy=actor,
-        frames_per_batch=cfg.ppo.collector.frames_per_batch,
-        total_frames=cfg.ppo.collector.total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
+    actor, critic = make_ppo_models(
+        cfg.env.env_name, compile=cfg.compile.compile, device=device
     )
 
     # Create data buffer
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
+        storage=LazyTensorStorage(
+            cfg.ppo.collector.frames_per_batch,
+            device=device,
+            compilable=cfg.compile.compile,
+        ),
         sampler=SamplerWithoutReplacement(),
         batch_size=cfg.ppo.loss.mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
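
The on-policy buffer moves from `LazyMemmapStorage` (memory-mapped, host-side) to a `LazyTensorStorage` kept on the training device, and both the storage and the buffer are flagged `compilable` so the extend/sample path can live inside the compiled update. A hedged sketch of the same construction with a made-up capacity, batch size, and a dummy "obs" field:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1024, device=device, compilable=True),
    sampler=SamplerWithoutReplacement(),
    batch_size=256,
    compilable=True,  # flag introduced for use with torch.compile, as in this commit
)

# Extend with a flat batch of transitions, then iterate over mini-batches.
buffer.extend(TensorDict({"obs": torch.randn(1024, 4)}, batch_size=[1024], device=device))
for batch in buffer:
    assert batch.batch_size == (256,)
```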
@@ -98,6 +101,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         lmbda=cfg.ppo.loss.gae_lambda,
         value_network=critic,
         average_gae=False,
+        device=device,
     )
 
     loss_module = ClipPPOLoss(
@@ -111,8 +115,35 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # Create optimizers
-    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
-    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+    actor_optim = torch.optim.Adam(
+        actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+    )
+    critic_optim = torch.optim.Adam(
+        critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+    )
+    optim = group_optimizers(actor_optim, critic_optim)
+    del actor_optim, critic_optim
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, device),
+        policy=actor,
+        frames_per_batch=cfg.ppo.collector.frames_per_batch,
+        total_frames=cfg.ppo.collector.total_frames,
+        device=device,
+        max_frames_per_traj=-1,
+        compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
 
     # Create replay buffer
     replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
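
Two details are worth calling out in this hunk: the actor and critic optimizers are fused with `group_optimizers` so the update only issues one `step()`/`zero_grad()` pair, and the learning rate is created as a tensor on the training device so that annealing it later stays a device-side operation (friendlier to compilation and CUDA graph capture). A small sketch of the grouped-optimizer pattern, with `nn.Linear` stand-ins for the actor and critic:

```python
import torch
from torchrl.objectives import group_optimizers

actor = torch.nn.Linear(4, 2)   # illustrative stand-ins, not the script's models
critic = torch.nn.Linear(4, 1)

actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4, eps=1e-5)
optim = group_optimizers(actor_optim, critic_optim)

loss = actor(torch.randn(8, 4)).sum() + critic(torch.randn(8, 4)).sum()
loss.backward()
optim.step()                       # steps both underlying optimizers
optim.zero_grad(set_to_none=True)  # clears both sets of gradients
```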
@@ -140,32 +171,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
         )
     test_env.eval()
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
 
-    # Training loop
-    collected_frames = 0
-    num_network_updates = 0
-    pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
-
-    # extract cfg variables
-    cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
-    cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
-    cfg_optim_lr = cfg.ppo.optim.lr
-    cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
-    cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
-    cfg_logger_test_interval = cfg.logger.test_interval
-    cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
-
-    for i, data in enumerate(collector):
-
-        log_info = {}
-        frames_in_batch = data.numel()
-        collected_frames += frames_in_batch
-        pbar.update(data.numel())
-
-        # Update discriminator
-        # Get expert data
-        expert_data = replay_buffer.sample()
-        expert_data = expert_data.to(device)
+    def update(data, expert_data, num_network_updates=num_network_updates):
         # Add collector data to expert data
         expert_data.set(
             discriminator_loss.tensor_keys.collector_action,
@@ -178,9 +186,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         d_loss = discriminator_loss(expert_data)
 
         # Backward pass
-        discriminator_optim.zero_grad()
         d_loss.get("loss").backward()
         discriminator_optim.step()
+        discriminator_optim.zero_grad(set_to_none=True)
 
         # Compute discriminator reward
         with torch.no_grad():
@@ -190,40 +198,25 @@ def main(cfg: "DictConfig"):  # noqa: F821
         # Set discriminator rewards to tensordict
         data.set(("next", "reward"), d_rewards)
 
-        # Get training rewards and episode lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
-        if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
-                {
-                    "train/reward": episode_rewards.mean().item(),
-                    "train/episode_length": episode_length.sum().item()
-                    / len(episode_length),
-                }
-            )
         # Update PPO
         for _ in range(cfg_loss_ppo_epochs):
-
             # Compute GAE
             with torch.no_grad():
                 data = adv_module(data)
             data_reshape = data.reshape(-1)
 
             # Update the data buffer
+            data_buffer.empty()
             data_buffer.extend(data_reshape)
 
-            for _, batch in enumerate(data_buffer):
-
-                # Get a data batch
-                batch = batch.to(device)
+            for batch in data_buffer:
+                optim.zero_grad(set_to_none=True)
 
                 # Linearly decrease the learning rate and clip epsilon
-                alpha = 1.0
+                alpha = torch.ones((), device=device)
                 if cfg_optim_anneal_lr:
                     alpha = 1 - (num_network_updates / total_network_updates)
-                    for group in actor_optim.param_groups:
-                        group["lr"] = cfg_optim_lr * alpha
-                    for group in critic_optim.param_groups:
+                    for group in optim.param_groups:
                         group["lr"] = cfg_optim_lr * alpha
                 if cfg_loss_anneal_clip_eps:
                     loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
@@ -235,20 +228,68 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 actor_loss = loss["loss_objective"] + loss["loss_entropy"]
 
                 # Backward pass
-                actor_loss.backward()
-                critic_loss.backward()
+                (actor_loss + critic_loss).backward()
 
                 # Update the networks
-                actor_optim.step()
-                critic_optim.step()
-                actor_optim.zero_grad()
-                critic_optim.zero_grad()
+                optim.step()
+        return {"dloss": d_loss, "alpha": alpha}
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, warmup=2, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Training loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
+
+    # extract cfg variables
+    cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
+    cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
+    cfg_optim_lr = cfg.ppo.optim.lr
+    cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
+    cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
+    cfg_logger_test_interval = cfg.logger.test_interval
+    cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+
+    for i, data in enumerate(collector):
+
+        log_info = {}
+        frames_in_batch = data.numel()
+        collected_frames += frames_in_batch
+        pbar.update(data.numel())
+
+        # Update discriminator
+        # Get expert data
+        expert_data = replay_buffer.sample()
+        expert_data = expert_data.to(device)
+
+        metadata = update(data, expert_data)
+        d_loss = metadata["dloss"]
+        alpha = metadata["alpha"]
+
+        # Get training rewards and episode lengths
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        if len(episode_rewards) > 0:
+            episode_length = data["next", "step_count"][data["next", "done"]]
+
+            log_info.update(
+                {
+                    "train/reward": episode_rewards.mean().item(),
+                    "train/episode_length": episode_length.sum().item()
+                    / len(episode_length),
+                }
+            )
 
         log_info.update(
             {
-                "train/actor_loss": actor_loss.item(),
-                "train/critic_loss": critic_loss.item(),
-                "train/discriminator_loss": d_loss["loss"].item(),
+                # "train/actor_loss": actor_loss.item(),
+                # "train/critic_loss": critic_loss.item(),
+                "train/discriminator_loss": d_loss["loss"],
                 "train/lr": alpha * cfg_optim_lr,
                 "train/clip_epsilon": (
                     alpha * cfg_loss_clip_epsilon
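
This last hunk is the heart of the change: the discriminator-plus-PPO update is factored into a single `update(data, expert_data)` function that returns what the outer loop needs for logging, and that function is then optionally wrapped with `compile_with_warmup` and, behind an explicit warning, `CudaGraphModule`. A hedged sketch of the same wrapping pattern around a toy update function (the toy model, flags, and loss are illustrative only):

```python
import warnings

import torch
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup

use_compile, use_cudagraphs = True, False  # stand-ins for cfg.compile.*
model = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

def update(batch):
    loss = model(batch).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    return {"loss": loss.detach()}

if use_compile:
    # warmup=2 mirrors gail.py: the first calls run before compilation kicks in.
    update = compile_with_warmup(update, warmup=2, mode="default")
if use_cudagraphs:
    warnings.warn("CudaGraphModule is experimental; validate results carefully.")
    update = CudaGraphModule(update, warmup=50)

for _ in range(5):
    print(update(torch.randn(8, 4))["loss"])
```

Note also that the per-batch actor/critic loss entries are commented out of the logs, presumably because they now live inside the compiled/captured region and extracting them each iteration would force extra host syncs; only the discriminator loss and annealing factor returned by `update` are still logged.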
