Skip to content

Commit 19dfefc

Browse files
author
Vincent Moens
committed
[BugFix] Fix init_random_frames=0
ghstack-source-id: 38a544e Pull Request resolved: #2645
1 parent b840a77 commit 19dfefc

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

test/test_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1345,7 +1345,7 @@ def make_env():
13451345
functools.partial(MultiSyncDataCollector, cat_results="stack"),
13461346
],
13471347
)
1348-
@pytest.mark.parametrize("init_random_frames", [50]) # 1226: faster execution
1348+
@pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution
13491349
@pytest.mark.parametrize(
13501350
"explicit_spec,split_trajs", [[True, True], [False, False]]
13511351
) # 1226: faster execution

torchrl/collectors/collectors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,10 +712,10 @@ def __init__(
712712
)
713713
self.reset_at_each_iter = reset_at_each_iter
714714
self.init_random_frames = (
715-
int(init_random_frames) if init_random_frames is not None else 0
715+
int(init_random_frames) if init_random_frames not in (None, -1) else 0
716716
)
717717
if (
718-
init_random_frames is not None
718+
init_random_frames not in (-1, None, 0)
719719
and init_random_frames % frames_per_batch != 0
720720
and RL_WARNINGS
721721
):

0 commit comments

Comments
 (0)