Commit 507766a

Author: Vincent Moens (committed)

[Feature] A2C compatibility with compile

ghstack-source-id: 66a7f0d
Pull Request resolved: #2464

1 parent 1474f85 · commit 507766a

File tree: 21 files changed (+681 −338 lines)

benchmarks/test_objectives_benchmarks.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@
 )  # Anything from 2.5, incl. nightlies, allows for fullgraph
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="module", autouse=True)
 def set_default_device():
     cur_device = torch.get_default_device()
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
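For context, `autouse=True` makes the device fixture apply to every benchmark in the module without each test requesting it. A minimal self-contained sketch of the pattern; the set/yield/restore tail is an assumption, since the diff only shows the first lines of the fixture body:

```python
import pytest
import torch


@pytest.fixture(scope="module", autouse=True)  # autouse: applied to all tests
def set_default_device():
    cur_device = torch.get_default_device()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Assumed remainder: set the default device, run the tests, restore.
    torch.set_default_device(device)
    yield
    torch.set_default_device(cur_device)
```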

sota-implementations/a2c/README.md

Lines changed: 12 additions & 2 deletions

@@ -19,11 +19,21 @@ Please note that each example is independent of each other for the sake of simplicity.
 You can execute the A2C algorithm on Atari environments by running the following command:
 
 ```bash
-python a2c_atari.py
+python a2c_atari.py compile.compile=1 compile.cudagraphs=1
 ```
 
+
 You can execute the A2C algorithm on MuJoCo environments by running the following command:
 
 ```bash
-python a2c_mujoco.py
+python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1
 ```
+
+## Runtimes
+
+Runtimes when executed on H100:
+
+| Environment | Eager     | Compile   | Compile+cudagraphs |
+|-------------|-----------|-----------|--------------------|
+| MUJOCO      | < 25 mins | < 23 mins | < 20 mins          |
+| ATARI       | < 85 mins | < 60 mins | < 45 mins          |
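A rough way to run this kind of eager-vs-compiled comparison yourself (illustrative only: `update_fn` and `batch` are stand-ins for the script's update step and a sampled minibatch, and these timings are not how the table above was produced):

```python
import time

import torch


def mean_step_time(update_fn, batch, iters=100, warmup=10):
    for _ in range(warmup):  # give torch.compile time to trace and specialize
        update_fn(batch)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # flush queued kernels before starting the clock
    t0 = time.perf_counter()
    for _ in range(iters):
        update_fn(batch)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters
```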

sota-implementations/a2c/a2c_atari.py

Lines changed: 114 additions & 81 deletions
@@ -3,29 +3,37 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
-from torchrl._utils import logger as torchrl_logger
-from torchrl.record import VideoRecorder
+import torch
+
+torch.set_float32_matmul_precision("high")
 
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
-    import time
+    from copy import deepcopy
 
     import torch.optim
     import tqdm
+    from tensordict import from_module
+    from tensordict.nn import CudaGraphModule
 
-    from tensordict import TensorDict
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models
 
-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    device = cfg.loss.device
+    if not device:
+        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
+    else:
+        device = torch.device(device)
 
     # Correct for frame_skip
     frame_skip = 4
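The two global knobs introduced in this hunk, restated as a standalone sketch (`cfg_device` is a stand-in for `cfg.loss.device`): TF32 matmuls are enabled process-wide before anything is compiled, and an empty device string falls back to the first GPU when one is available.

```python
import torch

torch.set_float32_matmul_precision("high")  # allow TF32 matmuls on Ampere+ GPUs


def resolve_device(cfg_device=None) -> torch.device:
    # Empty string / None -> pick automatically, mirroring the diff's logic.
    if not cfg_device:
        return torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    return torch.device(cfg_device)
```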
@@ -35,28 +43,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
-    actor, critic, critic_head = (
-        actor.to(device),
-        critic.to(device),
-        critic_head.to(device),
-    )
-
-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
-        policy=actor,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-    )
+    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
+    with from_module(actor).data.to("meta").to_module(actor):
+        actor_eval = deepcopy(actor)
+    actor_eval.eval()
+    from_module(actor).data.to_module(actor_eval)
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, device=device),
         sampler=sampler,
         batch_size=mini_batch_size,
     )
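The `actor_eval` construction above is a tensordict idiom worth unpacking: the parameters are temporarily swapped to the `meta` device so `deepcopy` copies only the module structure, then the detached training weights are bound to the copy. The result is an eval-mode policy that always sees the current parameters without tracking gradients. A minimal sketch on a plain `nn.Linear`:

```python
from copy import deepcopy

import torch
from tensordict import from_module
from torch import nn

module = nn.Linear(3, 4)
params = from_module(module)  # TensorDict view of the module's parameters

with params.data.to("meta").to_module(module):
    module_eval = deepcopy(module)  # copies structure only; weights live on "meta"

module_eval.eval()
params.data.to_module(module_eval)  # bind detached training weights to the copy

# Same storage as the training module, no gradient tracking on the copy:
assert module_eval.weight.data_ptr() == module.weight.data_ptr()
assert not module_eval.weight.requires_grad
```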
@@ -67,6 +63,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         lmbda=cfg.loss.gae_lambda,
         value_network=critic,
         average_gae=True,
+        vectorized=not cfg.compile.compile,
+        device=device,
     )
     loss_module = A2CLoss(
         actor_network=actor,
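On the two new `GAE` arguments: `device` keeps the advantage computation on the training device, and `vectorized=not cfg.compile.compile` opts out of the vectorized lambda-return path when compiling (our reading is that the plain time loop traces more cleanly under `torch.compile`; the commit itself does not state the rationale). Constructed in isolation with a dummy value network:

```python
import torch
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value.advantages import GAE

value_net = TensorDictModule(
    nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
)
adv_module = GAE(
    gamma=0.99,  # illustrative values; the script reads them from cfg.loss
    lmbda=0.95,
    value_network=value_net,
    average_gae=True,
    vectorized=False,  # as when compile.compile=1
    device=torch.device("cpu"),
)
```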
@@ -83,9 +81,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create optimizer
     optim = torch.optim.Adam(
         loss_module.parameters(),
-        lr=cfg.optim.lr,
+        lr=torch.tensor(cfg.optim.lr, device=device),
         weight_decay=cfg.optim.weight_decay,
         eps=cfg.optim.eps,
+        capturable=device.type == "cuda",
     )
 
     # Create logger
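The optimizer change is what makes the update step CUDA-graph friendly: with `capturable=True`, Adam keeps its state on-device so `optim.step()` involves no host-side control flow, and passing the learning rate as a tensor lets the annealing loop later in the file mutate it in place via `group["lr"].copy_(...)` instead of rebinding a Python float. A small sketch, assuming a CUDA device:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(8, device="cuda"))]
optim = torch.optim.Adam(
    params,
    lr=torch.tensor(3e-4, device="cuda"),  # tensor lr: updatable in place
    capturable=True,  # state stays on-device; step() is graph-capturable
)

loss = params[0].sum()
loss.backward()
optim.step()
for group in optim.param_groups:
    group["lr"].copy_(group["lr"] * 0.5)  # in-place anneal, no param-group rebuild
```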
@@ -115,19 +114,71 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
     test_env.eval()
 
+    # update function
+    def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass
+        loss_sum.backward()
+        gn = torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        optim.zero_grad(set_to_none=True)
+
+        return (
+            loss.select("loss_critic", "loss_entropy", "loss_objective")
+            .detach()
+            .set("grad_norm", gn)
+        )
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
+    if cfg.compile.cudagraphs:
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        policy=actor,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        policy_device=device,
+        compile_policy={"mode": compile_mode} if cfg.compile.compile else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
     num_network_updates = 0
-    start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
+    lr = cfg.optim.lr
 
-    sampling_start = time.time()
-    for i, data in enumerate(collector):
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(c_iter)
 
         log_info = {}
-        sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
         pbar.update(data.numel())
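The wrapping order above, in miniature: `torch.compile` first (note the mode logic: "reduce-overhead" already captures CUDA graphs on its own, so when `CudaGraphModule` is also requested the script falls back to "default" to avoid capturing twice), then `CudaGraphModule` around the compiled callable, with a few warmup calls so compilation settles before capture. A sketch assuming a CUDA device; `in_keys=[]`/`out_keys=[]` mirrors the diff's usage for a function that takes and returns a tensordict:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule


def step(td: TensorDict) -> TensorDict:
    td["y"] = td["x"] * 2.0
    return td


step = torch.compile(step, mode="default")  # "default", since CUDA graphs come next
step = CudaGraphModule(step, in_keys=[], out_keys=[], warmup=5)

td = TensorDict({"x": torch.ones(4, device="cuda")}, batch_size=[4])
for _ in range(6):  # the first `warmup` calls run eagerly, then the graph is captured
    torch.compiler.cudagraph_mark_step_begin()
    out = step(td)
```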
@@ -144,94 +195,76 @@ def main(cfg: "DictConfig"):  # noqa: F821
             }
         )
 
-        losses = TensorDict(batch_size=[num_mini_batches])
-        training_start = time.time()
+        losses = []
 
         # Compute GAE
-        with torch.no_grad():
+        with torch.no_grad(), timeit("advantage"):
+            torch.compiler.cudagraph_mark_step_begin()
             data = adv_module(data)
         data_reshape = data.reshape(-1)
 
         # Update the data buffer
-        data_buffer.extend(data_reshape)
-
-        for k, batch in enumerate(data_buffer):
-
-            # Get a data batch
-            batch = batch.to(device)
-
-            # Linearly decrease the learning rate and clip epsilon
-            alpha = 1.0
-            if cfg.optim.anneal_lr:
-                alpha = 1 - (num_network_updates / total_network_updates)
-                for group in optim.param_groups:
-                    group["lr"] = cfg.optim.lr * alpha
-            num_network_updates += 1
-
-            # Forward pass A2C loss
-            loss = loss_module(batch)
-            losses[k] = loss.select(
-                "loss_critic", "loss_entropy", "loss_objective"
-            ).detach()
-            loss_sum = (
-                loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-            )
+        with timeit("rb - emptying"):
+            data_buffer.empty()
+        with timeit("rb - extending"):
+            data_buffer.extend(data_reshape)
 
-            # Backward pass
-            loss_sum.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
-            )
+        with timeit("optim"):
+            for batch in data_buffer:
 
-            # Update the networks
-            optim.step()
-            optim.zero_grad()
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
+                            group["lr"].copy_(lr * alpha)
+
+                num_network_updates += 1
+
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch).clone()
+                losses.append(loss)
 
         # Get training losses
-        training_time = time.time() - training_start
-        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+        losses = torch.stack(losses).float().mean()
+
         for key, value in losses.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg.optim.lr,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
+                "train/lr": lr * alpha,
             }
         )
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                 i * frames_in_batch * frame_skip
             ) // test_interval:
-                actor.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
-                    actor, test_env, num_episodes=cfg.logger.num_test_episodes
+                    actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "test/reward": test_rewards.mean(),
-                        "test/eval_time": eval_time,
                     }
                 )
-                actor.train()
+        if i % 200 == 0:
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
 
         if logger:
            for key, value in log_info.items():
                logger.log_scalar(key, value, collected_frames)
 
-        collector.update_policy_weights_()
-        sampling_start = time.time()
-
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
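Finally, two recurring details in the rewritten loop deserve a note: a CUDA-graph-backed callable replays into fixed output buffers, so each iteration is announced with `torch.compiler.cudagraph_mark_step_begin()` and any result kept across iterations is cloned first (hence `update(batch).clone()` above). In isolation:

```python
import torch


@torch.compile(mode="reduce-overhead")  # cudagraph-backed compilation
def double(x):
    return x * 2


kept = []
for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()  # mark a new capture/replay step
    out = double(torch.full((2,), float(i), device="cuda"))  # assumes a GPU
    kept.append(out.clone())  # clone: the graph reuses `out`'s buffer on the next call
```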
