import numpy as np
import torch
import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
@@ -69,6 +72,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
    # Create agent
    model = make_cql_model(cfg, train_env, eval_env, device)
    del train_env
+    if hasattr(eval_env, "start"):
+        # To set the number of threads to the definitive value
+        eval_env.start()

    # Create loss
    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
@@ -81,81 +87,104 @@ def main(cfg: "DictConfig"):  # noqa: F821
        alpha_prime_optim,
    ) = make_continuous_cql_optimizer(cfg, loss_module)

-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+    # Group optimizers
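+    # so that a single backward/step/zero_grad pass covers actor, critic, alpha and alpha-prime parameters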
+    optimizer = group_optimizers(
+        policy_optim, critic_optim, alpha_optim, alpha_prime_optim
+    )

-    gradient_steps = cfg.optim.gradient_steps
-    policy_eval_start = cfg.optim.policy_eval_start
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-        # compute loss
-        loss_vals = loss_module(data.clone().to(device))
+    def update(data, policy_eval_start, iteration):
+        loss_vals = loss_module(data.to(device))

        # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
-        if i >= policy_eval_start:
-            actor_loss = loss_vals["loss_actor"]
-        else:
-            actor_loss = loss_vals["loss_actor_bc"]
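+        # tensor-based selection (rather than Python branching) keeps the graph static for compile/cudagraph capture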
+        actor_loss = torch.where(
+            iteration >= policy_eval_start,
+            loss_vals["loss_actor"],
+            loss_vals["loss_actor_bc"],
+        )
        q_loss = loss_vals["loss_qvalue"]
        cql_loss = loss_vals["loss_cql"]

        q_loss = q_loss + cql_loss
+        loss_vals["q_loss"] = q_loss

        # update model
        alpha_loss = loss_vals["loss_alpha"]
        alpha_prime_loss = loss_vals["loss_alpha_prime"]
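+        # guard against a missing alpha-prime term (e.g. when the Lagrange variant of CQL is not used)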
+        if alpha_prime_loss is None:
+            alpha_prime_loss = 0

-        alpha_optim.zero_grad()
-        alpha_loss.backward()
-        alpha_optim.step()
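+        # sum all loss terms and back-propagate once over the grouped optimizer's parameters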
+        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss

-        policy_optim.zero_grad()
-        actor_loss.backward()
-        policy_optim.step()
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)

-        if alpha_prime_optim is not None:
-            alpha_prime_optim.zero_grad()
-            alpha_prime_loss.backward(retain_graph=True)
-            alpha_prime_optim.step()
+        # update qnet_target params
+        target_net_updater.step()

-        critic_optim.zero_grad()
-        # TODO: we have the option to compute losses independently retain is not needed?
-        q_loss.backward(retain_graph=False)
-        critic_optim.step()
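+        # detach outputs so logging does not keep the autograd graph alive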
+        return loss.detach(), loss_vals.detach()

-        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
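+        # warmup runs execute eagerly before the CUDA graph is captured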
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+
+    gradient_steps = cfg.optim.gradient_steps
+    policy_eval_start = cfg.optim.policy_eval_start
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+
+    # Training loop
+    start_time = time.time()
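+    # the BC-pretraining threshold is kept as a device tensor so the compiled update only receives tensor inputs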
+    policy_eval_start = torch.tensor(policy_eval_start, device=device)
+    for i in range(gradient_steps):
+        pbar.update(1)
+        # sample data
+        with timeit("sample"):
+            data = replay_buffer.sample()
+
+        with timeit("update"):
+            # compute loss
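+            # the step index is passed as a device tensor so torch.where can consume it inside update()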
+            i_device = torch.tensor(i, device=device)
+            loss, loss_vals = update(
+                data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
+            )

        # log metrics
        to_log = {
-            "loss": loss.item(),
-            "loss_actor_bc": loss_vals["loss_actor_bc"].item(),
-            "loss_actor": loss_vals["loss_actor"].item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_cql": cql_loss.item(),
-            "loss_alpha": alpha_loss.item(),
-            "loss_alpha_prime": alpha_prime_loss.item(),
+            "loss": loss.cpu(),
+            **loss_vals.cpu(),
        }

-        # update qnet_target params
-        target_net_updater.step()
-
        # evaluation
-        if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_td = eval_env.rollout(
-                    max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
-                )
-                eval_env.apply(dump_video)
-                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-                to_log["evaluation_reward"] = eval_reward
-
-        log_metrics(logger, to_log, i)
+        with timeit("log/eval"):
+            if i % evaluation_interval == 0:
+                with set_exploration_type(
+                    ExplorationType.DETERMINISTIC
+                ), torch.no_grad():
+                    eval_td = eval_env.rollout(
+                        max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
+                    )
+                    eval_env.apply(dump_video)
+                    eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+                    to_log["evaluation_reward"] = eval_reward
+
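+        # accumulated timeit measurements are flushed into the logs (and printed) every 200 iterations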
+        with timeit("log"):
+            if i % 200 == 0:
+                to_log.update(timeit.todict(prefix="time"))
+            log_metrics(logger, to_log, i)
+            if i % 200 == 0:
+                timeit.print()
+                timeit.erase()

    pbar.close()
    torchrl_logger.info(f"Training time: {time.time() - start_time}")