Skip to content

Commit 29f2bcc

Browse files
authored
[BugFix] Fix dreamer training loop (#915)
1 parent 4a74149 commit 29f2bcc

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

examples/dreamer/dreamer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,12 @@ def main(cfg: "DictConfig"): # noqa: F821
226226
current_frames = tensordict.numel()
227227
collected_frames += current_frames
228228

229-
# Compared to the original paper, the replay buffer is not temporally sampled. We fill it with trajectories of length batch_length.
230-
# To be closer to the paper, we would need to fill it with trajectories of lentgh 1000 and then sample subsequences of length batch_length.
229+
# Compared to the original paper, the replay buffer is not temporally
230+
# sampled. We fill it with trajectories of length batch_length.
231+
# To be closer to the paper, we would need to fill it with trajectories
232+
# of length 1000 and then sample subsequences of length batch_length.
231233

232-
# tensordict = tensordict.reshape(-1, cfg.batch_length)
233-
print(tensordict.shape)
234+
tensordict = tensordict.reshape(-1, cfg.batch_length)
234235
replay_buffer.extend(tensordict.cpu())
235236
logger.log_scalar(
236237
"r_training",

0 commit comments

Comments
 (0)