@@ -35,11 +35,11 @@ class LLMCollector(SyncDataCollector):
35
35
36
36
.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
37
37
38
- steps_per_batch (int): A keyword-only argument representing the total
39
- number of elements in a batch; -1 is never ending (until shutdown) .
40
- total_steps (int): A keyword-only argument representing the total
41
- number of steps returned by the collector
42
- during its lifespan .
38
+ dialog_turns_per_batch (int, optional): A keyword-only argument representing the total
39
+ number of elements in a batch. It is always required except when `yield_completed_trajectories=True`.
40
+ total_dialog_turns (int): A keyword-only argument representing the total
41
+ number of steps returned by the collector during its lifespan. A value of -1 means the collector never ends (it runs until shutdown).
42
+ Defaults to -1.
43
43
yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
44
44
(`yield_completed_trajectories=False`, default) or single, completed trajectories
45
45
(`yield_completed_trajectories=True`).
@@ -149,7 +149,7 @@ def __init__(
149
149
policy : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
150
150
policy_factory : Callable [[], Callable [[TensorDictBase ], TensorDictBase ]]
151
151
| None = None ,
152
- dialog_turns_per_batch : int ,
152
+ dialog_turns_per_batch : int | None = None ,
153
153
yield_only_last_steps : bool | None = None ,
154
154
yield_completed_trajectories : bool | None = None ,
155
155
postproc : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
@@ -172,6 +172,8 @@ def __init__(
172
172
elif queue is not None :
173
173
# disguise the queue as a replay buffer
174
174
replay_buffer = _QueueAsRB (queue )
175
+ if dialog_turns_per_batch is None and yield_completed_trajectories :
176
+ dialog_turns_per_batch = 0
175
177
super ().__init__ (
176
178
create_env_fn = env ,
177
179
policy = policy ,
0 commit comments