3
3
# This source code is licensed under the MIT license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
import hydra
6
- from torchrl ._utils import logger as torchrl_logger
7
- from torchrl .record import VideoRecorder
6
+ import torch
7
+
8
+ torch .set_float32_matmul_precision ("high" )
8
9
9
10
10
11
@hydra .main (config_path = "" , config_name = "config_atari" , version_base = "1.1" )
11
12
def main (cfg : "DictConfig" ): # noqa: F821
12
13
13
- import time
14
+ from copy import deepcopy
14
15
15
16
import torch .optim
16
17
import tqdm
18
+ from tensordict import from_module
19
+ from tensordict .nn import CudaGraphModule
17
20
18
- from tensordict import TensorDict
21
+ from torchrl . _utils import timeit
19
22
from torchrl .collectors import SyncDataCollector
20
- from torchrl .data import LazyMemmapStorage , TensorDictReplayBuffer
23
+ from torchrl .data import LazyTensorStorage , TensorDictReplayBuffer
21
24
from torchrl .data .replay_buffers .samplers import SamplerWithoutReplacement
22
25
from torchrl .envs import ExplorationType , set_exploration_type
23
26
from torchrl .objectives import A2CLoss
24
27
from torchrl .objectives .value .advantages import GAE
28
+ from torchrl .record import VideoRecorder
25
29
from torchrl .record .loggers import generate_exp_name , get_logger
26
30
from utils_atari import eval_model , make_parallel_env , make_ppo_models
27
31
28
- device = "cpu" if not torch .cuda .device_count () else "cuda"
32
+ device = cfg .loss .device
33
+ if not device :
34
+ device = torch .device ("cpu" if not torch .cuda .is_available () else "cuda:0" )
35
+ else :
36
+ device = torch .device (device )
29
37
30
38
# Correct for frame_skip
31
39
frame_skip = 4
@@ -35,28 +43,16 @@ def main(cfg: "DictConfig"): # noqa: F821
35
43
test_interval = cfg .logger .test_interval // frame_skip
36
44
37
45
# Create models (check utils_atari.py)
38
- actor , critic , critic_head = make_ppo_models (cfg .env .env_name )
39
- actor , critic , critic_head = (
40
- actor .to (device ),
41
- critic .to (device ),
42
- critic_head .to (device ),
43
- )
44
-
45
- # Create collector
46
- collector = SyncDataCollector (
47
- create_env_fn = make_parallel_env (cfg .env .env_name , cfg .env .num_envs , device ),
48
- policy = actor ,
49
- frames_per_batch = frames_per_batch ,
50
- total_frames = total_frames ,
51
- device = device ,
52
- storing_device = device ,
53
- max_frames_per_traj = - 1 ,
54
- )
46
+ actor , critic , critic_head = make_ppo_models (cfg .env .env_name , device = device )
47
+ with from_module (actor ).data .to ("meta" ).to_module (actor ):
48
+ actor_eval = deepcopy (actor )
49
+ actor_eval .eval ()
50
+ from_module (actor ).data .to_module (actor_eval )
55
51
56
52
# Create data buffer
57
53
sampler = SamplerWithoutReplacement ()
58
54
data_buffer = TensorDictReplayBuffer (
59
- storage = LazyMemmapStorage (frames_per_batch ),
55
+ storage = LazyTensorStorage (frames_per_batch , device = device ),
60
56
sampler = sampler ,
61
57
batch_size = mini_batch_size ,
62
58
)
@@ -67,6 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821
67
63
lmbda = cfg .loss .gae_lambda ,
68
64
value_network = critic ,
69
65
average_gae = True ,
66
+ vectorized = not cfg .compile .compile ,
67
+ device = device ,
70
68
)
71
69
loss_module = A2CLoss (
72
70
actor_network = actor ,
@@ -83,9 +81,10 @@ def main(cfg: "DictConfig"): # noqa: F821
83
81
# Create optimizer
84
82
optim = torch .optim .Adam (
85
83
loss_module .parameters (),
86
- lr = cfg .optim .lr ,
84
+ lr = torch . tensor ( cfg .optim .lr , device = device ) ,
87
85
weight_decay = cfg .optim .weight_decay ,
88
86
eps = cfg .optim .eps ,
87
+ capturable = device .type == "cuda" ,
89
88
)
90
89
91
90
# Create logger
@@ -115,19 +114,71 @@ def main(cfg: "DictConfig"): # noqa: F821
115
114
)
116
115
test_env .eval ()
117
116
117
+ # update function
118
+ def update (batch , max_grad_norm = cfg .optim .max_grad_norm ):
119
+ # Forward pass A2C loss
120
+ loss = loss_module (batch )
121
+
122
+ loss_sum = loss ["loss_critic" ] + loss ["loss_objective" ] + loss ["loss_entropy" ]
123
+
124
+ # Backward pass
125
+ loss_sum .backward ()
126
+ gn = torch .nn .utils .clip_grad_norm_ (
127
+ loss_module .parameters (), max_norm = max_grad_norm
128
+ )
129
+
130
+ # Update the networks
131
+ optim .step ()
132
+ optim .zero_grad (set_to_none = True )
133
+
134
+ return (
135
+ loss .select ("loss_critic" , "loss_entropy" , "loss_objective" )
136
+ .detach ()
137
+ .set ("grad_norm" , gn )
138
+ )
139
+
140
+ compile_mode = None
141
+ if cfg .compile .compile :
142
+ compile_mode = cfg .compile .compile_mode
143
+ if compile_mode in ("" , None ):
144
+ if cfg .compile .cudagraphs :
145
+ compile_mode = "default"
146
+ else :
147
+ compile_mode = "reduce-overhead"
148
+ update = torch .compile (update , mode = compile_mode )
149
+ adv_module = torch .compile (adv_module , mode = compile_mode )
150
+
151
+ if cfg .compile .cudagraphs :
152
+ update = CudaGraphModule (update , in_keys = [], out_keys = [], warmup = 5 )
153
+ adv_module = CudaGraphModule (adv_module )
154
+
155
+ # Create collector
156
+ collector = SyncDataCollector (
157
+ create_env_fn = make_parallel_env (cfg .env .env_name , cfg .env .num_envs , device ),
158
+ policy = actor ,
159
+ frames_per_batch = frames_per_batch ,
160
+ total_frames = total_frames ,
161
+ device = device ,
162
+ storing_device = device ,
163
+ policy_device = device ,
164
+ compile_policy = {"mode" : compile_mode } if cfg .compile .compile else False ,
165
+ cudagraph_policy = cfg .compile .cudagraphs ,
166
+ )
167
+
118
168
# Main loop
119
169
collected_frames = 0
120
170
num_network_updates = 0
121
- start_time = time .time ()
122
171
pbar = tqdm .tqdm (total = total_frames )
123
172
num_mini_batches = frames_per_batch // mini_batch_size
124
173
total_network_updates = (total_frames // frames_per_batch ) * num_mini_batches
174
+ lr = cfg .optim .lr
125
175
126
- sampling_start = time .time ()
127
- for i , data in enumerate (collector ):
176
+ c_iter = iter (collector )
177
+ for i in range (len (collector )):
178
+ with timeit ("collecting" ):
179
+ data = next (c_iter )
128
180
129
181
log_info = {}
130
- sampling_time = time .time () - sampling_start
131
182
frames_in_batch = data .numel ()
132
183
collected_frames += frames_in_batch * frame_skip
133
184
pbar .update (data .numel ())
@@ -144,94 +195,76 @@ def main(cfg: "DictConfig"): # noqa: F821
144
195
}
145
196
)
146
197
147
- losses = TensorDict (batch_size = [num_mini_batches ])
148
- training_start = time .time ()
198
+ losses = []
149
199
150
200
# Compute GAE
151
- with torch .no_grad ():
201
+ with torch .no_grad (), timeit ("advantage" ):
202
+ torch .compiler .cudagraph_mark_step_begin ()
152
203
data = adv_module (data )
153
204
data_reshape = data .reshape (- 1 )
154
205
155
206
# Update the data buffer
156
- data_buffer .extend (data_reshape )
157
-
158
- for k , batch in enumerate (data_buffer ):
159
-
160
- # Get a data batch
161
- batch = batch .to (device )
162
-
163
- # Linearly decrease the learning rate and clip epsilon
164
- alpha = 1.0
165
- if cfg .optim .anneal_lr :
166
- alpha = 1 - (num_network_updates / total_network_updates )
167
- for group in optim .param_groups :
168
- group ["lr" ] = cfg .optim .lr * alpha
169
- num_network_updates += 1
170
-
171
- # Forward pass A2C loss
172
- loss = loss_module (batch )
173
- losses [k ] = loss .select (
174
- "loss_critic" , "loss_entropy" , "loss_objective"
175
- ).detach ()
176
- loss_sum = (
177
- loss ["loss_critic" ] + loss ["loss_objective" ] + loss ["loss_entropy" ]
178
- )
207
+ with timeit ("rb - emptying" ):
208
+ data_buffer .empty ()
209
+ with timeit ("rb - extending" ):
210
+ data_buffer .extend (data_reshape )
179
211
180
- # Backward pass
181
- loss_sum .backward ()
182
- torch .nn .utils .clip_grad_norm_ (
183
- list (loss_module .parameters ()), max_norm = cfg .optim .max_grad_norm
184
- )
212
+ with timeit ("optim" ):
213
+ for batch in data_buffer :
185
214
186
- # Update the networks
187
- optim .step ()
188
- optim .zero_grad ()
215
+ # Linearly decrease the learning rate and clip epsilon
216
+ with timeit ("optim - lr" ):
217
+ alpha = 1.0
218
+ if cfg .optim .anneal_lr :
219
+ alpha = 1 - (num_network_updates / total_network_updates )
220
+ for group in optim .param_groups :
221
+ group ["lr" ].copy_ (lr * alpha )
222
+
223
+ num_network_updates += 1
224
+
225
+ with timeit ("update" ):
226
+ torch .compiler .cudagraph_mark_step_begin ()
227
+ loss = update (batch ).clone ()
228
+ losses .append (loss )
189
229
190
230
# Get training losses
191
- training_time = time . time () - training_start
192
- losses = losses . apply ( lambda x : x . float (). mean (), batch_size = [])
231
+ losses = torch . stack ( losses ). float (). mean ()
232
+
193
233
for key , value in losses .items ():
194
234
log_info .update ({f"train/{ key } " : value .item ()})
195
235
log_info .update (
196
236
{
197
- "train/lr" : alpha * cfg .optim .lr ,
198
- "train/sampling_time" : sampling_time ,
199
- "train/training_time" : training_time ,
237
+ "train/lr" : lr * alpha ,
200
238
}
201
239
)
202
240
203
241
# Get test rewards
204
- with torch .no_grad (), set_exploration_type (ExplorationType .DETERMINISTIC ):
242
+ with torch .no_grad (), set_exploration_type (
243
+ ExplorationType .DETERMINISTIC
244
+ ), timeit ("eval" ):
205
245
if ((i - 1 ) * frames_in_batch * frame_skip ) // test_interval < (
206
246
i * frames_in_batch * frame_skip
207
247
) // test_interval :
208
- actor .eval ()
209
- eval_start = time .time ()
210
248
test_rewards = eval_model (
211
- actor , test_env , num_episodes = cfg .logger .num_test_episodes
249
+ actor_eval , test_env , num_episodes = cfg .logger .num_test_episodes
212
250
)
213
- eval_time = time .time () - eval_start
214
251
log_info .update (
215
252
{
216
253
"test/reward" : test_rewards .mean (),
217
- "test/eval_time" : eval_time ,
218
254
}
219
255
)
220
- actor .train ()
256
+ if i % 200 == 0 :
257
+ log_info .update (timeit .todict (prefix = "time" ))
258
+ timeit .print ()
259
+ timeit .erase ()
221
260
222
261
if logger :
223
262
for key , value in log_info .items ():
224
263
logger .log_scalar (key , value , collected_frames )
225
264
226
- collector .update_policy_weights_ ()
227
- sampling_start = time .time ()
228
-
229
265
collector .shutdown ()
230
266
if not test_env .is_closed :
231
267
test_env .close ()
232
- end_time = time .time ()
233
- execution_time = end_time - start_time
234
- torchrl_logger .info (f"Training took { execution_time :.2f} seconds to finish" )
235
268
236
269
237
270
if __name__ == "__main__" :
0 commit comments