Skip to content

Commit 7691318

Browse files
authored
[Feature] Make dialog_turns_per_batch optional when yield_completed_trajectories=True (#3039)
1 parent 82c9707 commit 7691318

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

torchrl/collectors/llm/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ class LLMCollector(SyncDataCollector):
3535
3636
.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
3737
38-
steps_per_batch (int): A keyword-only argument representing the total
39-
number of elements in a batch; -1 is never ending (until shutdown).
40-
total_steps (int): A keyword-only argument representing the total
41-
number of steps returned by the collector
42-
during its lifespan.
38+
dialog_turns_per_batch (int, optional): A keyword-only argument representing the total
39+
number of elements in a batch. Required unless `yield_completed_trajectories=True`, in which case it may be omitted.
40+
total_dialog_turns (int): A keyword-only argument representing the total
41+
number of steps returned by the collector during its lifespan. A value of -1 means the collector never stops (until shutdown).
42+
Defaults to -1.
4343
yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
4444
(`yield_completed_trajectories=False`, default) or single, completed trajectories
4545
(`yield_completed_trajectories=True`).
@@ -149,7 +149,7 @@ def __init__(
149149
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
150150
policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
151151
| None = None,
152-
dialog_turns_per_batch: int,
152+
dialog_turns_per_batch: int | None = None,
153153
yield_only_last_steps: bool | None = None,
154154
yield_completed_trajectories: bool | None = None,
155155
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
@@ -172,6 +172,8 @@ def __init__(
172172
elif queue is not None:
173173
# disguise the queue as a replay buffer
174174
replay_buffer = _QueueAsRB(queue)
175+
if dialog_turns_per_batch is None and yield_completed_trajectories:
176+
dialog_turns_per_batch = 0
175177
super().__init__(
176178
create_env_fn=env,
177179
policy=policy,

0 commit comments

Comments
 (0)