Skip to content

Commit e679e71

Browse files
author
Vincent Moens
authored
[BugFix] Fix sequential step counts (#1838)
1 parent da7904e commit e679e71

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

test/test_transforms.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1477,6 +1477,22 @@ def test_step_count_gym(self):
14771477
env.rollout(1000)
14781478
check_env_specs(env)
14791479

1480+
@pytest.mark.skipif(not _has_gym, reason="no gym detected")
1481+
def test_step_count_gym_doublecount(self):
1482+
# tests that 2 truncations can be used together
1483+
env = TransformedEnv(
1484+
GymEnv(PENDULUM_VERSIONED),
1485+
Compose(
1486+
StepCounter(max_steps=2),
1487+
StepCounter(max_steps=3), # this one will be ignored
1488+
),
1489+
)
1490+
r = env.rollout(10, break_when_any_done=False)
1491+
assert (
1492+
r.get(("next", "truncated")).squeeze().nonzero().squeeze(-1)
1493+
== torch.arange(1, 10, 2)
1494+
).all()
1495+
14801496
@pytest.mark.skipif(not _has_dm_control, reason="no dm_control detected")
14811497
def test_step_count_dmc(self):
14821498
env = TransformedEnv(DMControlEnv("cheetah", "run"), StepCounter(max_steps=30))

torchrl/envs/transforms/transforms.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5169,6 +5169,7 @@ def _reset(
51695169
tensordict_reset.set(step_count_key, step_count)
51705170
if self.max_steps is not None:
51715171
truncated = step_count >= self.max_steps
5172+
truncated = truncated | tensordict_reset.get(truncated_key, False)
51725173
if self.update_done:
51735174
# we assume no done after reset
51745175
tensordict_reset.set(done_key, truncated)
@@ -5187,8 +5188,10 @@ def _step(
51875188
step_count = tensordict.get(step_count_key)
51885189
next_step_count = step_count + 1
51895190
next_tensordict.set(step_count_key, next_step_count)
5191+
51905192
if self.max_steps is not None:
51915193
truncated = next_step_count >= self.max_steps
5194+
truncated = truncated | next_tensordict.get(truncated_key, False)
51925195
if self.update_done:
51935196
done = next_tensordict.get(done_key, None)
51945197
terminated = next_tensordict.get(terminated_key, None)

0 commit comments

Comments
 (0)