@@ -13206,28 +13206,7 @@ def test_single_trans_env_check(self):
13206
13206
ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13207
13207
)
13208
13208
env = base_env .append_transform (transforms )
13209
- r = env .rollout (1000 , policy_odd , break_when_all_done = True )
13210
- assert r .shape [0 ] == 15
13211
- assert (r ["action" ] == 0 ).all ()
13212
- assert (
13213
- r ["step_count" ] == torch .arange (1 , r .numel () * 2 , 2 ).unsqueeze (- 1 )
13214
- ).all ()
13215
- assert r ["next" , "done" ].any ()
13216
-
13217
- # Player 1
13218
- condition = lambda td : ((td .get ("step_count" ) % 2 ) == 1 ).all ()
13219
- transforms = Compose (
13220
- StepCounter (),
13221
- ConditionalPolicySwitch (condition = condition , policy = policy_odd ),
13222
- )
13223
- env = base_env .append_transform (transforms )
13224
- r = env .rollout (1000 , policy_even , break_when_all_done = True )
13225
- assert r .shape [0 ] == 16
13226
- assert (r ["action" ] == 1 ).all ()
13227
- assert (
13228
- r ["step_count" ] == torch .arange (0 , r .numel () * 2 , 2 ).unsqueeze (- 1 )
13229
- ).all ()
13230
- assert r ["next" , "done" ].any ()
13209
+ env .check_env_specs ()
13231
13210
13232
13211
def _create_policy_odd (self , base_env ):
13233
13212
return WrapModule (
@@ -13324,43 +13303,95 @@ def make_env(max_count):
13324
13303
self ._test_env (env , policy_odd )
13325
13304
13326
13305
def test_transform_no_env (self ):
13327
- """tests the transform on dummy data, without an env."""
13328
- raise NotImplementedError
13306
+ policy_odd = lambda td : td
13307
+ policy_even = lambda td : td
13308
+ condition = lambda td : True
13309
+ transforms = ConditionalPolicySwitch (condition = condition , policy = policy_even )
13310
+ with pytest .raises (
13311
+ RuntimeError ,
13312
+ match = "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional." ,
13313
+ ):
13314
+ transforms (TensorDict ())
13329
13315
13330
13316
def test_transform_compose (self ):
13331
- """tests the transform on dummy data, without an env but inside a Compose."""
13332
- raise NotImplementedError
13317
+ policy_odd = lambda td : td
13318
+ policy_even = lambda td : td
13319
+ condition = lambda td : True
13320
+ transforms = Compose (
13321
+ ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13322
+ )
13323
+ with pytest .raises (
13324
+ RuntimeError ,
13325
+ match = "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional." ,
13326
+ ):
13327
+ transforms (TensorDict ())
13333
13328
13334
13329
def test_transform_env (self ):
13335
- """tests the transform on a real env.
13336
-
13337
- If possible, do not use a mock env, as bugs may go unnoticed if the dynamic is too
13338
- simplistic. A call to reset() and step() should be tested independently, ie
13339
- a check that reset produces the desired output and that step() does too.
13330
+ base_env = CountingEnv (max_steps = 15 )
13331
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 0 ).all ()
13332
+ # Player 0
13333
+ policy_odd = lambda td : td .set ("action" , env .action_spec .zero ())
13334
+ policy_even = lambda td : td .set ("action" , env .action_spec .one ())
13335
+ transforms = Compose (
13336
+ StepCounter (),
13337
+ ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13338
+ )
13339
+ env = base_env .append_transform (transforms )
13340
+ env .check_env_specs ()
13341
+ r = env .rollout (1000 , policy_odd , break_when_all_done = True )
13342
+ assert r .shape [0 ] == 15
13343
+ assert (r ["action" ] == 0 ).all ()
13344
+ assert (
13345
+ r ["step_count" ] == torch .arange (1 , r .numel () * 2 , 2 ).unsqueeze (- 1 )
13346
+ ).all ()
13347
+ assert r ["next" , "done" ].any ()
13340
13348
13341
- """
13342
- raise NotImplementedError
13349
+ # Player 1
13350
+ condition = lambda td : ((td .get ("step_count" ) % 2 ) == 1 ).all ()
13351
+ transforms = Compose (
13352
+ StepCounter (),
13353
+ ConditionalPolicySwitch (condition = condition , policy = policy_odd ),
13354
+ )
13355
+ env = base_env .append_transform (transforms )
13356
+ r = env .rollout (1000 , policy_even , break_when_all_done = True )
13357
+ assert r .shape [0 ] == 16
13358
+ assert (r ["action" ] == 1 ).all ()
13359
+ assert (
13360
+ r ["step_count" ] == torch .arange (0 , r .numel () * 2 , 2 ).unsqueeze (- 1 )
13361
+ ).all ()
13362
+ assert r ["next" , "done" ].any ()
13343
13363
13344
13364
def test_transform_model (self ):
13345
- """tests the transform before an nn.Module that reads the output."""
13346
- raise NotImplementedError
13347
-
13348
- def test_transform_rb (self ):
13349
- """tests the transform when used with a replay buffer.
13350
-
13351
- If your transform is not supposed to work with a replay buffer, test that
13352
- an error will be raised when called or appended to a RB.
13365
+ policy_odd = lambda td : td
13366
+ policy_even = lambda td : td
13367
+ condition = lambda td : True
13368
+ transforms = nn .Sequential (
13369
+ ConditionalPolicySwitch (condition = condition , policy = policy_even ),
13370
+ )
13371
+ with pytest .raises (
13372
+ RuntimeError ,
13373
+ match = "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional." ,
13374
+ ):
13375
+ transforms (TensorDict ())
13353
13376
13354
- """
13355
- raise NotImplementedError
13377
+ @pytest .mark .parametrize ("rbclass" , [ReplayBuffer , TensorDictReplayBuffer ])
13378
+ def test_transform_rb (self , rbclass ):
13379
+ policy_odd = lambda td : td
13380
+ policy_even = lambda td : td
13381
+ condition = lambda td : True
13382
+ rb = rbclass (storage = LazyTensorStorage (10 ))
13383
+ rb .append_transform (
13384
+ ConditionalPolicySwitch (condition = condition , policy = policy_even )
13385
+ )
13386
+ rb .extend (TensorDict (batch_size = [2 ]))
13387
+ with pytest .raises (
13388
+ RuntimeError ,
13389
+ match = "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional." ,
13390
+ ):
13391
+ rb .sample (2 )
13356
13392
13357
13393
def test_transform_inverse (self ):
13358
- """tests the inverse transform. If not applicable, simply skip this test.
13359
-
13360
- If your transform is not supposed to work offline, test that
13361
- an error will be raised when called in a nn.Module.
13362
- """
13363
- raise NotImplementedError
13394
+ return
13364
13395
13365
13396
13366
13397
if __name__ == "__main__" :
0 commit comments