Skip to content

Commit 53bee4b

Browse files
authored
[Feature] Adding additional checks to TensorDict.view to remove unnecessary ViewedTensorDict object creation (#319)
1 parent f5d6820 commit 53bee4b

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

test/test_tensordict.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,16 @@ def test_savedtensordict(device):
488488
assert ss.get("a").device == device
489489

490490

491+
def test_inferred_view_size():
492+
td = TensorDict({"a": torch.randn(3, 4)}, [3, 4])
493+
assert td.view(-1).view(-1, 4) is td
494+
495+
assert td.view(-1, 4) is td
496+
assert td.view(3, -1) is td
497+
assert td.view(3, 4) is td
498+
assert td.view(-1, 12).shape == torch.Size([1, 12])
499+
500+
491501
@pytest.mark.parametrize(
492502
"ellipsis_index, expected_index",
493503
[
@@ -1026,6 +1036,22 @@ def test_view(self, td_name, device):
10261036
assert (td_view.get("a") == 1).all()
10271037
assert (td.get("a") == 1).all()
10281038

1039+
def test_inferred_view_size(self, td_name, device):
1040+
if td_name in ("permute_td", "sub_td2"):
1041+
pytest.skip("view incompatible with stride / permutation")
1042+
torch.manual_seed(1)
1043+
td = getattr(self, td_name)(device)
1044+
for i in range(len(td.shape)):
1045+
# replacing every index one at a time
1046+
# with -1, to test that td.view(..., -1, ...)
1047+
# always returns the original tensordict
1048+
new_shape = [
1049+
dim_size if dim_idx != i else -1
1050+
for dim_idx, dim_size in enumerate(td.shape)
1051+
]
1052+
assert td.view(-1).view(*new_shape) is td
1053+
assert td.view(*new_shape) is td
1054+
10291055
def test_clone_td(self, td_name, device):
10301056
torch.manual_seed(1)
10311057
td = getattr(self, td_name)(device)

torchrl/data/tensordict/memmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class MemmapTensor(object):
5252
and remote workers that have access to
5353
a common storage, and as such it supports serialization and
5454
deserialization. It is possible to choose if the ownership is
55-
transferred upon serialization / deserialization: If owenership is not
55+
transferred upon serialization / deserialization: If ownership is not
5656
transferred (transfer_ownership=False, default), then the process where
5757
the MemmapTensor was created will be responsible of clearing it once it
5858
gets out of scope (in that process). Otherwise, the process that

torchrl/data/tensordict/tensordict.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import numpy as np
3535
import torch
36+
from torch.jit._shape_functions import infer_size_impl
3637

3738
from torchrl import KeyDependentDefaultDict, prod
3839
from torchrl.data.tensordict.memmap import MemmapTensor
@@ -1232,7 +1233,10 @@ def view(
12321233
elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)):
12331234
return self.view(*shape[0])
12341235
elif not isinstance(shape, torch.Size):
1236+
shape = infer_size_impl(shape, self.numel())
12351237
shape = torch.Size(shape)
1238+
if shape == self.shape:
1239+
return self
12361240
return ViewedTensorDict(
12371241
source=self,
12381242
custom_op="view",
@@ -4298,6 +4302,7 @@ def view(
42984302
elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)):
42994303
return self.view(*shape[0])
43004304
elif not isinstance(shape, torch.Size):
4305+
shape = infer_size_impl(shape, self.numel())
43014306
shape = torch.Size(shape)
43024307
if shape == self._source.batch_size:
43034308
return self._source

0 commit comments

Comments
 (0)