
Commit d28a8c3

vmoens and tcbegley authored
[Refactor] Relying on the standalone tensordict -- phase 1 (#650)
* init
* amend
* amend
* lint and other
* quickfix
* lint
* [Refactor] Relying on the standalone tensordict -- phase 1 updates (#665)
* Install tensordict in GitHub Actions
* Clean up remaining references to torchrl.data.tensordict
* Use in td.keys() for membership checks
* Rerun CI
* Rerun CI
* amend
* amend
* amend
* lint

Co-authored-by: Tom Begley <tomcbegley@gmail.com>
1 parent 278e9be · commit d28a8c3
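For readers skimming the diff, here is a minimal sketch of what this phase-1 migration looks like in downstream code. It is illustrative only and not taken from the diff itself: imports move from `torchrl.data.tensordict` to the standalone `tensordict` package, and key membership is checked against `td.keys()`, as noted in the commit message.

```python
# Hypothetical before/after sketch of the phase-1 migration (not part of this diff).

# Before: TensorDict lived inside torchrl
# from torchrl.data.tensordict import TensorDict

# After: TensorDict comes from the standalone package,
# installed in CI via `pip install git+https://github.com/pytorch-labs/tensordict`
import torch
from tensordict import TensorDict

td = TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4])

# Membership checks now go through td.keys().
assert "observation" in td.keys()
```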

File tree

87 files changed (+335, -9492 lines)

.circleci/unittest/linux/scripts/install.sh

Lines changed: 3 additions & 0 deletions
@@ -39,5 +39,8 @@ python -c "import functorch"
 # install snapshot
 pip install git+https://github.com/pytorch/torchsnapshot

+# install tensordict
+pip install git+https://github.com/pytorch-labs/tensordict
+
 printf "* Installing torchrl\n"
 python setup.py develop

.circleci/unittest/linux_libs/scripts_habitat/install.sh

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@ else
 pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall
 fi

+# install tensordict
+pip install git+https://github.com/pytorch-labs/tensordict
+
 # smoke test
 python -c "import functorch"

.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh

Lines changed: 3 additions & 0 deletions
@@ -41,5 +41,8 @@ else
 conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge -y
 fi

+# install tensordict
+pip install git+https://github.com/pytorch-labs/tensordict
+
 printf "* Installing torchrl\n"
 python setup.py develop

.circleci/unittest/linux_optdeps/scripts/install.sh

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ else
 pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu113
 fi

+# install tensordict
+pip install git+https://github.com/pytorch-labs/tensordict
+
 # smoke test
 python -c "import functorch"

.circleci/unittest/linux_stable/scripts/install.sh

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ else
 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
 fi

+# install tensordict
+pip install git+https://github.com/pytorch-labs/tensordict
+
 # smoke test
 python -c "import torch;import functorch"

.github/workflows/docs.yml

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ jobs:
       shell: bash
       run: |
         conda run -n build_binary python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+    - name: Install tensordict
+      run: |
+        python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
     - name: Install TorchRL
       run: |
         conda run -n build_binary python -m pip install -e .

.github/workflows/nightly_build.yml

Lines changed: 3 additions & 0 deletions
@@ -217,6 +217,9 @@ jobs:
       run: |
         export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
         python3 -mpip install --upgrade pip
+    - name: Install tensordict
+      run: |
+        python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
     - name: Install test dependencies
       run: |
         export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"

.github/workflows/wheels.yml

Lines changed: 3 additions & 0 deletions
@@ -99,6 +99,9 @@ jobs:
     - name: Upgrade pip
       run: |
         python3 -mpip install --upgrade pip
+    - name: Install tensordict
+      run: |
+        python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
     - name: Install test dependencies
       run: |
         python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml

README.md

Lines changed: 116 additions & 103 deletions
@@ -31,116 +31,129 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f

 TorchRL aims at (1) a high modularity and (2) good runtime performance.

-## Features
+## TensorDict as a common data carrier for RL

-On the high-level end, TorchRL provides:
-- [`TensorDict`](torchrl/data/tensordict/tensordict.py),
+TorchRL relies on [`TensorDict`](https://github.com/pytorch-labs/tensordict/),
 a convenient data structure<sup>(1)</sup> to pass data from
 one object to another without friction.
 `TensorDict` makes it easy to re-use pieces of code across environments, models and
 algorithms. For instance, here's how to code a rollout in TorchRL:
-<details>
-  <summary>Code</summary>
-
-```diff
-- obs, done = env.reset()
-+ tensordict = env.reset()
-policy = TensorDictModule(
-    model,
-    in_keys=["observation_pixels", "observation_vector"],
-    out_keys=["action"],
-)
-out = []
-for i in range(n_steps):
--     action, log_prob = policy(obs)
--     next_obs, reward, done, info = env.step(action)
--     out.append((obs, next_obs, action, log_prob, reward, done))
--     obs = next_obs
-+     tensordict = policy(tensordict)
-+     tensordict = env.step(tensordict)
-+     out.append(tensordict)
-+     tensordict = step_mdp(tensordict) # renames next_observation_* keys to observation_*
-- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
-+ out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
-```
-TensorDict abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing its primitives
-to be easily recycled across settings.
-Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
-
-```diff
-- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
-+ for i, tensordict in enumerate(collector):
--     replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
-+     replay_buffer.add(tensordict)
-    for j in range(num_optim_steps):
--         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
--         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
-+         tensordict = replay_buffer.sample(batch_size)
-+         loss = loss_fn(tensordict)
-        loss.backward()
-        optim.step()
-        optim.zero_grad()
-```
-Again, this training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
+<details>
+  <summary>Code</summary>
+
+```diff
+- obs, done = env.reset()
++ tensordict = env.reset()
+policy = TensorDictModule(
+    model,
+    in_keys=["observation_pixels", "observation_vector"],
+    out_keys=["action"],
+)
+out = []
+for i in range(n_steps):
+-     action, log_prob = policy(obs)
+-     next_obs, reward, done, info = env.step(action)
+-     out.append((obs, next_obs, action, log_prob, reward, done))
+-     obs = next_obs
++     tensordict = policy(tensordict)
++     tensordict = env.step(tensordict)
++     out.append(tensordict)
++     tensordict = step_mdp(tensordict) # renames next_observation_* keys to observation_*
+- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
++ out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
+```
+</details>
+TensorDict abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing its primitives
+to be easily recycled across settings.
+Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
+
+<details>
+  <summary>Code</summary>
+
+```diff
+- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
++ for i, tensordict in enumerate(collector):
+-     replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
++     replay_buffer.add(tensordict)
+    for j in range(num_optim_steps):
+-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
+-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
++         tensordict = replay_buffer.sample(batch_size)
++         loss = loss_fn(tensordict)
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+```
+Again, this training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
+</details>
+
+TensorDict supports multiple tensor operations on its device and shape
+(the shape of TensorDict, or its batch size, is the common arbitrary N first dimensions of all its contained tensors):
+
+<details>
+  <summary>Code</summary>
+
+```python
+# stack and cat
+tensordict = torch.stack(list_of_tensordicts, 0)
+tensordict = torch.cat(list_of_tensordicts, 0)
+# reshape
+tensordict = tensordict.view(-1)
+tensordict = tensordict.permute(0, 2, 1)
+tensordict = tensordict.unsqueeze(-1)
+tensordict = tensordict.squeeze(-1)
+# indexing
+tensordict = tensordict[:2]
+tensordict[:, 2] = sub_tensordict
+# device and memory location
+tensordict.cuda()
+tensordict.to("cuda:1")
+tensordict.share_memory_()
+```
+</details>
+
+Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.
+
+The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
+
+<details>
+  <summary>Code</summary>
+
+```diff
+transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
++ td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
+src = torch.rand((10, 32, 512))
+tgt = torch.rand((20, 32, 512))
++ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
+- out = transformer_model(src, tgt)
++ td_module(tensordict)
++ out = tensordict["out"]
+```
+
+The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
+For instance, here is an implementation of a transformer using the encoder and decoder blocks:
+```python
+encoder_module = TransformerEncoder(...)
+encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
+decoder_module = TransformerDecoder(...)
+decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
+transformer = TensorDictSequential(encoder, decoder)
+assert transformer.in_keys == ["src", "src_mask", "tgt"]
+assert transformer.out_keys == ["memory", "output"]
+```
+
+`TensorDictSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
+```python
+transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
+transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder
+```
+</details>
+
+The corresponding [tutorial](tutorials/tensordictmodule.ipynb) provides more context about its features.

-TensorDict supports multiple tensor operations on its device and shape
-(the shape of TensorDict, or its batch size, is the common arbitrary N first dimensions of all its contained tensors):
-```python
-# stack and cat
-tensordict = torch.stack(list_of_tensordicts, 0)
-tensordict = torch.cat(list_of_tensordicts, 0)
-# reshape
-tensordict = tensordict.view(-1)
-tensordict = tensordict.permute(0, 2, 1)
-tensordict = tensordict.unsqueeze(-1)
-tensordict = tensordict.squeeze(-1)
-# indexing
-tensordict = tensordict[:2]
-tensordict[:, 2] = sub_tensordict
-# device and memory location
-tensordict.cuda()
-tensordict.to("cuda:1")
-tensordict.share_memory_()
-```
-</details>

-Check our [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.

-- An associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
-<details>
-  <summary>Code</summary>
-
-```diff
-transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
-+ td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
-src = torch.rand((10, 32, 512))
-tgt = torch.rand((20, 32, 512))
-+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
-- out = transformer_model(src, tgt)
-+ td_module(tensordict)
-+ out = tensordict["out"]
-```
-
-The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
-For instance, here is an implementation of a transformer using the encoder and decoder blocks:
-```python
-encoder_module = TransformerEncoder(...)
-encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
-decoder_module = TransformerDecoder(...)
-decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
-transformer = TensorDictSequential(encoder, decoder)
-assert transformer.in_keys == ["src", "src_mask", "tgt"]
-assert transformer.out_keys == ["memory", "output"]
-```
-
-`TensorDictSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
-```python
-transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
-transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder
-```
-</details>
-
-The corresponding [tutorial](tutorials/tensordictmodule.ipynb) provides more context about its features.
+## Features

 - a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
   executes the aforementioned training loop. Through a hooking mechanism,
@@ -242,7 +255,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
 ```
 </details>

-- various tools for distributed learning (e.g. [memory mapped tensors](torchrl/data/tensordict/memmap.py))<sup>(2)</sup>;
+- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch-labs/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>;
 - various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
   <details>
     <summary>Code</summary>
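The second README hunk above repoints the memory-mapped tensors link at the standalone tensordict repository. As a rough, hypothetical illustration of that feature (not part of this diff), a TensorDict built with the standalone package can be moved to memory-mapped storage in place:

```python
# Hedged sketch: assumes the standalone tensordict package is installed
# (pip install git+https://github.com/pytorch-labs/tensordict).
import torch
from tensordict import TensorDict

td = TensorDict(
    {"observation": torch.zeros(8, 3), "reward": torch.zeros(8, 1)},
    batch_size=[8],
)

# Convert every entry to a memory-mapped tensor so other processes
# (e.g. distributed data collectors) can read the same buffers without copies.
td.memmap_()
assert td.is_memmap()
```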

benchmarks/storage/benchmark_sample_latency_over_rpc.py

Lines changed: 1 addition & 1 deletion
@@ -18,6 +18,7 @@

 import torch
 import torch.distributed.rpc as rpc
+from tensordict import TensorDict
 from torchrl.data.replay_buffers.rb_prototype import RemoteTensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import RandomSampler
 from torchrl.data.replay_buffers.storages import (
@@ -26,7 +27,6 @@
     ListStorage,
 )
 from torchrl.data.replay_buffers.writers import RoundRobinWriter
-from torchrl.data.tensordict import TensorDict

 RETRY_LIMIT = 2
 RETRY_DELAY_SECS = 3
