73
73
from torchrl .modules import Actor , ActorCriticOperator , MLP , SafeModule , ValueOperator
74
74
from torchrl .modules .tensordict_module import WorldModelWrapper
75
75
76
+ pytestmark = [
77
+ pytest .mark .filterwarnings ("error" ),
78
+ pytest .mark .filterwarnings (
79
+ "ignore:Got multiple backends for torchrl.data.replay_buffers.storages"
80
+ ),
81
+ ]
82
+
76
83
gym_version = None
77
84
if _has_gym :
78
85
try :
@@ -232,7 +239,7 @@ def test_run_type_checks(self):
232
239
check_env_specs (env )
233
240
env ._run_type_checks = True
234
241
check_env_specs (env )
235
- env .output_spec .unlock_ ()
242
+ env .output_spec .unlock_ (recurse = True )
236
243
# check type check on done
237
244
env .output_spec ["full_done_spec" , "done" ].dtype = torch .int
238
245
with pytest .raises (TypeError , match = "expected done.dtype to" ):
@@ -292,8 +299,8 @@ def test_single_env_spec(self):
292
299
assert not env .output_spec_unbatched .shape
293
300
assert not env .full_reward_spec_unbatched .shape
294
301
295
- assert env .action_spec_unbatched .shape
296
- assert env .reward_spec_unbatched .shape
302
+ assert env .full_action_spec_unbatched [ env . action_key ] .shape
303
+ assert env .full_reward_spec_unbatched [ env . reward_key ] .shape
297
304
298
305
assert env .output_spec .is_in (env .output_spec_unbatched .zeros (env .shape ))
299
306
assert env .input_spec .is_in (env .input_spec_unbatched .zeros (env .shape ))
@@ -307,7 +314,10 @@ def forward(self, values):
307
314
return values .argmax (- 1 )
308
315
309
316
policy = nn .Sequential (
310
- nn .Linear (env .observation_spec ["observation" ].shape [- 1 ], env .action_spec .n ),
317
+ nn .Linear (
318
+ env .observation_spec ["observation" ].shape [- 1 ],
319
+ env .full_action_spec [env .action_key ].n ,
320
+ ),
311
321
ArgMaxModule (),
312
322
)
313
323
env .rollout (10 , policy )
@@ -507,7 +517,7 @@ def test_auto_cast_to_device(self, break_when_any_done):
507
517
policy = Actor (
508
518
nn .Linear (
509
519
env .observation_spec ["observation" ].shape [- 1 ],
510
- env .action_spec .shape [- 1 ],
520
+ env .full_action_spec [ env . action_key ] .shape [- 1 ],
511
521
device = "cuda:0" ,
512
522
),
513
523
in_keys = ["observation" ],
@@ -538,7 +548,7 @@ def test_auto_cast_to_device(self, break_when_any_done):
538
548
def test_env_seed (self , env_name , frame_skip , seed = 0 ):
539
549
env_name = env_name ()
540
550
env = GymEnv (env_name , frame_skip = frame_skip )
541
- action = env .action_spec .rand ()
551
+ action = env .full_action_spec [ env . action_key ] .rand ()
542
552
543
553
env .set_seed (seed )
544
554
td0a = env .reset ()
@@ -624,7 +634,7 @@ def test_env_base_reset_flag(self, batch_size, max_steps=3):
624
634
env = CountingEnv (max_steps = max_steps , batch_size = batch_size )
625
635
env .set_seed (1 )
626
636
627
- action = env .action_spec .rand ()
637
+ action = env .full_action_spec [ env . action_key ] .rand ()
628
638
action [:] = 1
629
639
630
640
for i in range (max_steps ):
@@ -695,7 +705,7 @@ def test_batch_locked(self, device):
695
705
with pytest .raises (RuntimeError , match = "batch_locked is a read-only property" ):
696
706
env .batch_locked = False
697
707
td = env .reset ()
698
- td ["action" ] = env .action_spec .rand ()
708
+ td ["action" ] = env .full_action_spec [ env . action_key ] .rand ()
699
709
td_expanded = td .expand (2 ).clone ()
700
710
_ = env .step (td )
701
711
@@ -712,7 +722,7 @@ def test_batch_unlocked(self, device):
712
722
with pytest .raises (RuntimeError , match = "batch_locked is a read-only property" ):
713
723
env .batch_locked = False
714
724
td = env .reset ()
715
- td ["action" ] = env .action_spec .rand ()
725
+ td ["action" ] = env .full_action_spec [ env . action_key ] .rand ()
716
726
td_expanded = td .expand (2 ).clone ()
717
727
td = env .step (td )
718
728
@@ -727,7 +737,7 @@ def test_batch_unlocked_with_batch_size(self, device):
727
737
env .batch_locked = False
728
738
729
739
td = env .reset ()
730
- td ["action" ] = env .action_spec .rand ()
740
+ td ["action" ] = env .full_action_spec [ env . action_key ] .rand ()
731
741
td_expanded = td .expand (2 , 2 ).reshape (- 1 ).to_tensordict ()
732
742
td = env .step (td )
733
743
@@ -803,7 +813,7 @@ def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
803
813
# CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
804
814
env = CountingEnv (max_steps = max_steps - 1 , batch_size = batch_size )
805
815
policy = CountingEnvCountPolicy (
806
- action_spec = env .action_spec , action_key = env .action_key
816
+ action_spec = env .full_action_spec [ env . action_key ] , action_key = env .action_key
807
817
)
808
818
809
819
input_td = env .reset ()
@@ -1010,7 +1020,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
1010
1020
with pytest .raises (RuntimeError , match = "batch_locked is a read-only property" ):
1011
1021
mb_env .batch_locked = False
1012
1022
td = mb_env .reset ()
1013
- td ["action" ] = mb_env .action_spec .rand ()
1023
+ td ["action" ] = mb_env .full_action_spec [ mb_env . action_key ] .rand ()
1014
1024
td_expanded = td .unsqueeze (- 1 ).expand (10 , 2 ).reshape (- 1 ).to_tensordict ()
1015
1025
mb_env .step (td )
1016
1026
@@ -1028,7 +1038,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
1028
1038
with pytest .raises (RuntimeError , match = "batch_locked is a read-only property" ):
1029
1039
mb_env .batch_locked = False
1030
1040
td = mb_env .reset ()
1031
- td ["action" ] = mb_env .action_spec .rand ()
1041
+ td ["action" ] = mb_env .full_action_spec [ mb_env . action_key ] .rand ()
1032
1042
td_expanded = td .expand (2 )
1033
1043
mb_env .step (td )
1034
1044
# we should be able to do a step with a tensordict that has been expanded
@@ -1242,6 +1252,7 @@ def test_parallel_env(
1242
1252
N = N ,
1243
1253
)
1244
1254
td = TensorDict (source = {"action" : env0 .action_spec .rand ((N ,))}, batch_size = [N ])
1255
+ env_parallel .reset ()
1245
1256
td1 = env_parallel .step (td )
1246
1257
assert not td1 .is_shared ()
1247
1258
assert ("next" , "done" ) in td1 .keys (True )
@@ -1308,6 +1319,7 @@ def test_parallel_env_with_policy(
1308
1319
)
1309
1320
1310
1321
td = TensorDict (source = {"action" : env0 .action_spec .rand ((N ,))}, batch_size = [N ])
1322
+ env_parallel .reset ()
1311
1323
td1 = env_parallel .step (td )
1312
1324
assert not td1 .is_shared ()
1313
1325
assert ("next" , "done" ) in td1 .keys (True )
@@ -1715,7 +1727,7 @@ def test_parallel_env_reset_flag(
1715
1727
n_workers , lambda : CountingEnv (max_steps = max_steps , batch_size = batch_size )
1716
1728
)
1717
1729
env .set_seed (1 )
1718
- action = env .action_spec .rand ()
1730
+ action = env .full_action_spec [ env . action_key ] .rand ()
1719
1731
action [:] = 1
1720
1732
for i in range (max_steps ):
1721
1733
td = env .step (
@@ -1787,7 +1799,9 @@ def test_parallel_env_nested(
1787
1799
if not nested_done and not nested_reward and not nested_obs_action :
1788
1800
assert "data" not in td .keys ()
1789
1801
1790
- policy = CountingEnvCountPolicy (env .action_spec , env .action_key )
1802
+ policy = CountingEnvCountPolicy (
1803
+ env .full_action_spec [env .action_key ], env .action_key
1804
+ )
1791
1805
td = env .rollout (rollout_length , policy )
1792
1806
assert td .batch_size == (* batch_size , rollout_length )
1793
1807
if nested_done or nested_obs_action :
@@ -2558,6 +2572,7 @@ def main_collector(j, q=None):
2558
2572
total_frames = N * n_workers * 100 ,
2559
2573
storing_device = device ,
2560
2574
device = device ,
2575
+ trust_policy = True ,
2561
2576
cat_results = - 1 ,
2562
2577
)
2563
2578
single_collectors = [
@@ -2567,6 +2582,7 @@ def main_collector(j, q=None):
2567
2582
frames_per_batch = n_workers * 100 ,
2568
2583
total_frames = N * n_workers * 100 ,
2569
2584
storing_device = device ,
2585
+ trust_policy = True ,
2570
2586
device = device ,
2571
2587
)
2572
2588
for i in range (n_workers )
@@ -2662,18 +2678,24 @@ def test_nested_env(self, envclass):
2662
2678
else :
2663
2679
raise NotImplementedError
2664
2680
reset = env .reset ()
2665
- assert not isinstance (env .reward_spec , Composite )
2681
+ with pytest .warns (
2682
+ DeprecationWarning , match = "non-trivial"
2683
+ ) if envclass == "NestedCountingEnv" else contextlib .nullcontext ():
2684
+ assert not isinstance (env .reward_spec , Composite )
2666
2685
for done_key in env .done_keys :
2667
2686
assert (
2668
2687
env .full_done_spec [done_key ]
2669
2688
== env .output_spec [("full_done_spec" , * _unravel_key_to_tuple (done_key ))]
2670
2689
)
2671
- assert (
2672
- env .reward_spec
2673
- == env .output_spec [
2674
- ("full_reward_spec" , * _unravel_key_to_tuple (env .reward_key ))
2675
- ]
2676
- )
2690
+ with pytest .warns (
2691
+ DeprecationWarning , match = "non-trivial"
2692
+ ) if envclass == "NestedCountingEnv" else contextlib .nullcontext ():
2693
+ assert (
2694
+ env .reward_spec
2695
+ == env .output_spec [
2696
+ ("full_reward_spec" , * _unravel_key_to_tuple (env .reward_key ))
2697
+ ]
2698
+ )
2677
2699
if envclass == "NestedCountingEnv" :
2678
2700
for done_key in env .done_keys :
2679
2701
assert done_key in (("data" , "done" ), ("data" , "terminated" ))
@@ -2734,7 +2756,9 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
2734
2756
nested_dim ,
2735
2757
)
2736
2758
2737
- policy = CountingEnvCountPolicy (env .action_spec , env .action_key )
2759
+ policy = CountingEnvCountPolicy (
2760
+ env .full_action_spec [env .action_key ], env .action_key
2761
+ )
2738
2762
td = env .rollout (rollout_length , policy )
2739
2763
assert td .batch_size == (* batch_size , rollout_length )
2740
2764
assert td ["data" ].batch_size == (* batch_size , rollout_length , nested_dim )
@@ -2858,7 +2882,7 @@ class TestMultiKeyEnvs:
2858
2882
@pytest .mark .parametrize ("max_steps" , [2 , 5 ])
2859
2883
def test_rollout (self , batch_size , rollout_steps , max_steps , seed ):
2860
2884
env = MultiKeyCountingEnv (batch_size = batch_size , max_steps = max_steps )
2861
- policy = MultiKeyCountingEnvPolicy (full_action_spec = env .action_spec )
2885
+ policy = MultiKeyCountingEnvPolicy (full_action_spec = env .full_action_spec )
2862
2886
td = env .rollout (rollout_steps , policy = policy )
2863
2887
torch .manual_seed (seed )
2864
2888
check_rollout_consistency_multikey_env (td , max_steps = max_steps )
@@ -2924,11 +2948,17 @@ def test_parallel(
2924
2948
)
2925
2949
def test_mocking_envs (envclass ):
2926
2950
env = envclass ()
2927
- env .set_seed (100 )
2951
+ with pytest .warns (UserWarning , match = "model based" ) if isinstance (
2952
+ env , DummyModelBasedEnvBase
2953
+ ) else contextlib .nullcontext ():
2954
+ env .set_seed (100 )
2928
2955
reset = env .reset ()
2929
2956
_ = env .rand_step (reset )
2930
2957
r = env .rollout (3 )
2931
- check_env_specs (env , seed = 100 , return_contiguous = False )
2958
+ with pytest .warns (UserWarning , match = "model based" ) if isinstance (
2959
+ env , DummyModelBasedEnvBase
2960
+ ) else contextlib .nullcontext ():
2961
+ check_env_specs (env , seed = 100 , return_contiguous = False )
2932
2962
2933
2963
2934
2964
class TestTerminatedOrTruncated :
@@ -4019,7 +4049,7 @@ def test_parallel_partial_steps(
4019
4049
psteps [[1 , 3 ]] = True
4020
4050
td .set ("_step" , psteps )
4021
4051
4022
- td .set ("action" , penv .action_spec .one ())
4052
+ td .set ("action" , penv .full_action_spec [ penv . action_key ] .one ())
4023
4053
td = penv .step (td )
4024
4054
assert (td [0 ].get ("next" ) == 0 ).all ()
4025
4055
assert (td [1 ].get ("next" ) != 0 ).any ()
@@ -4042,7 +4072,7 @@ def test_parallel_partial_step_and_maybe_reset(
4042
4072
psteps [[1 , 3 ]] = True
4043
4073
td .set ("_step" , psteps )
4044
4074
4045
- td .set ("action" , penv .action_spec .one ())
4075
+ td .set ("action" , penv .full_action_spec [ penv . action_key ] .one ())
4046
4076
td , tdreset = penv .step_and_maybe_reset (td )
4047
4077
assert (td [0 ].get ("next" ) == 0 ).all ()
4048
4078
assert (td [1 ].get ("next" ) != 0 ).any ()
@@ -4063,7 +4093,7 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
4063
4093
psteps [[1 , 3 ]] = True
4064
4094
td .set ("_step" , psteps )
4065
4095
4066
- td .set ("action" , penv .action_spec .one ())
4096
+ td .set ("action" , penv .full_action_spec [ penv . action_key ] .one ())
4067
4097
td = penv .step (td )
4068
4098
assert (td [0 ].get ("next" ) == 0 ).all ()
4069
4099
assert (td [1 ].get ("next" ) != 0 ).any ()
@@ -4084,7 +4114,7 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
4084
4114
psteps [[1 , 3 ]] = True
4085
4115
td .set ("_step" , psteps )
4086
4116
4087
- td .set ("action" , penv .action_spec .one ())
4117
+ td .set ("action" , penv .full_action_spec [ penv . action_key ] .one ())
4088
4118
td = penv .step (td )
4089
4119
assert (td [0 ].get ("next" ) == 0 ).all ()
4090
4120
assert (td [1 ].get ("next" ) != 0 ).any ()
0 commit comments