
Commit 8c313c8

[Doc] Add a knowledge base (#375)
* Add a file with description of common issues with rendering mujoco envs
* adding more files to the knowledge base
* lint
* nit
* pro-tips and resources
* MTRL

Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 2127e60 commit 8c313c8

6 files changed: +309 −61 lines changed

README.md

Lines changed: 61 additions & 61 deletions
@@ -10,12 +10,12 @@ The features are available before an official release so that users and collabor
 
 ---
 
-**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
+**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
 
-It provides pytorch and **python-first**, low and high level abstractions for RL that are intended to be **efficient**, **modular**, **documented** and properly **tested**.
+It provides pytorch and **python-first**, low and high level abstractions for RL that are intended to be **efficient**, **modular**, **documented** and properly **tested**.
 The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort.
 
-This repo attempts to align with the existing pytorch ecosystem libraries in that it has a dataset pillar ([torchrl/envs](torchrl/envs)), [transforms](torchrl/envs/transforms), [models](torchrl/modules), data utilities (e.g. collectors and containers), etc.
+This repo attempts to align with the existing pytorch ecosystem libraries in that it has a dataset pillar ([torchrl/envs](torchrl/envs)), [transforms](torchrl/envs/transforms), [models](torchrl/modules), data utilities (e.g. collectors and containers), etc.
 TorchRL aims at having as few dependencies as possible (python standard library, numpy and pytorch). Common environment libraries (e.g. OpenAI gym) are only optional.
 
 On the low-level end, torchrl comes with a set of highly re-usable functionals for [cost functions](torchrl/objectives/costs), [returns](torchrl/objectives/returns) and data processing.
@@ -25,19 +25,19 @@ TorchRL aims at (1) a high modularity and (2) good runtime performance.
 ## Features
 
 On the high-level end, TorchRL provides:
-- [`TensorDict`](torchrl/data/tensordict/tensordict.py),
-  a convenient data structure<sup>(1)</sup> to pass data from
+- [`TensorDict`](torchrl/data/tensordict/tensordict.py),
+  a convenient data structure<sup>(1)</sup> to pass data from
   one object to another without friction.
   `TensorDict` makes it easy to re-use pieces of code across environments, models and
   algorithms. For instance, here's how to code a rollout in TorchRL:
   <details>
     <summary>Code</summary>
-
+
   ```diff
   - obs, done = env.reset()
   + tensordict = env.reset()
   policy = TensorDictModule(
-      model,
+      model,
       in_keys=["observation_pixels", "observation_vector"],
       out_keys=["action"],
   )
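The hunk above cuts the rollout off at the policy definition. A minimal, self-contained sketch of the pattern it describes, using `env.rollout` to drive the reset/step loop; the `nn.LazyLinear` stand-in model and the `"observation"` key are illustrative assumptions, not the README's exact snippet:

```python
# A sketch of the TensorDict-based rollout pattern, not the README's exact code.
# The stand-in model and key names are illustrative assumptions.
import torch.nn as nn
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import TensorDictModule

env = GymEnv("Pendulum-v1")
# stand-in policy network; a real one would be shaped to the action spec
model = nn.LazyLinear(env.action_spec.shape[-1])
policy = TensorDictModule(model, in_keys=["observation"], out_keys=["action"])

# env.rollout drives reset/step and stacks the per-step results in one TensorDict
tensordict = env.rollout(max_steps=100, policy=policy)
print(tensordict)  # batch_size reflects the number of collected steps
```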
@@ -57,7 +57,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   TensorDict abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing its primitives
   to be easily recycled across settings.
   Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
-
+
   ```diff
   - for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
   + for i, tensordict in enumerate(collector):
@@ -73,7 +73,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
       optim.zero_grad()
   ```
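The two hunks above show only fragments of that loop. A minimal sketch of the whole pattern, assuming (as the text says) that a collector, replay buffer, loss module and optimizer already exist; the names `loss_fn`, `n_optim_steps` and `batch_size` are illustrative:

```python
# Sketch of the off-policy loop; collector, replay_buffer, loss_fn and optim
# are assumed to be instantiated already, as stated in the README text.
for i, tensordict in enumerate(collector):
    # each batch of experience arrives as a single TensorDict
    replay_buffer.extend(tensordict)
    for _ in range(n_optim_steps):
        sample = replay_buffer.sample(batch_size)
        loss_td = loss_fn(sample)  # loss modules read and write TensorDicts
        # by convention, entries whose key starts with "loss" are summed
        loss = sum(v for k, v in loss_td.items() if k.startswith("loss"))
        loss.backward()
        optim.step()
        optim.zero_grad()
```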
   Again, this training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
-
+
   TensorDict supports multiple tensor operations on its device and shape
   (the shape of TensorDict, or its batch size, is the common arbitrary N first dimensions of all its contained tensors):
   ```python
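The hunk truncates the README's example at the opening fence. A small sketch of the batch-size semantics just described; shapes and key names are arbitrary:

```python
# Sketch: the TensorDict batch size is the shared leading dims of all entries;
# shape and device ops act on those dims / on every entry at once.
import torch
from torchrl.data import TensorDict

td = TensorDict(
    {"obs": torch.zeros(3, 4, 84, 84), "reward": torch.zeros(3, 4, 1)},
    batch_size=[3, 4],  # the common leading dimensions
)
td_flat = td.view(-1)   # batch_size becomes [12]
td_cpu = td.to("cpu")   # moves every entry
row = td[0]             # indexing applies to the batch dims
assert row.batch_size == torch.Size([4])
```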
@@ -96,11 +96,11 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   </details>
 
   Check our [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.
-
-- An associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
+
+- An associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
   <details>
     <summary>Code</summary>
-
+
   ```diff
   transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
   + td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
@@ -111,7 +111,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   + td_module(tensordict)
   + out = tensordict["out"]
   ```
-
+
   The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
   For instance, here is an implementation of a transformer using the encoder and decoder blocks:
   ```python
@@ -123,7 +123,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   assert transformer.in_keys == ["src", "src_mask", "tgt"]
   assert transformer.out_keys == ["memory", "output"]
   ```
-
+
   `TensorDictSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
   ```python
   transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
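The transformer example is cut off by the hunk boundary. A toy sketch of the same chaining and subgraph-isolation mechanics, with two linear layers standing in for the README's encoder and decoder blocks; module shapes and key names are illustrative:

```python
# Toy sketch of TensorDictSequential chaining and subgraph isolation;
# two linear layers stand in for the README's encoder/decoder blocks.
import torch.nn as nn
from torchrl.modules import TensorDictModule, TensorDictSequential

encoder = TensorDictModule(nn.Linear(8, 16), in_keys=["src"], out_keys=["memory"])
decoder = TensorDictModule(nn.Linear(16, 8), in_keys=["memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)

assert transformer.in_keys == ["src"]          # "memory" is produced internally
assert transformer.out_keys == ["memory", "output"]

sub = transformer.select_subsequence(out_keys=["memory"])  # just the encoder
assert sub.out_keys == ["memory"]
```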
@@ -132,19 +132,19 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   </details>
 
   The corresponding [tutorial](tutorials/tensordictmodule.ipynb) provides more context about its features.
-
-- a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
-  executes the aforementioned training loop. Through a hooking mechanism,
+
+- a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
+  executes the aforementioned training loop. Through a hooking mechanism,
   it also supports any logging or data transformation operation at any given
   time.
 
 - A common [interface for environments](torchrl/envs)
-  which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and state-less execution (e.g. Model-based environments).
+  which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and state-less execution (e.g. Model-based environments).
   The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
   A common pytorch-first class of [tensor-specification class](torchrl/data/tensor_specs.py) is also provided.
   <details>
     <summary>Code</summary>
-
+
   ```python
   env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
   env_parallel = ParallelEnv(4, env_make) # creates 4 envs in parallel
@@ -154,17 +154,17 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   ```
   </details>
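Once built, the batched env can be driven like a single env. A sketch of what the truncated example continues into; the `max_steps` value is arbitrary, and `from_pixels` is dropped here so the sketch runs without a renderer:

```python
# Sketch: drive the batched env; the leading batch dimension is the
# number of workers. from_pixels is dropped so no renderer is needed.
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

env_make = lambda: GymEnv("Pendulum-v1")
env_parallel = ParallelEnv(4, env_make)       # 4 envs stepped in parallel
tensordict = env_parallel.reset()             # batch_size is [4]
rollout = env_parallel.rollout(max_steps=20)  # random policy if none is given
assert rollout.batch_size[0] == 4
env_parallel.close()
```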
 
-- multiprocess [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup> that work synchronously or asynchronously.
-  Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised
+- multiprocess [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup> that work synchronously or asynchronously.
+  Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised
   learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
   <details>
     <summary>Code</summary>
-
+
   ```python
   env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
   collector = MultiaSyncDataCollector(
-      [env_make, env_make],
-      policy=policy,
+      [env_make, env_make],
+      policy=policy,
       devices=["cuda:0", "cuda:0"],
       total_frames=10000,
       frames_per_batch=50,
@@ -182,7 +182,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
 - efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
   <details>
     <summary>Code</summary>
-
+
   ```python
   storage = LazyMemmapStorage( # memory-mapped (physical) storage
       cfg.buffer_size,
@@ -200,19 +200,19 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   ```
   </details>
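The storage example is truncated by the hunk. A sketch of the modular pattern it illustrates; the `TensorDictReplayBuffer` constructor arguments shown here are an assumption and may differ across torchrl versions:

```python
# Sketch of a modular replay buffer; constructor arguments are assumptions
# and may differ between torchrl versions.
import torch
from torchrl.data import TensorDict, TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

buffer_size = 1000
storage = LazyMemmapStorage(buffer_size)  # memory-mapped, lazily allocated
rb = TensorDictReplayBuffer(buffer_size, storage=storage)

data = TensorDict({"obs": torch.randn(16, 3)}, batch_size=[16])
rb.extend(data)        # add a batch of experience
sample = rb.sample(8)  # a TensorDict with batch_size [8]
```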
 
-- cross-library [environment transforms](torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
-  executed on device and in a vectorized fashion<sup>(2)</sup>,
+- cross-library [environment transforms](torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
+  executed on device and in a vectorized fashion<sup>(2)</sup>,
   which process and prepare the data coming out of the environments to be used by the agent:
   <details>
     <summary>Code</summary>
-
+
   ```python
   env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
   env_base = ParallelEnv(4, env_make, device="cuda:0") # creates 4 envs in parallel
   env = TransformedEnv(
-      env_base,
+      env_base,
       Compose(
-          ToTensorImage(),
+          ToTensorImage(),
           ObservationNorm(loc=0.5, scale=1.0)), # executes the transforms once and on device
   )
   tensordict = env.reset()
@@ -237,7 +237,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
 - various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
   <details>
     <summary>Code</summary>
-
+
   ```python
   # create an nn.Module
   common_module = ConvNet(
@@ -255,7 +255,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
       out_keys=["hidden"],
   )
   # Wrap the policy module in NormalParamsWrapper, such that the output
-  # tensor is split in loc and scale, and scale is mapped onto a positive space
+  # tensor is split in loc and scale, and scale is mapped onto a positive space
   policy_module = NormalParamsWrapper(
       MLP(
           num_cells=[64, 64],
@@ -287,11 +287,11 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   ```
   </details>
 
-- exploration [wrappers](torchrl/modules/tensordict_module/exploration.py) and
+- exploration [wrappers](torchrl/modules/tensordict_module/exploration.py) and
   [modules](torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
   <details>
     <summary>Code</summary>
-
+
   ```python
   policy_explore = EGreedyWrapper(policy)
   with set_exploration_mode("random"):
@@ -301,22 +301,22 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
   ```
   </details>
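The hunk cuts the example off after the `with` statement. A sketch of the mode switch it describes; `policy` and `tensordict` are assumed to exist, and `"mode"` as the deterministic setting is an assumption about the accepted exploration-mode strings:

```python
# Sketch: the same wrapped policy, explorative vs. (assumed) deterministic.
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import EGreedyWrapper

policy_explore = EGreedyWrapper(policy)      # `policy` assumed defined
with set_exploration_mode("random"):
    tensordict = policy_explore(tensordict)  # epsilon-greedy actions
with set_exploration_mode("mode"):
    tensordict = policy_explore(tensordict)  # greedy actions
```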
 
-- A series of efficient [loss modules](https://github.com/facebookresearch/rl/blob/main/torchrl/objectives/costs)
-  and highly vectorized
-  [functional return and advantage](https://github.com/facebookresearch/rl/blob/main/torchrl/objectives/returns/functional.py)
-  computation.
+- A series of efficient [loss modules](https://github.com/facebookresearch/rl/blob/main/torchrl/objectives/costs)
+  and highly vectorized
+  [functional return and advantage](https://github.com/facebookresearch/rl/blob/main/torchrl/objectives/returns/functional.py)
+  computation.
 
   <details>
     <summary>Code</summary>
-
+
   ### Loss modules
   ```python
   from torchrl.objectives.costs import DQNLoss
   loss_module = DQNLoss(value_network=value_network, gamma=0.99)
   tensordict = replay_buffer.sample(batch_size)
   loss = loss_module(tensordict)
   ```
-
+
   ### Advantage computation
   ```python
   from torchrl.objectives.returns.functional import vec_td_lambda_return_estimate
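The hunk stops at the import. A sketch of the call it leads into, assuming the argument order `gamma, lmbda, next_state_value, reward, done` and `[batch, time, 1]` shaped tensors:

```python
# Sketch: vectorized TD(lambda) return estimation over a batch of trajectories.
# Argument order and the [batch, time, 1] shapes are assumptions.
import torch
from torchrl.objectives.returns.functional import vec_td_lambda_return_estimate

gamma, lmbda = 0.99, 0.95
reward = torch.randn(8, 100, 1)
next_state_value = torch.randn(8, 100, 1)
done = torch.zeros(8, 100, 1, dtype=torch.bool)

returns = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done)
print(returns.shape)  # same shape as reward
```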
@@ -325,7 +325,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
 
   </details>
 
-- various [recipes](torchrl/trainers/helpers/models.py) to build models that
+- various [recipes](torchrl/trainers/helpers/models.py) to build models that
   correspond to the environment being deployed.
 
 ## Examples, tutorials and demos
@@ -339,11 +339,11 @@ A series of [examples](examples/) are provided with an illustrative purpose:
 
 and many more to come!
 
-We also provide [tutorials and demos](tutorials) that give a sense of what the
+We also provide [tutorials and demos](tutorials) that give a sense of what the
 library can do.
 
 ## Installation
-Create a conda environment where the packages will be installed.
+Create a conda environment where the packages will be installed.
 
 ```
 conda create --name torch_rl python=3.9
@@ -381,16 +381,16 @@ pip3 install ninja # Makes the build go faster
 pip3 install "git+https://github.com/pytorch/functorch.git"
 ```
 
-If this fails, you can get the latest version of functorch that was marked to be
+If this fails, you can get the latest version of functorch that was marked to be
 compatible with the current torch version:
 ```bash
 pip3 install ninja # Makes the build go faster
 PYTORCH_VERSION=`python -c "import torch.version; print(torch.version.git_version)"`
 pip3 install "git+https://github.com/pytorch/pytorch.git@$PYTORCH_VERSION#subdirectory=functorch"
 ```
 
-If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
-`(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears,
+If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
+`(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears,
 try erasing the previously created build artifacts (`torchrl.egg-info/`, `build/`, `torchrl/_torchsl.so`)
 or re-clone the library from GitHub, then try

@@ -405,7 +405,7 @@ You can install the latest release by using
 ```
 pip3 install torchrl
 ```
-This should work on linux and MacOs (not M1). For Windows and M1/M2 machines, one
+This should work on linux and MacOs (not M1). For Windows and M1/M2 machines, one
 should install the library locally (see below).
 
 To install extra dependencies, call
@@ -414,9 +414,9 @@ pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils]"
 ```
 or a subset of these.
 
-Alternatively, as the library is at an early stage, it may be wise to install
-it in develop mode as this will make it possible to pull the latest changes and
-benefit from them immediately.
+Alternatively, as the library is at an early stage, it may be wise to install
+it in develop mode as this will make it possible to pull the latest changes and
+benefit from them immediately.
 Start by cloning the repo:
 ```
 git clone https://github.com/facebookresearch/rl
@@ -428,14 +428,14 @@ cd /path/to/torchrl/
 python setup.py develop
 ```
 
-If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
+If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
 `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears, then try
 
 ```
 ARCHFLAGS="-arch arm64" python setup.py develop
 ```
 
-To run a quick sanity check, leave that directory (e.g. by executing `cd ~/`)
+To run a quick sanity check, leave that directory (e.g. by executing `cd ~/`)
 and try to import the library.
 ```
 python -c "import torchrl"
@@ -444,7 +444,7 @@ This should not return any warning or error.
 
 **Optional dependencies**
 
-The following libraries can be installed depending on the usage one wants to
+The following libraries can be installed depending on the usage one wants to
 make of torchrl:
 ```
 # diverse
@@ -454,7 +454,7 @@ pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
 pip3 install moviepy
 
 # deepmind control suite
-pip3 install dm_control
+pip3 install dm_control
 
 # gym, atari games
 pip3 install gym[atari] "gym[accept-rom-license]" pygame
@@ -471,19 +471,19 @@ pip3 install wandb
 
 **Troubleshooting**
 
-If a `ModuleNotFoundError: No module named ‘torchrl._torchrl` errors occurs,
-it means that the C++ extensions were not installed or not found.
-One common reason might be that you are trying to import torchrl from within the
-git repo location. Indeed the following code snippet should return an error if
+If a `ModuleNotFoundError: No module named ‘torchrl._torchrl` errors occurs,
+it means that the C++ extensions were not installed or not found.
+One common reason might be that you are trying to import torchrl from within the
+git repo location. Indeed the following code snippet should return an error if
 torchrl has not been installed in `develop` mode:
 ```
 cd ~/path/to/rl/repo
 python -c 'from torchrl.envs.libs.gym import GymEnv'
 ```
 If this is the case, consider executing torchrl from another location.
 
-On **MacOs**, we recommend installing XCode first.
-With Apple Silicon M1 chips, make sure you are using the arm64-built python
+On **MacOs**, we recommend installing XCode first.
+With Apple Silicon M1 chips, make sure you are using the arm64-built python
 (e.g. [here](https://betterprogramming.pub/how-to-install-pytorch-on-apple-m1-series-512b3ad9bc6)). Running the following lines of code
 
 ```
@@ -505,7 +505,7 @@ To train an algorithm it is therefore advised to use the predefined configuratio
 ```
 python examples/ppo/ppo.py --config=examples/ppo/configs/humanoid.txt
 ```
-Note that using the config files requires the [configargparse](https://pypi.org/project/ConfigArgParse/) library.
+Note that using the config files requires the [configargparse](https://pypi.org/project/ConfigArgParse/) library.
 
 One can also overwrite the config parameters using flags, e.g.
 ```

knowledge_base/GYM.md

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# Working with gym
+
+## Versioning
+TorchRL is tested against the latest version of gym and we only guarantee compatibility
+against the gym version that was available at the time of release.
+
+However, for specific projects we may be willing to work on keeping a backward
+compatibility with older versions of gym.
+If you run into an issue when running TorchRL with a specific version of gym,
+feel free to open an issue and we will gladly look into this.
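When opening such an issue, a snippet along these lines (illustrative, not part of the file) pins down the versions involved:

```python
# Illustrative: report the exact versions when filing a gym-compatibility issue.
import gym
import torch
import torchrl

print("gym:", gym.__version__)
print("torch:", torch.__version__)
print("torchrl:", torchrl.__version__)
```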

knowledge_base/HABITAT.md

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Working with [`habitat-lab`](https://github.com/facebookresearch/habitat-lab)
+
+We are currently working on integrating habitat-lab environments into torchrl.
+Stay tuned for more info on this.
