 The helper functions are coded in the utils.py associated with this script.
 """

+import time
+
 import hydra

 import numpy as np
 import torch
 import torch.cuda
 import tqdm
+
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    log_metrics,
     make_collector,
     make_ddpg_agent,
     make_environment,

 def main(cfg: "DictConfig"):  # noqa: F821
     device = torch.device(cfg.network.device)

+    # Create logger
     exp_name = generate_exp_name("DDPG", cfg.env.exp_name)
     logger = None
     if cfg.logger.backend:
@@ -43,137 +48,149 @@ def main(cfg: "DictConfig"): # noqa: F821
             wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
         )

+    # Set seeds
     torch.manual_seed(cfg.env.seed)
     np.random.seed(cfg.env.seed)

-    # Create Environments
+    # Create environments
     train_env, eval_env = make_environment(cfg)

-    # Create Agent
+    # Create agent
     model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device)

-    # Create Loss Module and Target Updater
+    # Create DDPG loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

-    # Make Off-Policy Collector
+    # Create off-policy collector
     collector = make_collector(cfg, train_env, exploration_policy)

-    # Make Replay Buffer
+    # Create replay buffer
     replay_buffer = make_replay_buffer(
-        batch_size=cfg.optimization.batch_size,
+        batch_size=cfg.optim.batch_size,
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
+        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
         device=device,
     )

-    # Make Optimizers
+    # Create optimizers
     optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

-    rewards = []
-    rewards_eval = []
-
     # Main loop
+    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
-    r0 = None
-    q_loss = None

     init_random_frames = cfg.collector.init_random_frames
     num_updates = int(
         cfg.collector.env_per_collector
         * cfg.collector.frames_per_batch
-        * cfg.optimization.utd_ratio
+        * cfg.optim.utd_ratio
     )
     prb = cfg.replay_buffer.prb
-    env_per_collector = cfg.collector.env_per_collector
-    frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch
     eval_iter = cfg.logger.eval_iter
-    eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
+    eval_rollout_steps = cfg.env.max_episode_steps

-    for i, tensordict in enumerate(collector):
+    sampling_start = time.time()
+    for _, tensordict in enumerate(collector):
+        sampling_time = time.time() - sampling_start
+        # Update exploration policy
         exploration_policy.step(tensordict.numel())
-        # update weights of the inference policy
+
+        # Update weights of the inference policy
         collector.update_policy_weights_()

-        if r0 is None:
-            r0 = tensordict["next", "reward"].sum(-1).mean().item()
         pbar.update(tensordict.numel())

         tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
+        # Add to replay buffer
         replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames

-        # optimization steps
+        # Optimization steps
+        training_start = time.time()
         if collected_frames >= init_random_frames:
             (
                 actor_losses,
                 q_losses,
             ) = ([], [])
             for _ in range(num_updates):
-                # sample from replay buffer
+                # Sample from replay buffer
                 sampled_tensordict = replay_buffer.sample().clone()

+                # Compute loss
                 loss_td = loss_module(sampled_tensordict)

-                optimizer_critic.zero_grad()
-                optimizer_actor.zero_grad()
-
                 actor_loss = loss_td["loss_actor"]
                 q_loss = loss_td["loss_value"]
-                (actor_loss + q_loss).backward()

+                # Update critic
+                optimizer_critic.zero_grad()
+                q_loss.backward()
                 optimizer_critic.step()
-                q_losses.append(q_loss.item())

+                # Update actor
+                optimizer_actor.zero_grad()
+                actor_loss.backward()
                 optimizer_actor.step()
+
+                q_losses.append(q_loss.item())
                 actor_losses.append(actor_loss.item())

-                # update qnet_target params
+                # Update qnet_target params
                 target_net_updater.step()

-                # update priority
+                # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)

-        rewards.append(
-            (i, tensordict["next", "reward"].sum().item() / env_per_collector)
+        training_time = time.time() - training_start
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
         )
-        train_log = {
-            "train_reward": rewards[-1][1],
-            "collected_frames": collected_frames,
-        }
-        if q_loss is not None:
-            train_log.update(
-                {
-                    "actor_loss": np.mean(actor_losses),
-                    "q_loss": np.mean(q_losses),
-                }
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
             )
-        if logger is not None:
-            for key, value in train_log.items():
-                logger.log_scalar(key, value, step=collected_frames)
-        if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
+
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = np.mean(q_losses)
+            metrics_to_log["train/a_loss"] = np.mean(actor_losses)
+            metrics_to_log["train/sampling_time"] = sampling_time
+            metrics_to_log["train/training_time"] = training_time
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
             with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+                eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     exploration_policy,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
-                rewards_eval.append((i, eval_reward))
-                eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
-                if logger is not None:
-                    logger.log_scalar(
-                        "evaluation_reward", rewards_eval[-1][1], step=collected_frames
-                    )
-        if len(rewards_eval):
-            pbar.set_description(
-                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
-            )
+                metrics_to_log["eval/reward"] = eval_reward
+                metrics_to_log["eval/time"] = eval_time
+        if logger is not None:
+            log_metrics(logger, metrics_to_log, collected_frames)
+        sampling_start = time.time()

     collector.shutdown()
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
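The refactor routes all scalar logging through a log_metrics helper newly imported from utils.py, which is not shown in this diff. A minimal sketch of what such a helper might look like, assuming the torchrl Logger.log_scalar interface already used by the code this commit removes (the name and signature are inferred from the call site, not copied from the actual utils.py):

def log_metrics(logger, metrics, step):
    # Write each metric under its key (e.g. "train/reward") at the given global step.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)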
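The number of gradient updates per collected batch is env_per_collector * frames_per_batch * utd_ratio, i.e. it is governed by the update-to-data (UTD) ratio from the renamed optim config group. A worked example with illustrative values (these numbers are assumptions, not the repository defaults):

# Illustrative values only; the real ones come from the Hydra config.
env_per_collector = 1
frames_per_batch = 1000
utd_ratio = 1.0

# One gradient step per collected frame in this example: 1 * 1000 * 1.0 = 1000.
num_updates = int(env_per_collector * frames_per_batch * utd_ratio)
assert num_updates == 1000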