[Refactor] Relying on the standalone tensordict -- phase 1 (#650)
* init
* amend
* amend
* lint and other
* quickfix
* lint
* [Refactor] Relying on the standalone tensordict -- phase 1 updates (#665)
* Install tensordict in GitHub Actions
* Clean up remaining references to torchrl.data.tensordict
* Use in td.keys() for membership checks
* Rerun CI
* Rerun CI
* amend
* amend
* amend
* lint
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
TensorDict abstracts away the input / output signatures of the modules, environments, collectors, replay buffers and losses of the library, allowing its primitives
to be easily recycled across settings.
Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):

<details>
  <summary>Code</summary>

```diff
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
-     loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+     tensordict = replay_buffer.sample(batch_size)
+     loss = loss_fn(tensordict)
      loss.backward()
      optim.step()
      optim.zero_grad()
```

Again, this training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
</details>
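To make the point concrete, here is a minimal, dependency-free sketch of why a loss that consumes a single dict-like batch needs no per-algorithm signature. The `ReplayBuffer` and `loss_fn` below are hypothetical stand-ins for illustration only, not TorchRL's actual classes:

```python
import random

# Toy stand-ins for illustration only -- not TorchRL's actual API.
class ReplayBuffer:
    """Stores dict-like transitions and samples mini-batches."""
    def __init__(self):
        self._storage = []

    def extend(self, transitions):
        self._storage.extend(transitions)

    def sample(self, batch_size):
        return random.sample(self._storage, batch_size)

def loss_fn(batch):
    # Reads whatever keys it needs; callers never unpack fields by position.
    return sum(td["reward"] for td in batch)

buffer = ReplayBuffer()
buffer.extend({"obs": i, "action": 0, "reward": 1.0} for i in range(10))
batch = buffer.sample(batch_size=4)
loss = loss_fn(batch)  # same call site whatever keys the algorithm uses
```

Because the call site never spells out `(obs, action, reward, ...)`, swapping in a loss that reads different keys requires no change to the loop itself.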
TensorDict supports multiple tensor operations on its device and shape
(the shape of a TensorDict, or its batch size, is the leading N dimensions common to all of its contained tensors):

<details>
  <summary>Code</summary>

```python
# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location
tensordict.cuda()
tensordict.to("cuda:1")
tensordict.share_memory_()
```
</details>
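The batch-size contract behind these operations can be illustrated with a toy container (plain Python over shape tuples; `ToyTensorDict` is a hypothetical name, not the real `TensorDict` class): every entry must share the same leading dimensions, which is what makes shape operations like indexing well-defined on the container itself.

```python
# Toy illustration of the batch-size contract (not the real TensorDict API).
class ToyTensorDict:
    """Maps entry names to shape tuples; enforces shared leading batch dims."""
    def __init__(self, shapes, batch_size):
        for name, shape in shapes.items():
            # Every entry's leading dims must match the container's batch size.
            assert shape[: len(batch_size)] == tuple(batch_size), name
        self.shapes = shapes
        self.batch_size = tuple(batch_size)

    def __getitem__(self, i):
        # Indexing consumes the first batch dim of every entry at once.
        new_shapes = {name: shape[1:] for name, shape in self.shapes.items()}
        return ToyTensorDict(new_shapes, self.batch_size[1:])

td = ToyTensorDict({"obs": (4, 3, 84, 84), "reward": (4, 3, 1)}, (4, 3))
row = td[0]  # row.batch_size == (3,); row.shapes["obs"] == (3, 84, 84)
```

Trailing (non-batch) dimensions are free to differ per entry — only the leading batch dims are shared, which is exactly what `view`, `stack` and indexing operate on.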
Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.

The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) is [functorch](https://github.com/pytorch/functorch)-compatible!
`TensorDictSequential` makes it possible to isolate subgraphs by querying a set of desired input / output keys:

```python
transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
```
</details>

The corresponding [tutorial](tutorials/tensordictmodule.ipynb) provides more context about its features.
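The key-based subgraph selection shown above can be sketched as a walk backwards over an ordered list of `(in_keys, out_keys)` modules, keeping only those needed to produce the requested outputs. This is a simplified sketch of the idea (only `out_keys` selection, hypothetical helper names — not TorchRL's implementation):

```python
# Toy sketch of key-based subgraph selection (not TorchRL's implementation).
def select_subsequence(modules, out_keys):
    """Keep only the modules needed to produce `out_keys`.

    `modules` is an ordered list of (in_keys, out_keys) pairs.
    """
    needed = set(out_keys)
    selected = []
    for in_keys, outs in reversed(modules):
        if needed & set(outs):
            selected.append((in_keys, outs))
            needed |= set(in_keys)  # their inputs become needed too
    return list(reversed(selected))

# A transformer-like pipeline: embed -> encoder -> decoder
pipeline = [
    (["src"], ["src_emb"]),
    (["src_emb"], ["memory"]),        # encoder
    (["tgt", "memory"], ["logits"]),  # decoder
]
encoder = select_subsequence(pipeline, out_keys=["memory"])
# keeps the embedding and encoder modules, drops the decoder
```

Because modules declare their keys explicitly, the subgraph can be recovered purely from the key structure, without tracing tensor operations.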
## Features

- a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
  executes the aforementioned training loop. Through a hooking mechanism,

@@ -242,7 +255,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
```
</details>
- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch-labs/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>;
- various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>: