Skip to content

Commit 23ca67c

Browse files
authored
[Feature] Rename _TensorDict into TensorDictBase (#316)
1 parent f07015d commit 23ca67c

39 files changed

+548
-494
lines changed

test/mocking_classes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
UnboundedContinuousTensorSpec,
1717
OneHotDiscreteTensorSpec,
1818
)
19-
from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
19+
from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict
2020
from torchrl.envs.common import _EnvClass
2121

2222
spec_dict = {
@@ -110,15 +110,15 @@ def _step(self, tensordict):
110110
done = torch.tensor([done], dtype=torch.bool, device=self.device)
111111
return TensorDict({"reward": n, "done": done, "next_observation": n}, [])
112112

113-
def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict:
113+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
114114
self.max_val = max(self.counter + 100, self.counter * 2)
115115

116116
n = torch.tensor([self.counter]).to(self.device).to(torch.get_default_dtype())
117117
done = self.counter >= self.max_val
118118
done = torch.tensor([done], dtype=torch.bool, device=self.device)
119119
return TensorDict({"done": done, "next_observation": n}, [])
120120

121-
def rand_step(self, tensordict: Optional[_TensorDict] = None) -> _TensorDict:
121+
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
122122
return self.step(tensordict)
123123

124124

@@ -144,7 +144,7 @@ def _get_in_obs(self, obs):
144144
def _get_out_obs(self, obs):
145145
return obs
146146

147-
def _reset(self, tensordict: _TensorDict) -> _TensorDict:
147+
def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
148148
self.counter += 1
149149
state = torch.zeros(self.size) + self.counter
150150
tensordict = tensordict.select().set(
@@ -156,8 +156,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
156156

157157
def _step(
158158
self,
159-
tensordict: _TensorDict,
160-
) -> _TensorDict:
159+
tensordict: TensorDictBase,
160+
) -> TensorDictBase:
161161
tensordict = tensordict.to(self.device)
162162
a = tensordict.get("action")
163163
assert (a.sum(-1) == 1).all()
@@ -199,7 +199,7 @@ def _get_in_obs(self, obs):
199199
def _get_out_obs(self, obs):
200200
return obs
201201

202-
def _reset(self, tensordict: _TensorDict) -> _TensorDict:
202+
def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
203203
self.counter += 1
204204
self.step_count = 0
205205
state = torch.zeros(self.size) + self.counter
@@ -211,8 +211,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
211211

212212
def _step(
213213
self,
214-
tensordict: _TensorDict,
215-
) -> _TensorDict:
214+
tensordict: TensorDictBase,
215+
) -> TensorDictBase:
216216
self.step_count += 1
217217
tensordict = tensordict.to(self.device)
218218
a = tensordict.get("action")

test/test_distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from _utils_internal import get_available_devices
1111
from torch import nn, autograd
12-
from torchrl.data.tensordict.tensordict import _TensorDict
12+
from torchrl.data.tensordict.tensordict import TensorDictBase
1313
from torchrl.modules import (
1414
TanhNormal,
1515
NormalParamWrapper,
@@ -59,7 +59,7 @@ def test_delta(device, div_up, div_down):
5959

6060
def _map_all(*tensors_or_other, device):
6161
for t in tensors_or_other:
62-
if isinstance(t, (torch.Tensor, _TensorDict)):
62+
if isinstance(t, (torch.Tensor, TensorDictBase)):
6363
yield t.to(device)
6464
else:
6565
yield t

test/test_rb.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
LazyMemmapStorage,
2222
LazyTensorStorage,
2323
)
24-
from torchrl.data.tensordict.tensordict import assert_allclose_td, _TensorDict
24+
from torchrl.data.tensordict.tensordict import assert_allclose_td, TensorDictBase
2525

2626

2727
collate_fn_dict = {
@@ -128,7 +128,7 @@ def test_add(self, rbtype, storage, size, prefetch):
128128
data = self._get_datum(rbtype)
129129
rb.add(data)
130130
s = rb._storage[0]
131-
if isinstance(s, _TensorDict):
131+
if isinstance(s, TensorDictBase):
132132
assert (s == data.select(*s.keys())).all()
133133
else:
134134
assert (s == data).all()
@@ -142,12 +142,12 @@ def test_extend(self, rbtype, storage, size, prefetch):
142142
for d in data[-length:]:
143143
found_similar = False
144144
for b in rb._storage:
145-
if isinstance(b, _TensorDict):
145+
if isinstance(b, TensorDictBase):
146146
b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
147147
d = d.select(*set(d.keys()).intersection(b.keys()))
148148

149149
value = b == d
150-
if isinstance(value, (torch.Tensor, _TensorDict)):
150+
if isinstance(value, (torch.Tensor, TensorDictBase)):
151151
value = value.all()
152152
if value:
153153
found_similar = True
@@ -160,18 +160,18 @@ def test_sample(self, rbtype, storage, size, prefetch):
160160
data = self._get_data(rbtype, size=5)
161161
rb.extend(data)
162162
new_data = rb.sample(3)
163-
if not isinstance(new_data, (torch.Tensor, _TensorDict)):
163+
if not isinstance(new_data, (torch.Tensor, TensorDictBase)):
164164
new_data = new_data[0]
165165

166166
for d in new_data:
167167
found_similar = False
168168
for b in data:
169-
if isinstance(b, _TensorDict):
169+
if isinstance(b, TensorDictBase):
170170
b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
171171
d = d.select(*set(d.keys()).intersection(b.keys()))
172172

173173
value = b == d
174-
if isinstance(value, (torch.Tensor, _TensorDict)):
174+
if isinstance(value, (torch.Tensor, TensorDictBase)):
175175
value = value.all()
176176
if value:
177177
found_similar = True

test/test_tensor_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
UnboundedContinuousTensorSpec,
1919
OneHotDiscreteTensorSpec,
2020
)
21-
from torchrl.data.tensordict.tensordict import TensorDict, _TensorDict
21+
from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase
2222

2323

2424
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@@ -376,7 +376,7 @@ def test_nested_composite_spec(self, is_complete, device, dtype):
376376
ts = self._composite_spec(is_complete, device, dtype)
377377
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
378378
td = ts.rand()
379-
assert isinstance(td["nested_cp"], _TensorDict)
379+
assert isinstance(td["nested_cp"], TensorDictBase)
380380
keys = list(td.keys())
381381
for key in keys:
382382
if key != "nested_cp":

test/test_tensordict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
LazyStackedTensorDict,
1919
stack as stack_td,
2020
pad,
21-
_TensorDict,
21+
TensorDictBase,
2222
)
2323
from torchrl.data.tensordict.utils import _getitem_batch_size, convert_ellipsis_to_idx
2424

@@ -833,7 +833,7 @@ def test_masking_set(self, td_name, device, from_list):
833833
def zeros_like(item, n, d):
834834
if isinstance(item, (MemmapTensor, torch.Tensor)):
835835
return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device)
836-
elif isinstance(item, _TensorDict):
836+
elif isinstance(item, TensorDictBase):
837837
batch_size = item.batch_size
838838
batch_size = [n, *batch_size[d:]]
839839
out = TensorDict(
@@ -1344,7 +1344,7 @@ def test_flatten_keys(self, td_name, device, inplace, separator):
13441344

13451345
td_flatten = td.flatten_keys(inplace=inplace, separator=separator)
13461346
for key, value in td_flatten.items():
1347-
assert not isinstance(value, _TensorDict)
1347+
assert not isinstance(value, TensorDictBase)
13481348
assert (
13491349
separator.join(["nested_tensordict", "nested_nested_tensordict", "a"])
13501350
in td_flatten.keys()

0 commit comments

Comments
 (0)