The helper functions are coded in the utils.py associated with this script.
"""

+import time
+
import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
-
+from tensordict import TensorDict
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
+    log_metrics,
    make_collector,
    make_environment,
    make_loss_module,
def main(cfg: "DictConfig"):  # noqa: F821
    device = torch.device(cfg.network.device)

+    # Create logger
    exp_name = generate_exp_name("SAC", cfg.env.exp_name)
    logger = None
    if cfg.logger.backend:
@@ -48,132 +52,158 @@ def main(cfg: "DictConfig"):  # noqa: F821
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

-    # Create Environments
+    # Create environments
    train_env, eval_env = make_environment(cfg)
-    # Create Agent
+
+    # Create agent
    model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)

-    # Create TD3 loss
+    # Create SAC loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Make Off-Policy Collector
+    # Create off-policy collector
    collector = make_collector(cfg, train_env, exploration_policy)

-    # Make Replay Buffer
+    # Create replay buffer
    replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
        device=device,
    )
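The `make_replay_buffer` helper called above lives in utils.py and is not part of this diff. As a rough, illustrative sketch only (not the actual helper), a TorchRL buffer taking the same keyword arguments could be assembled as below; the function name and the `alpha`/`beta` priority settings are placeholders:

# Hypothetical sketch of a replay-buffer factory; the real helper is in utils.py.
from torchrl.data import (
    LazyMemmapStorage,
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)


def make_replay_buffer_sketch(batch_size, prb, buffer_size, buffer_scratch_dir, device):
    # Memory-mapped storage on disk keeps large buffers off the GPU/CPU RAM.
    storage = LazyMemmapStorage(buffer_size, scratch_dir=buffer_scratch_dir, device=device)
    if prb:
        # Prioritized sampling; alpha/beta here are placeholder values.
        return TensorDictPrioritizedReplayBuffer(
            alpha=0.7, beta=0.5, storage=storage, batch_size=batch_size
        )
    return TensorDictReplayBuffer(storage=storage, batch_size=batch_size)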
-    # Make Optimizers
-    optimizer = make_sac_optimizer(cfg, loss_module)
-
-    rewards = []
-    rewards_eval = []
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)
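This change replaces the single optimizer with three returned by `make_sac_optimizer` (also in utils.py, not shown here). A minimal sketch of such a helper, assuming the SAC loss module exposes its actor/Q-value parameters and a `log_alpha` entropy coefficient, with a hypothetical `cfg.optim.lr` learning rate:

# Hypothetical sketch of a three-optimizer factory; the real helper is in utils.py.
from torch import optim


def make_sac_optimizer_sketch(cfg, loss_module):
    # Parameter groups for the actor and the Q-value critics.
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())

    optimizer_actor = optim.Adam(actor_params, lr=cfg.optim.lr)
    optimizer_critic = optim.Adam(critic_params, lr=cfg.optim.lr)
    # The entropy temperature is trained through the loss module's log_alpha parameter.
    optimizer_alpha = optim.Adam([loss_module.log_alpha], lr=cfg.optim.lr)
    return optimizer_actor, optimizer_critic, optimizer_alpha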

    # Main loop
+    start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        cfg.collector.env_per_collector
        * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
    )
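`num_updates` implements the update-to-data (UTD) ratio: each collected batch is followed by that many gradient steps. With hypothetical settings of `env_per_collector=1`, `frames_per_batch=1000`, and `utd_ratio=1.0`, each collector iteration would run `int(1 * 1000 * 1.0) = 1000` updates.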
    prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
    eval_iter = cfg.logger.eval_iter
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_rollout_steps = cfg.env.max_episode_steps

+    sampling_start = time.time()
    for i, tensordict in enumerate(collector):
-        # update weights of the inference policy
+        sampling_time = time.time() - sampling_start
+
+        # Update weights of the inference policy
        collector.update_policy_weights_()

-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
        pbar.update(tensordict.numel())

-        tensordict = tensordict.view(-1)
+        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
+        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
        if collected_frames >= init_random_frames:
-            (actor_losses, q_losses, alpha_losses) = ([], [], [])
-            for _ in range(num_updates):
-                # sample from replay buffer
+            losses = TensorDict(
+                {},
+                batch_size=[
+                    num_updates,
+                ],
+            )
+            for i in range(num_updates):
+                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample().clone()

+                # Compute loss
                loss_td = loss_module(sampled_tensordict)

                actor_loss = loss_td["loss_actor"]
                q_loss = loss_td["loss_qvalue"]
                alpha_loss = loss_td["loss_alpha"]
-                loss = actor_loss + q_loss + alpha_loss

-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
+                # Update actor
+                optimizer_actor.zero_grad()
+                actor_loss.backward()
+                optimizer_actor.step()

-                q_losses.append(q_loss.item())
-                actor_losses.append(actor_loss.item())
-                alpha_losses.append(alpha_loss.item())
+                # Update critic
+                optimizer_critic.zero_grad()
+                q_loss.backward()
+                optimizer_critic.step()

-                # update qnet_target params
+                # Update alpha
+                optimizer_alpha.zero_grad()
+                alpha_loss.backward()
+                optimizer_alpha.step()
+
+                losses[i] = loss_td.select(
+                    "loss_actor", "loss_qvalue", "loss_alpha"
+                ).detach()
+
+                # Update qnet_target params
                target_net_updater.step()

-                # update priority
+                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
        )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                    "alpha_loss": np.mean(alpha_losses),
-                    "alpha": loss_td["alpha"],
-                    "entropy": loss_td["entropy"],
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
            )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
+            metrics_to_log["train/alpha"] = loss_td["alpha"].item()
+            metrics_to_log["train/entropy"] = loss_td["entropy"].item()
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
+                eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-            if logger is not None:
-                logger.log_scalar(
-                    "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()

    collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
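For reference, the `log_metrics` helper imported at the top is defined in utils.py and is not shown in this diff. Judging from the `logger.log_scalar(key, value, step=collected_frames)` loop it replaces, a minimal sketch might look like this (the name and body are assumptions):

# Hypothetical sketch of the log_metrics helper from utils.py.
def log_metrics_sketch(logger, metrics, step):
    # Forward each scalar metric to the configured TorchRL logger backend.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)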