Commit 82f8ec2

Author: Vincent Moens
[Feature] Pass lists of policy_factory
ghstack-source-id: e42b100
Pull Request resolved: #2888
1 parent 93ba865 commit 82f8ec2
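
Both distributed collectors normalize the new argument the same way: a single factory (or None) is replicated so that every remote worker gets its own entry, while a list is taken as one factory per env constructor. A minimal sketch of that normalization outside of any collector (the helper name below is hypothetical; the pattern mirrors the diff):

from collections.abc import Sequence

import torch.nn as nn

def _normalize_policy_factory(policy_factory, num_workers):
    # A single callable (or None) is broadcast to one entry per worker,
    # as done in DistributedDataCollector.__init__ and RayCollector.__init__.
    if not isinstance(policy_factory, Sequence):
        policy_factory = [policy_factory] * num_workers
    return policy_factory

def make_policy():  # hypothetical user-defined factory
    return nn.Identity()

assert len(_normalize_policy_factory(make_policy, num_workers=3)) == 3
assert len(_normalize_policy_factory([make_policy, make_policy], num_workers=2)) == 2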

File tree

10 files changed (+290, -137 lines)

test/mocking_classes.py

Lines changed: 6 additions & 3 deletions
@@ -1138,14 +1138,17 @@ def _step(
             dtype=torch.int,
             device=device if self.device is None else self.device,
         )
+        if self.reward_keys:
+            reward_spec = self.full_reward_spec[self.reward_keys[0]]
+            reward_spec_dtype = reward_spec.dtype
+        else:
+            reward_spec_dtype = torch.get_default_dtype()
         tensordict = TensorDict(
             source={
                 "observation": self.count.clone(),
                 "done": self.count > self.max_steps,
                 "terminated": self.count > self.max_steps,
-                "reward": torch.zeros_like(
-                    self.count, dtype=self.full_reward_spec[self.reward_keys[0]].dtype
-                ),
+                "reward": torch.zeros_like(self.count, dtype=reward_spec_dtype),
             },
             batch_size=self.batch_size,
             device=self.device,
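
The change above guards the mock envs against an empty reward spec: the reward dtype is only read from full_reward_spec when reward_keys is non-empty, and otherwise falls back to the default torch dtype. A standalone sketch of the same guard (the plain dict below stands in for the env's full_reward_spec):

import torch

reward_keys = []          # e.g. a mock env that defines no reward entries
full_reward_spec = {}     # keyed by reward_keys in the real env

if reward_keys:
    reward_spec_dtype = full_reward_spec[reward_keys[0]].dtype
else:
    # Default is torch.float32 unless the user changed it globally.
    reward_spec_dtype = torch.get_default_dtype()

reward = torch.zeros(4, dtype=reward_spec_dtype)
print(reward.dtype)  # torch.float32 in the default configuration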

test/test_env.py

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,7 @@
     dense_stack_tds,
     LazyStackedTensorDict,
     set_capture_non_tensor_stack,
+    set_list_to_stack,
     TensorDict,
     TensorDictBase,
 )
@@ -3094,6 +3095,7 @@ def test_mocking_envs(envclass):
     check_env_specs(env, seed=100, return_contiguous=False)


+@set_list_to_stack(True)
 class TestTerminatedOrTruncated:
     @pytest.mark.parametrize("done_key", ["done", "terminated", "truncated"])
     def test_root_prevail(self, done_key):
@@ -3409,6 +3411,7 @@ def test_single_task_share_individual_td():
     )


+@set_list_to_stack(True)
 def test_stackable():
     # Tests the _stackable util
     stack = [TensorDict({"a": 0}, []), TensorDict({"b": 1}, [])]
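
set_list_to_stack comes from tensordict and is applied here as a test decorator: it opts the tests into the behavior where a list value assigned to a TensorDict is stacked along a new leading dimension instead of being handled as legacy/non-tensor data. A minimal sketch of the decorator form, assuming that stacking behavior:

import torch
from tensordict import TensorDict, set_list_to_stack

@set_list_to_stack(True)
def build_td():
    # With list-to-stack enabled, the two tensors are stacked into a
    # single entry of shape [2, 3].
    return TensorDict({"obs": [torch.zeros(3), torch.ones(3)]})

td = build_td()
print(td["obs"].shape)  # expected: torch.Size([2, 3])

The same object also works as a context manager (with set_list_to_stack(True): ...).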

torchrl/collectors/collectors.py

Lines changed: 97 additions & 33 deletions
Large diffs are not rendered by default.
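
The per-worker semantics of this commit presumably extend to the multiprocess collectors defined in this file, whose diff is collapsed above. A hedged usage sketch under that assumption, with a gym-backed env and a toy policy; treat the acceptance of a list by MultiSyncDataCollector as an assumption rather than documented behavior here:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv

def make_env():
    return GymEnv("Pendulum-v1")  # illustrative environment

def make_policy():
    # Built inside each worker, so the policy never has to be pickled.
    return TensorDictModule(
        nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
    )

collector = MultiSyncDataCollector(
    [make_env, make_env],
    policy_factory=[make_policy, make_policy],  # one factory per worker
    frames_per_batch=64,
    total_frames=256,
)
for batch in collector:
    print(batch.shape)
collector.shutdown()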

torchrl/collectors/distributed/generic.py

Lines changed: 45 additions & 19 deletions
@@ -11,7 +11,7 @@
 import warnings
 from copy import copy, deepcopy
 from datetime import timedelta
-from typing import Callable, OrderedDict
+from typing import Any, Callable, OrderedDict, Sequence

 import torch.cuda
 from tensordict import TensorDict, TensorDictBase
@@ -131,6 +131,7 @@ def _distributed_init_collection_node(
     num_workers,
     env_make,
     policy,
+    policy_factory,
     frames_per_batch,
     collector_kwargs,
     verbose=True,
@@ -143,6 +144,7 @@ def _distributed_init_collection_node(
         num_workers,
         env_make,
         policy,
+        policy_factory,
         frames_per_batch,
         collector_kwargs,
         verbose=verbose,
@@ -156,6 +158,7 @@ def _run_collector(
     num_workers,
     env_make,
     policy,
+    policy_factory,
     frames_per_batch,
     collector_kwargs,
     verbose=True,
@@ -178,12 +181,17 @@ def _run_collector(
         policy_weights = TensorDict.from_module(policy)
         policy_weights = policy_weights.data.lock_()
     else:
-        warnings.warn(_NON_NN_POLICY_WEIGHTS)
+        if collector_kwargs.get("remote_weight_updater") is None and (
+            policy_factory is None
+            or (isinstance(policy_factory, Sequence) and not any(policy_factory))
+        ):
+            warnings.warn(_NON_NN_POLICY_WEIGHTS)
         policy_weights = TensorDict(lock=True)

     collector = collector_class(
         env_make,
         policy,
+        policy_factory=policy_factory,
         frames_per_batch=frames_per_batch,
         total_frames=-1,
         split_trajs=False,
@@ -278,8 +286,8 @@ class DistributedDataCollector(DataCollectorBase):
         pickled directly), the :arg:`policy_factory` should be used instead.

     Keyword Args:
-        policy_factory (Callable[[], Callable], optional): a callable that returns
-            a policy instance. This is exclusive with the `policy` argument.
+        policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
+            (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.

         .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

@@ -411,14 +419,16 @@ class DistributedDataCollector(DataCollectorBase):
             to learn more.
             Defaults to ``"submitit"``.
         tcp_port (int, optional): the TCP port to be used. Defaults to 10003.
-        local_weight_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+        local_weight_updater (LocalWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on the local inference worker.
             This is typically not used in :class:`~torchrl.collectors.distributed.DistributedDataCollector` as it
             focuses on distributed environments.
-        remote_weight_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+            Consider using a constructor if the updater needs to be serialized.
+        remote_weight_updater (RemoteWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on distributed inference workers.
             If not provided, a :class:`~torchrl.collectors.distributed.DistributedRemoteWeightUpdater` will be used by
             default, which handles weight synchronization across distributed workers.
+            Consider using a constructor if the updater needs to be serialized.

     """

@@ -429,31 +439,37 @@ def __init__(
         create_env_fn,
         policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
         *,
-        policy_factory: Callable[[], Callable] | None = None,
+        policy_factory: Callable[[], Callable]
+        | list[Callable[[], Callable]]
+        | None = None,
         frames_per_batch: int,
         total_frames: int = -1,
-        device: torch.device | list[torch.device] = None,
-        storing_device: torch.device | list[torch.device] = None,
-        env_device: torch.device | list[torch.device] = None,
-        policy_device: torch.device | list[torch.device] = None,
+        device: torch.device | list[torch.device] | None = None,
+        storing_device: torch.device | list[torch.device] | None = None,
+        env_device: torch.device | list[torch.device] | None = None,
+        policy_device: torch.device | list[torch.device] | None = None,
         max_frames_per_traj: int = -1,
         init_random_frames: int = -1,
         reset_at_each_iter: bool = False,
         postproc: Callable | None = None,
         split_trajs: bool = False,
         exploration_type: ExporationType = DEFAULT_EXPLORATION_TYPE,  # noqa
         collector_class: type = SyncDataCollector,
-        collector_kwargs: dict = None,
+        collector_kwargs: dict[str, Any] | None = None,
         num_workers_per_collector: int = 1,
         sync: bool = False,
-        slurm_kwargs: dict | None = None,
+        slurm_kwargs: dict[str, Any] | None = None,
         backend: str = "gloo",
         update_after_each_batch: bool = False,
         max_weight_update_interval: int = -1,
         launcher: str = "submitit",
-        tcp_port: int = None,
-        remote_weight_updater: RemoteWeightUpdaterBase | None = None,
-        local_weight_updater: LocalWeightUpdaterBase | None = None,
+        tcp_port: int | None = None,
+        remote_weight_updater: RemoteWeightUpdaterBase
+        | Callable[[], RemoteWeightUpdaterBase]
+        | None = None,
+        local_weight_updater: LocalWeightUpdaterBase
+        | Callable[[], LocalWeightUpdaterBase]
+        | None = None,
     ):

         if collector_class == "async":
@@ -465,18 +481,22 @@ def __init__(
         self.collector_class = collector_class
         self.env_constructors = create_env_fn
         self.policy = policy
+        if not isinstance(policy_factory, Sequence):
+            policy_factory = [policy_factory for _ in range(len(self.env_constructors))]
+        self.policy_factory = policy_factory
         if isinstance(policy, nn.Module):
             policy_weights = TensorDict.from_module(policy)
             policy_weights = policy_weights.data.lock_()
-        elif policy_factory is not None:
+        elif any(policy_factory):
             policy_weights = None
             if remote_weight_updater is None:
                 raise RuntimeError(
                     "remote_weight_updater must be passed along with "
                     "a policy_factory."
                 )
         else:
-            warnings.warn(_NON_NN_POLICY_WEIGHTS)
+            if not any(policy_factory):
+                warnings.warn(_NON_NN_POLICY_WEIGHTS)
             policy_weights = TensorDict(lock=True)
         self.policy_weights = policy_weights
         self.num_workers = len(create_env_fn)
@@ -664,12 +684,15 @@ def _make_container(self):
         if self._VERBOSE:
             torchrl_logger.info("making container")
         env_constructor = self.env_constructors[0]
+        kwargs = self.collector_kwargs[0]
         pseudo_collector = SyncDataCollector(
             env_constructor,
-            self.policy,
+            policy=self.policy,
+            policy_factory=self.policy_factory[0],
             frames_per_batch=self._frames_per_batch_corrected,
             total_frames=-1,
             split_trajs=False,
+            **kwargs,
         )
         for _data in pseudo_collector:
             break
@@ -713,6 +736,7 @@ def _init_worker_dist_submitit(self, executor, i):
            self.num_workers_per_collector,
            env_make,
            self.policy,
+            self.policy_factory[i],
            self._frames_per_batch_corrected,
            self.collector_kwargs[i],
            self._VERBOSE,
@@ -734,6 +758,7 @@ def get_env_make(i):
            "num_workers": self.num_workers_per_collector,
            "env_make": get_env_make(i),
            "policy": self.policy,
+            "policy_factory": self.policy_factory[i],
            "frames_per_batch": self._frames_per_batch_corrected,
            "collector_kwargs": self.collector_kwargs[i],
         }
@@ -760,6 +785,7 @@ def _init_worker_dist_mp(self, i):
            self.num_workers_per_collector,
            env_make,
            self.policy,
+            self.policy_factory[i],
            self._frames_per_batch_corrected,
            self.collector_kwargs[i],
            self._VERBOSE,
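
One behavioral consequence is visible in the updated __init__ above: when the policy is supplied through policy_factory, no weights can be extracted locally, so a remote_weight_updater must be provided explicitly or the constructor raises. A minimal sketch of that check, with hypothetical make_env/make_policy helpers (the error is raised before any remote job is submitted):

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import GymEnv

def make_env():
    return GymEnv("Pendulum-v1")  # illustrative environment

def make_policy():
    return TensorDictModule(
        nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
    )

try:
    DistributedDataCollector(
        [make_env, make_env],
        policy_factory=[make_policy, make_policy],  # one factory per node
        frames_per_batch=200,
        total_frames=2_000,
        # remote_weight_updater is deliberately omitted here.
    )
except RuntimeError as err:
    print(err)  # "remote_weight_updater must be passed along with a policy_factory."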

torchrl/collectors/distributed/ray.py

Lines changed: 36 additions & 22 deletions
@@ -7,7 +7,7 @@

 import asyncio
 import warnings
-from typing import Callable, Iterator, OrderedDict
+from typing import Any, Callable, Iterator, OrderedDict, Sequence

 import torch
 import torch.nn as nn
@@ -153,8 +153,8 @@ class RayCollector(DataCollectorBase):
         pickled directly), the :arg:`policy_factory` should be used instead.

     Keyword Args:
-        policy_factory (Callable[[], Callable], optional): a callable that returns
-            a policy instance. This is exclusive with the `policy` argument.
+        policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
+            (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.

         .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

@@ -230,7 +230,7 @@ class RayCollector(DataCollectorBase):
             collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
             ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
             or ``torchrl.envs.utils.ExplorationType.MEAN``.
-        collector_class (Python class): a collector class to be remotely instantiated. Can be
+        collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be
             :class:`~torchrl.collectors.SyncDataCollector`,
             :class:`~torchrl.collectors.MultiSyncDataCollector`,
             :class:`~torchrl.collectors.MultiaSyncDataCollector`
@@ -277,13 +277,16 @@ class RayCollector(DataCollectorBase):

         .. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a
             :class:`~torchrl.data.RayReplayBuffer` instance should be used here.
-        local_weight_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+        local_weight_updater (LocalWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on the local inference worker.
-            This is typically not used in :class:`~torchrl.collectors.RayCollector` as it focuses on distributed environments.
-        remote_weight_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+            This is typically not used in :class:`~torchrl.collectors.RayCollector` as it focuses on distributed
+            environments.
+            Consider using a constructor if the updater needs to be serialized.
+        remote_weight_updater (RemoteWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on remote inference workers managed by Ray.
             If not provided, a :class:`~torchrl.collectors.RayRemoteWeightUpdater` will be used by default, leveraging
             Ray's distributed capabilities.
+            Consider using a constructor if the updater needs to be serialized.

     Examples:
         >>> from torch import nn
@@ -319,31 +322,37 @@ def __init__(
         create_env_fn: Callable | EnvBase | list[Callable] | list[EnvBase],
         policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
         *,
-        policy_factory: Callable[[], Callable] | None = None,
+        policy_factory: Callable[[], Callable]
+        | list[Callable[[], Callable]]
+        | None = None,
         frames_per_batch: int,
         total_frames: int = -1,
-        device: torch.device | list[torch.device] = None,
-        storing_device: torch.device | list[torch.device] = None,
-        env_device: torch.device | list[torch.device] = None,
-        policy_device: torch.device | list[torch.device] = None,
+        device: torch.device | list[torch.device] | None = None,
+        storing_device: torch.device | list[torch.device] | None = None,
+        env_device: torch.device | list[torch.device] | None = None,
+        policy_device: torch.device | list[torch.device] | None = None,
         max_frames_per_traj=-1,
         init_random_frames=-1,
         reset_at_each_iter=False,
         postproc=None,
         split_trajs=False,
         exploration_type=DEFAULT_EXPLORATION_TYPE,
         collector_class: Callable[[TensorDict], TensorDict] = SyncDataCollector,
-        collector_kwargs: dict | list[dict] = None,
+        collector_kwargs: dict[str, Any] | list[dict] | None = None,
         num_workers_per_collector: int = 1,
         sync: bool = False,
-        ray_init_config: dict = None,
-        remote_configs: dict | list[dict] = None,
-        num_collectors: int = None,
-        update_after_each_batch=False,
-        max_weight_update_interval=-1,
-        replay_buffer: ReplayBuffer = None,
-        remote_weight_updater: RemoteWeightUpdaterBase | None = None,
-        local_weight_updater: LocalWeightUpdaterBase | None = None,
+        ray_init_config: dict[str, Any] | None = None,
+        remote_configs: dict[str, Any] | list[dict[str, Any]] | None = None,
+        num_collectors: int | None = None,
+        update_after_each_batch: bool = False,
+        max_weight_update_interval: int = -1,
+        replay_buffer: ReplayBuffer | None = None,
+        remote_weight_updater: RemoteWeightUpdaterBase
+        | Callable[[], RemoteWeightUpdaterBase]
+        | None = None,
+        local_weight_updater: LocalWeightUpdaterBase
+        | Callable[[], LocalWeightUpdaterBase]
+        | None = None,
     ):
         self.frames_per_batch = frames_per_batch
         if remote_configs is None:
@@ -451,6 +460,9 @@ def check_list_length_consistency(*lists):
         collector_class.print_remote_collector_info = print_remote_collector_info

         self.replay_buffer = replay_buffer
+        if not isinstance(policy_factory, Sequence):
+            policy_factory = [policy_factory] * len(create_env_fn)
+        self.policy_factory = policy_factory
         self._local_policy = policy
         if isinstance(self._local_policy, nn.Module):
             policy_weights = TensorDict.from_module(self._local_policy)
@@ -491,7 +503,7 @@ def check_list_length_consistency(*lists):

         # update collector kwargs
         for i, collector_kwarg in enumerate(self.collector_kwargs):
-            collector_kwarg["policy_factory"] = policy_factory
+            collector_kwarg["policy_factory"] = policy_factory[i]
             collector_kwarg["max_frames_per_traj"] = max_frames_per_traj
             collector_kwarg["init_random_frames"] = (
                 init_random_frames // self.num_collectors
@@ -678,6 +690,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
         """Collects one data batch per remote collector in each iteration."""
         while self.collected_frames < self.total_frames:
             if self.update_after_each_batch or self.max_weight_update_interval > -1:
+                torchrl_logger.info("Updating weights on all workers")
                 self.update_policy_weights_()

             # Ask for batches to all remote workers.
@@ -759,6 +772,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
             yield out_td

             if self.update_after_each_batch or self.max_weight_update_interval > -1:
+                torchrl_logger.info(f"Updating weights on worker {collector_index}")
                 self.update_policy_weights_(worker_ids=collector_index + 1)

             # Schedule a new collection task
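
On the Ray side, the factory list is normalized to one entry per env constructor and then forwarded into each remote collector's kwargs (collector_kwarg["policy_factory"] = policy_factory[i] above). A hedged usage sketch, assuming a reachable Ray runtime and a gym-backed env; the device argument is only there to show why per-worker factories can be useful:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors.distributed import RayCollector
from torchrl.envs import GymEnv

def make_env():
    return GymEnv("Pendulum-v1")  # illustrative environment

def make_policy_on(device):
    # Each remote collector builds (and places) its own policy.
    def factory():
        module = TensorDictModule(
            nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
        )
        return module.to(device)
    return factory

collector = RayCollector(
    [make_env, make_env],
    policy_factory=[make_policy_on("cpu"), make_policy_on("cpu")],
    frames_per_batch=200,
    total_frames=1_000,
)
for data in collector:
    print(data.shape)
    break
collector.shutdown()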
