Skip to content

Commit 902a393

Browse files
ronertRonert Obst
andauthored
[Refactor] Rename TensorDictSequence to TensorDictSequential (#440)
The TensorDictSequence class behaves like the nn.Sequential module, hence those names should match. Co-authored-by: Ronert Obst <ronert@fb.com>
1 parent d694814 commit 902a393

File tree

11 files changed

+3859
-3862
lines changed

11 files changed

+3859
-3862
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,19 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
112112
+ out = tensordict["out"]
113113
```
114114

115-
The `TensorDictSequence` class allows to branch sequences of `nn.Module` instances in a highly modular way.
115+
The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
116116
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
117117
```python
118118
encoder_module = TransformerEncoder(...)
119119
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
120120
decoder_module = TransformerDecoder(...)
121121
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
122-
transformer = TensorDictSequence(encoder, decoder)
122+
transformer = TensorDictSequential(encoder, decoder)
123123
assert transformer.in_keys == ["src", "src_mask", "tgt"]
124124
assert transformer.out_keys == ["memory", "output"]
125125
```
126126

127-
`TensorDictSequence` allows to isolate subgraphs by querying a set of desired input / output keys:
127+
`TensorDictSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
128128
```python
129129
transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
130130
transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder

test/test_exploration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torchrl.data.tensordict.tensordict import TensorDict
1515
from torchrl.envs.transforms.transforms import gSDENoise
1616
from torchrl.envs.utils import set_exploration_mode
17-
from torchrl.modules import TensorDictModule, TensorDictSequence
17+
from torchrl.modules import TensorDictModule, TensorDictSequential
1818
from torchrl.modules.distributions import TanhNormal
1919
from torchrl.modules.distributions.continuous import (
2020
IndependentNormal,
@@ -227,7 +227,7 @@ def test_gsde(
227227
if gSDE:
228228
model = torch.nn.LazyLinear(action_dim)
229229
in_keys = ["observation"]
230-
module = TensorDictSequence(
230+
module = TensorDictSequential(
231231
TensorDictModule(model, in_keys=in_keys, out_keys=["action"]),
232232
TensorDictModule(
233233
LazygSDEModule(),

test/test_functorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from functorch import vmap
66
from torch import nn
77
from torchrl.data import TensorDict
8-
from torchrl.modules import TensorDictModule, TensorDictSequence
8+
from torchrl.modules import TensorDictModule, TensorDictSequential
99
from torchrl.modules.functional_modules import (
1010
FunctionalModuleWithBuffers,
1111
FunctionalModule,
@@ -158,7 +158,7 @@ def test_vmap_tdsequence(moduletype, batch_params):
158158
tdmodule1 = TensorDictModule(fmodule1, in_keys=["x"], out_keys=["y"])
159159
tdmodule2 = TensorDictModule(fmodule2, in_keys=["y"], out_keys=["z"])
160160
params = TensorDict({"0": params1, "1": params2}, [])
161-
tdmodule = TensorDictSequence(tdmodule1, tdmodule2)
161+
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)
162162
assert {"0", "1"} == set(params.keys())
163163
x = torch.randn(10, 1, 3)
164164
td = TensorDict({"x": x}, [10])
@@ -174,7 +174,7 @@ def test_vmap_tdsequence(moduletype, batch_params):
174174
tdmodule2 = TensorDictModule(fmodule2, in_keys=["y"], out_keys=["z"])
175175
params = TensorDict({"0": params1, "1": params2}, [])
176176
buffers = TensorDict({"0": buffers1, "1": buffers2}, [])
177-
tdmodule = TensorDictSequence(tdmodule1, tdmodule2)
177+
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)
178178
assert {"0", "1"} == set(params.keys())
179179
assert {"0", "1"} == set(buffers.keys())
180180
x = torch.randn(10, 2, 3)
@@ -209,7 +209,7 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params):
209209
if moduletype == "linear":
210210
tdmodule1 = TensorDictModule(module1, in_keys=["x"], out_keys=["y"])
211211
tdmodule2 = TensorDictModule(module2, in_keys=["y"], out_keys=["z"])
212-
tdmodule = TensorDictSequence(tdmodule1, tdmodule2)
212+
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)
213213
tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True)
214214
assert {"0", "1"} == set(params.keys())
215215
x = torch.randn(10, 1, 3)
@@ -225,7 +225,7 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params):
225225
elif moduletype == "bn1":
226226
tdmodule1 = TensorDictModule(module1, in_keys=["x"], out_keys=["y"])
227227
tdmodule2 = TensorDictModule(module2, in_keys=["y"], out_keys=["z"])
228-
tdmodule = TensorDictSequence(tdmodule1, tdmodule2)
228+
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)
229229
tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True)
230230
assert {"0", "1"} == set(params.keys())
231231
assert {"0", "1"} == set(buffers.keys())

test/test_tensordictmodules.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from torchrl.modules.tensordict_module.probabilistic import (
2525
ProbabilisticTensorDictModule,
2626
)
27-
from torchrl.modules.tensordict_module.sequence import TensorDictSequence
27+
from torchrl.modules.tensordict_module.sequence import TensorDictSequential
2828

2929

3030
class TestTDModule:
@@ -785,7 +785,7 @@ def test_key_exclusion(self):
785785
module3 = TensorDictModule(
786786
nn.Linear(3, 4), in_keys=["foo1", "key3"], out_keys=["key2"]
787787
)
788-
seq = TensorDictSequence(module1, module2, module3)
788+
seq = TensorDictSequential(module1, module2, module3)
789789
assert set(seq.in_keys) == {"key1", "key2", "key3"}
790790
assert set(seq.out_keys) == {"foo1", "key1", "key2"}
791791

@@ -838,7 +838,7 @@ def test_stateful(self, safe, spec_type, lazy):
838838
safe=False,
839839
**kwargs
840840
)
841-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
841+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
842842

843843
assert hasattr(tdmodule, "__setitem__")
844844
assert len(tdmodule) == 3
@@ -926,7 +926,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy):
926926
safe=False,
927927
**kwargs
928928
)
929-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
929+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
930930

931931
assert hasattr(tdmodule, "__setitem__")
932932
assert len(tdmodule) == 3
@@ -998,7 +998,7 @@ def test_functional(self, safe, spec_type):
998998
out_keys=["out"],
999999
safe=safe,
10001000
)
1001-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
1001+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
10021002

10031003
assert hasattr(tdmodule, "__setitem__")
10041004
assert len(tdmodule) == 3
@@ -1082,7 +1082,7 @@ def test_functional_probabilistic(self, safe, spec_type):
10821082
safe=safe,
10831083
**kwargs
10841084
)
1085-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
1085+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
10861086

10871087
assert hasattr(tdmodule, "__setitem__")
10881088
assert len(tdmodule) == 3
@@ -1162,7 +1162,7 @@ def test_functional_with_buffer(
11621162
out_keys=["out"],
11631163
safe=safe,
11641164
)
1165-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
1165+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
11661166

11671167
assert hasattr(tdmodule, "__setitem__")
11681168
assert len(tdmodule) == 3
@@ -1255,7 +1255,7 @@ def test_functional_with_buffer_probabilistic(
12551255
safe=safe,
12561256
**kwargs
12571257
)
1258-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
1258+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
12591259

12601260
assert hasattr(tdmodule, "__setitem__")
12611261
assert len(tdmodule) == 3
@@ -1331,7 +1331,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(
13311331
safe=safe,
13321332
**kwargs
13331333
)
1334-
tdmodule = TensorDictSequence(tdmodule1, tdmodule2)
1334+
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)
13351335

13361336
tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers()
13371337

@@ -1396,7 +1396,7 @@ def test_vmap(self, safe, spec_type):
13961396
out_keys=["out"],
13971397
safe=safe,
13981398
)
1399-
tdmodule = TensorDictSequence(tdmodule1, dummy_tdmodule, tdmodule2)
1399+
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)
14001400

14011401
assert hasattr(tdmodule, "__setitem__")
14021402
assert len(tdmodule) == 3
@@ -1496,7 +1496,7 @@ def test_vmap_probabilistic(self, safe, spec_type):
14961496
safe=safe,
14971497
**kwargs
14981498
)
1499-
tdmodule = TensorDictSequence(tdmodule1, tdmodule2)
1499+
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)
15001500

15011501
# vmap = True
15021502
params = [p.repeat(10, *[1 for _ in p.shape]) for p in params]
@@ -1546,7 +1546,7 @@ def test_submodule_sequence(self, functional):
15461546
in_keys=["hidden"],
15471547
out_keys=["out"],
15481548
)
1549-
td_module = TensorDictSequence(td_module_1, td_module_2)
1549+
td_module = TensorDictSequential(td_module_1, td_module_2)
15501550

15511551
if functional:
15521552
td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
@@ -1640,7 +1640,7 @@ def test_sequential_partial(self, stack, functional):
16401640
safe=True,
16411641
**kwargs
16421642
)
1643-
tdmodule = TensorDictSequence(
1643+
tdmodule = TensorDictSequential(
16441644
tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True
16451645
)
16461646

@@ -1686,7 +1686,7 @@ def test_subsequence_weight_update(self):
16861686
in_keys=["hidden"],
16871687
out_keys=["out"],
16881688
)
1689-
td_module = TensorDictSequence(td_module_1, td_module_2)
1689+
td_module = TensorDictSequential(td_module_1, td_module_2)
16901690

16911691
td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
16921692
sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])

torchrl/modules/models/exploration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class gSDEModule(nn.Module):
262262
to the sampled action.
263263
264264
Examples:
265-
>>> from torchrl.modules import TensorDictModule, TensorDictSequence, ProbabilisticActor, TanhNormal
265+
>>> from torchrl.modules import TensorDictModule, TensorDictSequential, ProbabilisticActor, TanhNormal
266266
>>> from torchrl.data import TensorDict
267267
>>> batch, state_dim, action_dim = 3, 7, 5
268268
>>> model = nn.Linear(state_dim, action_dim)
@@ -274,7 +274,7 @@ class gSDEModule(nn.Module):
274274
>>> stochatstic_part = ProbabilisticActor(stochatstic_part,
275275
... dist_param_keys=["loc", "scale"],
276276
... distribution_class=TanhNormal)
277-
>>> stochatstic_policy = TensorDictSequence(deterministic_policy, stochatstic_part)
277+
>>> stochatstic_policy = TensorDictSequential(deterministic_policy, stochatstic_part)
278278
>>> tensordict = TensorDict({'obs': torch.randn(state_dim), '_epx_gSDE': torch.zeros(1)}, [])
279279
>>> _ = stochatstic_policy(tensordict)
280280
>>> print(tensordict)

torchrl/modules/tensordict_module/actors.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torchrl.modules.tensordict_module.probabilistic import (
1717
ProbabilisticTensorDictModule,
1818
)
19-
from torchrl.modules.tensordict_module.sequence import TensorDictSequence
19+
from torchrl.modules.tensordict_module.sequence import TensorDictSequential
2020

2121
__all__ = [
2222
"Actor",
@@ -498,7 +498,7 @@ def __init__(
498498
)
499499

500500

501-
class ActorValueOperator(TensorDictSequence):
501+
class ActorValueOperator(TensorDictSequential):
502502
"""
503503
Actor-value operator.
504504
@@ -615,21 +615,21 @@ def __init__(
615615
value_operator,
616616
)
617617

618-
def get_policy_operator(self) -> TensorDictSequence:
618+
def get_policy_operator(self) -> TensorDictSequential:
619619
"""
620620
621621
Returns a stand-alone policy operator that maps an observation to an action.
622622
623623
"""
624-
return TensorDictSequence(self.module[0], self.module[1])
624+
return TensorDictSequential(self.module[0], self.module[1])
625625

626-
def get_value_operator(self) -> TensorDictSequence:
626+
def get_value_operator(self) -> TensorDictSequential:
627627
"""
628628
629629
Returns a stand-alone value network operator that maps an observation to a value estimate.
630630
631631
"""
632-
return TensorDictSequence(self.module[0], self.module[2])
632+
return TensorDictSequential(self.module[0], self.module[2])
633633

634634

635635
class ActorCriticOperator(ActorValueOperator):
@@ -772,7 +772,7 @@ def get_value_operator(self) -> TensorDictModuleWrapper:
772772
)
773773

774774

775-
class ActorCriticWrapper(TensorDictSequence):
775+
class ActorCriticWrapper(TensorDictSequential):
776776
"""
777777
Actor-value operator without common module.
778778
@@ -863,15 +863,15 @@ def __init__(
863863
value_operator,
864864
)
865865

866-
def get_policy_operator(self) -> TensorDictSequence:
866+
def get_policy_operator(self) -> TensorDictSequential:
867867
"""
868868
869869
Returns a stand-alone policy operator that maps an observation to an action.
870870
871871
"""
872872
return self.module[0]
873873

874-
def get_value_operator(self) -> TensorDictSequence:
874+
def get_value_operator(self) -> TensorDictSequential:
875875
"""
876876
877877
Returns a stand-alone value network operator that maps an observation to a value estimate.

0 commit comments

Comments
 (0)