The helper functions are coded in the utils.py associated with this script.
"""

-import hydra
+import time

+import hydra
import numpy as np
import torch
import torch.cuda

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
+    log_metrics,
    make_collector,
    make_environment,
    make_loss_module,

def main(cfg: "DictConfig"):  # noqa: F821
    device = torch.device(cfg.network.device)

+    # Create logger
    exp_name = generate_exp_name("TD3", cfg.env.exp_name)
    logger = None
    if cfg.logger.backend:
@@ -45,140 +48,155 @@ def main(cfg: "DictConfig"): # noqa: F821
            wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
        )

+    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

-    # Create Environments
+    # Create environments
    train_env, eval_env = make_environment(cfg)

-    # Create Agent
+    # Create agent
    model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

    # Create TD3 loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Make Off-Policy Collector
+    # Create off-policy collector
    collector = make_collector(cfg, train_env, exploration_policy)

-    # Make Replay Buffer
+    # Create replay buffer
    replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
        device=device,
    )

-    # Make Optimizers
+    # Create optimizers
    optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

-    rewards = []
-    rewards_eval = []
-
    # Main loop
+    start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        cfg.collector.env_per_collector
        * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
    )
-    delayed_updates = cfg.optimization.policy_update_delay
+    delayed_updates = cfg.optim.policy_update_delay
    prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip
+    eval_rollout_steps = cfg.env.max_episode_steps
    eval_iter = cfg.logger.eval_iter
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
+    update_counter = 0

-    for i, tensordict in enumerate(collector):
+    sampling_start = time.time()
+    for tensordict in collector:
+        sampling_time = time.time() - sampling_start
        exploration_policy.step(tensordict.numel())
-        # update weights of the inference policy
+
+        # Update weights of the inference policy
        collector.update_policy_weights_()

-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
+        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
        if collected_frames >= init_random_frames:
            (
                actor_losses,
                q_losses,
            ) = ([], [])
-            for j in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
+            for _ in range(num_updates):
+
+                # Update actor every delayed_updates
+                update_counter += 1
+                update_actor = update_counter % delayed_updates == 0

-                loss_td = loss_module(sampled_tensordict)
+                # Sample from replay buffer
+                sampled_tensordict = replay_buffer.sample().clone()

-                actor_loss = loss_td["loss_actor"]
-                q_loss = loss_td["loss_qvalue"]
+                # Compute loss
+                q_loss, *_ = loss_module.value_loss(sampled_tensordict)

+                # Update critic
                optimizer_critic.zero_grad()
-                update_actor = j % delayed_updates == 0
-                q_loss.backward(retain_graph=update_actor)
+                q_loss.backward()
                optimizer_critic.step()
                q_losses.append(q_loss.item())

+                # Update actor
                if update_actor:
+                    actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
                    optimizer_actor.zero_grad()
                    actor_loss.backward()
                    optimizer_actor.step()
+
                    actor_losses.append(actor_loss.item())

-                    # update qnet_target params
+                    # Update target params
                    target_net_updater.step()

-                # update priority
+                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
        )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
            )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = np.mean(q_losses)
+            if update_actor:
+                metrics_to_log["train/a_loss"] = np.mean(actor_losses)
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    exploration_policy,
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
+                eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-            if logger is not None:
-                logger.log_scalar(
-                    "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()

    collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")

if __name__ == "__main__":
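
The refactor imports a new log_metrics helper from utils.py and calls it once per collector iteration as log_metrics(logger, metrics_to_log, collected_frames). Its body is not part of this diff; a minimal sketch consistent with that call site, assuming it simply forwards each entry through the logger's log_scalar API (the same API the removed code called directly), would be:

# Hypothetical sketch of the log_metrics helper expected in utils.py.
# Only the call signature is taken from the diff; the body is an assumption.
def log_metrics(logger, metrics: dict, step: int) -> None:
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)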
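
Likewise, the new buffer_scratch_dir argument passed to make_replay_buffer suggests that the replay buffer storage is memory-mapped to disk under that directory. The helper lives in utils.py and is not shown here; a rough sketch of one possible implementation, assuming torchrl's LazyMemmapStorage and the tensordict replay buffer classes (the alpha/beta values and prefetch default are placeholders, not taken from the diff):

# Hypothetical sketch of make_replay_buffer from utils.py; argument names mirror
# the call site in the diff, everything else is an assumption.
from torchrl.data import (
    LazyMemmapStorage,
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)


def make_replay_buffer(
    batch_size,
    prb=False,
    buffer_size=1_000_000,
    buffer_scratch_dir="/tmp/",
    device="cpu",
    prefetch=3,
):
    # Disk-backed storage: transitions are memory-mapped under buffer_scratch_dir
    # rather than held entirely in RAM.
    storage = LazyMemmapStorage(
        buffer_size,
        scratch_dir=buffer_scratch_dir,
        device=device,
    )
    if prb:
        # Prioritized buffer, selected when cfg.replay_buffer.prb is True.
        return TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.5,
            storage=storage,
            batch_size=batch_size,
            prefetch=prefetch,
        )
    return TensorDictReplayBuffer(
        storage=storage,
        batch_size=batch_size,
        prefetch=prefetch,
    )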