Skip to content

Commit 5d13488

Browse files
[Feature] Warning for init_random_frames rounding in collectors (#1616)
Signed-off-by: Matteo Bettini <matbet@meta.com>
1 parent bf264e0 commit 5d13488

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

torchrl/collectors/collectors.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ class SyncDataCollector(DataCollectorBase):
398398
policy is ignored before it is called. This feature is mainly
399399
intended to be used in offline/model-based settings, where a
400400
batch of random trajectories can be used to initialize training.
401+
If provided, it will be rounded up to the closest multiple of frames_per_batch.
401402
Defaults to ``None`` (i.e. no random frames).
402403
reset_at_each_iter (bool, optional): Whether environments should be reset
403404
at the beginning of a batch collection.
@@ -599,13 +600,26 @@ def __init__(
599600
self.total_frames = total_frames
600601
self.reset_at_each_iter = reset_at_each_iter
601602
self.init_random_frames = init_random_frames
603+
if (
604+
init_random_frames is not None
605+
and init_random_frames % frames_per_batch != 0
606+
and RL_WARNINGS
607+
):
608+
warnings.warn(
609+
f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
610+
f" this results in more init_random_frames than requested"
611+
f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})."
612+
"To silence this message, set the environment variable RL_WARNINGS to False."
613+
)
614+
602615
self.postproc = postproc
603616
if self.postproc is not None and hasattr(self.postproc, "to"):
604617
self.postproc.to(self.storing_device)
605618
if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
606619
warnings.warn(
607-
f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, "
608-
f" this results in more frames_per_batch per iteration that requested."
620+
f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
621+
f" this results in more frames_per_batch per iteration that requested"
622+
f" ({-(-frames_per_batch // self.n_env) * self.n_env})."
609623
"To silence this message, set the environment variable RL_WARNINGS to False."
610624
)
611625
self.requested_frames_per_batch = frames_per_batch
@@ -1026,6 +1040,7 @@ class _MultiDataCollector(DataCollectorBase):
10261040
policy is ignored before it is called. This feature is mainly
10271041
intended to be used in offline/model-based settings, where a
10281042
batch of random trajectories can be used to initialize training.
1043+
If provided, it will be rounded up to the closest multiple of frames_per_batch.
10291044
Defaults to ``None`` (i.e. no random frames).
10301045
reset_at_each_iter (bool, optional): Whether environments should be reset
10311046
at the beginning of a batch collection.

0 commit comments

Comments
 (0)