Commit 0da9044

Author: Vincent Moens

[Refactor] Refactor the weight update logic
ghstack-source-id: 72b710a
Pull Request resolved: #2914
1 parent 21ef725 commit 0da9044

12 files changed, +206 -777 lines changed

docs/source/reference/collectors.rst

Lines changed: 24 additions & 50 deletions
@@ -118,75 +118,49 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
 Policy copy decision tree in Collectors.
 
 Weight Synchronization in Distributed Environments
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------------------------
+
 In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
 latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
 mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
 
-Local and Remote Weight Updaters
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Sending and receiving model weights with WeightUpdaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
-and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
+The weight synchronization process is facilitated by one dedicated extension point:
+:class:`~torchrl.collectors.WeightUpdaterBase`. This base class provides a structured interface for
 implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
 
-- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
-  the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
-  different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
-  It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
-  situations where the server decides when to update the worker policies).
-- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
-  remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
-  the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
-  devices or processes.
+:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to
+the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary.
+Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the
+weight synchronization with the policy.
+Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy
+state-dict (assuming it is a :class:`~torch.nn.Module` instance).
 
-Extending the Updater Classes
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Extending the Updater Class
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations.
+The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation
+untouched.
 This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
-setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved,
+setups.
+By implementing the abstract methods in this base class, users can define how weights are retrieved,
 transformed, and applied, ensuring seamless integration with their existing infrastructure.
 
-Default Implementations
-~~~~~~~~~~~~~~~~~~~~~~~
-
-For common scenarios, the API provides default implementations of these updaters, such as
-:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
-:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
-:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
-These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
-distributed systems.
-
-Practical Considerations
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-When designing a system that leverages this API, consider the following:
-
-- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
-  implementation accounts for potential delays and optimizes data transfer where possible.
-- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
-  the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
-  suboptimal policy performance.
-- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
-  overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
-By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
-scenarios, ensuring that their policies remain up-to-date and performant.
-
 .. currentmodule:: torchrl.collectors
 
 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst
 
-    WeightUpdateReceiverBase
-    WeightUpdateSenderBase
-    VanillaLocalWeightUpdater
-    MultiProcessedRemoteWeightUpdate
-    RayWeightUpdateSender
-    DistributedWeightUpdateSender
-    RPCWeightUpdateSender
+    WeightUpdaterBase
+    VanillaWeightUpdater
+    MultiProcessedWeightUpdater
+    RayWeightUpdater
+    DistributedWeightUpdater
+    RPCWeightUpdater
 
 Collectors and replay buffers interoperability
 ----------------------------------------------
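The new documentation above describes the default `VanillaWeightUpdater` behavior as a state-dict update of an `nn.Module` policy. As a rough, non-TorchRL illustration of what that amounts to, here is a minimal plain-PyTorch sketch; the `nn.Linear(4, 2)` modules are arbitrary placeholders and no collector is involved:

```python
# Plain-PyTorch sketch of a state-dict based weight update (no TorchRL
# internals; the modules and their shapes are illustrative assumptions).
import torch
from torch import nn

server_policy = nn.Linear(4, 2)  # trained copy held by the trainer ("server")
worker_policy = nn.Linear(4, 2)  # copy used by an inference worker / collector

# Synchronization here is just: pull the server's state-dict and load it
# into the worker's policy.
worker_policy.load_state_dict(server_policy.state_dict())

# Both copies now hold identical parameters.
for p_server, p_worker in zip(server_policy.parameters(), worker_policy.parameters()):
    assert torch.equal(p_server, p_worker)
```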

examples/collectors/mp_collector_mps.py

Lines changed: 3 additions & 3 deletions
@@ -45,12 +45,12 @@ class is necessary because MPS tensors cannot be sent over a pipe due to seriali
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule
 from torch import nn
-from torchrl.collectors import MultiSyncDataCollector, WeightUpdateSenderBase
+from torchrl.collectors import MultiSyncDataCollector, WeightUpdaterBase
 
 from torchrl.envs.libs.gym import GymEnv
 
 
-class MPSWeightUpdaterBase(WeightUpdateSenderBase):
+class MPSWeightUpdaterBase(WeightUpdaterBase):
     def __init__(self, policy_weights, num_workers):
         # Weights are on mps device, which cannot be shared
         self.policy_weights = policy_weights.data
@@ -101,7 +101,7 @@ def policy_factory(device=device):
     reset_at_each_iter=False,
     device=device,
     storing_device="cpu",
-    weight_update_sender=MPSWeightUpdaterBase(policy_weights, 2),
+    weight_updater=MPSWeightUpdaterBase(policy_weights, 2),
     # use_buffers=False,
     # cat_results="stack",
 )
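The example above ports `MPSWeightUpdaterBase` to the single `WeightUpdaterBase` extension point. Below is a hedged sketch of what such a subclass can look like: only `WeightUpdaterBase`, the `weight_updater` argument and the `policy_weights.data` pattern are taken from this commit, while the hook names (`_get_server_weights`, `_maybe_map_weights`, `_sync_weights_with_worker`, `all_worker_ids`) and the `per_worker_weights` plumbing are assumptions made for illustration.

```python
# Hedged sketch of a custom updater under the renamed API. The overridden
# hooks and the per_worker_weights handle are assumptions, not a verbatim
# copy of the TorchRL interface.
from torchrl.collectors import WeightUpdaterBase


class CpuTransferWeightUpdater(WeightUpdaterBase):
    """Ships a server-side TensorDict of weights to per-worker copies via CPU."""

    def __init__(self, policy_weights, per_worker_weights):
        # Server-side reference weights (e.g. living on an MPS/CUDA device).
        self.policy_weights = policy_weights.data
        # Hypothetical handle on each worker's weight copy, keyed by worker id.
        self.per_worker_weights = per_worker_weights

    def _get_server_weights(self):
        # Where the latest trained weights are fetched from.
        return self.policy_weights

    def _maybe_map_weights(self, server_weights):
        # Format / move weights before sending, e.g. MPS -> CPU as in the
        # example file above, since MPS tensors cannot be sent over a pipe.
        return server_weights.cpu()

    def _sync_weights_with_worker(self, worker_id, server_weights):
        # Apply the formatted weights to one worker's copy, in place.
        self.per_worker_weights[worker_id].update_(server_weights)

    def all_worker_ids(self):
        return list(self.per_worker_weights)
```

An instance of such a class would be passed to the collector through the `weight_updater` argument shown above, and `collector.update_policy_weights_()` would then route the new weights through these hooks.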

test/test_collector.py

Lines changed: 3 additions & 7 deletions
@@ -39,11 +39,7 @@
     prod,
     seed_generator,
 )
-from torchrl.collectors import (
-    aSyncDataCollector,
-    SyncDataCollector,
-    WeightUpdateSenderBase,
-)
+from torchrl.collectors import aSyncDataCollector, SyncDataCollector, WeightUpdaterBase
 from torchrl.collectors.collectors import (
     _Interruptor,
     MultiaSyncDataCollector,
@@ -3489,7 +3485,7 @@ def __deepcopy_error__(*args, **kwargs):
 
 
 class TestPolicyFactory:
-    class MPSWeightUpdaterBase(WeightUpdateSenderBase):
+    class MPSWeightUpdaterBase(WeightUpdaterBase):
         def __init__(self, policy_weights, num_workers):
             # Weights are on mps device, which cannot be shared
             self.policy_weights = policy_weights.data
@@ -3533,7 +3529,7 @@ def test_weight_update(self):
             reset_at_each_iter=False,
             device=device,
             storing_device="cpu",
-            weight_update_sender=self.MPSWeightUpdaterBase(policy_weights, 2),
+            weight_updater=self.MPSWeightUpdaterBase(policy_weights, 2),
         )
 
         collector.update_policy_weights_()

torchrl/collectors/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -16,14 +16,12 @@
     MultiProcessedWeightUpdate,
     RayWeightUpdater,
     VanillaWeightUpdater,
-    WeightUpdateReceiverBase,
-    WeightUpdateSenderBase,
+    WeightUpdaterBase,
 )
 
 __all__ = [
     "RandomPolicy",
-    "WeightUpdateReceiverBase",
-    "WeightUpdateSenderBase",
+    "WeightUpdaterBase",
     "VanillaWeightUpdater",
     "RayWeightUpdater",
     "MultiProcessedWeightUpdate",
