Skip to content

Commit ccadb67

Browse files
author
Vincent Moens
committed
[BugFix] Fix encode overrides
ghstack-source-id: 2b7710a Pull-Request-resolved: #2943
1 parent 3bad905 commit ccadb67

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

test/test_specs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3939,6 +3939,12 @@ def test_device_cast(self):
39393939
comp["nontensor"] = NonTensor(device="cpu")
39403940
assert comp["nontensor"].device == torch.device("cpu")
39413941

3942+
def test_encode(self):
3943+
comp = Composite(device="cpu")
3944+
comp["nontensor"] = NonTensor(shape=())
3945+
r = comp.encode({"nontensor": "a string"})
3946+
assert isinstance(r["nontensor"], str)
3947+
39423948

39433949
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
39443950
def test_device_ordinal():

torchrl/data/tensor_specs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,7 +1609,7 @@ def space(self):
16091609
def _project(self, val: TensorDictBase) -> TensorDictBase:
16101610
raise NOT_IMPLEMENTED_ERROR
16111611

1612-
def encode(
1612+
def _encode_eager(
16131613
self, val: np.ndarray | torch.Tensor, *, ignore_device=False
16141614
) -> torch.Tensor:
16151615
if self.dim != 0 and not isinstance(val, tuple):
@@ -2912,7 +2912,7 @@ def to_numpy(
29122912
) -> np.ndarray | dict:
29132913
return val
29142914

2915-
def encode(
2915+
def _encode_eager(
29162916
self,
29172917
val: np.ndarray | torch.Tensor | TensorDictBase,
29182918
*,
@@ -6357,7 +6357,7 @@ def empty(self):
63576357
[spec.empty() for spec in self._specs], dim=self.stack_dim
63586358
)
63596359

6360-
def encode(
6360+
def _encode_eager(
63616361
self, vals: dict[str, Any], ignore_device: bool = False
63626362
) -> dict[str, torch.Tensor]:
63636363
raise NOT_IMPLEMENTED_ERROR

0 commit comments

Comments
 (0)