Commit bea2474

[Doc] README rewrite (#971)
1 parent c8c9157 commit bea2474

1 file changed

+149
-71
lines changed

README.md

Lines changed: 149 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,9 @@
1111

1212
# TorchRL
1313

14-
[**Documentation**](#documentation) | [**TensorDict**](#tensordict-as-a-common-data-carrier-for-rl) |
15-
[**Features**](#features) | [**Examples, tutorials and demos**](#examples-tutorials-and-demos) |
16-
[**Running examples**](#running-examples) | [**Upcoming features**](#upcoming-features) | [**Contributing**](#contributing)
14+
[**Documentation**](#documentation-and-knowledge-base) | [**TensorDict**](#writing-simplified-and-portable-rl-codebase-with-tensordict) |
15+
[**Features**](#features) | [**Examples, tutorials and demos**](#examples-tutorials-and-demos) | [**Installation**](#installation) |
16+
[**Asking a question**](#asking-a-question) | [**Citation**](#citation) | [**Contributing**](#contributing)
1717

1818
**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
1919

@@ -27,27 +27,117 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f
2727

2828
TorchRL aims at (1) a high modularity and (2) good runtime performance.
2929

30-
## Documentation
30+
## Documentation and knowledge base
3131

3232
The TorchRL documentation can be found [here](https://pytorch.org/rl).
3333
It contains tutorials and the API reference.
3434

35-
## TensorDict as a common data carrier for RL
35+
TorchRL also provides an RL knowledge base to help you debug your code, or simply
36+
learn the basics of RL. Check it out [here](https://pytorch.org/rl/reference/knowledge_base.html).
3637

37-
TorchRL relies on [`TensorDict`](https://github.com/pytorch-labs/tensordict/),
38-
a convenient data structure<sup>(1)</sup> to pass data from
39-
one object to another without friction.
38+
We have some introductory videos to help you get to know the library better. Check them out:
4039

40+
- [TorchRL intro at PyTorch day 2022](https://youtu.be/cIKMhZoykEE)
41+
- [PyTorch 2.0 Q&A: TorchRL](https://www.youtube.com/live/myEfUoYrbts?feature=share)
42+
43+
## Writing simplified and portable RL codebase with `TensorDict`
44+
45+
RL algorithms are very heterogeneous, and it can be hard to recycle a codebase
46+
across settings (e.g. from online to offline, from state-based to pixel-based
47+
learning).
48+
TorchRL solves this problem through [`TensorDict`](https://github.com/pytorch-labs/tensordict/),
49+
a convenient data structure<sup>(1)</sup> that can be used to streamline one's
50+
RL codebase.
51+
With this tool, one can write a *complete PPO training script in less than 100
52+
lines of code*!
53+
54+
<details>
55+
<summary>Code</summary>
56+
57+
```python
58+
import torch
59+
from tensordict.nn import TensorDictModule
60+
from tensordict.nn.distributions import NormalParamExtractor
61+
from torch import nn
62+
63+
from torchrl.collectors import SyncDataCollector
64+
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
65+
LazyTensorStorage, SamplerWithoutReplacement
66+
from torchrl.envs.libs.gym import GymEnv
67+
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
68+
from torchrl.objectives import ClipPPOLoss
69+
from torchrl.objectives.value import GAE
70+
71+
env = GymEnv("Pendulum-v1")
72+
model = TensorDictModule(
73+
nn.Sequential(
74+
nn.Linear(3, 128), nn.Tanh(),
75+
nn.Linear(128, 128), nn.Tanh(),
76+
nn.Linear(128, 128), nn.Tanh(),
77+
nn.Linear(128, 2),
78+
NormalParamExtractor()
79+
),
80+
in_keys=["observation"],
81+
out_keys=["loc", "scale"]
82+
)
83+
critic = ValueOperator(
84+
nn.Sequential(
85+
nn.Linear(3, 128), nn.Tanh(),
86+
nn.Linear(128, 128), nn.Tanh(),
87+
nn.Linear(128, 128), nn.Tanh(),
88+
nn.Linear(128, 1),
89+
),
90+
in_keys=["observation"],
91+
)
92+
actor = ProbabilisticActor(
93+
model,
94+
in_keys=["loc", "scale"],
95+
distribution_class=TanhNormal,
96+
distribution_kwargs={"min": -1.0, "max": 1.0},
97+
return_log_prob=True
98+
)
99+
buffer = TensorDictReplayBuffer(
100+
LazyTensorStorage(1000),
101+
SamplerWithoutReplacement()
102+
)
103+
collector = SyncDataCollector(
104+
env,
105+
actor,
106+
frames_per_batch=1000,
107+
total_frames=1_000_000
108+
)
109+
loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
110+
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
111+
adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
112+
for data in collector: # collect data
113+
for epoch in range(10):
114+
adv_fn(data) # compute advantage
115+
buffer.extend(data.view(-1))
116+
for i in range(20): # consume data
117+
sample = buffer.sample(50) # mini-batch
118+
loss_vals = loss_fn(sample)
119+
loss_val = sum(
120+
value for key, value in loss_vals.items() if
121+
key.startswith("loss")
122+
)
123+
loss_val.backward()
124+
optim.step()
125+
optim.zero_grad()
126+
print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
127+
```
128+
</details>
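
The `GAE` module in the script above computes advantages from rewards and critic values. As a rough illustration of the underlying recursion, here is a plain-Python sketch of generalized advantage estimation. It is not TorchRL's implementation: the function name `gae` and its list-based signature are assumptions made for this example, and the `average_gae=True` option used in the script is not modeled.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Illustrative generalized advantage estimation over one trajectory.

    All arguments are plain Python lists of equal length (hypothetical
    signature, chosen for readability rather than speed).
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk the trajectory backwards so each step can reuse the next advantage.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; the bootstrap term is dropped at episode end.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages


# Three steps with unit reward, a zero critic and a terminal last step.
adv = gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
    gamma=0.5,
    lmbda=1.0,
)
print(adv)  # [1.75, 1.5, 1.0]
```

With a zero critic and `lmbda=1.0`, the advantage reduces to the discounted return, which makes the output easy to check by hand.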
41129

42130
Here is an example of how the [environment API](https://pytorch.org/rl/reference/envs.html)
43131
relies on tensordict to carry data from one function to another during a rollout
44132
execution:
45133
![Alt Text](docs/source/_static/img/rollout.gif)
46134

47135
`TensorDict` makes it easy to re-use pieces of code across environments, models and
48-
algorithms. For instance, here's how to code a rollout in TorchRL:
136+
algorithms.
49137
<details>
50138
<summary>Code</summary>
139+
140+
For instance, here's how to code a rollout in TorchRL:
51141

52142
```diff
53143
- obs, done = env.reset()
@@ -71,13 +161,17 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
71161
+ out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
72162
```
73163
</details>
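
To make the idea of a single data carrier concrete without any dependency, here is a plain-Python sketch in which an ordinary `dict` plays the role of the tensordict during a rollout. `CountingEnv`, `policy` and `rollout` are hypothetical names invented for this illustration; they mimic the shape of the loop, not TorchRL's actual environment API.

```python
class CountingEnv:
    """Hypothetical toy environment whose state is a step counter."""

    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.count = 0

    def reset(self):
        self.count = 0
        return {"observation": self.count, "done": False}

    def step(self, data):
        # Reads "action" from the carrier and returns the next-step entries.
        self.count += data["action"]
        return {
            "observation": self.count,
            "reward": float(data["action"]),
            "done": self.count >= self.max_steps,
        }


def policy(data):
    # Writes its output into the same carrier it read from.
    data["action"] = 1
    return data


def rollout(env, policy, max_steps=100):
    out = []
    data = env.reset()
    for _ in range(max_steps):
        data = policy(data)
        nxt = env.step(data)
        # Store the current step together with its successor under "next".
        out.append({**data, "next": nxt})
        if nxt["done"]:
            break
        # The successor becomes the root of the next step.
        data = {"observation": nxt["observation"], "done": nxt["done"]}
    return out


trajectory = rollout(CountingEnv(max_steps=5), policy)
print(len(trajectory))  # 5
```

Because the policy and the environment only agree on key names, either one can be swapped out without touching the loop.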
74-
TensorDict abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing its primitives
164+
165+
Using this, TorchRL abstracts away the input / output signatures of the modules, env,
166+
collectors, replay buffers and losses of the library, allowing all primitives
75167
to be easily recycled across settings.
76-
Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
77168

78169
<details>
79170
<summary>Code</summary>
80171

172+
Here's another example of an off-policy training loop in TorchRL (assuming
173+
that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
174+
81175
```diff
82176
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
83177
+ for i, tensordict in enumerate(collector):
@@ -92,7 +186,7 @@ Here's another example of an off-policy training loop in TorchRL (assuming that
92186
optim.step()
93187
optim.zero_grad()
94188
```
95-
Again, this training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
189+
This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
96190
</details>
97191

98192
TensorDict supports multiple tensor operations on its device and shape
@@ -120,9 +214,9 @@ Here's another example of an off-policy training loop in TorchRL (assuming that
120214
```
121215
</details>
122216

123-
Check our TorchRL-specific [TensorDict tutorial](https://pytorch.org/rl/tutorials/tensordict_tutorial.html) for more information.
124-
125-
The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
217+
TensorDict comes with a dedicated [`tensordict.nn`](https://pytorch-labs.github.io/tensordict/reference/nn.html)
218+
module that contains everything you might need to write your model with it.
219+
And it is `functorch` and `torch.compile` compatible!
126220

127221
<details>
128222
<summary>Code</summary>
@@ -138,27 +232,27 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
138232
+ out = tensordict["out"]
139233
```
140234

141-
The `SafeSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
235+
The `TensorDictSequential` class makes it possible to chain sequences of `nn.Module` instances in a highly modular way.
142236
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
143237
```python
144238
encoder_module = TransformerEncoder(...)
145-
encoder = SafeModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
239+
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
146240
decoder_module = TransformerDecoder(...)
147-
decoder = SafeModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
148-
transformer = SafeSequential(encoder, decoder)
241+
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
242+
transformer = TensorDictSequential(encoder, decoder)
149243
assert transformer.in_keys == ["src", "src_mask", "tgt"]
150244
assert transformer.out_keys == ["memory", "output"]
151245
```
152246

153-
`SafeSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
247+
`TensorDictSequential` makes it possible to isolate subgraphs by querying a set of desired input / output keys:
154248
```python
155249
transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
156250
transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder
157251
```
158252
</details>
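
The key bookkeeping asserted in the transformer example can be sketched in plain Python. The classes below, `KeyedModule` and `KeyedSequential`, are hypothetical stand-ins written for this illustration and are not part of `tensordict`. They only show one plausible way a sequence can derive its own `in_keys` (entries that no earlier module produces) and `out_keys` (every entry written along the way).

```python
class KeyedModule:
    """Hypothetical wrapper: reads `in_keys` from a dict, writes `out_keys`."""

    def __init__(self, fn, in_keys, out_keys):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def __call__(self, data):
        outputs = self.fn(*(data[key] for key in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


class KeyedSequential:
    """Hypothetical chain of KeyedModule instances sharing one dict."""

    def __init__(self, *modules):
        self.modules = modules
        self.in_keys, self.out_keys = [], []
        for module in modules:
            for key in module.in_keys:
                # A key is an external input only if nothing upstream wrote it.
                if key not in self.out_keys and key not in self.in_keys:
                    self.in_keys.append(key)
            for key in module.out_keys:
                if key not in self.out_keys:
                    self.out_keys.append(key)

    def __call__(self, data):
        for module in self.modules:
            data = module(data)
        return data


# Toy "encoder" keeps the unmasked items; toy "decoder" concatenates lists.
encoder = KeyedModule(
    lambda src, src_mask: [s for s, keep in zip(src, src_mask) if keep],
    in_keys=["src", "src_mask"],
    out_keys=["memory"],
)
decoder = KeyedModule(
    lambda tgt, memory: tgt + memory,
    in_keys=["tgt", "memory"],
    out_keys=["output"],
)
transformer = KeyedSequential(encoder, decoder)
result = transformer({"src": [1, 2, 3], "src_mask": [True, False, True], "tgt": [9]})
print(transformer.in_keys)   # ['src', 'src_mask', 'tgt']
print(transformer.out_keys)  # ['memory', 'output']
```

Here `"memory"` is not listed among the inputs of the sequence because the encoder writes it before the decoder reads it.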
159253

160-
The corresponding [tutorial](https://pytorch.org/rl/tutorials/tensordict_module.html) provides more context about its features.
161-
254+
Check [TensorDict tutorials](https://pytorch-labs.github.io/tensordict/) to
255+
learn more!
162256

163257

164258
## Features
@@ -167,7 +261,7 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
167261
which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and state-less execution
168262
(e.g. Model-based environments).
169263
The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
170-
A common pytorch-first class of [tensor-specification class](torchrl/data/tensor_specs.py) is also provided.
264+
A common PyTorch-first class of [tensor-specification class](torchrl/data/tensor_specs.py) is also provided.
171265
TorchRL's environments API is simple but stringent and specific. Check the
172266
[documentation](https://pytorch.org/rl/reference/envs.html)
173267
and [tutorial](https://pytorch.org/rl/tutorials/pendulum.html) to learn more!
@@ -328,9 +422,9 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
328422
```
329423
</details>
330424

331-
- A series of efficient [loss modules](https://github.com/pytorch/rl/blob/main/torchrl/objectives/costs)
425+
- A series of efficient [loss modules](https://github.com/pytorch/rl/tree/main/torchrl/objectives)
332426
and highly vectorized
333-
[functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/returns/functional.py)
427+
[functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/value/functional.py)
334428
computation.
335429

336430
<details>
@@ -367,8 +461,10 @@ If you would like to contribute to new features, check our [call for contributio
367461
## Examples, tutorials and demos
368462

369463
A series of [examples](examples/) are provided with an illustrative purpose:
370-
- [DQN (and add-ons up to Rainbow)](examples/dqn/dqn.py)
464+
- [DQN and Rainbow](examples/dqn/dqn.py)
371465
- [DDPG](examples/ddpg/ddpg.py)
466+
- [IQL](examples/iql/iql.py)
467+
- [TD3](examples/td3/td3.py)
372468
- [A2C](examples/a2c/a2c.py)
373469
- [PPO](examples/ppo/ppo.py)
374470
- [SAC](examples/sac/sac.py)
@@ -377,7 +473,10 @@ A series of [examples](examples/) are provided with an illustrative purpose:
377473

378474
and many more to come!
379475

380-
We also provide [tutorials and demos](tutorials/README.md) that give a sense of
476+
Check the [examples markdown](examples/EXAMPLES.md) directory for more details
477+
about handling the various configuration settings.
478+
479+
We also provide [tutorials and demos](https://pytorch.org/rl/#tutorials) that give a sense of
381480
what the library can do.
382481

383482
## Installation
@@ -388,36 +487,12 @@ conda create --name torch_rl python=3.9
388487
conda activate torch_rl
389488
```
390489

391-
Depending on the use of functorch that you want to make, you may want to install the latest (nightly) pytorch release or the latest stable version of pytorch.
392-
See [here](https://pytorch.org/get-started/locally/) for a more detailed list of commands, including `pip3` or windows/OSX compatible installation commands:
393-
394-
**Stable**
395-
396-
```
397-
# For CUDA 11.3
398-
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
399-
# For CUDA 11.6
400-
conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c conda-forge
401-
# For CPU-only build
402-
conda install pytorch torchvision cpuonly -c pytorch
403-
404-
# For torch 1.12 (and not above), one should install functorch separately:
405-
pip3 install functorch
406-
```
490+
**PyTorch**
407491

408-
**Nightly**
409-
```
410-
# For CUDA 11.6
411-
conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch-nightly -c nvidia
412-
# For CUDA 11.7
413-
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
414-
# For CPU-only build
415-
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
416-
```
417-
418-
`functorch` is included in the nightly PyTorch package, so no need to install it separately.
419-
420-
For M1 Mac users, if the above commands do not work, you can build torch from source by following [this guide](https://github.com/pytorch/pytorch#from-source).
492+
Depending on the use of functorch that you want to make, you may want to
493+
install the latest (nightly) PyTorch release or the latest stable version of PyTorch.
494+
See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands,
495+
including `pip3` or windows/OSX compatible installation commands.
421496

422497
**Torchrl**
423498

@@ -528,22 +603,25 @@ OS: macOS **** (x86_64)
528603
Versioning issues can cause error message of the type ```undefined symbol``` and such. For these, refer to the [versioning issues document](knowledge_base/VERSIONING_ISSUES.md) for a complete explanation and proposed workarounds.
529604

530605

531-
## Running examples
532-
Examples are coded in a very similar way but the configuration may change from one algorithm to another (e.g. async/sync data collection, hyperparameters, ratio of model updates / frame etc.)
533-
534-
Check the [examples markdown](examples/EXAMPLES.md) directory for more details about handling the various configuration settings.
535-
606+
## Asking a question
536607

537-
## Upcoming features
608+
If you spot a bug in the library, please raise an issue in this repo.
538609

539-
In the near future, we plan to:
540-
- provide tutorials on how to design new actors or environment wrappers;
541-
- implement IMPALA (as a distributed RL example) and Meta-RL algorithms;
542-
- improve the tests, documentation and nomenclature.
610+
If you have a more generic question regarding RL in PyTorch, post it on
611+
the [PyTorch forum](https://discuss.pytorch.org/c/reinforcement-learning/6).
543612

544-
We welcome any contribution, should you want to contribute to these new features
545-
or any other, lister or not, in the issues section of this repository.
613+
## Citation
546614

615+
If you're using TorchRL, please refer to this BibTeX entry to cite this work:
616+
```
617+
@software{TorchRL,
618+
author = {Moens, Vincent},
619+
title = {{TorchRL: an open-source Reinforcement Learning (RL) library for PyTorch}},
620+
url = {https://github.com/pytorch/rl},
621+
version = {0.1.0},
622+
year = {2023}
623+
}
624+
```
547625

548626
## Contributing
549627

@@ -556,9 +634,9 @@ Contributors are recommended to install [pre-commit hooks](https://pre-commit.co
556634

557635
## Disclaimer
558636

559-
This library is not officially released yet and is subject to change.
560-
561-
The features are available before an official release so that users and collaborators can get early access and provide feedback. No guarantee of stability, robustness or backward compatibility is provided.
637+
This library is released as a PyTorch beta feature.
638+
BC-breaking changes are likely to happen but they will be introduced with a deprecation
639+
warranty after a few release cycles.
562640

563641
# License
564642
TorchRL is licensed under the MIT License. See [LICENSE](LICENSE) for details.
