
The helper functions are coded in the utils.py associated with this script.
"""
+
from __future__ import annotations

- import time
+ import warnings

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
- from torchrl._utils import logger as torchrl_logger
-
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
-
+ from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    dump_video,
@@ -75,9 +76,6 @@ def main(cfg: "DictConfig"): # noqa: F821
    # Create TD3 loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

-     # Create off-policy collector
-     collector = make_collector(cfg, train_env, model[0])
-
    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
@@ -91,9 +89,57 @@ def main(cfg: "DictConfig"): # noqa: F821
    optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
        cfg, loss_module
    )
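+     # Group the three optimizers so a single step() updates actor, critic and alpha parameters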
+     optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+     del optimizer_actor, optimizer_critic, optimizer_alpha
+
+     def update(sampled_tensordict):
+         optimizer.zero_grad(set_to_none=True)
+
+         # Compute loss
+         loss_out = loss_module(sampled_tensordict)
+
+         actor_loss, q_loss, alpha_loss = (
+             loss_out["loss_actor"],
+             loss_out["loss_qvalue"],
+             loss_out["loss_alpha"],
+         )
+
+         # Update actor, critic and alpha with a single backward pass and optimizer step
+         (q_loss + actor_loss + alpha_loss).backward()
+         optimizer.step()
+
+         # Update target params
+         target_net_updater.step()
+
+         return loss_out.detach()
+
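+     # Optionally compile the update step and/or capture it in a CUDA graph (cfg.compile settings)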
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
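+         # Capture the update in a CUDA graph after a number of eager warm-up calls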
+         update = CudaGraphModule(update, warmup=50)
+
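+     # The collector is created only now so the policy can be compiled / captured with the same settings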
+     # Create off-policy collector
+     collector = make_collector(
+         cfg,
+         train_env,
+         model[0],
+         compile=compile_mode is not None,
+         compile_mode=compile_mode,
+         cudagraphs=cfg.compile.cudagraphs,
+     )

    # Main loop
-     start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -108,129 +154,93 @@ def main(cfg: "DictConfig"): # noqa: F821
    eval_iter = cfg.logger.eval_iter
    frames_per_batch = cfg.collector.frames_per_batch

-     sampling_start = time.time()
-     for i, tensordict in enumerate(collector):
-         sampling_time = time.time() - sampling_start
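+     # Iterate the collector explicitly so data collection can be profiled with timeit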
+     c_iter = iter(collector)
+     for i in range(len(collector)):
+         with timeit("collecting"):
+             collected_data = next(c_iter)

        # Update weights of the inference policy
        collector.update_policy_weights_()
+         current_frames = collected_data.numel()

-         pbar.update(tensordict.numel())
+         pbar.update(current_frames)

-         tensordict = tensordict.reshape(-1)
-         current_frames = tensordict.numel()
-         # Add to replay buffer
-         replay_buffer.extend(tensordict.cpu())
+         collected_data = collected_data.reshape(-1)
+         with timeit("rb - extend"):
+             # Add to replay buffer
+             replay_buffer.extend(collected_data)
        collected_frames += current_frames

        # Optimization steps
-         training_start = time.time()
        if collected_frames >= init_random_frames:
-             (
-                 actor_losses,
-                 q_losses,
-                 alpha_losses,
-             ) = ([], [], [])
+             tds = []
            for _ in range(num_updates):
-                 # Sample from replay buffer
-                 sampled_tensordict = replay_buffer.sample()
-                 if sampled_tensordict.device != device:
-                     sampled_tensordict = sampled_tensordict.to(
-                         device, non_blocking=True
-                     )
-                 else:
-                     sampled_tensordict = sampled_tensordict.clone()
-
-                 # Compute loss
-                 loss_out = loss_module(sampled_tensordict)
-
-                 actor_loss, q_loss, alpha_loss = (
-                     loss_out["loss_actor"],
-                     loss_out["loss_qvalue"],
-                     loss_out["loss_alpha"],
-                 )
-
-                 # Update critic
-                 optimizer_critic.zero_grad()
-                 q_loss.backward()
-                 optimizer_critic.step()
-                 q_losses.append(q_loss.item())
+                 with timeit("rb - sample"):
+                     # Sample from replay buffer
+                     sampled_tensordict = replay_buffer.sample()

-                 # Update actor
-                 optimizer_actor.zero_grad()
-                 actor_loss.backward()
-                 optimizer_actor.step()
+                 with timeit("update"):
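+                     # Mark the start of a new CUDA-graph step so buffers from the previous update can be reused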
+                     torch.compiler.cudagraph_mark_step_begin()
+                     sampled_tensordict = sampled_tensordict.to(device)
+                     loss_out = update(sampled_tensordict).clone()

-                 actor_losses.append(actor_loss.item())
-
-                 # Update alpha
-                 optimizer_alpha.zero_grad()
-                 alpha_loss.backward()
-                 optimizer_alpha.step()
-
-                 alpha_losses.append(alpha_loss.item())
-
-                 # Update target params
-                 target_net_updater.step()
+                 tds.append(loss_out)

                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)
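+             # Average the collected loss values over the update loop for logging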
+             tds = torch.stack(tds).mean()

-         training_time = time.time() - training_start
+         # Logging
        episode_end = (
-             tensordict["next", "done"]
-             if tensordict["next", "done"].any()
-             else tensordict["next", "truncated"]
+             collected_data["next", "done"]
+             if collected_data["next", "done"].any()
+             else collected_data["next", "truncated"]
        )
-         episode_rewards = tensordict["next", "episode_reward"][episode_end]
+         episode_rewards = collected_data["next", "episode_reward"][episode_end]

-         # Logging
        metrics_to_log = {}
        if len(episode_rewards) > 0:
-             episode_length = tensordict["next", "step_count"][episode_end]
+             episode_length = collected_data["next", "step_count"][episode_end]
            metrics_to_log["train/reward"] = episode_rewards.mean().item()
            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                episode_length
            )

        if collected_frames >= init_random_frames:
-             metrics_to_log["train/q_loss"] = np.mean(q_losses)
-             metrics_to_log["train/a_loss"] = np.mean(actor_losses)
-             metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses)
-             metrics_to_log["train/sampling_time"] = sampling_time
-             metrics_to_log["train/training_time"] = training_time
+             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+             metrics_to_log["train/a_loss"] = tds["loss_actor"]
+             metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

        # Evaluation
        prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
        cur_test_frame = (i * frames_per_batch) // eval_iter
        final = current_frames >= collector.total_frames
        if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
-             with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                 eval_start = time.time()
+             with set_exploration_type(
+                 ExplorationType.DETERMINISTIC
+             ), torch.no_grad(), timeit("eval"):
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_env.apply(dump_video)
-                 eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
-                 metrics_to_log["eval/time"] = eval_time
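+         # Every 50 iterations, log the accumulated timings, print them and reset the counters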
+         if i % 50 == 0:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             timeit.print()
+             timeit.erase()
        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)
-         sampling_start = time.time()

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()
    if not train_env.is_closed:
        train_env.close()
-     end_time = time.time()
-     execution_time = end_time - start_time
-     torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":