Skip to content

Commit 7deff86

Browse files
author
Vincent Moens
committed
[BugFix] NonTensor.encode must return a NonTensorData
ghstack-source-id: 36fc87c Pull-Request-resolved: #2944
1 parent ccadb67 commit 7deff86

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

test/test_specs.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3940,10 +3940,17 @@ def test_device_cast(self):
39403940
assert comp["nontensor"].device == torch.device("cpu")
39413941

39423942
def test_encode(self):
3943+
nt = NonTensor(shape=(1,))
3944+
r = nt.encode("a string")
3945+
assert isinstance(r, NonTensorData)
3946+
assert r.shape == nt.shape
3947+
39433948
comp = Composite(device="cpu")
3944-
comp["nontensor"] = NonTensor(shape=())
3949+
comp["nontensor"] = nt
39453950
r = comp.encode({"nontensor": "a string"})
3946-
assert isinstance(r["nontensor"], str)
3951+
assert isinstance(r, TensorDict)
3952+
assert isinstance(r.get("nontensor"), NonTensorData)
3953+
assert r.get("nontensor").shape == (1,)
39473954

39483955

39493956
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2918,7 +2918,7 @@ def _encode_eager(
29182918
*,
29192919
ignore_device: bool = False,
29202920
) -> torch.Tensor | TensorDictBase:
2921-
return val
2921+
return NonTensorData(val, device=self.device, batch_size=self.shape)
29222922

29232923

29242924
class _UnboundedMeta(abc.ABCMeta):

0 commit comments

Comments
 (0)