Skip to content

Commit cdc6798

Browse files
authored
[Refactor] Refactor composite spec keys to match tensordict (#956)
1 parent 1d0c335 commit cdc6798

File tree

14 files changed

+96
-69
lines changed

14 files changed

+96
-69
lines changed

test/test_specs.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -521,16 +521,24 @@ def test_nested_composite_spec(self, is_complete, device, dtype):
521521
assert set(ts.keys()) == {
522522
"obs",
523523
"act",
524+
"nested_cp",
525+
}
526+
assert set(ts.keys(include_nested=True)) == {
527+
"obs",
528+
"act",
529+
"nested_cp",
524530
("nested_cp", "obs"),
525531
("nested_cp", "act"),
526532
}
527-
assert len(ts.keys()) == len(ts.keys(yield_nesting_keys=True)) - 1
528-
assert set(ts.keys(yield_nesting_keys=True)) == {
533+
assert set(ts.keys(include_nested=True, leaves_only=True)) == {
529534
"obs",
530535
"act",
531536
("nested_cp", "obs"),
532537
("nested_cp", "act"),
533-
"nested_cp",
538+
}
539+
assert set(ts.keys(leaves_only=True)) == {
540+
"obs",
541+
"act",
534542
}
535543
td = ts.rand()
536544
assert isinstance(td["nested_cp"], TensorDictBase)
@@ -577,9 +585,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
577585
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
578586
td2 = CompositeSpec(new=None)
579587
ts.update(td2)
580-
assert set(ts.keys()) == {
588+
assert set(ts.keys(include_nested=True)) == {
581589
"obs",
582590
"act",
591+
"nested_cp",
583592
("nested_cp", "obs"),
584593
("nested_cp", "act"),
585594
"new",
@@ -589,9 +598,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
589598
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
590599
td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device))
591600
ts.update(td2)
592-
assert set(ts.keys()) == {
601+
assert set(ts.keys(include_nested=True)) == {
593602
"obs",
594603
"act",
604+
"nested_cp",
595605
("nested_cp", "obs"),
596606
("nested_cp", "act"),
597607
("nested_cp", "new"),
@@ -601,9 +611,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
601611
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
602612
td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device))
603613
ts.update(td2)
604-
assert set(ts.keys()) == {
614+
assert set(ts.keys(include_nested=True)) == {
605615
"obs",
606616
"act",
617+
"nested_cp",
607618
("nested_cp", "obs"),
608619
("nested_cp", "act"),
609620
}
@@ -617,9 +628,10 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
617628
nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device=device))
618629
)
619630
ts.update(td2)
620-
assert set(ts.keys()) == {
631+
assert set(ts.keys(include_nested=True)) == {
621632
"obs",
622633
"act",
634+
"nested_cp",
623635
("nested_cp", "obs"),
624636
("nested_cp", "act"),
625637
}
@@ -629,7 +641,7 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
629641
def test_keys_to_empty_composite_spec():
630642
keys = [("key1", "out"), ("key1", "in"), "key2", ("key1", "subkey1", "subkey2")]
631643
composite = _keys_to_empty_composite_spec(keys)
632-
assert set(composite.keys()) == set(keys)
644+
assert set(composite.keys(True, True)) == set(keys)
633645

634646

635647
class TestEquality:

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def __init__(
447447
hasattr(self.policy, "spec")
448448
and self.policy.spec is not None
449449
and all(v is not None for v in self.policy.spec.values())
450-
and set(self.policy.spec.keys()) == set(self.policy.out_keys)
450+
and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys)
451451
):
452452
# if policy spec is non-empty, all the values are not None and the keys
453453
# match the out_keys we assume the user has given all relevant information

torchrl/data/tensor_specs.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,13 +1862,12 @@ def type_check(
18621862
self._specs[_key].type_check(value[_key], _key)
18631863

18641864
def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
1865-
return all(
1866-
[
1867-
item.is_in(val.get(key))
1868-
for (key, item) in self._specs.items()
1869-
if item is not None
1870-
]
1871-
)
1865+
for (key, item) in self._specs.items():
1866+
if item is None:
1867+
continue
1868+
if not item.is_in(val.get(key)):
1869+
return False
1870+
return True
18721871

18731872
def project(self, val: TensorDictBase) -> TensorDictBase:
18741873
for key, item in self.items():
@@ -1894,22 +1893,29 @@ def rand(self, shape=None) -> TensorDictBase:
18941893
)
18951894

18961895
def keys(
1897-
self, yield_nesting_keys: bool = False, nested_keys: bool = True
1896+
self,
1897+
include_nested: bool = False,
1898+
leaves_only: bool = False,
18981899
) -> KeysView:
18991900
"""Keys of the CompositeSpec.
19001901
1902+
The keys argument reflect those of :class:`tensordict.TensorDict`.
1903+
19011904
Args:
1902-
yield_nesting_keys (bool, optional): if :obj:`True`, the values returned
1903-
will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
1904-
will lead to the keys :obj:`["next", ("next", "obs")]`. Default is :obj:`False`, i.e.
1905-
only nested keys will be returned.
1906-
nested_keys (bool, optional): if :obj:`False`, the returned keys will not be nested. They will
1905+
include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
19071906
represent only the immediate children of the root, and not the whole nested sequence, i.e. a
19081907
:obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
1909-
:obj:`["next"]. Default is :obj:`True`, i.e. nested keys will be returned.
1908+
:obj:`["next"]. Default is ``False``, i.e. nested keys will not
1909+
be returned.
1910+
leaves_only (bool, optional): if :obj:`False`, the values returned
1911+
will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
1912+
will lead to the keys :obj:`["next", ("next", "obs")]`.
1913+
Default is ``False``.
19101914
"""
19111915
return _CompositeSpecKeysView(
1912-
self, _yield_nesting_keys=yield_nesting_keys, nested_keys=nested_keys
1916+
self,
1917+
include_nested=include_nested,
1918+
leaves_only=leaves_only,
19131919
)
19141920

19151921
def items(self) -> ItemsView:
@@ -2014,13 +2020,14 @@ def expand(self, *shape):
20142020

20152021

20162022
def _keys_to_empty_composite_spec(keys):
2023+
"""Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value."""
20172024
if not len(keys):
20182025
return
20192026
c = CompositeSpec()
20202027
for key in keys:
20212028
if isinstance(key, str):
20222029
c[key] = None
2023-
elif key[0] in c.keys(yield_nesting_keys=True):
2030+
elif key[0] in c.keys():
20242031
if c[key[0]] is None:
20252032
# if the value is None we just replace it
20262033
c[key[0]] = _keys_to_empty_composite_spec([key[1:]])
@@ -2042,28 +2049,34 @@ class _CompositeSpecKeysView:
20422049
def __init__(
20432050
self,
20442051
composite: CompositeSpec,
2045-
nested_keys: bool = True,
2046-
_yield_nesting_keys: bool = False,
2052+
include_nested,
2053+
leaves_only,
20472054
):
20482055
self.composite = composite
2049-
self._yield_nesting_keys = _yield_nesting_keys
2050-
self.nested_keys = nested_keys
2056+
self.leaves_only = leaves_only
2057+
self.include_nested = include_nested
20512058

20522059
def __iter__(
20532060
self,
20542061
):
20552062
for key, item in self.composite.items():
2056-
if self.nested_keys and isinstance(item, CompositeSpec):
2057-
for subkey in item.keys():
2058-
yield (key, *subkey) if isinstance(subkey, tuple) else (key, subkey)
2059-
if self._yield_nesting_keys:
2060-
yield key
2061-
else:
2062-
if not isinstance(item, CompositeSpec) or len(item):
2063+
if self.include_nested and isinstance(item, CompositeSpec):
2064+
for subkey in item.keys(
2065+
include_nested=True, leaves_only=self.leaves_only
2066+
):
2067+
if not isinstance(subkey, tuple):
2068+
subkey = (subkey,)
2069+
yield (key, *subkey)
2070+
if not self.leaves_only:
20632071
yield key
2072+
elif not isinstance(item, CompositeSpec) or not self.leaves_only:
2073+
yield key
20642074

20652075
def __len__(self):
20662076
i = 0
20672077
for _ in self:
20682078
i += 1
20692079
return i
2080+
2081+
def __repr__(self):
2082+
return f"_CompositeSpecKeysView(keys={list(self)})"

torchrl/envs/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
389389
)
390390
tensordict.unlock_()
391391

392-
obs_keys = self.observation_spec.keys(nested_keys=False)
392+
obs_keys = self.observation_spec.keys(True)
393393
# we deliberately do not update the input values, but we want to keep track of
394394
# new keys considered as "input" by inverse transforms.
395395
in_keys = self._get_in_keys_to_exclude(tensordict)

torchrl/envs/gym_like.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def read_obs(
174174
# when queried with and without pixels
175175
observations["observation"] = observations.pop("state")
176176
if not isinstance(observations, (TensorDict, dict)):
177-
(key,) = itertools.islice(self.observation_spec.keys(), 1)
177+
(key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
178178
observations = {key: observations}
179179
observations = self.observation_spec.encode(observations)
180180
return observations

torchrl/envs/transforms/r3m.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
107107
if not isinstance(observation_spec, CompositeSpec):
108108
raise ValueError("_R3MNet can only infer CompositeSpec")
109109

110-
keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
110+
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
111111
device = observation_spec[keys[0]].device
112112
dim = observation_spec[keys[0]].shape[:-3]
113113

torchrl/envs/transforms/transforms.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def new_fun(self, observation_spec):
6161
if isinstance(observation_spec, CompositeSpec):
6262
d = observation_spec._specs
6363
for in_key, out_key in zip(self.in_keys, self.out_keys):
64-
if in_key in observation_spec.keys():
64+
if in_key in observation_spec.keys(True, True):
6565
d[out_key] = function(self, observation_spec[in_key].clone())
6666
return CompositeSpec(
6767
d, shape=observation_spec.shape, device=observation_spec.device
@@ -85,7 +85,7 @@ def new_fun(self, input_spec):
8585
if isinstance(input_spec, CompositeSpec):
8686
d = input_spec._specs
8787
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
88-
if in_key in input_spec.keys():
88+
if in_key in input_spec.keys(True, True):
8989
d[out_key] = function(self, input_spec[in_key].clone())
9090
return CompositeSpec(d, shape=input_spec.shape, device=input_spec.device)
9191
else:
@@ -2066,7 +2066,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
20662066
# by def, there must be only one key
20672067
return observation_spec
20682068

2069-
keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
2069+
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
20702070

20712071
sum_shape = sum(
20722072
[
@@ -2849,7 +2849,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
28492849
raise KeyError(
28502850
f"The key {in_key} was not found in the parent "
28512851
f"observation_spec with keys "
2852-
f"{list(self.parent.observation_spec.keys())}. "
2852+
f"{list(self.parent.observation_spec.keys(True))}. "
28532853
) from err
28542854

28552855
return tensordict
@@ -2880,7 +2880,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
28802880
episode_specs = {}
28812881
if isinstance(reward_spec, CompositeSpec):
28822882
# If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
2883-
if not all(k in reward_spec.keys() for k in self.in_keys):
2883+
if not all(k in reward_spec.keys(True, True) for k in self.in_keys):
28842884
raise KeyError("Not all in_keys are present in ´reward_spec´")
28852885

28862886
# Define episode specs for all out_keys
@@ -3042,7 +3042,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
30423042
return tensordict.exclude(*self.excluded_keys)
30433043

30443044
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
3045-
if any(key in observation_spec.keys() for key in self.excluded_keys):
3045+
if any(key in observation_spec.keys(True, True) for key in self.excluded_keys):
30463046
return CompositeSpec(
30473047
**{
30483048
key: value
@@ -3074,7 +3074,7 @@ def __init__(self, *selected_keys):
30743074

30753075
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
30763076
if self.parent:
3077-
input_keys = self.parent.input_spec.keys()
3077+
input_keys = self.parent.input_spec.keys(True, True)
30783078
else:
30793079
input_keys = []
30803080
return tensordict.select(
@@ -3085,7 +3085,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
30853085

30863086
def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
30873087
if self.parent:
3088-
input_keys = self.parent.input_spec.keys()
3088+
input_keys = self.parent.input_spec.keys(True, True)
30893089
else:
30903090
input_keys = []
30913091
return tensordict.select(

torchrl/envs/transforms/vip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
9696
if not isinstance(observation_spec, CompositeSpec):
9797
raise ValueError("_VIPNet can only infer CompositeSpec")
9898

99-
keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
99+
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
100100
device = observation_spec[keys[0]].device
101101
dim = observation_spec[keys[0]].shape[:-3]
102102

torchrl/envs/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,22 +242,24 @@ def _check_isin(key, value, obs_spec, input_spec):
242242
for _key, _value in value.items():
243243
_check_isin(_key, _value, obs_spec, input_spec)
244244
return
245-
elif key in input_spec.keys(yield_nesting_keys=True):
245+
elif key in input_spec.keys(True):
246246
if not input_spec[key].is_in(value):
247247
raise AssertionError(
248248
f"input_spec.is_in failed for key {key}. "
249249
f"Got input_spec={input_spec[key]} and real={value}."
250250
)
251251
return
252-
elif key in obs_spec.keys(yield_nesting_keys=True):
252+
elif key in obs_spec.keys(True):
253253
if not obs_spec[key].is_in(value):
254254
raise AssertionError(
255255
f"obs_spec.is_in failed for key {key}. "
256256
f"Got obs_spec={obs_spec[key]} and real={value}."
257257
)
258258
return
259259
else:
260-
raise KeyError(key)
260+
raise KeyError(
261+
f"key {key} was not found in input spec with keys {input_spec.keys(True)} or obs spec with keys {obs_spec.keys(True)}"
262+
)
261263

262264

263265
def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1):

0 commit comments

Comments
 (0)