"""
 from __future__ import annotations

-import time
+import warnings

 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit

 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger

 from utils import (
 )


+torch.set_float32_matmul_precision("high")
+
+
 @hydra.main(config_path="", config_name="discrete_iql")
 def main(cfg: "DictConfig"):  # noqa: F821
     set_gym_backend(cfg.env.backend).set()
@@ -87,16 +94,54 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create model
     model = make_discrete_iql_model(cfg, train_env, eval_env, device)

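+    # Pick a torch.compile mode: honor the configured value, otherwise fall back to
+    # "default" when CUDA graphs are requested and "reduce-overhead" when they are not.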
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0])
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )

     # Create loss
-    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

     # Create optimizer
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
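+    # Fuse the three optimizers so a single zero_grad()/step() drives the actor, value and Q-value nets.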
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
+    del optimizer_actor, optimizer_critic, optimizer_value
+
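+    # One training step over a sampled batch: compute the three IQL losses, backpropagate
+    # their sum through the grouped optimizer, then refresh the target Q-network.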
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+        # compute losses
+        actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
+        value_loss, _ = loss_module.value_loss(sampled_tensordict)
+        q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
+
+        # update qnet_target params
+        target_net_updater.step()
+        metadata.update(
+            {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
+        )
+        return TensorDict(metadata).detach()
+
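+    # Optionally compile the update step and, on top of that, capture it in a CUDA graph.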
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
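+        # warmup=50 runs the wrapped update eagerly a number of times before capture,
+        # so lazy initialisations and compilation happen outside the recorded graph.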

     # Main loop
     collected_frames = 0
@@ -112,103 +157,82 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_iter = cfg.logger.eval_iter
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.collector.max_frames_per_traj
-    sampling_start = start_time = time.time()
-    for tensordict in collector:
-        sampling_time = time.time() - sampling_start
-        pbar.update(tensordict.numel())
+
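+    # Drive the collector through an explicit iterator so the wait for each batch is
+    # attributed to the "collection" timer.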
+    collector_iter = iter(collector)
+    for _ in range(len(collector)):
+        with timeit("collection"):
+            tensordict = next(collector_iter)
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
         # update weights of the inference policy
         collector.update_policy_weights_()

-        tensordict = tensordict.reshape(-1)
-        current_frames = tensordict.numel()
-        # add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("buffer - extend"):
+            tensordict = tensordict.reshape(-1)
+
+            # add to replay buffer
+            replay_buffer.extend(tensordict)
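+            # data is stored as produced by the collector; it is moved to the
+            # training device only when sampled below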
         collected_frames += current_frames

         # optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            for _ in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict
-                # compute losses
-                actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                value_loss, _ = loss_module.value_loss(sampled_tensordict)
-                optimizer_value.zero_grad()
-                value_loss.backward()
-                optimizer_value.step()
-
-                q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # update qnet_target params
-                target_net_updater.step()
-
-                # update priority
-                if prb:
-                    sampled_tensordict.set(
-                        loss_module.tensor_keys.priority,
-                        metadata.pop("td_error").detach().max(0).values,
-                    )
-                    replay_buffer.update_priority(sampled_tensordict)
-
-        training_time = time.time() - training_start
+        with timeit("training"):
+            if collected_frames >= init_random_frames:
+                for _ in range(num_updates):
+                    # sample from replay buffer
+                    with timeit("buffer - sample"):
+                        sampled_tensordict = replay_buffer.sample().to(device)
+
+                    with timeit("training - update"):
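+                        # tell the CUDA-graph machinery that a new step begins so
+                        # outputs of the previous replay may be overwritten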
+                        torch.compiler.cudagraph_mark_step_begin()
+                        metadata = update(sampled_tensordict)
+                    # update priority
+                    if prb:
+                        sampled_tensordict.set(
+                            loss_module.tensor_keys.priority,
+                            metadata.pop("td_error").detach().max(0).values,
+                        )
+                        replay_buffer.update_priority(sampled_tensordict)
+
         episode_rewards = tensordict["next", "episode_reward"][
             tensordict["next", "done"]
         ]

-        # Logging
         metrics_to_log = {}
-        if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][
-                tensordict["next", "done"]
-            ]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
-                episode_length
-            )
-        if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = q_loss.detach()
-            metrics_to_log["train/actor_loss"] = actor_loss.detach()
-            metrics_to_log["train/value_loss"] = value_loss.detach()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
-
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+
+        # Logging
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][
+                tensordict["next", "done"]
+            ]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = metadata["q_loss"]
+            metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
+            metrics_to_log["train/value_loss"] = metadata["value_loss"]
+        metrics_to_log.update(timeit.todict(prefix="time"))
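+        # export the accumulated per-label timings under the "time" prefix;
+        # timeit.erase() below resets them for the next collection cycle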
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
+        timeit.erase()

     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":