Commit b2aeed9

Author: Vincent Moens
Update (base update)
[ghstack-poisoned]
2 parents 04d8866 + efe9389 commit b2aeed9

107 files changed, +7079 -1620 lines changed

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+pip install --upgrade setuptools
+
+export TORCHRL_BUILD_VERSION=0.7.0
+
+${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U

.github/scripts/pre-build-script.sh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+pip install --upgrade setuptools
+
+${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U

.github/scripts/td_script.sh

Lines changed: 21 additions & 1 deletion
@@ -1,5 +1,25 @@
 #!/bin/bash
 
 export TORCHRL_BUILD_VERSION=0.7.0
+pip install --upgrade setuptools
 
-${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
+# Check if ARCH is set to aarch64
+ARCH=${ARCH:-} # This sets ARCH to an empty string if it's not defined
+
+if pip list | grep -q torch; then
+    echo "Torch is installed."
+    if [[ "$ARCH" == "aarch64" ]]; then
+        ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps
+    else
+        ${CONDA_RUN} pip install tensordict-nightly -U
+    fi
+elif [[ -n "${SMOKE_TEST_SCRIPT:-}" ]]; then
+    ${CONDA_RUN} ${PIP_INSTALL_TORCH}
+    if [[ "$ARCH" == "aarch64" ]]; then
+        ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps
+    else
+        ${CONDA_RUN} pip install tensordict-nightly -U
+    fi
+else
+    echo "Torch is not installed - tensordict will be installed later."
+fi
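
The branching added to td_script.sh can be summarized as a small decision function. The sketch below uses Python purely for illustration: the function name `tensordict_install_command` and its parameters are hypothetical, and it mirrors only the branch structure of the script, leaving out the `${CONDA_RUN}` prefix and the torch installation step.

```python
from typing import Optional

# URL taken from the script itself.
GIT_SOURCE = "git+https://github.com/pytorch/tensordict.git"


def tensordict_install_command(
    torch_installed: bool,
    arch: str = "",
    smoke_test_script: str = "",
) -> Optional[str]:
    """Return the pip command the script would run, or None if it defers.

    Hypothetical helper: it models the if/elif/else chain only.
    """
    # tensordict is installed when torch is already present, or when a smoke
    # test script is configured (the script then installs torch first).
    if not (torch_installed or smoke_test_script):
        return None
    # aarch64 builds tensordict from source without pulling dependencies.
    if arch == "aarch64":
        return f"pip install {GIT_SOURCE} -U --no-deps"
    # Every other architecture uses the nightly wheel.
    return "pip install tensordict-nightly -U"


if __name__ == "__main__":
    print(tensordict_install_command(True, arch="aarch64"))
    print(tensordict_install_command(False, smoke_test_script="test/smoke_test.py"))
    print(tensordict_install_command(False))
```

The `torch_installed` flag abstracts the script's `pip list | grep -q torch` check, which matches any installed package whose name contains `torch`.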

.github/scripts/version_script.bat

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ if "%CU_VERSION%" == "xpu" call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat
 
 set DISTUTILS_USE_SDK=1
 
+:: Upgrade setuptools before installing PyTorch
+pip install --upgrade setuptools==72.1.0 || exit /b 1
+
 set args=%1
 shift
 :start

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-timeout
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_distributed/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ dependencies:
 - pytest-mock
 - pytest-instafail
 - pytest-rerunfailures
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_ataridqn/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_brax/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_chess/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_d4rl/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_envpool/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_gen-dgrl/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_gym/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_habitat/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-error-for-skips
 - pytest-rerunfailures
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy==1.9.1

.github/unittest/linux_libs/scripts_jumanji/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_llm/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_meltingpot/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -12,4 +12,5 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_openx/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_robohive/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_roboset/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_sklearn/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_smacv2/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - numpy==1.23.0

.github/unittest/linux_libs/scripts_vd4rl/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/unittest/linux_libs/scripts_vmas/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies:
 - pytest-instafail
 - pytest-rerunfailures
 - pytest-error-for-skips
+- pytest-asyncio
 - expecttest
 - pyyaml
 - scipy

.github/workflows/build-wheels-aarch64-linux.yml

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ jobs:
 - repository: pytorch/rl
 smoke-test-script: test/smoke_test.py
 package-name: torchrl
+pre-script: .github/scripts/pre-build-script.sh
 name: pytorch/rl
 uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
 with:

.github/workflows/build-wheels-linux.yml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ jobs:
 - repository: pytorch/rl
 smoke-test-script: test/smoke_test.py
 package-name: torchrl
+pre-script: .github/scripts/pre-build-script.sh
 name: pytorch/rl
 uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
 with:

.github/workflows/build-wheels-m1.yml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ jobs:
 - repository: pytorch/rl
 smoke-test-script: test/smoke_test.py
 package-name: torchrl
+pre-script: .github/scripts/pre-build-script.sh
 name: ${{ matrix.repository }}
 uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main
 with:

.github/workflows/build-wheels-windows.yml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
 matrix:
 include:
 - repository: pytorch/rl
-pre-script: .github/scripts/td_script.sh
+pre-script: .github/scripts/pre-build-script-win.sh
 env-script: .github/scripts/version_script.bat
 post-script: "python packaging/wheel/relocate.py"
 smoke-test-script: test/smoke_test.py

docs/source/_static/js/torchrl_theme.js

Lines changed: 1 addition & 1 deletion
@@ -944,7 +944,7 @@ var downloadNote = $(".sphx-glr-download-link-note.admonition.note");
 if (downloadNote.length >= 1) {
 var tutorialUrlArray = $("#tutorial-type").text().split('/');
 
-var githubLink = "https://github.com/pytorch/rl/tree/tutorial_py_dup/sphinx-tutorials/" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py",
+var githubLink = "https://github.com/pytorch/rl/tree/tutorial_py_dup/sphinx-" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py",
 notebookLink = $(".sphx-glr-download-jupyter").find(".download.reference")[0].href,
 notebookDownloadPath = notebookLink.split('_downloads')[1],
 colabLink = "https://colab.research.google.com/github/pytorch/rl/blob/gh-pages/_downloads" + notebookDownloadPath;
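
To make the effect of the one-line change concrete, the sketch below rebuilds the `githubLink` string in Python for both prefixes. The helper name `github_link` and the sample value `"tutorials/coding_ppo"` are hypothetical; only the two prefix strings and the "last path segment plus `.py`" rule come from the diff.

```python
# Prefixes copied from the removed (-) and added (+) lines of the hunk.
OLD_PREFIX = "https://github.com/pytorch/rl/tree/tutorial_py_dup/sphinx-tutorials/"
NEW_PREFIX = "https://github.com/pytorch/rl/tree/tutorial_py_dup/sphinx-"


def github_link(tutorial_type: str, prefix: str) -> str:
    """Mirror the theme script: prefix + last '/'-separated segment + '.py'."""
    # Equivalent of tutorialUrlArray[tutorialUrlArray.length - 1] in the JS.
    last_segment = tutorial_type.split("/")[-1]
    return prefix + last_segment + ".py"


if __name__ == "__main__":
    sample = "tutorials/coding_ppo"  # hypothetical "#tutorial-type" text
    print(github_link(sample, OLD_PREFIX))
    print(github_link(sample, NEW_PREFIX))
```

With the new prefix, the tutorial name is appended directly after `sphinx-` rather than placed under a `sphinx-tutorials/` directory.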

docs/source/_templates/layout.html

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
 var downloadNote = $(".sphx-glr-download-link-note.admonition.note");
 if (downloadNote.length >= 1) {
 var tutorialUrl = $("#tutorial-type").text();
-var githubLink = "https://github.com/pytorch/rl/blob/main/tutorials/sphinx-tutorials/" + tutorialUrl + ".py",
+var githubLink = "https://github.com/pytorch/rl/blob/main/tutorials/sphinx-" + tutorialUrl + ".py",
 notebookLink = $(".sphx-glr-download-jupyter").find(".download.reference")[0].href,
 notebookDownloadPath = notebookLink.split('_downloads')[1],
 colabLink = "https://colab.research.google.com/github/pytorch/rl/blob/gh-pages/main/_downloads" + notebookDownloadPath;

docs/source/reference/collectors.rst

Lines changed: 26 additions & 11 deletions
@@ -126,16 +126,16 @@ mechanism for updating policy weights across different devices and processes, ac
 Local and Remote Weight Updaters
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.LocalWeightUpdaterBase`
-and :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. These base classes provide a structured interface for
+The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.WeightUpdateReceiverBase`
+and :class:`~torchrl.collectors.WeightUpdateSenderBase`. These base classes provide a structured interface for
 implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
 
-- :class:`~torchrl.collectors.LocalWeightUpdaterBase`: This component is responsible for updating the policy weights on
+- :class:`~torchrl.collectors.WeightUpdateReceiverBase`: This component is responsible for updating the policy weights on
 the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
 different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
 It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
 situations where the server decides when to update the worker policies).
-- :class:`~torchrl.collectors.RemoteWeightUpdaterBase`: This component handles the distribution of policy weights to
+- :class:`~torchrl.collectors.WeightUpdateSenderBase`: This component handles the distribution of policy weights to
 remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
 the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
 devices or processes.
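
The receiver/sender split described in this hunk amounts to a pull model versus a push model. The sketch below illustrates that distinction in plain Python. It does not use or reproduce the TorchRL classes: every name in it (`ToyServer`, `ToyWorker`, `ToyReceiver`, `ToySender`, `pull`, `push`) is hypothetical, and the weights are a plain dict standing in for policy parameters.

```python
from typing import Dict, List


class ToyServer:
    """Hypothetical trainer that owns the latest policy weights."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"w": 0.0}

    def train_step(self) -> None:
        # Pretend an optimization step changed the weights.
        self.weights = {"w": self.weights["w"] + 1.0}


class ToyWorker:
    """Hypothetical inference worker holding its own copy of the weights."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"w": 0.0}


class ToyReceiver:
    """Pull model: one worker fetches weights and applies them locally."""

    def __init__(self, server: ToyServer, worker: ToyWorker) -> None:
        self.server = server
        self.worker = worker

    def pull(self) -> None:
        # Copy so the worker never aliases the server's dict.
        self.worker.weights = dict(self.server.weights)


class ToySender:
    """Push model: the server side distributes weights to every worker."""

    def __init__(self, server: ToyServer, workers: List[ToyWorker]) -> None:
        self.server = server
        self.workers = workers

    def push(self) -> None:
        for worker in self.workers:
            worker.weights = dict(self.server.weights)


if __name__ == "__main__":
    server = ToyServer()
    workers = [ToyWorker(), ToyWorker()]
    server.train_step()
    ToyReceiver(server, workers[0]).pull()  # only worker 0 is refreshed
    print([w.weights for w in workers])
    ToySender(server, workers).push()       # every worker is refreshed
    print([w.weights for w in workers])
```

In the pull case the worker chooses when to refresh; in the push case the server chooses, which matches the contrast drawn in the receiver bullet above.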
@@ -153,8 +153,8 @@ Default Implementations
 
 For common scenarios, the API provides default implementations of these updaters, such as
 :class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
-:class:`~torchrl.collectors.RayRemoteWeightUpdater`, :class:`~torchrl.collectors.RPCRemoteWeightUpdater`, and
-:class:`~torchrl.collectors.DistributedRemoteWeightUpdater`.
+:class:`~torchrl.collectors.RayWeightUpdateSender`, :class:`~torchrl.collectors.RPCWeightUpdateSender`, and
+:class:`~torchrl.collectors.DistributedWeightUpdateSender`.
 These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
 distributed systems.
 
@@ -180,13 +180,13 @@ scenarios, ensuring that their policies remain up-to-date and performant.
 :toctree: generated/
 :template: rl_template.rst
 
-LocalWeightUpdaterBase
-RemoteWeightUpdaterBase
+WeightUpdateReceiverBase
+WeightUpdateSenderBase
 VanillaLocalWeightUpdater
 MultiProcessedRemoteWeightUpdate
-RayRemoteWeightUpdater
-DistributedRemoteWeightUpdater
-RPCRemoteWeightUpdater
+RayWeightUpdateSender
+DistributedWeightUpdateSender
+RPCWeightUpdateSender
 
 Collectors and replay buffers interoperability
 ----------------------------------------------
@@ -319,6 +319,21 @@ node or across multiple nodes.
 submitit_delayed_launcher
 RayCollector
 
+LLM Collectors
+---------------------------
+TorchRL also provides data collectors for large language models. These collectors
+are meant to include a subset of the functionality of other data collectors, targeted
+at supporting researchers in fine-tuning large language models. These classes
+currently derive from the :class:`~torchrl.collectors.SyncDataCollector` class.
+These classes are experimental and subject to change.
+
+.. currentmodule:: torchrl.collectors.llm_collectors
+
+.. autosummary::
+:toctree: generated/
+:template: rl_template.rst
+
+LLMCollector
 
 Helper functions
 ----------------

docs/source/reference/data.rst

Lines changed: 4 additions & 3 deletions
@@ -1107,10 +1107,10 @@ and the tree can be expanded for each of these. The following figure shows how t
 Tree
 
 
-Reinforcement Learning From Human Feedback (RLHF)
--------------------------------------------------
+Large language models and Reinforcement Learning From Human Feedback (RLHF)
+---------------------------------------------------------------------------
 
-Data is of utmost importance in Reinforcement Learning from Human Feedback (RLHF).
+Data is of utmost importance in LLM post-training (e.g., GRPO or Reinforcement Learning from Human Feedback (RLHF)).
 Given that these techniques are commonly employed in the realm of language,
 which is scarcely addressed in other subdomains of RL within the library,
 we offer specific utilities to facilitate interaction with external libraries
@@ -1124,6 +1124,7 @@ efficient sampling.
 :toctree: generated/
 :template: rl_template.rst
 
+History
 PairwiseDataset
 PromptData
 PromptTensorDictTokenizer
