Skip to content

Commit a0af473

Browse files
authored
[Feature] Truly invertible tensordict permutation of dimensions (#295)
1 parent 5c74ab7 commit a0af473

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

test/test_tensordict.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,24 @@ def test_permute(device):
306306
assert torch.sum(t["a"]) == torch.Tensor([0])
307307

308308

309+
@pytest.mark.parametrize("device", get_available_devices())
def test_permute_applied_twice(device):
    """Permuting with the inverse order returns the original object; any other order does not."""
    torch.manual_seed(1)
    source = {
        "a": torch.randn(4, 5, 6, 9, device=device),
        "b": torch.randn(4, 5, 6, 7, device=device),
        "c": torch.randn(4, 5, 6, device=device),
    }
    # (2, 1, 0) is its own inverse, so applying it twice must be the identity.
    base = TensorDict(batch_size=(4, 5, 6), source=source)
    flipped = torch.permute(base, dims=(2, 1, 0))
    assert torch.permute(flipped, dims=(2, 1, 0)) is base
    # A second permutation that is not the inverse must yield a new lazy view.
    base = TensorDict(batch_size=(4, 5, 6), source=source)
    flipped = torch.permute(base, dims=(2, 1, 0))
    assert torch.permute(flipped, dims=(0, 1, 2)) is not base
325+
326+
309327
@pytest.mark.parametrize("device", get_available_devices())
310328
def test_permute_exceptions(device):
311329
torch.manual_seed(1)
@@ -626,6 +644,18 @@ def permute_td(self, device):
626644
# batch_size=[3, 1, 2, 4],
627645
# ).permute(2, 0, 1, 3)
628646

647+
def test_permute_applied_twice(self, td_name, device):
    """Permuting then applying the inverse permutation must be the identity;
    applying any other permutation must produce a distinct lazy view."""
    torch.manual_seed(0)
    td = getattr(self, td_name)(device)
    for _ in range(10):
        perm = torch.randperm(4)
        inverse = perm.argsort()
        # Draw a second permutation guaranteed to differ from the inverse.
        alternative = inverse
        while (alternative == inverse).all():
            alternative = torch.randperm(4)
        assert td.permute(*perm).permute(*inverse) is td
        assert td.permute(*perm).permute(*alternative) is not td
658+
629659
def unsqueezed_td(self, device):
630660
td = TensorDict(
631661
source={

torchrl/data/tensordict/tensordict.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,10 @@ def permute(
12911291
f"number of dims don't match in permute (got {len(dims_list)}, expected {len(self.shape)}"
12921292
)
12931293

1294+
if not len(dims_list) and not self.batch_dims:
1295+
return self
1296+
if np.array_equal(dims_list, range(self.batch_dims)):
1297+
return self
12941298
min_dim, max_dim = -self.batch_dims, self.batch_dims - 1
12951299
seen = [False for dim in range(max_dim + 1)]
12961300
for idx in dims_list:
@@ -4271,8 +4275,45 @@ def view(
42714275
class PermutedTensorDict(_CustomOpTensorDict):
42724276
"""
42734277
A lazy view on a TensorDict with the batch dimensions permuted.
4278+
4279+
When calling `tensordict.permute(dims_list)`, a lazy view of this operation is
returned such that the following code snippet works without raising an
exception, where `inv_dims_list` is the inverse permutation of `dims_list`:

>>> assert tensordict.permute(dims_list).permute(inv_dims_list) is tensordict
4284+
4285+
Examples:
4286+
>>> from torchrl.data import TensorDict
4287+
>>> import torch
4288+
>>> td = TensorDict({'a': torch.randn(4, 5, 6, 9)}, batch_size=[4, 5, 6])
4289+
>>> td_permute = td.permute(dims=(2, 1, 0))
4290+
>>> print(td_permute.shape)
4291+
torch.Size([6, 5, 4])
4292+
>>> print(td_permute.permute(dims=(2, 1, 0)) is td)
4293+
True
42744294
"""
42754295

4296+
def permute(
    self,
    *dims_list: int,
    dims=None,
) -> _TensorDict:
    """Permute the batch dimensions of this lazily-permuted tensordict.

    The target order may be given as positional ints, as a single
    iterable, or through the ``dims`` keyword. Returns ``self`` for an
    identity permutation, the wrapped source tensordict when the
    requested order undoes the stored permutation, and a new lazy view
    otherwise.
    """
    # Normalise the three accepted call conventions to one sequence.
    if len(dims_list) == 0:
        dims_list = dims
    elif len(dims_list) == 1 and not isinstance(dims_list[0], int):
        dims_list = dims_list[0]
    if len(dims_list) != len(self.shape):
        raise RuntimeError(
            f"number of dims don't match in permute (got {len(dims_list)}, expected {len(self.shape)}"
        )
    # Identity permutations (including the 0-dim case) are no-ops.
    # NOTE: this must be checked before the inverse-permutation branch so
    # that an identity request on an identity-permuted td returns `self`.
    if (not len(dims_list) and not self.batch_dims) or np.array_equal(
        dims_list, range(self.batch_dims)
    ):
        return self
    # np.argsort(dims_list) is the inverse of dims_list: if it equals the
    # stored permutation, the two cancel out and the source is recovered.
    if np.array_equal(np.argsort(dims_list), self.inv_op_kwargs.get("dims")):
        return self._source
    return super().permute(*dims_list)
4316+
42764317
def add_missing_dims(self, num_dims: int, batch_dims: tuple) -> tuple:
42774318
dim_diff = num_dims - len(batch_dims)
42784319
all_dims = [i for i in range(num_dims)]

0 commit comments

Comments
 (0)