 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 
 import numpy as np
 import torch
 import torch.cuda
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit
 
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
     dump_video,
@@ -46,6 +50,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
             device = "cpu"
     device = torch.device(device)
 
+    collector_device = cfg.collector.device
+    if collector_device in ("", None):
+        if torch.cuda.is_available():
+            collector_device = "cuda:0"
+        else:
+            collector_device = "cpu"
+    collector_device = torch.device(collector_device)
+
     # Create logger
     exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
     logger = None
@@ -75,8 +87,25 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create DDPG loss
     loss_module, target_net_updater = make_loss_module(cfg, model)
 
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+
     # Create off-policy collector
-    collector = make_collector(cfg, train_env, exploration_policy)
+    collector = make_collector(
+        cfg,
+        train_env,
+        exploration_policy,
+        compile=cfg.compile.compile,
+        compile_mode=compile_mode,
+        cudagraph=cfg.compile.cudagraphs,
+        device=collector_device,
+    )
 
     # Create replay buffer
     replay_buffer = make_replay_buffer(
@@ -89,9 +118,29 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Create optimizers
     optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic)
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+
+        td_loss: TensorDict = loss_module(sampled_tensordict)
+        td_loss.sum(reduce=True).backward()
+        optimizer.step()
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return td_loss.detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
 
     # Main loop
-    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
@@ -106,63 +155,43 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_iter = cfg.logger.eval_iter
     eval_rollout_steps = cfg.env.max_episode_steps
 
-    sampling_start = time.time()
-    for _, tensordict in enumerate(collector):
-        sampling_time = time.time() - sampling_start
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            tensordict = next(c_iter)
 
         # Update exploration policy
         exploration_policy[1].step(tensordict.numel())
 
         # Update weights of the inference policy
         collector.update_policy_weights_()
 
-        pbar.update(tensordict.numel())
-
-        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
         # Add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("rb - extend"):
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
         collected_frames += current_frames
 
         # Optimization steps
-        training_start = time.time()
         if collected_frames >= init_random_frames:
-            (
-                actor_losses,
-                q_losses,
-            ) = ([], [])
+            tds = []
             for _ in range(num_updates):
                 # Sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict.clone()
-
-                # Update critic
-                q_loss, *_ = loss_module.loss_value(sampled_tensordict)
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # Update actor
-                actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                q_losses.append(q_loss.item())
-                actor_losses.append(actor_loss.item())
-
-                # Update qnet_target params
-                target_net_updater.step()
+                with timeit("rb - sample"):
+                    sampled_tensordict = replay_buffer.sample().to(device)
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    td_loss = update(sampled_tensordict)
+                tds.append(td_loss.clone())
 
                 # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)
+            tds = torch.stack(tds)
 
-        training_time = time.time() - training_start
         episode_end = (
             tensordict["next", "done"]
             if tensordict["next", "done"].any()
@@ -180,38 +209,36 @@ def main(cfg: "DictConfig"):  # noqa: F821
         )
 
         if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = np.mean(q_losses)
-            metrics_to_log["train/a_loss"] = np.mean(actor_losses)
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
+            tds = TensorDict(train=tds).flatten_keys("/").mean()
+            metrics_to_log.update(tds.to_dict())
 
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     exploration_policy,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+        if i % 20 == 0:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
+
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
 
     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
     if not eval_env.is_closed:
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":