Skip to content

Commit b875979

Browse files
author
Vincent Moens
committed
[BugFix] Fix collector length with non-empty batch size
ghstack-source-id: 0c6a7a4 Pull Request resolved: #2575
1 parent 5a2d9e2 commit b875979

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

test/test_collector.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,6 +3172,28 @@ def make_and_test_policy(
31723172
)
31733173

31743174

3175+
@pytest.mark.parametrize(
3176+
"ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
3177+
)
3178+
def test_no_stopiteration(ctype):
3179+
# Tests that there is no StopIteration raised and that the length of the collector is properly set
3180+
if ctype is SyncDataCollector:
3181+
envs = SerialEnv(16, CountingEnv)
3182+
else:
3183+
envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]
3184+
3185+
collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
3186+
try:
3187+
c_iter = iter(collector)
3188+
for i in range(len(collector)): # noqa: B007
3189+
c = next(c_iter)
3190+
assert c is not None
3191+
assert i == 1
3192+
finally:
3193+
collector.shutdown()
3194+
del collector
3195+
3196+
31753197
if __name__ == "__main__":
31763198
args, unknown = argparse.ArgumentParser().parse_known_args()
31773199
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/collectors/collectors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
138138
_iterator = None
139139
total_frames: int
140140
frames_per_batch: int
141+
requested_frames_per_batch: int
141142
trust_policy: bool
142143
compiled_policy: bool
143144
cudagraphed_policy: bool
@@ -296,7 +297,7 @@ def __class_getitem__(self, index):
296297

297298
def __len__(self) -> int:
298299
if self.total_frames > 0:
299-
return -(self.total_frames // -self.frames_per_batch)
300+
return -(self.total_frames // -self.requested_frames_per_batch)
300301
raise RuntimeError("Non-terminating collectors do not have a length")
301302

302303

@@ -691,7 +692,7 @@ def __init__(
691692
remainder = total_frames % frames_per_batch
692693
if remainder != 0 and RL_WARNINGS:
693694
warnings.warn(
694-
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
695+
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
695696
f"This means {frames_per_batch - remainder} additional frames will be collected."
696697
"To silence this message, set the environment variable RL_WARNINGS to False."
697698
)

0 commit comments

Comments
 (0)