File tree Expand file tree Collapse file tree 2 files changed +19
-0
lines changed Expand file tree Collapse file tree 2 files changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -1477,6 +1477,22 @@ def test_step_count_gym(self):
1477
1477
env .rollout (1000 )
1478
1478
check_env_specs (env )
1479
1479
1480
+ @pytest .mark .skipif (not _has_gym , reason = "no gym detected" )
1481
+ def test_step_count_gym_doublecount (self ):
1482
+ # tests that 2 truncations can be used together
1483
+ env = TransformedEnv (
1484
+ GymEnv (PENDULUM_VERSIONED ),
1485
+ Compose (
1486
+ StepCounter (max_steps = 2 ),
1487
+ StepCounter (max_steps = 3 ), # this one will be ignored
1488
+ ),
1489
+ )
1490
+ r = env .rollout (10 , break_when_any_done = False )
1491
+ assert (
1492
+ r .get (("next" , "truncated" )).squeeze ().nonzero ().squeeze (- 1 )
1493
+ == torch .arange (1 , 10 , 2 )
1494
+ ).all ()
1495
+
1480
1496
@pytest .mark .skipif (not _has_dm_control , reason = "no dm_control detected" )
1481
1497
def test_step_count_dmc (self ):
1482
1498
env = TransformedEnv (DMControlEnv ("cheetah" , "run" ), StepCounter (max_steps = 30 ))
Original file line number Diff line number Diff line change @@ -5169,6 +5169,7 @@ def _reset(
5169
5169
tensordict_reset .set (step_count_key , step_count )
5170
5170
if self .max_steps is not None :
5171
5171
truncated = step_count >= self .max_steps
5172
+ truncated = truncated | tensordict_reset .get (truncated_key , False )
5172
5173
if self .update_done :
5173
5174
# we assume no done after reset
5174
5175
tensordict_reset .set (done_key , truncated )
@@ -5187,8 +5188,10 @@ def _step(
5187
5188
step_count = tensordict .get (step_count_key )
5188
5189
next_step_count = step_count + 1
5189
5190
next_tensordict .set (step_count_key , next_step_count )
5191
+
5190
5192
if self .max_steps is not None :
5191
5193
truncated = next_step_count >= self .max_steps
5194
+ truncated = truncated | next_tensordict .get (truncated_key , False )
5192
5195
if self .update_done :
5193
5196
done = next_tensordict .get (done_key , None )
5194
5197
terminated = next_tensordict .get (terminated_key , None )
You can’t perform that action at this time.
0 commit comments