34
34
MultiKeyCountingEnvPolicy ,
35
35
NestedCountingEnv ,
36
36
)
37
- from tensordict .nn import TensorDictModule
37
+ from tensordict .nn import TensorDictModule , TensorDictSequential
38
38
from tensordict .tensordict import assert_allclose_td , TensorDict
39
39
40
40
from torch import nn
@@ -939,17 +939,22 @@ def create_env():
939
939
[MultiSyncDataCollector , MultiaSyncDataCollector , SyncDataCollector ],
940
940
)
941
941
@pytest .mark .parametrize ("exclude" , [True , False ])
942
- def test_excluded_keys (collector_class , exclude ):
942
+ @pytest .mark .parametrize ("out_key" , ["_dummy" , ("out" , "_dummy" ), ("_out" , "dummy" )])
943
+ def test_excluded_keys (collector_class , exclude , out_key ):
943
944
if not exclude and collector_class is not SyncDataCollector :
944
945
pytest .skip ("defining _exclude_private_keys is not possible" )
945
946
946
947
def make_env ():
947
- return ContinuousActionVecMockEnv ()
948
+ return TransformedEnv ( ContinuousActionVecMockEnv (), InitTracker () )
948
949
949
950
dummy_env = make_env ()
950
951
obs_spec = dummy_env .observation_spec ["observation" ]
951
952
policy_module = nn .Linear (obs_spec .shape [- 1 ], dummy_env .action_spec .shape [- 1 ])
952
- policy = Actor (policy_module , spec = dummy_env .action_spec )
953
+ policy = TensorDictModule (
954
+ policy_module , in_keys = ["observation" ], out_keys = ["action" ]
955
+ )
956
+ copier = TensorDictModule (lambda x : x , in_keys = ["observation" ], out_keys = [out_key ])
957
+ policy = TensorDictSequential (policy , copier )
953
958
policy_explore = OrnsteinUhlenbeckProcessWrapper (policy )
954
959
955
960
collector_kwargs = {
@@ -966,11 +971,13 @@ def make_env():
966
971
collector = collector_class (** collector_kwargs )
967
972
collector ._exclude_private_keys = exclude
968
973
for b in collector :
969
- keys = b .keys ()
974
+ keys = set ( b .keys () )
970
975
if exclude :
971
976
assert not any (key .startswith ("_" ) for key in keys )
977
+ assert out_key not in b .keys (True , True )
972
978
else :
973
979
assert any (key .startswith ("_" ) for key in keys )
980
+ assert out_key in b .keys (True , True )
974
981
break
975
982
collector .shutdown ()
976
983
dummy_env .close ()
0 commit comments