 """
 from __future__ import annotations
 
+import warnings
+
 import hydra
 import numpy as np
 import torch
 import tqdm
 
 from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
 from ppo_utils import eval_model, make_env, make_ppo_models
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import compile_with_warmup
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.objectives import ClipPPOLoss, GAILLoss
+from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
 from torchrl.objectives.value.advantages import GAE
 from torchrl.record import VideoRecorder
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 
+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
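
Note: the imports added above (compile_with_warmup, CudaGraphModule) are the compile/capture helpers this diff later uses to wrap the update step. A minimal, self-contained sketch of how they compose; the toy scale_td function, shapes and warmup counts are illustrative only, the real update function appears further down:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup

def scale_td(td: TensorDict) -> TensorDict:
    # stand-in for a real training step: mutate a field and hand the tensordict back
    td["x"] = td["x"] * 0.5
    return td

# torch.compile with a couple of eager warmup calls, mirroring the pattern used below
step = compile_with_warmup(scale_td, warmup=2, mode="default")
if torch.cuda.is_available():
    # optional CUDA-graph capture; the first `warmup` calls still run eagerly
    step = CudaGraphModule(step, warmup=2)

device = "cuda" if torch.cuda.is_available() else "cpu"
td = TensorDict({"x": torch.ones(8, device=device)}, batch_size=[8])
td = step(td)
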
@@ -71,25 +79,20 @@ def main(cfg: "DictConfig"): # noqa: F821
     np.random.seed(cfg.env.seed)
 
     # Create models (check utils_mujoco.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
-    actor, critic = actor.to(device), critic.to(device)
-
-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, device),
-        policy=actor,
-        frames_per_batch=cfg.ppo.collector.frames_per_batch,
-        total_frames=cfg.ppo.collector.total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
+    actor, critic = make_ppo_models(
+        cfg.env.env_name, compile=cfg.compile.compile, device=device
     )
 
     # Create data buffer
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
+        storage=LazyTensorStorage(
+            cfg.ppo.collector.frames_per_batch,
+            device=device,
+            compilable=cfg.compile.compile,
+        ),
         sampler=SamplerWithoutReplacement(),
         batch_size=cfg.ppo.loss.mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
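
Note: the storage swap in this hunk (disk-backed LazyMemmapStorage to a device-resident LazyTensorStorage) is what lets later hunks drop the per-minibatch `.to(device)` copy. A minimal sketch with made-up capacity and batch sizes:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

device = "cuda" if torch.cuda.is_available() else "cpu"
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1024, device=device),  # kept in (GPU) memory rather than memory-mapped files
    sampler=SamplerWithoutReplacement(),
    batch_size=256,
)
buffer.extend(TensorDict({"obs": torch.randn(1024, 4, device=device)}, batch_size=[1024]))
batch = buffer.sample()  # samples already live on `device`; no extra copy in the inner loop
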
@@ -98,6 +101,7 @@ def main(cfg: "DictConfig"): # noqa: F821
         lmbda=cfg.ppo.loss.gae_lambda,
         value_network=critic,
         average_gae=False,
+        device=device,
     )
 
     loss_module = ClipPPOLoss(
@@ -111,8 +115,35 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
 
     # Create optimizers
-    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
-    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+    actor_optim = torch.optim.Adam(
+        actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+    )
+    critic_optim = torch.optim.Adam(
+        critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+    )
+    optim = group_optimizers(actor_optim, critic_optim)
+    del actor_optim, critic_optim
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, device),
+        policy=actor,
+        frames_per_batch=cfg.ppo.collector.frames_per_batch,
+        total_frames=cfg.ppo.collector.total_frames,
+        device=device,
+        max_frames_per_traj=-1,
+        compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
 
     # Create replay buffer
     replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
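
Note: two optimizer-related changes land in this hunk: the two Adam instances are fused with group_optimizers so a single step()/zero_grad() covers both networks, and the learning rate is stored as a 0-dim tensor so the annealing arithmetic later in the diff stays on `device`. A small sketch with toy modules standing in for the actor and critic (names and sizes are made up):

import torch
from torch import nn
from torchrl.objectives import group_optimizers

device = "cpu"  # or "cuda"
actor, critic = nn.Linear(4, 2, device=device), nn.Linear(4, 1, device=device)
base_lr = 3e-4
actor_optim = torch.optim.Adam(actor.parameters(), lr=torch.tensor(base_lr, device=device), eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=torch.tensor(base_lr, device=device), eps=1e-5)
optim = group_optimizers(actor_optim, critic_optim)  # exposes the param groups of both optimizers

alpha = torch.ones((), device=device)  # annealing factor, decayed towards 0 during training
for group in optim.param_groups:
    group["lr"] = base_lr * alpha  # the lr stays a tensor, no host/device round trip
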
@@ -140,32 +171,9 @@ def main(cfg: "DictConfig"): # noqa: F821
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
         )
     test_env.eval()
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
 
-    # Training loop
-    collected_frames = 0
-    num_network_updates = 0
-    pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
-
-    # extract cfg variables
-    cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
-    cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
-    cfg_optim_lr = cfg.ppo.optim.lr
-    cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
-    cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
-    cfg_logger_test_interval = cfg.logger.test_interval
-    cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
-
-    for i, data in enumerate(collector):
-
-        log_info = {}
-        frames_in_batch = data.numel()
-        collected_frames += frames_in_batch
-        pbar.update(data.numel())
-
-        # Update discriminator
-        # Get expert data
-        expert_data = replay_buffer.sample()
-        expert_data = expert_data.to(device)
+    def update(data, expert_data, num_network_updates=num_network_updates):
         # Add collector data to expert data
         expert_data.set(
             discriminator_loss.tensor_keys.collector_action,
@@ -178,9 +186,9 @@ def main(cfg: "DictConfig"): # noqa: F821
         d_loss = discriminator_loss(expert_data)
 
         # Backward pass
-        discriminator_optim.zero_grad()
         d_loss.get("loss").backward()
         discriminator_optim.step()
+        discriminator_optim.zero_grad(set_to_none=True)
 
         # Compute discriminator reward
         with torch.no_grad():
@@ -190,40 +198,25 @@ def main(cfg: "DictConfig"): # noqa: F821
         # Set discriminator rewards to tensordict
         data.set(("next", "reward"), d_rewards)
 
-        # Get training rewards and episode lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
-        if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
-                {
-                    "train/reward": episode_rewards.mean().item(),
-                    "train/episode_length": episode_length.sum().item()
-                    / len(episode_length),
-                }
-            )
         # Update PPO
         for _ in range(cfg_loss_ppo_epochs):
-
             # Compute GAE
             with torch.no_grad():
                 data = adv_module(data)
             data_reshape = data.reshape(-1)
 
             # Update the data buffer
+            data_buffer.empty()
             data_buffer.extend(data_reshape)
 
-            for _, batch in enumerate(data_buffer):
-
-                # Get a data batch
-                batch = batch.to(device)
+            for batch in data_buffer:
+                optim.zero_grad(set_to_none=True)
 
                 # Linearly decrease the learning rate and clip epsilon
-                alpha = 1.0
+                alpha = torch.ones((), device=device)
                 if cfg_optim_anneal_lr:
                     alpha = 1 - (num_network_updates / total_network_updates)
-                for group in actor_optim.param_groups:
-                    group["lr"] = cfg_optim_lr * alpha
-                for group in critic_optim.param_groups:
+                for group in optim.param_groups:
                     group["lr"] = cfg_optim_lr * alpha
                 if cfg_loss_anneal_clip_eps:
                     loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
@@ -235,20 +228,68 @@ def main(cfg: "DictConfig"): # noqa: F821
                 actor_loss = loss["loss_objective"] + loss["loss_entropy"]
 
                 # Backward pass
-                actor_loss.backward()
-                critic_loss.backward()
+                (actor_loss + critic_loss).backward()
 
                 # Update the networks
-                actor_optim.step()
-                critic_optim.step()
-                actor_optim.zero_grad()
-                critic_optim.zero_grad()
+                optim.step()
+        return {"dloss": d_loss, "alpha": alpha}
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, warmup=2, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Training loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
+
+    # extract cfg variables
+    cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
+    cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
+    cfg_optim_lr = cfg.ppo.optim.lr
+    cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
+    cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
+    cfg_logger_test_interval = cfg.logger.test_interval
+    cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+
+    for i, data in enumerate(collector):
+
+        log_info = {}
+        frames_in_batch = data.numel()
+        collected_frames += frames_in_batch
+        pbar.update(data.numel())
+
+        # Update discriminator
+        # Get expert data
+        expert_data = replay_buffer.sample()
+        expert_data = expert_data.to(device)
+
+        metadata = update(data, expert_data)
+        d_loss = metadata["dloss"]
+        alpha = metadata["alpha"]
+
+        # Get training rewards and episode lengths
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        if len(episode_rewards) > 0:
+            episode_length = data["next", "step_count"][data["next", "done"]]
+
+            log_info.update(
+                {
+                    "train/reward": episode_rewards.mean().item(),
+                    "train/episode_length": episode_length.sum().item()
+                    / len(episode_length),
+                }
+            )
 
         log_info.update(
             {
-                "train/actor_loss": actor_loss.item(),
-                "train/critic_loss": critic_loss.item(),
-                "train/discriminator_loss": d_loss["loss"].item(),
+                # "train/actor_loss": actor_loss.item(),
+                # "train/critic_loss": critic_loss.item(),
+                "train/discriminator_loss": d_loss["loss"],
                 "train/lr": alpha * cfg_optim_lr,
                 "train/clip_epsilon": (
                     alpha * cfg_loss_clip_epsilon