Commit 28c3c7a

Author: Vincent Moens

[BugFix] Avoid calling reset during env init

ghstack-source-id: 5ab8281
Pull Request resolved: #2770
(cherry picked from commit 09e93c1)

1 parent 6b0d5b8 commit 28c3c7a

File tree

4 files changed (+108, -48 lines)

test/test_env.py

Lines changed: 60 additions & 30 deletions
```diff
@@ -73,6 +73,13 @@
 from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
 from torchrl.modules.tensordict_module import WorldModelWrapper
 
+pytestmark = [
+    pytest.mark.filterwarnings("error"),
+    pytest.mark.filterwarnings(
+        "ignore:Got multiple backends for torchrl.data.replay_buffers.storages"
+    ),
+]
+
 gym_version = None
 if _has_gym:
     try:
```
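A note on the new `pytestmark`: marks listed in this module-level variable apply to every test in the file, so `filterwarnings("error")` turns any unexpected warning into a test failure, while the second mark exempts one known, benign warning by message prefix. A minimal self-contained sketch of the mechanism, with hypothetical warning text:

```python
# Sketch of module-level warning promotion; warning messages here are
# illustrative, not torchrl's.
import warnings

import pytest

pytestmark = [
    # Any warning not matched by a later "ignore" filter fails the test.
    pytest.mark.filterwarnings("error"),
    # Whitelist a known benign warning by message prefix.
    pytest.mark.filterwarnings("ignore:legacy backend selected"),
]


def test_unlisted_warning_becomes_error():
    # With the "error" filter active, warnings.warn raises the category.
    with pytest.raises(DeprecationWarning):
        warnings.warn("old API", DeprecationWarning)
```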
```diff
@@ -232,7 +239,7 @@ def test_run_type_checks(self):
         check_env_specs(env)
         env._run_type_checks = True
         check_env_specs(env)
-        env.output_spec.unlock_()
+        env.output_spec.unlock_(recurse=True)
         # check type check on done
         env.output_spec["full_done_spec", "done"].dtype = torch.int
         with pytest.raises(TypeError, match="expected done.dtype to"):
```
```diff
@@ -292,8 +299,8 @@ def test_single_env_spec(self):
         assert not env.output_spec_unbatched.shape
         assert not env.full_reward_spec_unbatched.shape
 
-        assert env.action_spec_unbatched.shape
-        assert env.reward_spec_unbatched.shape
+        assert env.full_action_spec_unbatched[env.action_key].shape
+        assert env.full_reward_spec_unbatched[env.reward_key].shape
 
         assert env.output_spec.is_in(env.output_spec_unbatched.zeros(env.shape))
         assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
```
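This hunk is the first instance of the substitution that recurs throughout the file: the `action_spec`/`reward_spec` shortcut properties are replaced by explicit indexing into the full Composite spec, which stays unambiguous (and warning-free) when an env carries nested or multiple action entries. A hedged sketch of the two forms, assuming a simple single-action Gym env:

```python
# Sketch: shortcut property vs explicit Composite indexing.
# Assumes torchrl's GymEnv and a single, flat action key.
from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")

# Shortcut: returns the leaf action spec, but warns/deprecates when the
# full spec is a non-trivial Composite (nested keys, multi-agent).
leaf_shortcut = env.action_spec

# Explicit: index the full spec with the env's action key; this works
# the same for flat keys ("action") and nested ones (("agent", "action")).
leaf_explicit = env.full_action_spec[env.action_key]

assert leaf_shortcut == leaf_explicit
action = leaf_explicit.rand()  # sample a valid action from the leaf spec
```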
```diff
@@ -307,7 +314,10 @@ def forward(self, values):
                 return values.argmax(-1)
 
         policy = nn.Sequential(
-            nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.n),
+            nn.Linear(
+                env.observation_spec["observation"].shape[-1],
+                env.full_action_spec[env.action_key].n,
+            ),
             ArgMaxModule(),
         )
         env.rollout(10, policy)
```
```diff
@@ -507,7 +517,7 @@ def test_auto_cast_to_device(self, break_when_any_done):
         policy = Actor(
             nn.Linear(
                 env.observation_spec["observation"].shape[-1],
-                env.action_spec.shape[-1],
+                env.full_action_spec[env.action_key].shape[-1],
                 device="cuda:0",
             ),
             in_keys=["observation"],
```
```diff
@@ -538,7 +548,7 @@ def test_auto_cast_to_device(self, break_when_any_done):
     def test_env_seed(self, env_name, frame_skip, seed=0):
         env_name = env_name()
         env = GymEnv(env_name, frame_skip=frame_skip)
-        action = env.action_spec.rand()
+        action = env.full_action_spec[env.action_key].rand()
 
         env.set_seed(seed)
         td0a = env.reset()
```
```diff
@@ -624,7 +634,7 @@ def test_env_base_reset_flag(self, batch_size, max_steps=3):
         env = CountingEnv(max_steps=max_steps, batch_size=batch_size)
         env.set_seed(1)
 
-        action = env.action_spec.rand()
+        action = env.full_action_spec[env.action_key].rand()
         action[:] = 1
 
         for i in range(max_steps):
```
```diff
@@ -695,7 +705,7 @@ def test_batch_locked(self, device):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             env.batch_locked = False
         td = env.reset()
-        td["action"] = env.action_spec.rand()
+        td["action"] = env.full_action_spec[env.action_key].rand()
         td_expanded = td.expand(2).clone()
         _ = env.step(td)
```
```diff
@@ -712,7 +722,7 @@ def test_batch_unlocked(self, device):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             env.batch_locked = False
         td = env.reset()
-        td["action"] = env.action_spec.rand()
+        td["action"] = env.full_action_spec[env.action_key].rand()
         td_expanded = td.expand(2).clone()
         td = env.step(td)
```
```diff
@@ -727,7 +737,7 @@ def test_batch_unlocked_with_batch_size(self, device):
         env.batch_locked = False
 
         td = env.reset()
-        td["action"] = env.action_spec.rand()
+        td["action"] = env.full_action_spec[env.action_key].rand()
         td_expanded = td.expand(2, 2).reshape(-1).to_tensordict()
         td = env.step(td)
```
```diff
@@ -803,7 +813,7 @@ def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
         # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
         env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
         policy = CountingEnvCountPolicy(
-            action_spec=env.action_spec, action_key=env.action_key
+            action_spec=env.full_action_spec[env.action_key], action_key=env.action_key
         )
 
         input_td = env.reset()
```
```diff
@@ -1010,7 +1020,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             mb_env.batch_locked = False
         td = mb_env.reset()
-        td["action"] = mb_env.action_spec.rand()
+        td["action"] = mb_env.full_action_spec[mb_env.action_key].rand()
         td_expanded = td.unsqueeze(-1).expand(10, 2).reshape(-1).to_tensordict()
         mb_env.step(td)
```
```diff
@@ -1028,7 +1038,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
         with pytest.raises(RuntimeError, match="batch_locked is a read-only property"):
             mb_env.batch_locked = False
         td = mb_env.reset()
-        td["action"] = mb_env.action_spec.rand()
+        td["action"] = mb_env.full_action_spec[mb_env.action_key].rand()
         td_expanded = td.expand(2)
         mb_env.step(td)
         # we should be able to do a step with a tensordict that has been expended
```
```diff
@@ -1242,6 +1252,7 @@ def test_parallel_env(
             N=N,
         )
         td = TensorDict(source={"action": env0.action_spec.rand((N,))}, batch_size=[N])
+        env_parallel.reset()
         td1 = env_parallel.step(td)
         assert not td1.is_shared()
         assert ("next", "done") in td1.keys(True)
```
```diff
@@ -1308,6 +1319,7 @@ def test_parallel_env_with_policy(
         )
 
         td = TensorDict(source={"action": env0.action_spec.rand((N,))}, batch_size=[N])
+        env_parallel.reset()
         td1 = env_parallel.step(td)
         assert not td1.is_shared()
         assert ("next", "done") in td1.keys(True)
```
```diff
@@ -1715,7 +1727,7 @@ def test_parallel_env_reset_flag(
             n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size)
         )
         env.set_seed(1)
-        action = env.action_spec.rand()
+        action = env.full_action_spec[env.action_key].rand()
         action[:] = 1
         for i in range(max_steps):
             td = env.step(
```
```diff
@@ -1787,7 +1799,9 @@ def test_parallel_env_nested(
         if not nested_done and not nested_reward and not nested_obs_action:
             assert "data" not in td.keys()
 
-        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        policy = CountingEnvCountPolicy(
+            env.full_action_spec[env.action_key], env.action_key
+        )
         td = env.rollout(rollout_length, policy)
         assert td.batch_size == (*batch_size, rollout_length)
         if nested_done or nested_obs_action:
```
```diff
@@ -2558,6 +2572,7 @@ def main_collector(j, q=None):
             total_frames=N * n_workers * 100,
             storing_device=device,
             device=device,
+            trust_policy=True,
             cat_results=-1,
         )
         single_collectors = [
```
```diff
@@ -2567,6 +2582,7 @@ def main_collector(j, q=None):
                 frames_per_batch=n_workers * 100,
                 total_frames=N * n_workers * 100,
                 storing_device=device,
+                trust_policy=True,
                 device=device,
             )
             for i in range(n_workers)
```
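The `trust_policy=True` flag added to both collector constructors tells the collector to accept the policy as given instead of inspecting and wrapping it, which matters when the policy is not a standard TensorDictModule. A hedged sketch, assuming `SyncDataCollector` accepts the same flag:

```python
# Sketch: trust_policy=True skips the collector's policy introspection,
# letting a plain callable through. Env and numbers are illustrative.
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")

def policy(td):
    # A plain callable the collector cannot introspect like a module.
    td["action"] = env.full_action_spec[env.action_key].rand()
    return td

collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=16,
    total_frames=64,
    trust_policy=True,  # accept the callable as-is
)
for batch in collector:
    break
collector.shutdown()
```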
```diff
@@ -2662,18 +2678,24 @@ def test_nested_env(self, envclass):
         else:
             raise NotImplementedError
         reset = env.reset()
-        assert not isinstance(env.reward_spec, Composite)
+        with pytest.warns(
+            DeprecationWarning, match="non-trivial"
+        ) if envclass == "NestedCountingEnv" else contextlib.nullcontext():
+            assert not isinstance(env.reward_spec, Composite)
         for done_key in env.done_keys:
             assert (
                 env.full_done_spec[done_key]
                 == env.output_spec[("full_done_spec", *_unravel_key_to_tuple(done_key))]
             )
-        assert (
-            env.reward_spec
-            == env.output_spec[
-                ("full_reward_spec", *_unravel_key_to_tuple(env.reward_key))
-            ]
-        )
+        with pytest.warns(
+            DeprecationWarning, match="non-trivial"
+        ) if envclass == "NestedCountingEnv" else contextlib.nullcontext():
+            assert (
+                env.reward_spec
+                == env.output_spec[
+                    ("full_reward_spec", *_unravel_key_to_tuple(env.reward_key))
+                ]
+            )
         if envclass == "NestedCountingEnv":
             for done_key in env.done_keys:
                 assert done_key in (("data", "done"), ("data", "terminated"))
```
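The conditional context-manager idiom above (`pytest.warns(...) if cond else contextlib.nullcontext()`) keeps a single assertion path whether or not the parametrized case is expected to warn. A self-contained sketch with a hypothetical warning:

```python
# Sketch of conditionally asserting a DeprecationWarning; the message
# and trigger are hypothetical.
import contextlib
import warnings

import pytest


@pytest.mark.parametrize("should_warn", [True, False])
def test_maybe_warns(should_warn):
    ctx = (
        pytest.warns(DeprecationWarning, match="non-trivial")
        if should_warn
        else contextlib.nullcontext()  # no-op context for the quiet case
    )
    with ctx:
        if should_warn:
            warnings.warn("non-trivial spec access", DeprecationWarning)
```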
```diff
@@ -2734,7 +2756,9 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
             nested_dim,
         )
 
-        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        policy = CountingEnvCountPolicy(
+            env.full_action_spec[env.action_key], env.action_key
+        )
         td = env.rollout(rollout_length, policy)
         assert td.batch_size == (*batch_size, rollout_length)
         assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
```
```diff
@@ -2858,7 +2882,7 @@ class TestMultiKeyEnvs:
     @pytest.mark.parametrize("max_steps", [2, 5])
     def test_rollout(self, batch_size, rollout_steps, max_steps, seed):
         env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps)
-        policy = MultiKeyCountingEnvPolicy(full_action_spec=env.action_spec)
+        policy = MultiKeyCountingEnvPolicy(full_action_spec=env.full_action_spec)
         td = env.rollout(rollout_steps, policy=policy)
         torch.manual_seed(seed)
         check_rollout_consistency_multikey_env(td, max_steps=max_steps)
```
```diff
@@ -2924,11 +2948,17 @@ def test_parallel(
 )
 def test_mocking_envs(envclass):
     env = envclass()
-    env.set_seed(100)
+    with pytest.warns(UserWarning, match="model based") if isinstance(
+        env, DummyModelBasedEnvBase
+    ) else contextlib.nullcontext():
+        env.set_seed(100)
     reset = env.reset()
     _ = env.rand_step(reset)
     r = env.rollout(3)
-    check_env_specs(env, seed=100, return_contiguous=False)
+    with pytest.warns(UserWarning, match="model based") if isinstance(
+        env, DummyModelBasedEnvBase
+    ) else contextlib.nullcontext():
+        check_env_specs(env, seed=100, return_contiguous=False)
 
 
 class TestTerminatedOrTruncated:
```
```diff
@@ -4019,7 +4049,7 @@ def test_parallel_partial_steps(
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
```
```diff
@@ -4042,7 +4072,7 @@ def test_parallel_partial_step_and_maybe_reset(
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td, tdreset = penv.step_and_maybe_reset(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
```
```diff
@@ -4063,7 +4093,7 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
```
```diff
@@ -4084,7 +4114,7 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
         psteps[[1, 3]] = True
         td.set("_step", psteps)
 
-        td.set("action", penv.action_spec.one())
+        td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
         assert (td[0].get("next") == 0).all()
         assert (td[1].get("next") != 0).any()
```
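All four partial-step hunks make the same spec substitution inside tests where a boolean `_step` entry masks which batch elements actually advance; the masked-out entries are expected to come back unchanged. A hedged sketch of the convention, assuming batched envs honor a `_step` mask as in these tests:

```python
# Sketch: step only workers 1 and 3 of a 4-worker batched env via the
# "_step" boolean mask used above. Env choice is illustrative.
import torch
from torchrl.envs import GymEnv, ParallelEnv

penv = ParallelEnv(4, lambda: GymEnv("CartPole-v1"))
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_step", psteps)  # only entries 1 and 3 perform a real step
td.set("action", penv.full_action_spec[penv.action_key].rand())
td = penv.step(td)
penv.close()
```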

torchrl/envs/batched_envs.py

Lines changed: 34 additions & 13 deletions
```diff
@@ -10,6 +10,7 @@
 import gc
 
 import os
+import time
 import weakref
 from collections import OrderedDict
 from copy import deepcopy
```
```diff
@@ -1616,11 +1617,7 @@ def step_and_maybe_reset(
         for i, _data in zip(workers_range, data):
             self.parent_channels[i].send(("step_and_maybe_reset", _data))
 
-        for i in workers_range:
-            event = self._events[i]
-            event.wait(self.BATCHED_PIPE_TIMEOUT)
-            event.clear()
-
+        self._wait_for_workers(workers_range)
         if self._non_tensor_keys:
             non_tensor_tds = []
             for i in workers_range:
```
```diff
@@ -1670,6 +1667,36 @@ def step_and_maybe_reset(
 
         return tensordict, tensordict_
 
+    def _wait_for_workers(self, workers_range):
+        workers_range_consume = set(workers_range)
+        t0 = time.time()
+        while (
+            len(workers_range_consume)
+            and (time.time() - t0) < self.BATCHED_PIPE_TIMEOUT
+        ):
+            for i in workers_range:
+                if i not in workers_range_consume:
+                    continue
+                worker = self._workers[i]
+                if worker.is_alive():
+                    event: mp.Event = self._events[i]
+                    if event.is_set():
+                        workers_range_consume.discard(i)
+                        event.clear()
+                    else:
+                        continue
+                else:
+                    try:
+                        self._shutdown_workers()
+                    finally:
+                        raise RuntimeError(f"Cannot proceed, worker {i} dead.")
+            # event.wait(self.BATCHED_PIPE_TIMEOUT)
+        if len(workers_range_consume):
+            raise RuntimeError(
+                f"Failed to run all workers within the {self.BATCHED_PIPE_TIMEOUT} sec time limit. This "
+                f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable."
+            )
+
     def _step_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
```
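The new `_wait_for_workers` helper replaces N sequential blocking `event.wait(timeout)` calls, which in the worst case stack up to N × timeout and hang silently when a worker dies before setting its event, with one polling loop under a single global deadline that also checks worker liveness on every pass. A self-contained sketch of the same pattern with plain multiprocessing primitives (names are illustrative, not torchrl's):

```python
# Illustrative re-implementation: poll per-worker events under one
# global deadline and fail fast if a worker process has died.
import multiprocessing as mp
import time


def wait_for_workers(workers, events, timeout):
    pending = set(range(len(workers)))
    t0 = time.time()
    while pending and (time.time() - t0) < timeout:
        for i in list(pending):
            if not workers[i].is_alive():
                raise RuntimeError(f"Cannot proceed, worker {i} dead.")
            if events[i].is_set():
                events[i].clear()  # consume the completion signal
                pending.discard(i)
        time.sleep(0.001)  # yield instead of hot-spinning
    if pending:
        raise RuntimeError(f"Workers {sorted(pending)} timed out after {timeout}s.")


def _worker(event):
    time.sleep(0.1)  # simulate one env step
    event.set()


if __name__ == "__main__":
    events = [mp.Event() for _ in range(2)]
    procs = [mp.Process(target=_worker, args=(e,)) for e in events]
    for p in procs:
        p.start()
    wait_for_workers(procs, events, timeout=5.0)
    for p in procs:
        p.join()
```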
```diff
@@ -1806,10 +1833,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         for i in workers_range:
             self.parent_channels[i].send(("step", data[i]))
 
-        for i in workers_range:
-            event = self._events[i]
-            event.wait(self.BATCHED_PIPE_TIMEOUT)
-            event.clear()
+        self._wait_for_workers(workers_range)
 
         if self._non_tensor_keys:
             non_tensor_tds = []
```
```diff
@@ -1975,10 +1999,7 @@ def tentative_update(val, other):
         for i, out in outs:
             self.parent_channels[i].send(out)
 
-        for i, _ in outs:
-            event = self._events[i]
-            event.wait(self.BATCHED_PIPE_TIMEOUT)
-            event.clear()
+        self._wait_for_workers(list(zip(*outs))[0])
 
         workers_nontensor = []
         if self._non_tensor_keys:
```
