Skip to content

Commit 187de7c

Browse files
author
Vincent Moens
committed
[Feature] timeit.printevery
ghstack-source-id: 19165bb Pull Request resolved: #2653
1 parent f5a187d commit 187de7c

File tree

21 files changed

+104
-87
lines changed

21 files changed

+104
-87
lines changed

sota-implementations/a2c/a2c_atari.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
182182
lr = cfg.optim.lr
183183

184184
c_iter = iter(collector)
185-
for i in range(len(collector)):
185+
total_iter = len(collector)
186+
for i in range(total_iter):
187+
timeit.printevery(1000, total_iter, erase=True)
188+
186189
with timeit("collecting"):
187190
data = next(c_iter)
188191

@@ -261,10 +264,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
261264
"test/reward": test_rewards.mean(),
262265
}
263266
)
264-
if i % 200 == 0:
265-
log_info.update(timeit.todict(prefix="time"))
266-
timeit.print()
267-
timeit.erase()
267+
log_info.update(timeit.todict(prefix="time"))
268268

269269
if logger:
270270
for key, value in log_info.items():

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ def update(batch):
179179
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
180180

181181
c_iter = iter(collector)
182-
for i in range(len(collector)):
182+
total_iter = len(collector)
183+
for i in range(total_iter):
184+
timeit.printevery(1000, total_iter, erase=True)
185+
183186
with timeit("collecting"):
184187
data = next(c_iter)
185188

@@ -257,10 +260,7 @@ def update(batch):
257260
)
258261
actor.train()
259262

260-
if i % 200 == 0:
261-
log_info.update(timeit.todict(prefix="time"))
262-
timeit.print()
263-
timeit.erase()
263+
log_info.update(timeit.todict(prefix="time"))
264264

265265
if logger:
266266
for key, value in log_info.items():

sota-implementations/cql/cql_offline.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
"""
1212
from __future__ import annotations
1313

14-
import time
1514
import warnings
1615

1716
import hydra
@@ -21,7 +20,7 @@
2120
import tqdm
2221
from tensordict.nn import CudaGraphModule
2322

24-
from torchrl._utils import logger as torchrl_logger, timeit
23+
from torchrl._utils import timeit
2524
from torchrl.envs.utils import ExplorationType, set_exploration_type
2625
from torchrl.objectives import group_optimizers
2726
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration):
156155
eval_steps = cfg.logger.eval_steps
157156

158157
# Training loop
159-
start_time = time.time()
160158
policy_eval_start = torch.tensor(policy_eval_start, device=device)
161159
for i in range(gradient_steps):
160+
timeit.printevery(1000, gradient_steps, erase=True)
162161
pbar.update(1)
163162
# sample data
164163
with timeit("sample"):
@@ -192,15 +191,10 @@ def update(data, policy_eval_start, iteration):
192191
to_log["evaluation_reward"] = eval_reward
193192

194193
with timeit("log"):
195-
if i % 200 == 0:
196-
to_log.update(timeit.todict(prefix="time"))
194+
to_log.update(timeit.todict(prefix="time"))
197195
log_metrics(logger, to_log, i)
198-
if i % 200 == 0:
199-
timeit.print()
200-
timeit.erase()
201196

202197
pbar.close()
203-
torchrl_logger.info(f"Training time: {time.time() - start_time}")
204198
if not eval_env.is_closed:
205199
eval_env.close()
206200

sota-implementations/cql/cql_online.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def update(sampled_tensordict):
170170
eval_rollout_steps = cfg.logger.eval_steps
171171

172172
c_iter = iter(collector)
173-
for i in range(len(collector)):
173+
total_iter = len(collector)
174+
for i in range(total_iter):
175+
timeit.printevery(1000, total_iter, erase=True)
174176
with timeit("collecting"):
175177
tensordict = next(c_iter)
176178
pbar.update(tensordict.numel())
@@ -222,8 +224,7 @@ def update(sampled_tensordict):
222224
"loss_alpha_prime"
223225
).mean()
224226
metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
225-
if i % 10 == 0:
226-
metrics_to_log.update(timeit.todict(prefix="time"))
227+
metrics_to_log.update(timeit.todict(prefix="time"))
227228

228229
# Evaluation
229230
with timeit("eval"):
@@ -245,9 +246,6 @@ def update(sampled_tensordict):
245246
metrics_to_log["eval/reward"] = eval_reward
246247

247248
log_metrics(logger, metrics_to_log, collected_frames)
248-
if i % 10 == 0:
249-
timeit.print()
250-
timeit.erase()
251249

252250
collector.shutdown()
253251
if not eval_env.is_closed:

sota-implementations/cql/discrete_cql_online.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def update(sampled_tensordict):
151151
frames_per_batch = cfg.collector.frames_per_batch
152152

153153
c_iter = iter(collector)
154-
for i in range(len(collector)):
154+
total_iter = len(collector)
155+
for _ in range(total_iter):
156+
timeit.printevery(1000, total_iter, erase=True)
155157
with timeit("collecting"):
156158
torch.compiler.cudagraph_mark_step_begin()
157159
tensordict = next(c_iter)
@@ -224,12 +226,7 @@ def update(sampled_tensordict):
224226
tds = torch.stack(tds, dim=0).mean()
225227
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
226228
metrics_to_log["train/cql_loss"] = tds["loss_cql"]
227-
if i % 100 == 0:
228-
metrics_to_log.update(timeit.todict(prefix="time"))
229-
230-
if i % 100 == 0:
231-
timeit.print()
232-
timeit.erase()
229+
metrics_to_log.update(timeit.todict(prefix="time"))
233230

234231
if logger is not None:
235232
log_metrics(logger, metrics_to_log, collected_frames)

sota-implementations/crossq/crossq.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
192192
update_counter = 0
193193
delayed_updates = cfg.optim.policy_update_delay
194194
c_iter = iter(collector)
195-
for i in range(len(collector)):
195+
total_iter = len(collector)
196+
for _ in range(total_iter):
197+
timeit.printevery(1000, total_iter, erase=True)
196198
with timeit("collecting"):
197199
torch.compiler.cudagraph_mark_step_begin()
198200
tensordict = next(c_iter)
@@ -258,18 +260,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
258260
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
259261
episode_length
260262
)
261-
if i % 20 == 0:
262-
metrics_to_log.update(timeit.todict(prefix="time"))
263+
metrics_to_log.update(timeit.todict(prefix="time"))
263264
if collected_frames >= init_random_frames:
264265
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
265266
metrics_to_log["train/actor_loss"] = tds["loss_actor"]
266267
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
267268

268269
if logger is not None:
269270
log_metrics(logger, metrics_to_log, collected_frames)
270-
if i % 20 == 0:
271-
timeit.print()
272-
timeit.erase()
273271

274272
collector.shutdown()
275273
if not eval_env.is_closed:

sota-implementations/ddpg/ddpg.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
156156
eval_rollout_steps = cfg.env.max_episode_steps
157157

158158
c_iter = iter(collector)
159-
for i in range(len(collector)):
159+
total_iter = len(collector)
160+
for _ in range(total_iter):
161+
timeit.printevery(1000, total_iter, erase=True)
160162
with timeit("collecting"):
161163
tensordict = next(c_iter)
162164
# Update exploration policy
@@ -226,10 +228,7 @@ def update(sampled_tensordict):
226228
eval_env.apply(dump_video)
227229
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
228230
metrics_to_log["eval/reward"] = eval_reward
229-
if i % 20 == 0:
230-
metrics_to_log.update(timeit.todict(prefix="time"))
231-
timeit.print()
232-
timeit.erase()
231+
metrics_to_log.update(timeit.todict(prefix="time"))
233232

234233
if logger is not None:
235234
log_metrics(logger, metrics_to_log, collected_frames)

sota-implementations/decision_transformer/dt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def update(data: TensorDict) -> TensorDict:
128128
# Pretraining
129129
pbar = tqdm.tqdm(range(pretrain_gradient_steps))
130130
for i in pbar:
131+
timeit.printevery(1000, pretrain_gradient_steps, erase=True)
131132
# Sample data
132133
with timeit("rb - sample"):
133134
data = offline_buffer.sample().to(model_device)
@@ -151,10 +152,7 @@ def update(data: TensorDict) -> TensorDict:
151152
to_log["eval/reward"] = (
152153
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
153154
)
154-
if i % 200 == 0:
155-
to_log.update(timeit.todict(prefix="time"))
156-
timeit.print()
157-
timeit.erase()
155+
to_log.update(timeit.todict(prefix="time"))
158156

159157
if logger is not None:
160158
log_metrics(logger, to_log, i)

sota-implementations/decision_transformer/online_dt.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"""
99
from __future__ import annotations
1010

11-
import time
1211
import warnings
1312

1413
import hydra
@@ -130,8 +129,8 @@ def update(data):
130129

131130
torchrl_logger.info(" ***Pretraining*** ")
132131
# Pretraining
133-
start_time = time.time()
134132
for i in range(pretrain_gradient_steps):
133+
timeit.printevery(1000, pretrain_gradient_steps, erase=True)
135134
pbar.update(1)
136135
with timeit("sample"):
137136
# Sample data
@@ -170,18 +169,14 @@ def update(data):
170169
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
171170
)
172171

173-
if i % 200 == 0:
174-
to_log.update(timeit.todict(prefix="time"))
175-
timeit.print()
176-
timeit.erase()
172+
to_log.update(timeit.todict(prefix="time"))
177173

178174
if logger is not None:
179175
log_metrics(logger, to_log, i)
180176

181177
pbar.close()
182178
if not test_env.is_closed:
183179
test_env.close()
184-
torchrl_logger.info(f"Training time: {time.time() - start_time}")
185180

186181

187182
if __name__ == "__main__":

sota-implementations/discrete_sac/discrete_sac.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ def update(sampled_tensordict):
155155
frames_per_batch = cfg.collector.frames_per_batch
156156

157157
c_iter = iter(collector)
158-
for i in range(len(collector)):
158+
total_iter = len(collector)
159+
for i in range(total_iter):
160+
timeit.printevery(1000, total_iter, erase=True)
159161
with timeit("collecting"):
160162
collected_data = next(c_iter)
161163

@@ -229,10 +231,7 @@ def update(sampled_tensordict):
229231
eval_env.apply(dump_video)
230232
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
231233
metrics_to_log["eval/reward"] = eval_reward
232-
if i % 50 == 0:
233-
metrics_to_log.update(timeit.todict(prefix="time"))
234-
timeit.print()
235-
timeit.erase()
234+
metrics_to_log.update(timeit.todict(prefix="time"))
236235
if logger is not None:
237236
log_metrics(logger, metrics_to_log, collected_frames)
238237

0 commit comments

Comments (0)