Skip to content

Commit e7630f1

Browse files
author
Vincent Moens
authored
[Feature] Exclude all private keys in collectors (#1644)
1 parent 2e32c10 commit e7630f1

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

test/test_collector.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
MultiKeyCountingEnvPolicy,
3535
NestedCountingEnv,
3636
)
37-
from tensordict.nn import TensorDictModule
37+
from tensordict.nn import TensorDictModule, TensorDictSequential
3838
from tensordict.tensordict import assert_allclose_td, TensorDict
3939

4040
from torch import nn
@@ -939,17 +939,22 @@ def create_env():
939939
[MultiSyncDataCollector, MultiaSyncDataCollector, SyncDataCollector],
940940
)
941941
@pytest.mark.parametrize("exclude", [True, False])
942-
def test_excluded_keys(collector_class, exclude):
942+
@pytest.mark.parametrize("out_key", ["_dummy", ("out", "_dummy"), ("_out", "dummy")])
943+
def test_excluded_keys(collector_class, exclude, out_key):
943944
if not exclude and collector_class is not SyncDataCollector:
944945
pytest.skip("defining _exclude_private_keys is not possible")
945946

946947
def make_env():
947-
return ContinuousActionVecMockEnv()
948+
return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())
948949

949950
dummy_env = make_env()
950951
obs_spec = dummy_env.observation_spec["observation"]
951952
policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1])
952-
policy = Actor(policy_module, spec=dummy_env.action_spec)
953+
policy = TensorDictModule(
954+
policy_module, in_keys=["observation"], out_keys=["action"]
955+
)
956+
copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key])
957+
policy = TensorDictSequential(policy, copier)
953958
policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)
954959

955960
collector_kwargs = {
@@ -966,11 +971,13 @@ def make_env():
966971
collector = collector_class(**collector_kwargs)
967972
collector._exclude_private_keys = exclude
968973
for b in collector:
969-
keys = b.keys()
974+
keys = set(b.keys())
970975
if exclude:
971976
assert not any(key.startswith("_") for key in keys)
977+
assert out_key not in b.keys(True, True)
972978
else:
973979
assert any(key.startswith("_") for key in keys)
980+
assert out_key in b.keys(True, True)
974981
break
975982
collector.shutdown()
976983
dummy_env.close()

torchrl/collectors/collectors.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,8 +758,18 @@ def iterator(self) -> Iterator[TensorDictBase]:
758758
if self.postproc is not None:
759759
tensordict_out = self.postproc(tensordict_out)
760760
if self._exclude_private_keys:
761+
762+
def is_private(key):
763+
if isinstance(key, str) and key.startswith("_"):
764+
return True
765+
if isinstance(key, tuple) and any(
766+
_key.startswith("_") for _key in key
767+
):
768+
return True
769+
return False
770+
761771
excluded_keys = [
762-
key for key in tensordict_out.keys() if key.startswith("_")
772+
key for key in tensordict_out.keys(True) if is_private(key)
763773
]
764774
tensordict_out = tensordict_out.exclude(
765775
*excluded_keys, inplace=True

0 commit comments

Comments
 (0)