 """
 from __future__ import annotations

-import time
+import warnings

 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import compile_with_warmup, timeit

 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -72,7 +75,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )

     # Create replay buffer
-    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, device=device)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"

     # Create agent
     model, _ = make_td3_agent(cfg, eval_env, device)
@@ -83,67 +95,86 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create optimizer
     optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

-    gradient_steps = cfg.optim.gradient_steps
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-    delayed_updates = cfg.optim.policy_update_delay
-    update_counter = 0
-    pbar = tqdm.tqdm(range(gradient_steps))
-    # Training loop
-    start_time = time.time()
-    for i in pbar:
-        pbar.update(1)
-        # Update actor every delayed_updates
-        update_counter += 1
-        update_actor = update_counter % delayed_updates == 0
-
-        # Sample from replay buffer
-        sampled_tensordict = replay_buffer.sample()
-        if sampled_tensordict.device != device:
-            sampled_tensordict = sampled_tensordict.to(device)
-        else:
-            sampled_tensordict = sampled_tensordict.clone()
-
+    def update(sampled_tensordict, update_actor):
         # Compute loss
         q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)

         # Update critic
-        optimizer_critic.zero_grad()
         q_loss.backward()
         optimizer_critic.step()
-        q_loss.item()
-
-        to_log = {"q_loss": q_loss.item()}
+        optimizer_critic.zero_grad(set_to_none=True)

         # Update actor
         if update_actor:
             actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
-            optimizer_actor.zero_grad()
             actor_loss.backward()
             optimizer_actor.step()
+            optimizer_actor.zero_grad(set_to_none=True)

             # Update target params
             target_net_updater.step()
+        else:
+            actorloss_metadata = {}
+            actor_loss = q_loss.new_zeros(())
+        metadata = TensorDict(actorloss_metadata)
+        metadata.set("q_loss", q_loss.detach())
+        metadata.set("actor_loss", actor_loss.detach())
+        return metadata
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+
+    gradient_steps = cfg.optim.gradient_steps
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+    delayed_updates = cfg.optim.policy_update_delay
+    pbar = tqdm.tqdm(range(gradient_steps))
+    # Training loop
+    for update_counter in pbar:
+        timeit.printevery(num_prints=1000, total_count=gradient_steps, erase=True)

-            to_log["actor_loss"] = actor_loss.item()
-            to_log.update(actorloss_metadata)
+        # Update actor every delayed_updates
+        update_actor = update_counter % delayed_updates == 0
+
+        with timeit("rb - sample"):
+            # Sample from replay buffer
+            sampled_tensordict = replay_buffer.sample()
+
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            metadata = update(sampled_tensordict, update_actor).clone()
+
+        to_log = {}
+        if update_actor:
+            to_log.update(metadata.to_dict())
+        else:
+            to_log.update(metadata.exclude("actor_loss").to_dict())

         # evaluation
-        if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+        if update_counter % evaluation_interval == 0:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_td = eval_env.rollout(
                     max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                 )
                 eval_env.apply(dump_video)
             eval_reward = eval_td["next", "reward"].sum(1).mean().item()
             to_log["evaluation_reward"] = eval_reward
         if logger is not None:
-            log_metrics(logger, to_log, i)
+            to_log.update(timeit.todict(prefix="time"))
+            log_metrics(logger, to_log, update_counter)

     if not eval_env.is_closed:
         eval_env.close()
     pbar.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")


 if __name__ == "__main__":
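
For reference, below is a minimal, self-contained sketch of the compile/CUDA-graph wrapping pattern this change introduces, with a toy regression step standing in for the TD3+BC losses. The network, optimizer, random batch, and loop count are illustrative placeholders, not part of the script; the compile_with_warmup, CudaGraphModule, timeit, and torch.compiler.cudagraph_mark_step_begin calls mirror the ones in the diff above.

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup, timeit

device = "cuda" if torch.cuda.is_available() else "cpu"
use_cudagraphs = device == "cuda"

net = torch.nn.Linear(8, 1).to(device)
# capturable=True lets the optimizer step run inside a CUDA graph capture.
optim = torch.optim.Adam(net.parameters(), lr=1e-3, capturable=use_cudagraphs)


def update(batch):
    # One gradient step; return detached scalars in a TensorDict, as above.
    loss = (net(batch["obs"]) - batch["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    return TensorDict({"loss": loss.detach()})


# Same mode selection as above: "default" when CudaGraphModule does the graph
# capture itself, "reduce-overhead" when torch.compile runs on its own.
update = compile_with_warmup(
    update, mode="default" if use_cudagraphs else "reduce-overhead", warmup=1
)
if use_cudagraphs:
    update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

for _ in range(100):
    batch = TensorDict(
        {"obs": torch.randn(32, 8), "target": torch.randn(32, 1)}, batch_size=[32]
    ).to(device)
    with timeit("update"):
        torch.compiler.cudagraph_mark_step_begin()
        # Clone the output: CUDA graph replays reuse the same output buffers.
        metrics = update(batch).clone()

print(timeit.todict(prefix="time"))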