"""
from __future__ import annotations

- import time
+ import warnings

import hydra

import torch.cuda
import tqdm
from tensordict import TensorDict
- from torchrl._utils import logger as torchrl_logger
+ from tensordict.nn import CudaGraphModule
+
+ from torchrl._utils import compile_with_warmup, timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import group_optimizers

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    make_sac_optimizer,
)

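+ # Use "high" float32 matmul precision (TF32 on supported GPUs) to speed up matrix multiplications.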
+ torch.set_float32_matmul_precision("high")
+

@hydra.main(version_base="1.1", config_path="", config_name="config")
def main(cfg: "DictConfig"):  # noqa: F821
@@ -75,16 +80,27 @@ def main(cfg: "DictConfig"): # noqa: F821
    # Create SAC loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

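+     # Pick a torch.compile mode if none was specified: plain "default" when CUDA graphs handle capture, "reduce-overhead" otherwise.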
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+

    # Create off-policy collector
-     collector = make_collector(cfg, train_env, exploration_policy)
+     collector = make_collector(
+         cfg, train_env, exploration_policy, compile_mode=compile_mode
+     )

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        scratch_dir=cfg.replay_buffer.scratch_dir,
-         device="cpu",
+         device=device,
    )

    # Create optimizers
@@ -93,9 +109,36 @@ def main(cfg: "DictConfig"): # noqa: F821
        optimizer_critic,
        optimizer_alpha,
    ) = make_sac_optimizer(cfg, loss_module)
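+     # Group the three optimizers so a single step()/zero_grad() updates actor, critic and alpha together.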
+     optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+     del optimizer_actor, optimizer_critic, optimizer_alpha
+
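+     # One SAC update step, kept as a standalone function so it can be compiled or captured in a CUDA graph.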
+     def update(sampled_tensordict):
+         # Compute loss
+         loss_td = loss_module(sampled_tensordict)
+
+         actor_loss = loss_td["loss_actor"]
+         q_loss = loss_td["loss_qvalue"]
+         alpha_loss = loss_td["loss_alpha"]
+
+         (actor_loss + q_loss + alpha_loss).sum().backward()
+         optimizer.step()
+         optimizer.zero_grad(set_to_none=True)
+
+         # Update qnet_target params
+         target_net_updater.step()
+         return loss_td.detach()
+
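+     # Optionally compile the update step (with a short eager warmup before compilation).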
+     if cfg.compile.compile:
+         update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
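+     # Optionally wrap the update step in a CudaGraphModule to cut per-iteration kernel-launch overhead.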
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

    # Main loop
-     start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -110,69 +153,48 @@ def main(cfg: "DictConfig"): # noqa: F821
    frames_per_batch = cfg.collector.frames_per_batch
    eval_rollout_steps = cfg.env.max_episode_steps

-     sampling_start = time.time()
-     for i, tensordict in enumerate(collector):
-         sampling_time = time.time() - sampling_start
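+     # Drive the collector manually so data collection can be timed separately from training.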
+     collector_iter = iter(collector)
+     total_iter = len(collector)
+
+     for i in range(total_iter):
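+         # Periodically print the timings accumulated by the timeit blocks below.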
+         timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
+
+         with timeit("collect"):
+             tensordict = next(collector_iter)

        # Update weights of the inference policy
        collector.update_policy_weights_()

-         pbar.update(tensordict.numel())
-
-         tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
-         # Add to replay buffer
-         replay_buffer.extend(tensordict.cpu())
+         pbar.update(current_frames)
+
+         with timeit("rb - extend"):
+             # Add to replay buffer
+             tensordict = tensordict.reshape(-1)
+             replay_buffer.extend(tensordict)
+
        collected_frames += current_frames

        # Optimization steps
-         training_start = time.time()
-         if collected_frames >= init_random_frames:
-             losses = TensorDict(batch_size=[num_updates])
-             for i in range(num_updates):
-                 # Sample from replay buffer
-                 sampled_tensordict = replay_buffer.sample()
-                 if sampled_tensordict.device != device:
-                     sampled_tensordict = sampled_tensordict.to(
-                         device, non_blocking=True
+         with timeit("train"):
+             if collected_frames >= init_random_frames:
+                 losses = TensorDict(batch_size=[num_updates])
+                 for i in range(num_updates):
+                     with timeit("rb - sample"):
+                         # Sample from replay buffer
+                         sampled_tensordict = replay_buffer.sample()
+
+                     with timeit("update"):
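+                         # Mark a new CUDA-graph step; the result is cloned since graph-captured output buffers are reused across calls.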
+                         torch.compiler.cudagraph_mark_step_begin()
+                         loss_td = update(sampled_tensordict).clone()
+                     losses[i] = loss_td.select(
+                         "loss_actor", "loss_qvalue", "loss_alpha"
                    )
-                 else:
-                     sampled_tensordict = sampled_tensordict.clone()
-
-                 # Compute loss
-                 loss_td = loss_module(sampled_tensordict)
-
-                 actor_loss = loss_td["loss_actor"]
-                 q_loss = loss_td["loss_qvalue"]
-                 alpha_loss = loss_td["loss_alpha"]
-
-                 # Update actor
-                 optimizer_actor.zero_grad()
-                 actor_loss.backward()
-                 optimizer_actor.step()
-
-                 # Update critic
-                 optimizer_critic.zero_grad()
-                 q_loss.backward()
-                 optimizer_critic.step()
-
-                 # Update alpha
-                 optimizer_alpha.zero_grad()
-                 alpha_loss.backward()
-                 optimizer_alpha.step()
-
-                 losses[i] = loss_td.select(
-                     "loss_actor", "loss_qvalue", "loss_alpha"
-                 ).detach()
-
-                 # Update qnet_target params
-                 target_net_updater.step()

-             # Update priority
-             if prb:
-                 replay_buffer.update_priority(sampled_tensordict)
+                 # Update priority
+                 if prb:
+                     replay_buffer.update_priority(sampled_tensordict)

-         training_time = time.time() - training_start

        episode_end = (
            tensordict["next", "done"]
            if tensordict["next", "done"].any()
@@ -184,46 +206,41 @@ def main(cfg: "DictConfig"): # noqa: F821
        metrics_to_log = {}
        if len(episode_rewards) > 0:
            episode_length = tensordict["next", "step_count"][episode_end]
-             metrics_to_log["train/reward"] = episode_rewards.mean().item()
-             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+             metrics_to_log["train/reward"] = episode_rewards
+             metrics_to_log["train/episode_length"] = episode_length.sum() / len(
                episode_length
            )
        if collected_frames >= init_random_frames:
-             metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
-             metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
-             metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
-             metrics_to_log["train/alpha"] = loss_td["alpha"].item()
-             metrics_to_log["train/entropy"] = loss_td["entropy"].item()
-             metrics_to_log["train/sampling_time"] = sampling_time
-             metrics_to_log["train/training_time"] = training_time
+             losses = losses.mean()
+             metrics_to_log["train/q_loss"] = losses.get("loss_qvalue")
+             metrics_to_log["train/actor_loss"] = losses.get("loss_actor")
+             metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha")
+             metrics_to_log["train/alpha"] = loss_td["alpha"]
+             metrics_to_log["train/entropy"] = loss_td["entropy"]

        # Evaluation
        if abs(collected_frames % eval_iter) < frames_per_batch:
-             with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                 eval_start = time.time()
+             with set_exploration_type(
+                 ExplorationType.DETERMINISTIC
+             ), torch.no_grad(), timeit("eval"):
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_env.apply(dump_video)
-                 eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
-                 metrics_to_log["eval/time"] = eval_time
        if logger is not None:
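+             # Attach the per-block timings (collect, train, eval, ...) to the logged metrics.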
+             metrics_to_log.update(timeit.todict(prefix="time"))
            log_metrics(logger, metrics_to_log, collected_frames)
-         sampling_start = time.time()

    collector.shutdown()
    if not eval_env.is_closed:
        eval_env.close()
    if not train_env.is_closed:
        train_env.close()
-     end_time = time.time()
-     execution_time = end_time - start_time
-     torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":