
Commit 49a8a42

Author: Vincent Moens (committed)
[Feature] policy factory for collectors
ghstack-source-id: 96b928e
Pull Request resolved: #2841
1 parent 50af984 commit 49a8a42

File tree

6 files changed: +194, -60 lines


test/test_distributed.py

Lines changed: 43 additions & 13 deletions
@@ -556,19 +556,21 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync):
         total = 0
         first_batch = None
         last_batch = None
-        for i, data in enumerate(collector):
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-            if i == 0:
-                first_batch = data
-                policy.weight.data += 1
-                collector.update_policy_weights_()
-            elif total == total_frames - frames_per_batch:
-                last_batch = data
-        assert (first_batch["action"] == 1).all(), first_batch["action"]
-        assert (last_batch["action"] == 2).all(), last_batch["action"]
-        collector.shutdown()
-        assert total == total_frames
+        try:
+            for i, data in enumerate(collector):
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+                if i == 0:
+                    first_batch = data
+                    policy.weight.data += 1
+                    collector.update_policy_weights_()
+                elif total == total_frames - frames_per_batch:
+                    last_batch = data
+            assert (first_batch["action"] == 1).all(), first_batch["action"]
+            assert (last_batch["action"] == 2).all(), last_batch["action"]
+            assert total == total_frames
+        finally:
+            collector.shutdown()

     @pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)])
     @pytest.mark.parametrize(
@@ -593,6 +595,34 @@ def test_ray_replaybuffer(self, storage, sampler, writer):
         if sampler is SamplerWithoutReplacement:
             assert sample["a"].unique().numel() == sample.numel()

+    # class CustomCollectorCls(SyncDataCollector):
+    #     def __init__(self, create_env_fn, **kwargs):
+    #         policy = lambda td: td.set("action", torch.full(td.shape, 2))
+    #         super().__init__(create_env_fn, policy, **kwargs)
+
+    def test_ray_collector_policy_constructor(self):
+        n_collectors = 2
+        frames_per_batch = 50
+        total_frames = 300
+        env = CountingEnv
+
+        def policy_constructor():
+            return lambda td: td.set("action", torch.full(td.shape, 2))
+
+        collector = self.distributed_class()(
+            [env] * n_collectors,
+            collector_class=SyncDataCollector,
+            policy_factory=policy_constructor,
+            total_frames=total_frames,
+            frames_per_batch=frames_per_batch,
+            **self.distributed_kwargs(),
+        )
+        try:
+            for data in collector:
+                assert (data["action"] == 2).all()
+        finally:
+            collector.shutdown()
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
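
Editor's note: for orientation, here is a minimal sketch of how the new keyword reads outside the Ray test above, applied to a plain SyncDataCollector. `make_env` is a stand-in for any callable returning an EnvBase, and the constant-action lambda mirrors the test's policy; the frame counts are placeholders.

import torch
from torchrl.collectors import SyncDataCollector

def policy_factory():
    # Built where the collector runs, so the policy itself is never pickled.
    return lambda td: td.set("action", torch.full(td.shape, 2))

collector = SyncDataCollector(
    make_env,                       # assumed: a callable returning an EnvBase
    policy_factory=policy_factory,  # exclusive with the `policy` argument
    frames_per_batch=50,
    total_frames=300,
)
try:
    for data in collector:
        assert (data["action"] == 2).all()
finally:
    collector.shutdown()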

torchrl/collectors/collectors.py

Lines changed: 85 additions & 32 deletions
@@ -21,7 +21,7 @@
 from multiprocessing.managers import SyncManager
 from queue import Empty
 from textwrap import indent
-from typing import Any, Callable, Iterator, Sequence
+from typing import Any, Callable, Iterator, Sequence, TypeVar

 import numpy as np
 import torch
@@ -86,6 +86,8 @@ def cudagraph_mark_step_begin():

 _is_osx = sys.platform.startswith("darwin")

+T = TypeVar("T")
+

 class _Interruptor:
     """A class for managing the collection state of a process.
@@ -343,7 +345,15 @@ class SyncDataCollector(DataCollectorBase):

             - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

+            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+                pickled directly), the :arg:`policy_factory` should be used instead.
+
     Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
         frames_per_batch (int): A keyword-only argument representing the total
             number of elements in a batch.
         total_frames (int): A keyword-only argument representing the total
@@ -515,6 +525,7 @@ def __init__(
         policy: None
         | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
+        policy_factory: Callable[[], Callable] | None = None,
         frames_per_batch: int,
         total_frames: int = -1,
         device: DEVICE_TYPING = None,
@@ -558,8 +569,13 @@ def __init__(
         env.update_kwargs(create_env_kwargs)

         if policy is None:
+            if policy_factory is not None:
+                policy = policy_factory()
+            else:
+                policy = RandomPolicy(env.full_action_spec)
+        elif policy_factory is not None:
+            raise TypeError("policy_factory cannot be used with policy argument.")

-            policy = RandomPolicy(env.full_action_spec)
         if trust_policy is None:
             trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
         self.trust_policy = trust_policy
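
Editor's note: the hunk above makes `policy` and `policy_factory` mutually exclusive. A quick sketch of the two call patterns (`make_env` and `my_policy` are placeholders):

# Valid: exactly one of the two.
SyncDataCollector(make_env, policy=my_policy, frames_per_batch=10, total_frames=100)
SyncDataCollector(make_env, policy_factory=lambda: my_policy, frames_per_batch=10, total_frames=100)

# Invalid: passing both now raises
# TypeError: policy_factory cannot be used with policy argument.
SyncDataCollector(
    make_env,
    policy=my_policy,
    policy_factory=lambda: my_policy,
    frames_per_batch=10,
    total_frames=100,
)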
@@ -1429,17 +1445,22 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
         self._iter = state_dict["iter"]

     def __repr__(self) -> str:
-        env_str = indent(f"env={self.env}", 4 * " ")
-        policy_str = indent(f"policy={self.policy}", 4 * " ")
-        td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ")
-        string = (
-            f"{self.__class__.__name__}("
-            f"\n{env_str},"
-            f"\n{policy_str},"
-            f"\n{td_out_str},"
-            f"\nexploration={self.exploration_type})"
-        )
-        return string
+        try:
+            env_str = indent(f"env={self.env}", 4 * " ")
+            policy_str = indent(f"policy={self.policy}", 4 * " ")
+            td_out_str = indent(
+                f"td_out={getattr(self, '_final_rollout', None)}", 4 * " "
+            )
+            string = (
+                f"{self.__class__.__name__}("
+                f"\n{env_str},"
+                f"\n{policy_str},"
+                f"\n{td_out_str},"
+                f"\nexploration={self.exploration_type})"
+            )
+            return string
+        except AttributeError:
+            return f"{type(self).__name__}(not_init)"


 class _MultiDataCollector(DataCollectorBase):
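
Editor's note: the try/except guards against calling repr() on a collector whose __init__ never completed (for instance when construction raised early), in which case attributes like self.env do not exist yet. A self-contained sketch of the pattern, using a hypothetical Demo class:

class Demo:
    def __init__(self):
        self.env = "CartPole"

    def __repr__(self) -> str:
        try:
            return f"Demo(env={self.env})"
        except AttributeError:
            return f"{type(self).__name__}(not_init)"

d = Demo.__new__(Demo)  # attribute-less instance, as after a failed __init__
print(repr(d))          # -> Demo(not_init)
print(repr(Demo()))     # -> Demo(env=CartPole)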
@@ -1469,7 +1490,18 @@ class _MultiDataCollector(DataCollectorBase):
             - In all other cases an attempt to wrap it will be undergone as such:
               ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

+            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+                pickled directly), the :arg:`policy_factory` should be used instead.
+
     Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+            .. warning:: `policy_factory` is currently not compatible with multiprocessed data
+                collectors.
+
         frames_per_batch (int): A keyword-only argument representing the
             total number of elements in a batch.
         total_frames (int, optional): A keyword-only argument representing the
@@ -1612,6 +1644,7 @@ def __init__(
         policy: None
         | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
+        policy_factory: Callable[[], Callable] | None = None,
         frames_per_batch: int,
         total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -1695,27 +1728,36 @@ def __init__(
         self._get_weights_fn_dict = {}

         if trust_policy is None:
-            trust_policy = isinstance(policy, CudaGraphModule)
+            trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
         self.trust_policy = trust_policy

-        for policy_device, env_maker, env_maker_kwargs in zip(
-            self.policy_device, self.create_env_fn, self.create_env_kwargs
-        ):
-            (policy_copy, get_weights_fn,) = self._get_policy_and_device(
-                policy=policy,
-                policy_device=policy_device,
-                env_maker=env_maker,
-                env_maker_kwargs=env_maker_kwargs,
-            )
-            if type(policy_copy) is not type(policy):
-                policy = policy_copy
-            weights = (
-                TensorDict.from_module(policy_copy)
-                if isinstance(policy_copy, nn.Module)
-                else TensorDict()
+        if policy_factory is not None and policy is not None:
+            raise TypeError("policy_factory and policy are mutually exclusive")
+        elif policy_factory is None:
+            for policy_device, env_maker, env_maker_kwargs in zip(
+                self.policy_device, self.create_env_fn, self.create_env_kwargs
+            ):
+                (policy_copy, get_weights_fn,) = self._get_policy_and_device(
+                    policy=policy,
+                    policy_device=policy_device,
+                    env_maker=env_maker,
+                    env_maker_kwargs=env_maker_kwargs,
+                )
+                if type(policy_copy) is not type(policy):
+                    policy = policy_copy
+                weights = (
+                    TensorDict.from_module(policy_copy)
+                    if isinstance(policy_copy, nn.Module)
+                    else TensorDict()
+                )
+                self._policy_weights_dict[policy_device] = weights
+                self._get_weights_fn_dict[policy_device] = get_weights_fn
+        else:
+            # TODO
+            raise NotImplementedError(
+                "weight syncing is not supported for multiprocessed data collectors at the "
+                "moment."
             )
-            self._policy_weights_dict[policy_device] = weights
-            self._get_weights_fn_dict[policy_device] = get_weights_fn
         self.policy = policy

         remainder = 0
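
Editor's note: the NotImplementedError branch follows from how multiprocessed collectors sync weights: the parent caches each policy's parameters as a TensorDict (the `TensorDict.from_module` call above) so that update_policy_weights_() can push fresh values to workers. With only a factory, the parent never holds a policy instance to snapshot. A minimal sketch of that cached representation, with a stand-in nn.Linear policy:

import torch
from torch import nn
from tensordict import TensorDict

policy = nn.Linear(4, 2)                  # stand-in policy module
weights = TensorDict.from_module(policy)  # what _policy_weights_dict stores per device
print(sorted(weights.keys()))             # ['bias', 'weight']

# Pushing fresh weights into a worker's copy amounts to writing them back:
worker_policy = nn.Linear(4, 2)           # e.g., the remote instance
weights.to_module(worker_policy)          # worker now mirrors the parent policy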
@@ -2782,7 +2824,15 @@ class aSyncDataCollector(MultiaSyncDataCollector):

             - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

+            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+                pickled directly), the :arg:`policy_factory` should be used instead.
+
     Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
         frames_per_batch (int): A keyword-only argument representing the
             total number of elements in a batch.
         total_frames (int, optional): A keyword-only argument representing the
@@ -2888,8 +2938,10 @@ class aSyncDataCollector(MultiaSyncDataCollector):
     def __init__(
         self,
         create_env_fn: Callable[[], EnvBase],
-        policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]),
+        policy: None
+        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
         *,
+        policy_factory: Callable[[], Callable] | None = None,
         frames_per_batch: int,
         total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -2914,6 +2966,7 @@ def __init__(
         super().__init__(
             create_env_fn=[create_env_fn],
             policy=policy,
+            policy_factory=policy_factory,
            total_frames=total_frames,
             create_env_kwargs=[create_env_kwargs],
             max_frames_per_traj=max_frames_per_traj,

torchrl/collectors/distributed/generic.py

Lines changed: 11 additions & 2 deletions
@@ -14,7 +14,7 @@
 from typing import Callable, OrderedDict

 import torch.cuda
-from tensordict import TensorDict
+from tensordict import TensorDict, TensorDictBase
 from torch import nn

 from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE
@@ -270,7 +270,15 @@ class DistributedDataCollector(DataCollectorBase):

             - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

+            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+                pickled directly), the :arg:`policy_factory` should be used instead.
+
     Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
         frames_per_batch (int): A keyword-only argument representing the total
             number of elements in a batch.
         total_frames (int): A keyword-only argument representing the total
@@ -406,8 +414,9 @@ class DistributedDataCollector(DataCollectorBase):
     def __init__(
         self,
         create_env_fn,
-        policy,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
         *,
+        policy_factory: Callable[[], Callable] | None = None,
         frames_per_batch: int,
         total_frames: int = -1,
         device: torch.device | list[torch.device] = None,
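
Editor's note: reading the new signature together with the Ray test above, the distributed collector can now be driven entirely by a factory, so only a small picklable callable crosses process or node boundaries. A hedged sketch under those assumptions (`make_env`, `policy_factory`, and the frame counts are placeholders, and per the docstring the factory is exclusive with `policy`):

from torchrl.collectors.distributed import DistributedDataCollector

collector = DistributedDataCollector(
    [make_env] * 2,                 # assumed: picklable env constructors
    policy_factory=policy_factory,  # shipped instead of a possibly unpicklable policy
    frames_per_batch=50,
    total_frames=300,
)
for data in collector:
    ...                             # consume batches as usual
collector.shutdown()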
