@@ -118,75 +118,49 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.

Weight Synchronization in Distributed Environments
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ --------------------------------------------------
+
In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.

- Local and Remote Weight Updaters
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ Sending and receiving model weights with WeightUpdaters
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
- and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
+ The weight synchronization process is facilitated by one dedicated extension point:
+ :class:`~torchrl.collectors.WeightUpdaterBase`. This base class provides a structured interface for
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.

- - :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
- the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
- different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
- It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
- situations where the server decides when to update the worker policies).
- - :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
- remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
- the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
- devices or processes.
+ :class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to
+ the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary.
+ Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the
+ weight synchronization with the policy.
+ Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy
+ state-dict (assuming it is a :class:`~torch.nn.Module` instance).
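As an illustration, the sketch below shows this default behavior in a single-process setting: after each optimization step, calling ``update_policy_weights_()`` lets the collector's default :class:`~torchrl.collectors.VanillaWeightUpdater` copy the trained state-dict into the inference policy. The environment name, module sizes and frame counts are arbitrary placeholders (a Gym/Gymnasium backend is assumed), not values mandated by the API.

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv

    # Toy policy wrapped as a TensorDictModule so the collector can call it.
    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )

    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy,
        frames_per_batch=64,
        total_frames=1280,
    )

    for batch in collector:
        ...  # optimization step that updates the trained policy's parameters
        # The default VanillaWeightUpdater copies the trained state-dict
        # into the collector's copy of the policy.
        collector.update_policy_weights_()
    collector.shutdown()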

- Extending the Updater Classes
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ Extending the Updater Class
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~

To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations.
+ The goal is to let users customize the weight sync strategy while leaving the collector and policy implementation
+ untouched.

This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
- setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved,
+ setups.
+ By implementing the abstract methods of this base class, users can define how weights are retrieved,
transformed, and applied, ensuring seamless integration with their existing infrastructure.
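For instance, a custom updater could fetch weights from a shared in-memory store and push them to a set of locally held worker policies. The sketch below is illustrative only: the store, the worker-policy registry and the hook names (``all_worker_ids``, ``_get_server_weights``, ``_maybe_map_weights``, ``_sync_weights_with_worker``) are assumptions meant to show the overall shape of a custom updater, and should be checked against the methods actually exposed by :class:`~torchrl.collectors.WeightUpdaterBase` in the installed TorchRL version.

.. code-block:: python

    from torchrl.collectors import WeightUpdaterBase


    class SharedStoreWeightUpdater(WeightUpdaterBase):
        """Hypothetical updater reading weights from an in-memory store and
        pushing them to a list of worker policies (illustrative names only)."""

        def __init__(self, store, worker_policies):
            super().__init__()
            self.store = store                    # e.g. a dict acting as a parameter server
            self.worker_policies = worker_policies

        def all_worker_ids(self):
            # Enumerate the workers this updater is responsible for.
            return list(range(len(self.worker_policies)))

        def _get_server_weights(self):
            # Fetch the latest trained weights from the (hypothetical) store.
            return self.store["policy_state_dict"]

        def _maybe_map_weights(self, server_weights):
            # Optional transformation before sending, e.g. moving to CPU.
            return {k: v.to("cpu") for k, v in server_weights.items()}

        def _sync_weights_with_worker(self, worker_id, server_weights):
            # Apply the weights to one worker's copy of the policy.
            self.worker_policies[worker_id].load_state_dict(server_weights)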

- Default Implementations
- ~~~~~~~~~~~~~~~~~~~~~~~
-
- For common scenarios, the API provides default implementations of these updaters, such as
- :class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
- :class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
- :class:`~torchrl.collectors.DistributedWeightUpdateSender`.
- These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
- distributed systems.
-
- Practical Considerations
- ~~~~~~~~~~~~~~~~~~~~~~~~
-
- When designing a system that leverages this API, consider the following:
-
- - Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
- implementation accounts for potential delays and optimizes data transfer where possible.
- - Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
- the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
- suboptimal policy performance.
- - Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
- overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
-
- By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
- scenarios, ensuring that their policies remain up-to-date and performant.
-

.. currentmodule:: torchrl.collectors

.. autosummary::
:toctree: generated/
:template: rl_template.rst
- WeightUpdateReceiverBase
- WeightUpdateSenderBase
- VanillaLocalWeightUpdater
- MultiProcessedRemoteWeightUpdate
- RayWeightUpdateSender
- DistributedWeightUpdateSender
- RPCWeightUpdateSender
+ WeightUpdaterBase
+ VanillaWeightUpdater
+ MultiProcessedWeightUpdater
+ RayWeightUpdater
+ DistributedWeightUpdater
+ RPCWeightUpdater

Collectors and replay buffers interoperability
----------------------------------------------