21
21
from multiprocessing .managers import SyncManager
22
22
from queue import Empty
23
23
from textwrap import indent
24
- from typing import Any , Callable , Iterator , Sequence
24
+ from typing import Any , Callable , Iterator , Sequence , TypeVar
25
25
26
26
import numpy as np
27
27
import torch
@@ -86,6 +86,8 @@ def cudagraph_mark_step_begin():
86
86
87
87
_is_osx = sys .platform .startswith ("darwin" )
88
88
89
+ T = TypeVar ("T" )
90
+
89
91
90
92
class _Interruptor :
91
93
"""A class for managing the collection state of a process.
@@ -343,7 +345,15 @@ class SyncDataCollector(DataCollectorBase):
343
345
344
346
- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
345
347
348
+ .. note:: If the policy must be built from a factory (e.g., because it must not be serialized /
349
+ pickled directly), use the ``policy_factory`` keyword argument instead.
350
+
346
351
Keyword Args:
352
+ policy_factory (Callable[[], Callable], optional): a callable that returns
353
+ a policy instance. This argument is mutually exclusive with the ``policy`` argument.
354
+
355
+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
356
+
347
357
frames_per_batch (int): A keyword-only argument representing the total
348
358
number of elements in a batch.
349
359
total_frames (int): A keyword-only argument representing the total
@@ -515,6 +525,7 @@ def __init__(
515
525
policy : None
516
526
| (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]) = None ,
517
527
* ,
528
+ policy_factory : Callable [[], Callable ] | None = None ,
518
529
frames_per_batch : int ,
519
530
total_frames : int = - 1 ,
520
531
device : DEVICE_TYPING = None ,
@@ -558,8 +569,13 @@ def __init__(
558
569
env .update_kwargs (create_env_kwargs )
559
570
560
571
if policy is None :
572
+ if policy_factory is not None :
573
+ policy = policy_factory ()
574
+ else :
575
+ policy = RandomPolicy (env .full_action_spec )
576
+ elif policy_factory is not None :
577
+ raise TypeError ("policy_factory cannot be used with policy argument." )
561
578
562
- policy = RandomPolicy (env .full_action_spec )
563
579
if trust_policy is None :
564
580
trust_policy = isinstance (policy , (RandomPolicy , CudaGraphModule ))
565
581
self .trust_policy = trust_policy
@@ -1429,17 +1445,22 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
1429
1445
self ._iter = state_dict ["iter" ]
1430
1446
1431
1447
def __repr__ (self ) -> str :
1432
- env_str = indent (f"env={ self .env } " , 4 * " " )
1433
- policy_str = indent (f"policy={ self .policy } " , 4 * " " )
1434
- td_out_str = indent (f"td_out={ getattr (self , '_final_rollout' , None )} " , 4 * " " )
1435
- string = (
1436
- f"{ self .__class__ .__name__ } ("
1437
- f"\n { env_str } ,"
1438
- f"\n { policy_str } ,"
1439
- f"\n { td_out_str } ,"
1440
- f"\n exploration={ self .exploration_type } )"
1441
- )
1442
- return string
1448
+ try :
1449
+ env_str = indent (f"env={ self .env } " , 4 * " " )
1450
+ policy_str = indent (f"policy={ self .policy } " , 4 * " " )
1451
+ td_out_str = indent (
1452
+ f"td_out={ getattr (self , '_final_rollout' , None )} " , 4 * " "
1453
+ )
1454
+ string = (
1455
+ f"{ self .__class__ .__name__ } ("
1456
+ f"\n { env_str } ,"
1457
+ f"\n { policy_str } ,"
1458
+ f"\n { td_out_str } ,"
1459
+ f"\n exploration={ self .exploration_type } )"
1460
+ )
1461
+ return string
1462
+ except AttributeError :
1463
+ return f"{ type (self ).__name__ } (not_init)"
1443
1464
1444
1465
1445
1466
class _MultiDataCollector (DataCollectorBase ):
@@ -1469,7 +1490,18 @@ class _MultiDataCollector(DataCollectorBase):
1469
1490
- In all other cases, an attempt will be made to wrap it as follows:
1470
1491
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
1471
1492
1493
+ .. note:: If the policy must be built from a factory (e.g., because it must not be serialized /
1494
+ pickled directly), use the ``policy_factory`` keyword argument instead.
1495
+
1472
1496
Keyword Args:
1497
+ policy_factory (Callable[[], Callable], optional): a callable that returns
1498
+ a policy instance. This argument is mutually exclusive with the ``policy`` argument.
1499
+
1500
+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
1501
+
1502
+ .. warning:: `policy_factory` is currently not compatible with multiprocessed data
1503
+ collectors.
1504
+
1473
1505
frames_per_batch (int): A keyword-only argument representing the
1474
1506
total number of elements in a batch.
1475
1507
total_frames (int, optional): A keyword-only argument representing the
@@ -1612,6 +1644,7 @@ def __init__(
1612
1644
policy : None
1613
1645
| (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]) = None ,
1614
1646
* ,
1647
+ policy_factory : Callable [[], Callable ] | None = None ,
1615
1648
frames_per_batch : int ,
1616
1649
total_frames : int | None = - 1 ,
1617
1650
device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
@@ -1695,27 +1728,36 @@ def __init__(
1695
1728
self ._get_weights_fn_dict = {}
1696
1729
1697
1730
if trust_policy is None :
1698
- trust_policy = isinstance (policy , CudaGraphModule )
1731
+ trust_policy = policy is not None and isinstance (policy , CudaGraphModule )
1699
1732
self .trust_policy = trust_policy
1700
1733
1701
- for policy_device , env_maker , env_maker_kwargs in zip (
1702
- self .policy_device , self .create_env_fn , self .create_env_kwargs
1703
- ):
1704
- (policy_copy , get_weights_fn ,) = self ._get_policy_and_device (
1705
- policy = policy ,
1706
- policy_device = policy_device ,
1707
- env_maker = env_maker ,
1708
- env_maker_kwargs = env_maker_kwargs ,
1709
- )
1710
- if type (policy_copy ) is not type (policy ):
1711
- policy = policy_copy
1712
- weights = (
1713
- TensorDict .from_module (policy_copy )
1714
- if isinstance (policy_copy , nn .Module )
1715
- else TensorDict ()
1734
+ if policy_factory is not None and policy is not None :
1735
+ raise TypeError ("policy_factory and policy are mutually exclusive" )
1736
+ elif policy_factory is None :
1737
+ for policy_device , env_maker , env_maker_kwargs in zip (
1738
+ self .policy_device , self .create_env_fn , self .create_env_kwargs
1739
+ ):
1740
+ (policy_copy , get_weights_fn ,) = self ._get_policy_and_device (
1741
+ policy = policy ,
1742
+ policy_device = policy_device ,
1743
+ env_maker = env_maker ,
1744
+ env_maker_kwargs = env_maker_kwargs ,
1745
+ )
1746
+ if type (policy_copy ) is not type (policy ):
1747
+ policy = policy_copy
1748
+ weights = (
1749
+ TensorDict .from_module (policy_copy )
1750
+ if isinstance (policy_copy , nn .Module )
1751
+ else TensorDict ()
1752
+ )
1753
+ self ._policy_weights_dict [policy_device ] = weights
1754
+ self ._get_weights_fn_dict [policy_device ] = get_weights_fn
1755
+ else :
1756
+ # TODO
1757
+ raise NotImplementedError (
1758
+ "weight syncing is not supported for multiprocessed data collectors at the "
1759
+ "moment."
1716
1760
)
1717
- self ._policy_weights_dict [policy_device ] = weights
1718
- self ._get_weights_fn_dict [policy_device ] = get_weights_fn
1719
1761
self .policy = policy
1720
1762
1721
1763
remainder = 0
@@ -2782,7 +2824,15 @@ class aSyncDataCollector(MultiaSyncDataCollector):
2782
2824
2783
2825
- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
2784
2826
2827
+ .. note:: If the policy must be built from a factory (e.g., because it must not be serialized /
2828
+ pickled directly), use the ``policy_factory`` keyword argument instead.
2829
+
2785
2830
Keyword Args:
2831
+ policy_factory (Callable[[], Callable], optional): a callable that returns
2832
+ a policy instance. This argument is mutually exclusive with the ``policy`` argument.
2833
+
2834
+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
2835
+
2786
2836
frames_per_batch (int): A keyword-only argument representing the
2787
2837
total number of elements in a batch.
2788
2838
total_frames (int, optional): A keyword-only argument representing the
@@ -2888,8 +2938,10 @@ class aSyncDataCollector(MultiaSyncDataCollector):
2888
2938
def __init__ (
2889
2939
self ,
2890
2940
create_env_fn : Callable [[], EnvBase ],
2891
- policy : None | (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]),
2941
+ policy : None
2942
+ | (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]) = None ,
2892
2943
* ,
2944
+ policy_factory : Callable [[], Callable ] | None = None ,
2893
2945
frames_per_batch : int ,
2894
2946
total_frames : int | None = - 1 ,
2895
2947
device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
@@ -2914,6 +2966,7 @@ def __init__(
2914
2966
super ().__init__ (
2915
2967
create_env_fn = [create_env_fn ],
2916
2968
policy = policy ,
2969
+ policy_factory = policy_factory ,
2917
2970
total_frames = total_frames ,
2918
2971
create_env_kwargs = [create_env_kwargs ],
2919
2972
max_frames_per_traj = max_frames_per_traj ,
0 commit comments