20
20
21
21
import tensordict .tensordict
22
22
import torch
23
+ from tensordict .nn import WrapModule
23
24
24
25
from torchrl .collectors import MultiSyncDataCollector
25
26
@@ -13208,7 +13209,9 @@ def test_single_trans_env_check(self):
13208
13209
r = env .rollout (1000 , policy_odd , break_when_all_done = True )
13209
13210
assert r .shape [0 ] == 15
13210
13211
assert (r ["action" ] == 0 ).all ()
13211
- assert (r ["step_count" ] == torch .arange (1 , r .numel () * 2 , 2 ).unsqueeze (- 1 )).all ()
13212
+ assert (
13213
+ r ["step_count" ] == torch .arange (1 , r .numel () * 2 , 2 ).unsqueeze (- 1 )
13214
+ ).all ()
13212
13215
assert r ["next" , "done" ].any ()
13213
13216
13214
13217
# Player 1
@@ -13221,58 +13224,104 @@ def test_single_trans_env_check(self):
13221
13224
r = env .rollout (1000 , policy_even , break_when_all_done = True )
13222
13225
assert r .shape [0 ] == 16
13223
13226
assert (r ["action" ] == 1 ).all ()
13224
- assert (r ["step_count" ] == torch .arange (0 , r .numel () * 2 , 2 ).unsqueeze (- 1 )).all ()
13227
+ assert (
13228
+ r ["step_count" ] == torch .arange (0 , r .numel () * 2 , 2 ).unsqueeze (- 1 )
13229
+ ).all ()
13225
13230
assert r ["next" , "done" ].any ()
13226
13231
13232
+ def _create_policy_odd (self , base_env ):
13233
+ return WrapModule (
13234
+ lambda td , base_env = base_env : td .set (
13235
+ "action" , base_env .action_spec_unbatched .zero (td .shape )
13236
+ ),
13237
+ out_keys = ["action" ],
13238
+ )
13227
13239
13228
- def test_trans_serial_env_check (self ):
13229
- def make_env (max_count ):
13230
- def make ():
13231
- base_env = CountingEnv (max_steps = max_count )
13232
- transforms =
13233
- return base_env .append_transform (transforms )
13234
- return make
13235
-
13236
- base_env = SerialEnv (3 ,
13237
- [partial (CountingEnv , 6 ), partial (CountingEnv , 7 ), partial (CountingEnv , 8 )])
13238
- condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 )
13239
- policy_odd = lambda td , base_env = base_env : td .set ("action" , base_env .action_spec .zero ())
13240
- policy_even = lambda td , base_env = base_env : td .set ("action" , base_env .action_spec .one ())
13241
- env = base_env .append_transform (Compose (
13242
- StepCounter (),
13243
- ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13244
- ))
13245
- r = env .rollout (100 , break_when_all_done = False )
13246
- print (r ["step_count" ].squeeze ())
13240
+ def _create_policy_even (self , base_env ):
13241
+ return WrapModule (
13242
+ lambda td , base_env = base_env : td .set (
13243
+ "action" , base_env .action_spec_unbatched .one (td .shape )
13244
+ ),
13245
+ out_keys = ["action" ],
13246
+ )
13247
+
13248
+ def _create_transforms (self , condition , policy_even ):
13249
+ return Compose (
13250
+ StepCounter (),
13251
+ ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13252
+ )
13247
13253
13254
+ def _make_env (self , max_count , env_cls ):
13255
+ torch .manual_seed (0 )
13256
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).squeeze (- 1 )
13257
+ base_env = env_cls (max_steps = max_count )
13258
+ policy_even = self ._create_policy_even (base_env )
13259
+ transforms = self ._create_transforms (condition , policy_even )
13260
+ return base_env .append_transform (transforms )
13261
+
13262
+ def _test_env (self , env , policy_odd ):
13263
+ env .check_env_specs ()
13264
+ env .set_seed (0 )
13265
+ r = env .rollout (100 , policy_odd , break_when_any_done = False )
13266
+ # Check results are independent: one reset / step in one env should not impact results in another
13267
+ r0 , r1 , r2 = r .unbind (0 )
13268
+ r0_split = r0 .split (6 )
13269
+ assert all (((r == r0_split [0 ][: r .numel ()]).all () for r in r0_split [1 :]))
13270
+ r1_split = r1 .split (7 )
13271
+ assert all (((r == r1_split [0 ][: r .numel ()]).all () for r in r1_split [1 :]))
13272
+ r2_split = r2 .split (8 )
13273
+ assert all (((r == r2_split [0 ][: r .numel ()]).all () for r in r2_split [1 :]))
13274
+
13275
+ def test_trans_serial_env_check (self ):
13276
+ torch .manual_seed (0 )
13277
+ base_env = SerialEnv (
13278
+ 3 ,
13279
+ [partial (CountingEnv , 6 ), partial (CountingEnv , 7 ), partial (CountingEnv , 8 )],
13280
+ batch_locked = False ,
13281
+ )
13282
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).squeeze (- 1 )
13283
+ policy_odd = self ._create_policy_odd (base_env )
13284
+ policy_even = self ._create_policy_even (base_env )
13285
+ transforms = self ._create_transforms (condition , policy_even )
13286
+ env = base_env .append_transform (transforms )
13287
+ self ._test_env (env , policy_odd )
13248
13288
13249
13289
def test_trans_parallel_env_check (self ):
13250
- """tests that a transformed paprallel env (TransformedEnv(ParallelEnv(N, lambda: env()), transform)) passes the check_env_specs test."""
13251
- raise NotImplementedError
13290
+ torch .manual_seed (0 )
13291
+ base_env = ParallelEnv (
13292
+ 3 ,
13293
+ [partial (CountingEnv , 6 ), partial (CountingEnv , 7 ), partial (CountingEnv , 8 )],
13294
+ batch_locked = False ,
13295
+ mp_start_method = mp_ctx ,
13296
+ )
13297
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).squeeze (- 1 )
13298
+ policy_odd = self ._create_policy_odd (base_env )
13299
+ policy_even = self ._create_policy_even (base_env )
13300
+ transforms = self ._create_transforms (condition , policy_even )
13301
+ env = base_env .append_transform (transforms )
13302
+ self ._test_env (env , policy_odd )
13252
13303
13253
13304
def test_serial_trans_env_check (self ):
13254
- condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).all ()
13255
- # Player 0
13256
- policy_odd = lambda td : td .set ("action" , env .action_spec .zero ())
13257
- policy_even = lambda td : td .set ("action" , env .action_spec .one ())
13305
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).squeeze (- 1 )
13306
+ policy_odd = self ._create_policy_odd (CountingEnv ())
13307
+
13258
13308
def make_env (max_count ):
13259
- def make ():
13260
- base_env = CountingEnv (max_steps = max_count )
13261
- transforms = Compose (
13262
- StepCounter (),
13263
- ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13264
- )
13265
- return base_env .append_transform (transforms )
13266
- return make
13309
+ return partial (self ._make_env , max_count , CountingEnv )
13267
13310
13268
- env = SerialEnv (3 ,
13269
- [make_env (6 ), make_env (7 ), make_env (8 )])
13270
- r = env .rollout (100 , break_when_all_done = False )
13271
- print (r ["step_count" ].squeeze ())
13311
+ env = SerialEnv (3 , [make_env (6 ), make_env (7 ), make_env (8 )])
13312
+ self ._test_env (env , policy_odd )
13272
13313
13273
13314
def test_parallel_trans_env_check (self ):
13274
- """tests that a parallel transformed env (ParallelEnv(N, lambda: TransformedEnv(env, transform))) passes the check_env_specs test."""
13275
- raise NotImplementedError
13315
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).squeeze (- 1 )
13316
+ policy_odd = self ._create_policy_odd (CountingEnv ())
13317
+
13318
+ def make_env (max_count ):
13319
+ return partial (self ._make_env , max_count , CountingEnv )
13320
+
13321
+ env = ParallelEnv (
13322
+ 3 , [make_env (6 ), make_env (7 ), make_env (8 )], mp_start_method = mp_ctx
13323
+ )
13324
+ self ._test_env (env , policy_odd )
13276
13325
13277
13326
def test_transform_no_env (self ):
13278
13327
"""tests the transform on dummy data, without an env."""
0 commit comments