Skip to content

Commit 8302b2e

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents fc163e4 + f67cbc7 commit 8302b2e

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

test/test_env.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import defaultdict
1717
from functools import partial
1818
from sys import platform
19-
from typing import Optional
19+
from typing import Any, Optional
2020

2121
import numpy as np
2222
import pytest
@@ -33,7 +33,7 @@
3333
TensorDictBase,
3434
)
3535
from tensordict.nn import TensorDictModuleBase
36-
from tensordict.tensorclass import NonTensorStack
36+
from tensordict.tensorclass import NonTensorStack, TensorClass
3737
from tensordict.utils import _unravel_key_to_tuple
3838
from torch import nn
3939

@@ -340,7 +340,8 @@ def forward(self, values):
340340
)
341341
env.rollout(10, policy)
342342

343-
def test_make_spec_from_td(self):
343+
@pytest.mark.parametrize("dynamic_shape", [True, False])
344+
def test_make_spec_from_td(self, dynamic_shape):
344345
data = TensorDict(
345346
{
346347
"obs": torch.randn(3),
@@ -353,10 +354,44 @@ def test_make_spec_from_td(self):
353354
},
354355
[],
355356
)
356-
spec = make_composite_from_td(data)
357+
spec = make_composite_from_td(data, dynamic_shape=dynamic_shape)
357358
assert (spec.zero() == data.zero_()).all()
358359
for key, val in data.items(True, True):
359360
assert val.dtype is spec[key].dtype
361+
if dynamic_shape:
362+
assert all(s.shape[-1] == -1 for s in spec.values(True, True))
363+
364+
def test_make_spec_from_tc(self):
365+
class Scratch(TensorClass):
366+
obs: torch.Tensor
367+
string: str
368+
some_object: Any
369+
370+
class Whatever:
371+
...
372+
373+
td = TensorDict(
374+
a=Scratch(
375+
obs=torch.ones(5, 3),
376+
string="another string!",
377+
some_object=Whatever(),
378+
batch_size=(5,),
379+
),
380+
b="a string!",
381+
batch_size=(5,),
382+
)
383+
spec = make_composite_from_td(td)
384+
assert isinstance(spec, Composite)
385+
assert isinstance(spec["a"], Composite)
386+
assert isinstance(spec["b"], NonTensor)
387+
assert spec["b"].example_data == "a string!", spec["b"].example_data
388+
assert spec["a", "string"].example_data == "another string!"
389+
one = spec.one()
390+
assert isinstance(one["a"], Scratch)
391+
assert isinstance(one["b"], str)
392+
assert isinstance(one["a"].string, str)
393+
assert isinstance(one["a"].some_object, Whatever)
394+
assert (one == td).all()
360395

361396
def test_env_that_does_nothing(self):
362397
env = EnvThatDoesNothing()

torchrl/envs/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,9 +940,9 @@ def make_shape(shape):
940940
unsqueeze_null_shapes=unsqueeze_null_shapes,
941941
dynamic_shape=dynamic_shape,
942942
)
943-
if isinstance(tensor, TensorDictBase)
943+
if is_tensor_collection(tensor) and not is_non_tensor(tensor)
944944
else NonTensor(
945-
shape=data.shape, example_data=data.data, device=tensor.device
945+
shape=tensor.shape, example_data=tensor.data, example_data=data.data, device=tensor.device
946946
)
947947
if is_non_tensor(tensor)
948948
else Unbounded(
@@ -951,6 +951,7 @@ def make_shape(shape):
951951
for key, tensor in data.items()
952952
},
953953
shape=data.shape,
954+
data_cls=type(data),
954955
)
955956
return composite
956957

0 commit comments

Comments
 (0)