Skip to content

Commit 726dc42

Browse files
authored
[BugFix] Fix CompositeSpec.to_numpy method (#931)
1 parent 5bd4b4f commit 726dc42

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

test/test_specs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,14 @@ def test_is_in(self, is_complete, device, dtype):
446446
r = ts.rand()
447447
assert ts.is_in(r)
448448

449+
def test_to_numpy(self, is_complete, device, dtype):
450+
ts = self._composite_spec(is_complete, device, dtype)
451+
for _ in range(100):
452+
r = ts.rand()
453+
for key, value in ts.to_numpy(r).items():
454+
spec = ts[key]
455+
assert (spec.to_numpy(r[key]) == value).all()
456+
449457
@pytest.mark.parametrize("shape", [[], [3]])
450458
def test_project(self, is_complete, device, dtype, shape):
451459
ts = self._composite_spec(is_complete, device, dtype)

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1947,7 +1947,7 @@ def clone(self) -> CompositeSpec:
19471947
)
19481948

19491949
def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
1950-
return {key: self[key]._to_numpy(val) for key, val in val.items()}
1950+
return {key: self[key].to_numpy(val) for key, val in val.items()}
19511951

19521952
def zero(self, shape=None) -> TensorDictBase:
19531953
if shape is None:

0 commit comments

Comments
 (0)