Skip to content

Commit 4a81a6c

Browse files
authored
[BugFix] Sync with tensordict (meta-tensor deprecation) (#842)
1 parent e8e511d commit 4a81a6c

File tree

3 files changed

+8
-21
lines changed

3 files changed

+8
-21
lines changed

test/test_shared.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111
import torch
12-
from tensordict import SavedTensorDict, TensorDict
12+
from tensordict import TensorDict
1313
from torch import multiprocessing as mp
1414

1515

@@ -145,10 +145,10 @@ def test_shared(self, shared):
145145
)
146146

147147

148-
@pytest.mark.skipif(
149-
sys.platform == "win32",
150-
reason="RuntimeError from Torch serialization.py when creating td_saved on Windows",
151-
)
148+
# @pytest.mark.skipif(
149+
# sys.platform == "win32",
150+
# reason="RuntimeError from Torch serialization.py when creating td_saved on Windows",
151+
# )
152152
@pytest.mark.parametrize(
153153
"idx",
154154
[
@@ -180,7 +180,6 @@ def test_memmap(idx, dtype, large_scale=False):
180180

181181
td_sm = td.clone().share_memory_()
182182
td_memmap = td.clone().memmap_()
183-
td_saved = td.to(SavedTensorDict)
184183

185184
print("\nTesting reading from TD")
186185
for i in range(2):
@@ -194,11 +193,6 @@ def test_memmap(idx, dtype, large_scale=False):
194193
if i == 1:
195194
print(f"memmap: {time.time() - t0:4.4f} sec")
196195

197-
t0 = time.time()
198-
td_saved[idx].clone()
199-
if i == 1:
200-
print(f"saved td: {time.time() - t0:4.4f} sec")
201-
202196
td_to_copy = td[idx].contiguous()
203197
for k in td_to_copy.keys():
204198
td_to_copy.set_(k, torch.ones_like(td_to_copy.get(k)))
@@ -219,13 +213,6 @@ def test_memmap(idx, dtype, large_scale=False):
219213
print(f"memmap td: {time.time() - t0:4.4f} sec")
220214
torch.testing.assert_close(sub_td_sm.get("a")._tensor, td_to_copy.get("a"))
221215

222-
t0 = time.time()
223-
sub_td_sm = td_saved.get_sub_tensordict(idx)
224-
sub_td_sm.update_(td_to_copy)
225-
if i == 1:
226-
print(f"saved td: {time.time() - t0:4.4f} sec")
227-
torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a"))
228-
229216

230217
if __name__ == "__main__":
231218
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/collectors/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
5959
torch.ones(
6060
out_split.shape,
6161
dtype=torch.bool,
62-
device=out_split._get_meta("done").device,
62+
device=out_split.get("done").device,
6363
),
6464
)
6565
MAX = max(*[out_split.shape[0] for out_split in out_splits])

torchrl/envs/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,13 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
380380
obs = tensordict_out.get(key)
381381
self.observation_spec.type_check(obs, key)
382382

383-
if tensordict_out._get_meta("reward").dtype is not self.reward_spec.dtype:
383+
if tensordict_out.get("reward").dtype is not self.reward_spec.dtype:
384384
raise TypeError(
385385
f"expected reward.dtype to be {self.reward_spec.dtype} "
386386
f"but got {tensordict_out.get('reward').dtype}"
387387
)
388388

389-
if tensordict_out._get_meta("done").dtype is not torch.bool:
389+
if tensordict_out.get("done").dtype is not torch.bool:
390390
raise TypeError(
391391
f"expected done.dtype to be torch.bool but got {tensordict_out.get('done').dtype}"
392392
)

0 commit comments

Comments
 (0)