@@ -11,7 +11,7 @@
 import warnings
 from copy import copy, deepcopy
 from datetime import timedelta
-from typing import Callable, OrderedDict
+from typing import Any, Callable, OrderedDict, Sequence
 
 import torch.cuda
 from tensordict import TensorDict, TensorDictBase
@@ -131,6 +131,7 @@ def _distributed_init_collection_node(
     num_workers,
     env_make,
     policy,
+    policy_factory,
     frames_per_batch,
     collector_kwargs,
     verbose=True,
@@ -143,6 +144,7 @@ def _distributed_init_collection_node(
         num_workers,
         env_make,
         policy,
+        policy_factory,
         frames_per_batch,
         collector_kwargs,
         verbose=verbose,
@@ -156,6 +158,7 @@ def _run_collector(
     num_workers,
     env_make,
     policy,
+    policy_factory,
     frames_per_batch,
     collector_kwargs,
     verbose=True,
@@ -178,12 +181,17 @@ def _run_collector(
         policy_weights = TensorDict.from_module(policy)
         policy_weights = policy_weights.data.lock_()
     else:
-        warnings.warn(_NON_NN_POLICY_WEIGHTS)
+        if collector_kwargs.get("remote_weight_updater") is None and (
+            policy_factory is None
+            or (isinstance(policy_factory, Sequence) and not any(policy_factory))
+        ):
+            warnings.warn(_NON_NN_POLICY_WEIGHTS)
         policy_weights = TensorDict(lock=True)
 
     collector = collector_class(
         env_make,
         policy,
+        policy_factory=policy_factory,
         frames_per_batch=frames_per_batch,
         total_frames=-1,
         split_trajs=False,
@@ -278,8 +286,8 @@ class DistributedDataCollector(DataCollectorBase):
             pickled directly), the :arg:`policy_factory` should be used instead.
 
     Keyword Args:
-        policy_factory (Callable[[], Callable], optional): a callable that returns
-            a policy instance. This is exclusive with the `policy` argument.
+        policy_factory (Callable[[], Callable] or list of Callable[[], Callable], optional): a callable
+            (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.
 
             .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
 
@@ -411,14 +419,16 @@ class DistributedDataCollector(DataCollectorBase):
             to learn more.
             Defaults to ``"submitit"``.
         tcp_port (int, optional): the TCP port to be used. Defaults to 10003.
-        local_weight_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+        local_weight_updater (LocalWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on the local inference worker.
             This is typically not used in :class:`~torchrl.collectors.distributed.DistributedDataCollector` as it
             focuses on distributed environments.
+            Consider using a constructor if the updater needs to be serialized.
-        remote_weight_updater (RemoteWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+        remote_weight_updater (RemoteWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on distributed inference workers.
             If not provided, a :class:`~torchrl.collectors.distributed.DistributedRemoteWeightUpdater` will be used by
             default, which handles weight synchronization across distributed workers.
+            Consider using a constructor if the updater needs to be serialized.
 
     """
 
@@ -429,31 +439,37 @@ def __init__(
         create_env_fn,
         policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
         *,
-        policy_factory: Callable[[], Callable] | None = None,
+        policy_factory: Callable[[], Callable]
+        | list[Callable[[], Callable]]
+        | None = None,
         frames_per_batch: int,
         total_frames: int = -1,
-        device: torch.device | list[torch.device] = None,
-        storing_device: torch.device | list[torch.device] = None,
-        env_device: torch.device | list[torch.device] = None,
-        policy_device: torch.device | list[torch.device] = None,
+        device: torch.device | list[torch.device] | None = None,
+        storing_device: torch.device | list[torch.device] | None = None,
+        env_device: torch.device | list[torch.device] | None = None,
+        policy_device: torch.device | list[torch.device] | None = None,
         max_frames_per_traj: int = -1,
         init_random_frames: int = -1,
         reset_at_each_iter: bool = False,
         postproc: Callable | None = None,
         split_trajs: bool = False,
         exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,  # noqa
         collector_class: type = SyncDataCollector,
-        collector_kwargs: dict = None,
+        collector_kwargs: dict[str, Any] | None = None,
         num_workers_per_collector: int = 1,
         sync: bool = False,
-        slurm_kwargs: dict | None = None,
+        slurm_kwargs: dict[str, Any] | None = None,
         backend: str = "gloo",
         update_after_each_batch: bool = False,
         max_weight_update_interval: int = -1,
         launcher: str = "submitit",
-        tcp_port: int = None,
-        remote_weight_updater: RemoteWeightUpdaterBase | None = None,
-        local_weight_updater: LocalWeightUpdaterBase | None = None,
+        tcp_port: int | None = None,
+        remote_weight_updater: RemoteWeightUpdaterBase
+        | Callable[[], RemoteWeightUpdaterBase]
+        | None = None,
+        local_weight_updater: LocalWeightUpdaterBase
+        | Callable[[], LocalWeightUpdaterBase]
+        | None = None,
     ):
 
         if collector_class == "async":
@@ -465,18 +481,22 @@ def __init__(
         self.collector_class = collector_class
         self.env_constructors = create_env_fn
         self.policy = policy
+        if not isinstance(policy_factory, Sequence):
+            policy_factory = [policy_factory for _ in range(len(self.env_constructors))]
+        self.policy_factory = policy_factory
         if isinstance(policy, nn.Module):
             policy_weights = TensorDict.from_module(policy)
             policy_weights = policy_weights.data.lock_()
-        elif policy_factory is not None:
+        elif any(policy_factory):
             policy_weights = None
             if remote_weight_updater is None:
                 raise RuntimeError(
                     "remote_weight_updater must be passed along with "
                     "a policy_factory."
                 )
         else:
-            warnings.warn(_NON_NN_POLICY_WEIGHTS)
+            if not any(policy_factory):
+                warnings.warn(_NON_NN_POLICY_WEIGHTS)
             policy_weights = TensorDict(lock=True)
         self.policy_weights = policy_weights
         self.num_workers = len(create_env_fn)
@@ -664,12 +684,15 @@ def _make_container(self):
         if self._VERBOSE:
             torchrl_logger.info("making container")
         env_constructor = self.env_constructors[0]
+        kwargs = self.collector_kwargs[0]
         pseudo_collector = SyncDataCollector(
             env_constructor,
-            self.policy,
+            policy=self.policy,
+            policy_factory=self.policy_factory[0],
             frames_per_batch=self._frames_per_batch_corrected,
             total_frames=-1,
             split_trajs=False,
+            **kwargs,
         )
         for _data in pseudo_collector:
             break
@@ -713,6 +736,7 @@ def _init_worker_dist_submitit(self, executor, i):
             self.num_workers_per_collector,
             env_make,
             self.policy,
+            self.policy_factory[i],
             self._frames_per_batch_corrected,
             self.collector_kwargs[i],
             self._VERBOSE,
@@ -734,6 +758,7 @@ def get_env_make(i):
                 "num_workers": self.num_workers_per_collector,
                 "env_make": get_env_make(i),
                 "policy": self.policy,
+                "policy_factory": self.policy_factory[i],
                 "frames_per_batch": self._frames_per_batch_corrected,
                 "collector_kwargs": self.collector_kwargs[i],
             }
@@ -760,6 +785,7 @@ def _init_worker_dist_mp(self, i):
             self.num_workers_per_collector,
             env_make,
             self.policy,
+            self.policy_factory[i],
             self._frames_per_batch_corrected,
             self.collector_kwargs[i],
             self._VERBOSE,
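
For context, a minimal usage sketch of the new `policy_factory` path follows. It is not part of this commit: `make_env` and `my_weight_updater` are illustrative stand-ins, and the RuntimeError added in `__init__` above means a `remote_weight_updater` must accompany a `policy_factory` whenever the policy is not an `nn.Module`.

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors.distributed import DistributedDataCollector
    from torchrl.envs import GymEnv

    def make_env():
        return GymEnv("Pendulum-v1")

    def make_policy():
        # Rebuilt on each remote worker, so the policy object itself
        # never has to be pickled.
        return TensorDictModule(
            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        )

    collector = DistributedDataCollector(
        [make_env, make_env],
        policy_factory=make_policy,  # a single factory is broadcast to every worker
        remote_weight_updater=my_weight_updater,  # hypothetical RemoteWeightUpdaterBase instance
        frames_per_batch=200,
        total_frames=2000,
    )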
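The widened updater annotations likewise accept a zero-argument constructor in place of a ready-made instance, which the docstring now recommends when the updater cannot be serialized. A sketch under the same assumptions (`MyRemoteUpdater` is a hypothetical RemoteWeightUpdaterBase subclass):

    collector = DistributedDataCollector(
        [make_env, make_env],
        policy_factory=make_policy,
        # Passing the class itself (a Callable[[], RemoteWeightUpdaterBase])
        # lets each process construct its own updater instead of unpickling one.
        remote_weight_updater=MyRemoteUpdater,
        frames_per_batch=200,
        total_frames=2000,
    )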