Skip to content

Commit efe9389

Browse files
author
Vincent Moens
committed
[Refactor] Rename weight updaters
ghstack-source-id: 8889046 Pull Request resolved: #2892
1 parent 31df775 commit efe9389

File tree

35 files changed

+231
-198
lines changed

35 files changed

+231
-198
lines changed

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- pytest-instafail
1818
- pytest-rerunfailures
1919
- pytest-timeout
20+
- pytest-asyncio
2021
- expecttest
2122
- pyyaml
2223
- scipy

.github/unittest/linux_distributed/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies:
1616
- pytest-mock
1717
- pytest-instafail
1818
- pytest-rerunfailures
19+
- pytest-asyncio
1920
- expecttest
2021
- pyyaml
2122
- scipy

.github/unittest/linux_libs/scripts_ataridqn/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies:
1515
- pytest-instafail
1616
- pytest-rerunfailures
1717
- pytest-error-for-skips
18+
- pytest-asyncio
1819
- expecttest
1920
- pyyaml
2021
- scipy

.github/unittest/linux_libs/scripts_brax/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_chess/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_d4rl/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_envpool/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies:
1515
- pytest-instafail
1616
- pytest-rerunfailures
1717
- pytest-error-for-skips
18+
- pytest-asyncio
1819
- expecttest
1920
- pyyaml
2021
- scipy

.github/unittest/linux_libs/scripts_gen-dgrl/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_gym/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- pytest-instafail
2020
- pytest-rerunfailures
2121
- pytest-error-for-skips
22+
- pytest-asyncio
2223
- expecttest
2324
- pyyaml
2425
- scipy

.github/unittest/linux_libs/scripts_habitat/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-error-for-skips
1515
- pytest-rerunfailures
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy==1.9.1

.github/unittest/linux_libs/scripts_jumanji/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_llm/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_meltingpot/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ dependencies:
1212
- pytest-instafail
1313
- pytest-rerunfailures
1414
- pytest-error-for-skips
15+
- pytest-asyncio
1516
- expecttest

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_openx/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_robohive/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- pytest-instafail
2020
- pytest-rerunfailures
2121
- pytest-error-for-skips
22+
- pytest-asyncio
2223
- expecttest
2324
- pyyaml
2425
- scipy

.github/unittest/linux_libs/scripts_roboset/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_sklearn/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_smacv2/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies:
1414
- pytest-instafail
1515
- pytest-rerunfailures
1616
- pytest-error-for-skips
17+
- pytest-asyncio
1718
- expecttest
1819
- pyyaml
1920
- numpy==1.23.0

.github/unittest/linux_libs/scripts_vd4rl/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- pytest-instafail
1414
- pytest-rerunfailures
1515
- pytest-error-for-skips
16+
- pytest-asyncio
1617
- expecttest
1718
- pyyaml
1819
- scipy

.github/unittest/linux_libs/scripts_vmas/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- pytest-instafail
2020
- pytest-rerunfailures
2121
- pytest-error-for-skips
22+
- pytest-asyncio
2223
- expecttest
2324
- pyyaml
2425
- scipy

docs/source/reference/collectors.rst

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,16 @@ mechanism for updating policy weights across different devices and processes, ac
126126
Local and Remote Weight Updaters
127127
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
128128

129-
The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.LocalWeightUpdaterBase`
130-
and :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. These base classes provide a structured interface for
129+
The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
130+
and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
131131
implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
132132

133-
- :class:`~torchrl.collectors.LocalWeightUpdaterBase`: This component is responsible for updating the policy weights on
133+
- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
134134
the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
135135
different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
136136
It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
137137
situations where the server decides when to update the worker policies).
138-
- :class:`~torchrl.collectors.RemoteWeightUpdaterBase`: This component handles the distribution of policy weights to
138+
- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
139139
remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
140140
the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
141141
devices or processes.
@@ -153,8 +153,8 @@ Default Implementations
153153

154154
For common scenarios, the API provides default implementations of these updaters, such as
155155
:class:`~torchrl.collectors.VanillaWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedWeightUpdate`,
156-
:class:`~torchrl.collectors.RayRemoteWeightUpdater`, :class:`~torchrl.collectors.RPCRemoteWeightUpdater`, and
157-
:class:`~torchrl.collectors.DistributedRemoteWeightUpdater`.
156+
:class:`~torchrl.collectors.RayWeightUpdater`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
157+
:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
158158
These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
159159
distributed systems.
160160

@@ -180,13 +180,13 @@ scenarios, ensuring that their policies remain up-to-date and performant.
180180
:toctree: generated/
181181
:template: rl_template.rst
182182

183-
LocalWeightUpdaterBase
184-
RemoteWeightUpdaterBase
183+
WeightUpdateReceiverBase
184+
WeightUpdateSenderBase
185185
VanillaWeightUpdater
186186
MultiProcessedWeightUpdate
187-
RayRemoteWeightUpdater
188-
DistributedRemoteWeightUpdater
189-
RPCRemoteWeightUpdater
187+
RayWeightUpdater
188+
DistributedWeightUpdateSender
189+
RPCWeightUpdateSender
190190

191191
Collectors and replay buffers interoperability
192192
----------------------------------------------

examples/collectors/mp_collector_mps.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,29 @@
1111
----------------------
1212
1313
This script demonstrates a weight update in TorchRL.
14-
The script uses a custom `MPSRemoteWeightUpdater` class to update the weights of a policy network across multiple workers.
14+
The script uses a custom `MPSWeightUpdateSender` class to update the weights of a policy network across multiple workers.
1515
1616
Key Features
1717
------------
1818
1919
- Multi-Worker Setup: The script creates two worker processes that collect data from a Gym environment
2020
("Pendulum-v1") using a policy network.
2121
- MPS (Metal Performance Shaders) Device: The policy network is placed on an MPS device.
22-
- Custom Weight Updater: The `MPSRemoteWeightUpdater` class is used to update the policy weights across workers. This
22+
- Custom Weight Updater: The `MPSWeightUpdateSender` class is used to update the policy weights across workers. This
2323
class is necessary because MPS tensors cannot be sent over a pipe due to serialization/pickling issues in PyTorch.
2424
2525
Workaround for MPS Tensor Serialization Issue
2626
---------------------------------------------
2727
2828
In PyTorch, MPS tensors cannot be serialized or pickled, which means they cannot be sent over a pipe or shared between
29-
processes. To work around this issue, the MPSRemoteWeightUpdater class sends the policy weights on the CPU device
29+
processes. To work around this issue, the MPSWeightUpdateSender class sends the policy weights on the CPU device
3030
instead of the MPS device. The local workers then copy the weights from the CPU device to the MPS device.
3131
3232
Script Flow
3333
-----------
3434
3535
1. Initialize the environment, policy network, and collector.
36-
2. Update the policy weights using the MPSRemoteWeightUpdater.
36+
2. Update the policy weights using the MPSWeightUpdateSender.
3737
3. Collect data from the environment using the policy network.
3838
4. Zero out the policy weights after a few iterations.
3939
5. Verify that the updated policy weights are being used by checking the actions generated by the policy network.
@@ -45,12 +45,12 @@ class is necessary because MPS tensors cannot be sent over a pipe due to seriali
4545
from tensordict import TensorDictBase
4646
from tensordict.nn import TensorDictModule
4747
from torch import nn
48-
from torchrl.collectors import MultiSyncDataCollector, RemoteWeightUpdaterBase
48+
from torchrl.collectors import MultiSyncDataCollector, WeightUpdateSenderBase
4949

5050
from torchrl.envs.libs.gym import GymEnv
5151

5252

53-
class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
53+
class MPSWeightUpdaterBase(WeightUpdateSenderBase):
5454
def __init__(self, policy_weights, num_workers):
5555
# Weights are on mps device, which cannot be shared
5656
self.policy_weights = policy_weights.data
@@ -101,7 +101,7 @@ def policy_factory(device=device):
101101
reset_at_each_iter=False,
102102
device=device,
103103
storing_device="cpu",
104-
remote_weight_updater=MPSRemoteWeightUpdater(policy_weights, 2),
104+
weight_update_sender=MPSWeightUpdaterBase(policy_weights, 2),
105105
# use_buffers=False,
106106
# cat_results="stack",
107107
)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def _main(argv):
216216
"scipy",
217217
"pytest-mock",
218218
"pytest-cov",
219+
"pytest-asyncio",
219220
"pytest-benchmark",
220221
"pytest-rerunfailures",
221222
"pytest-error-for-skips",

test/test_collector.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
)
4444
from torchrl.collectors import (
4545
aSyncDataCollector,
46-
RemoteWeightUpdaterBase,
4746
SyncDataCollector,
47+
WeightUpdateSenderBase,
4848
)
4949
from torchrl.collectors.collectors import (
5050
_Interruptor,
@@ -3498,7 +3498,7 @@ def __deepcopy_error__(*args, **kwargs):
34983498

34993499

35003500
class TestPolicyFactory:
3501-
class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
3501+
class MPSWeightUpdaterBase(WeightUpdateSenderBase):
35023502
def __init__(self, policy_weights, num_workers):
35033503
# Weights are on mps device, which cannot be shared
35043504
self.policy_weights = policy_weights.data
@@ -3542,7 +3542,7 @@ def test_weight_update(self):
35423542
reset_at_each_iter=False,
35433543
device=device,
35443544
storing_device="cpu",
3545-
remote_weight_updater=self.MPSRemoteWeightUpdater(policy_weights, 2),
3545+
weight_update_sender=self.MPSWeightUpdaterBase(policy_weights, 2),
35463546
)
35473547

35483548
collector.update_policy_weights_()
@@ -3683,10 +3683,9 @@ def _run_collector_test(self, total_steps, rb, policy, tokenizer):
36833683
assert len(stack["text"][i]) < len(stack["next", "text"][i])
36843684
assert collector._frames >= total_steps
36853685

3686-
def test_llm_collector_start(self, vllm_instance):
3687-
asyncio.run(self._async_run_collector_test(vllm_instance))
3688-
3689-
async def _async_run_collector_test(self, vllm_instance):
3686+
@pytest.mark.slow
3687+
@pytest.mark.asyncio
3688+
async def test_llm_collector_start(self, vllm_instance):
36903689
total_steps = 20
36913690
policy = vLLMWrapper(vllm_instance)
36923691
vllm_instance.get_tokenizer()
@@ -3708,28 +3707,29 @@ async def _async_run_collector_test(self, vllm_instance):
37083707
replay_buffer=rb,
37093708
total_steps=total_steps,
37103709
)
3710+
torchrl_logger.info("starting")
37113711
collector.start()
37123712

3713-
i = 0
3714-
wait = 0
3713+
j = 0
37153714
while True:
3716-
while not len(rb):
3715+
if not len(rb):
37173716
await asyncio.sleep(1) # Use asyncio.sleep instead of time.sleep
3718-
wait += 1
3719-
if wait > 20:
3720-
raise RuntimeError
37213717
sample = rb.sample(10)
3722-
for i in range(sample.numel()):
3718+
assert sample.ndim == 1
3719+
for i in range(10):
37233720
# Check that there are more chars in the next step
37243721
assert len(sample["text"][i]) < len(sample["next", "text"][i])
37253722
assert not sample._has_exclusive_keys, sample
3726-
await asyncio.sleep(0.1) # Use asyncio.sleep instead of time.sleep
3727-
i += 1
3728-
if i == 5:
3723+
j += 1
3724+
if j == 5:
37293725
break
37303726
assert collector._frames >= total_steps
37313727

3732-
await collector.async_shutdown()
3728+
try:
3729+
# Assuming collector._task is the task created in start()
3730+
await asyncio.wait_for(collector.async_shutdown(), timeout=30)
3731+
except asyncio.TimeoutError:
3732+
torchrl_logger.info("Collector shutdown timed out")
37333733

37343734
@pytest.mark.slow
37353735
@pytest.mark.parametrize("rb", [False, True])

torchrl/collectors/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,20 @@
1313
SyncDataCollector,
1414
)
1515
from .weight_update import (
16-
LocalWeightUpdaterBase,
17-
MultiProcessedRemoteWeightUpdate,
18-
RayRemoteWeightUpdater,
19-
RemoteWeightUpdaterBase,
20-
VanillaLocalWeightUpdater,
16+
MultiProcessedWeightUpdate,
17+
RayWeightUpdater,
18+
VanillaWeightUpdater,
19+
WeightUpdateReceiverBase,
20+
WeightUpdateSenderBase,
2121
)
2222

2323
__all__ = [
2424
"RandomPolicy",
25-
"LocalWeightUpdaterBase",
26-
"RemoteWeightUpdaterBase",
27-
"VanillaLocalWeightUpdater",
28-
"RayRemoteWeightUpdater",
29-
"MultiProcessedRemoteWeightUpdate",
25+
"WeightUpdateReceiverBase",
26+
"WeightUpdateSenderBase",
27+
"VanillaWeightUpdater",
28+
"RayWeightUpdater",
29+
"MultiProcessedWeightUpdate",
3030
"aSyncDataCollector",
3131
"DataCollectorBase",
3232
"MultiaSyncDataCollector",

0 commit comments

Comments
 (0)