Skip to content

Commit 69c6122

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 075e82b commit 69c6122

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

test/test_specs.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
14021402
assert spec2.zero().shape == spec2.shape
14031403

14041404
def test_non_tensor(self):
1405-
spec = NonTensor((3, 4), device="cpu")
1405+
spec = NonTensor((3, 4), device="cpu", example_data="example_data")
14061406
assert (
14071407
spec.expand(2, 3, 4)
14081408
== spec.expand((2, 3, 4))
1409-
== NonTensor((2, 3, 4), device="cpu")
1409+
== NonTensor((2, 3, 4), device="cpu", example_data="example_data")
14101410
)
1411+
assert spec.expand(2, 3, 4).example_data == "example_data"
14111412

14121413
@pytest.mark.parametrize("shape1", [None, (), (5,)])
14131414
@pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
16071608
assert spec is not spec.clone()
16081609

16091610
def test_non_tensor(self):
1610-
spec = NonTensor(shape=(3, 4), device="cpu")
1611+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
16111612
assert spec.clone() == spec
16121613
assert spec.clone() is not spec
1614+
assert spec.clone().example_data == "example_data"
16131615

16141616
@pytest.mark.parametrize("shape1", [None, (), (5,)])
16151617
def test_onehot(
@@ -1840,9 +1842,10 @@ def test_multionehot(
18401842
spec.unbind(-1)
18411843

18421844
def test_non_tensor(self):
1843-
spec = NonTensor(shape=(3, 4), device="cpu")
1845+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
18441846
assert spec.unbind(1)[0] == spec[:, 0]
18451847
assert spec.unbind(1)[0] is not spec[:, 0]
1848+
assert spec.unbind(1)[0].example_data == "example_data"
18461849

18471850
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
18481851
def test_onehot(
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
20012004
assert spec.to(device).device == device
20022005

20032006
def test_non_tensor(self, device):
2004-
spec = NonTensor(shape=(3, 4), device="cpu")
2007+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
20052008
assert spec.to(device).device == device
2009+
assert spec.to(device).example_data == "example_data"
20062010

20072011
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
20082012
def test_onehot(self, shape1, device):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
22622266
assert r.shape == c.shape
22632267

22642268
def test_stack_non_tensor(self, shape, stack_dim):
2265-
spec0 = NonTensor(shape=shape, device="cpu")
2266-
spec1 = NonTensor(shape=shape, device="cpu")
2269+
spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
2270+
spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
22672271
new_spec = torch.stack([spec0, spec1], stack_dim)
22682272
shape_insert = list(shape)
22692273
shape_insert.insert(stack_dim, 2)
22702274
assert new_spec.shape == torch.Size(shape_insert)
22712275
assert new_spec.device == torch.device("cpu")
2276+
assert new_spec.example_data == "example_data"
22722277

22732278
def test_stack_onehot(self, shape, stack_dim):
22742279
n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
36423647

36433648
class TestNonTensorSpec:
36443649
def test_sample(self):
3645-
nts = NonTensor(shape=(3, 4))
3650+
nts = NonTensor(shape=(3, 4), example_data="example_data")
36463651
assert nts.one((2,)).shape == (2, 3, 4)
36473652
assert nts.rand((2,)).shape == (2, 3, 4)
36483653
assert nts.zero((2,)).shape == (2, 3, 4)
3654+
assert nts.one((2,)).data == "example_data"
3655+
assert nts.rand((2,)).data == "example_data"
3656+
assert nts.zero((2,)).data == "example_data"
3657+
3658+
def test_example_data_ineq(self):
3659+
nts0 = NonTensor(shape=(3, 4), example_data="example_data")
3660+
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
3661+
assert nts0 != nts1
36493662

36503663

36513664
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")

torchrl/data/tensor_specs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,6 +2452,8 @@ class NonTensor(TensorSpec):
24522452
(same will go for :meth:`.zero` and :meth:`.one`).
24532453
"""
24542454

2455+
example_data: Any = None
2456+
24552457
def __init__(
24562458
self,
24572459
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
@@ -2470,6 +2472,11 @@ def __init__(
24702472
)
24712473
self.example_data = example_data
24722474

2475+
def __eq__(self, other):
2476+
eq = super().__eq__(other)
2477+
eq = eq & (self.example_data == getattr(other, "example_data", None))
2478+
return eq
2479+
24732480
def cardinality(self) -> Any:
24742481
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
24752482

@@ -2555,6 +2562,16 @@ def expand(self, *shape):
25552562
shape=shape, device=self.device, dtype=None, example_data=self.example_data
25562563
)
25572564

2565+
def unsqueeze(self, dim: int) -> NonTensor:
2566+
unsq = super().unsqueeze(dim=dim)
2567+
unsq.example_data = self.example_data
2568+
return unsq
2569+
2570+
def squeeze(self, dim: int | None = None) -> NonTensor:
2571+
sq = super().squeeze(dim=dim)
2572+
sq.example_data = self.example_data
2573+
return sq
2574+
25582575
def _reshape(self, shape):
25592576
return self.__class__(
25602577
shape=shape,

0 commit comments

Comments
 (0)