from __future__ import annotations

import tempfile
-import time
+import warnings

import hydra
import torch.nn
import torch.optim
import tqdm
-from tensordict.nn import TensorDictSequential
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule, TensorDictSequential
+from torchrl._utils import timeit

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
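
The import changes above swap manual wall-clock bookkeeping (time, torchrl_logger) for TorchRL's timeit profiler and pull in CudaGraphModule for CUDA-graph capture of the update step. A minimal sketch of how timeit is used, relying only on the calls that appear later in this diff (context-manager timing plus print, todict and erase):

import torch
from torchrl._utils import timeit

# Time named blocks; results are accumulated across iterations.
for _ in range(10):
    with timeit("collecting"):
        torch.randn(1000, 1000)  # stand-in for data = next(c_iter)
    with timeit("update"):
        torch.randn(1000, 1000).sum()  # stand-in for q_loss = update(...)

timeit.print()  # human-readable summary of the accumulated timings
stats = timeit.todict(prefix="time")  # dict of timings, keys prefixed with "time"
timeit.erase()  # reset the accumulators for the next reporting window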
@@ -48,28 +48,17 @@ def main(cfg: "DictConfig"): # noqa: F821
    test_interval = cfg.logger.test_interval // frame_skip

    # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip)
+    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
    greedy_module = EGreedyModule(
        annealing_num_steps=cfg.collector.annealing_frames,
        eps_init=cfg.collector.eps_start,
        eps_end=cfg.collector.eps_end,
        spec=model.spec,
+        device=device,
    )
    model_explore = TensorDictSequential(
        model,
        greedy_module,
-    ).to(device)
-
-    # Create the collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
-        policy=model_explore,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-        init_random_frames=init_random_frames,
    )

    # Create the replay buffer
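
In the hunk above, the device is now passed directly to make_dqn_model and EGreedyModule instead of moving the assembled TensorDictSequential with .to(device), and the collector construction is deferred until after the (optionally compiled) update function exists, as shown in the next hunk. For reference, a toy sketch of the same actor-plus-epsilon-greedy assembly; QValueActor, OneHot and the toy sizes are assumptions standing in for the script's make_dqn_model helper:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torchrl.data import OneHot
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyModule, QValueActor

n_obs, n_act = 8, 4
spec = OneHot(n_act)
# QValueActor maps raw Q-values to "action", "action_value" and "chosen_action_value".
actor = QValueActor(torch.nn.Linear(n_obs, n_act), in_keys=["observation"], spec=spec)
exploration = EGreedyModule(
    spec=spec, eps_init=1.0, eps_end=0.05, annealing_num_steps=10_000
)
policy = TensorDictSequential(actor, exploration)

with set_exploration_type(ExplorationType.RANDOM):
    td = policy(TensorDict({"observation": torch.randn(n_obs)}, []))
exploration.step(1)  # anneal epsilon, as greedy_module.step(current_frames) does in the loop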
@@ -129,25 +118,70 @@ def main(cfg: "DictConfig"): # noqa: F821
    )
    test_env.eval()

+    def update(sampled_tensordict):
+        loss_td = loss_module(sampled_tensordict)
+        q_loss = loss_td["loss"]
+        optimizer.zero_grad()
+        q_loss.backward()
+        torch.nn.utils.clip_grad_norm_(
+            list(loss_module.parameters()), max_norm=max_grad
+        )
+        optimizer.step()
+        target_net_updater.step()
+        return q_loss.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create the collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
+        policy=model_explore,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+        init_random_frames=init_random_frames,
+        compile_policy={"mode": compile_mode, "fullgraph": True}
+        if compile_mode is not None
+        else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
+
    # Main loop
    collected_frames = 0
-    start_time = time.time()
-    sampling_start = time.time()
    num_updates = cfg.loss.num_updates
    max_grad = cfg.optim.max_grad_norm
    num_test_episodes = cfg.logger.num_test_episodes
    q_losses = torch.zeros(num_updates, device=device)
    pbar = tqdm.tqdm(total=total_frames)
-    for i, data in enumerate(collector):

+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(c_iter)
        log_info = {}
-        sampling_time = time.time() - sampling_start
        pbar.update(data.numel())
        data = data.reshape(-1)
        current_frames = data.numel() * frame_skip
        collected_frames += current_frames
        greedy_module.step(current_frames)
-        replay_buffer.extend(data)
+        with timeit("rb - extend"):
+            replay_buffer.extend(data)

        # Get and log training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
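
The hunk above factors the loss/backward/step sequence into a single update function so it can be compiled as a whole and, optionally, captured in a CUDA graph; when cudagraphs is requested, the compile mode falls back to "default" because "reduce-overhead" already records its own CUDA graphs underneath. A toy sketch of the same wrapping pattern, assuming a CUDA device and fixed-shape inputs (the network, optimizer and loss here are placeholders, not the script's DQN components):

import torch
from tensordict.nn import CudaGraphModule

net = torch.nn.Linear(8, 1, device="cuda")
# capturable=True keeps Adam's step counter on the GPU so the step can be graph-captured.
optim = torch.optim.Adam(net.parameters(), capturable=True)

def update(x, target):
    # zero in place (set_to_none=False) so gradient buffers keep fixed addresses for capture
    optim.zero_grad(set_to_none=False)
    loss = (net(x) - target).pow(2).mean()
    loss.backward()
    optim.step()
    return loss.detach()

# Same order as the diff: compile the step, then capture the compiled step.
update = torch.compile(update, mode="default")
update = CudaGraphModule(update, warmup=50)  # first calls run eagerly, then the graph is captured

x = torch.randn(256, 8, device="cuda")
target = torch.randn(256, 1, device="cuda")
for _ in range(60):
    loss = update(x, target)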
@@ -169,74 +203,59 @@ def main(cfg: "DictConfig"): # noqa: F821
            continue

        # optimization steps
-        training_start = time.time()
        for j in range(num_updates):
-
-            sampled_tensordict = replay_buffer.sample()
-            sampled_tensordict = sampled_tensordict.to(device)
-
-            loss_td = loss_module(sampled_tensordict)
-            q_loss = loss_td["loss"]
-            optimizer.zero_grad()
-            q_loss.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=max_grad
-            )
-            optimizer.step()
-            target_net_updater.step()
-            q_losses[j].copy_(q_loss.detach())
-
-        training_time = time.time() - training_start
+            with timeit("rb - sample"):
+                sampled_tensordict = replay_buffer.sample()
+                sampled_tensordict = sampled_tensordict.to(device)
+            with timeit("update"):
+                q_loss = update(sampled_tensordict)
+            q_losses[j].copy_(q_loss)

        # Get and log q-values, loss, epsilon, sampling time and training time
        log_info.update(
            {
-                "train/q_values": (data["action_value"] * data["action"]).sum().item()
-                / frames_per_batch,
-                "train/q_loss": q_losses.mean().item(),
+                "train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
+                "train/q_loss": q_losses.mean(),
                "train/epsilon": greedy_module.eps,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
            }
        )

        # Get and log evaluation rewards and eval time
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
            prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
            cur_test_frame = (i * frames_per_batch) // test_interval
            final = current_frames >= collector.total_frames
            if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                model.eval()
-                eval_start = time.time()
                test_rewards = eval_model(
                    model, test_env, num_episodes=num_test_episodes
                )
-                eval_time = time.time() - eval_start
                log_info.update(
                    {
                        "eval/reward": test_rewards,
-                        "eval/eval_time": eval_time,
                    }
                )
                model.train()

+        if i % 200 == 0:
+            timeit.print()
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.erase()
+
        # Log all the information
        if logger:
            for key, value in log_info.items():
                logger.log_scalar(key, value, step=collected_frames)

        # update weights of the inference policy
        collector.update_policy_weights_()
-        sampling_start = time.time()

    collector.shutdown()
    if not test_env.is_closed:
        test_env.close()

-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-

if __name__ == "__main__":
    main()