Skip to content

Commit 3e42e7a

Browse files
author
Vincent Moens
committed
[BugFix] Fix update shape mismatch in _skip_tensordict
ghstack-source-id: 27e7d44
Pull Request resolved: #2792
1 parent 76aa9bc commit 3e42e7a

File tree

7 files changed

+80
-22
lines changed

7 files changed

+80
-22
lines changed

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13784,7 +13784,7 @@ def policy(td):
1378413784
assert r1["before_count"].max() == 18
1378513785
assert r1["after_count"].max() == 6
1378613786
finally:
13787-
env.close()
13787+
env.close(raise_if_closed=False)
1378813788

1378913789
@pytest.mark.parametrize("bwad", [False, True])
1379013790
def test_serial_trans_env_check(self, bwad):

torchrl/envs/batched_envs.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
clear_mpi_env_vars,
5858
)
5959

60+
_CONSOLIDATE_ERR_CAPTURE = (
61+
"TensorDict.consolidate failed. You can deactivate the tensordict consolidation via the "
62+
"`consolidate` keyword argument of the ParallelEnv constructor."
63+
)
64+
6065

6166
def _check_start(fun):
6267
def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
@@ -307,6 +312,7 @@ def __init__(
307312
non_blocking: bool = False,
308313
mp_start_method: str = None,
309314
use_buffers: bool = None,
315+
consolidate: bool = True,
310316
):
311317
super().__init__(device=device)
312318
self.serial_for_single = serial_for_single
@@ -315,6 +321,7 @@ def __init__(
315321
self.num_threads = num_threads
316322
self._cache_in_keys = None
317323
self._use_buffers = use_buffers
324+
self.consolidate = consolidate
318325

319326
self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
320327
if callable(create_env_fn):
@@ -841,9 +848,12 @@ def __repr__(self) -> str:
841848
f"\n\tbatch_size={self.batch_size})"
842849
)
843850

844-
def close(self) -> None:
851+
def close(self, *, raise_if_closed: bool = True) -> None:
845852
if self.is_closed:
846-
raise RuntimeError("trying to close a closed environment")
853+
if raise_if_closed:
854+
raise RuntimeError("trying to close a closed environment")
855+
else:
856+
return
847857
if self._verbose:
848858
torchrl_logger.info(f"closing {self.__class__.__name__}")
849859

@@ -1470,6 +1480,12 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
14701480
"_non_tensor_keys": self._non_tensor_keys,
14711481
}
14721482
)
1483+
else:
1484+
kwargs[idx].update(
1485+
{
1486+
"consolidate": self.consolidate,
1487+
}
1488+
)
14731489
process = proc_fun(target=func, kwargs=kwargs[idx])
14741490
process.daemon = True
14751491
process.start()
@@ -1526,7 +1542,16 @@ def _step_and_maybe_reset_no_buffers(
15261542
else:
15271543
workers_range = range(self.num_workers)
15281544

1529-
td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
1545+
if self.consolidate:
1546+
try:
1547+
td = tensordict.consolidate(
1548+
share_memory=True, inplace=True, num_threads=1
1549+
)
1550+
except Exception as err:
1551+
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
1552+
else:
1553+
td = tensordict
1554+
15301555
for i in workers_range:
15311556
# We send the same td multiple times as it is in shared mem and we just need to index it
15321557
# in each process.
@@ -1804,7 +1829,16 @@ def _step_no_buffers(
18041829
else:
18051830
workers_range = range(self.num_workers)
18061831

1807-
data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
1832+
if self.consolidate:
1833+
try:
1834+
data = tensordict.consolidate(
1835+
share_memory=True, inplace=True, num_threads=1
1836+
)
1837+
except Exception as err:
1838+
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
1839+
else:
1840+
data = tensordict
1841+
18081842
for i, local_data in zip(workers_range, data.unbind(0)):
18091843
self.parent_channels[i].send(("step", local_data))
18101844
# for i in range(data.shape[0]):
@@ -2026,9 +2060,14 @@ def _reset_no_buffers(
20262060
) -> Tuple[TensorDictBase, TensorDictBase]:
20272061
if is_tensor_collection(tensordict):
20282062
# tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
2029-
tensordict = tensordict.consolidate(
2030-
share_memory=True, num_threads=1
2031-
).unbind(0)
2063+
if self.consolidate:
2064+
try:
2065+
tensordict = tensordict.consolidate(
2066+
share_memory=True, num_threads=1
2067+
)
2068+
except Exception as err:
2069+
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2070+
tensordict = tensordict.unbind(0)
20322071
else:
20332072
tensordict = [None] * self.num_workers
20342073
out_tds = [None] * self.num_workers
@@ -2545,6 +2584,7 @@ def _run_worker_pipe_direct(
25452584
has_lazy_inputs: bool = False,
25462585
verbose: bool = False,
25472586
num_threads: int | None = None, # for fork start method
2587+
consolidate: bool = True,
25482588
) -> None:
25492589
if num_threads is not None:
25502590
torch.set_num_threads(num_threads)
@@ -2634,9 +2674,18 @@ def _run_worker_pipe_direct(
26342674
event.record()
26352675
event.synchronize()
26362676
mp_event.set()
2637-
child_pipe.send(
2638-
cur_td.consolidate(share_memory=True, inplace=True, num_threads=1)
2639-
)
2677+
if consolidate:
2678+
try:
2679+
child_pipe.send(
2680+
cur_td.consolidate(
2681+
share_memory=True, inplace=True, num_threads=1
2682+
)
2683+
)
2684+
except Exception as err:
2685+
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2686+
else:
2687+
child_pipe.send(cur_td)
2688+
26402689
del cur_td
26412690

26422691
elif cmd == "step":
@@ -2650,9 +2699,18 @@ def _run_worker_pipe_direct(
26502699
event.record()
26512700
event.synchronize()
26522701
mp_event.set()
2653-
child_pipe.send(
2654-
next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
2655-
)
2702+
if consolidate:
2703+
try:
2704+
child_pipe.send(
2705+
next_td.consolidate(
2706+
share_memory=True, inplace=True, num_threads=1
2707+
)
2708+
)
2709+
except Exception as err:
2710+
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2711+
else:
2712+
child_pipe.send(next_td)
2713+
26562714
del next_td
26572715

26582716
elif cmd == "step_and_maybe_reset":

torchrl/envs/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3651,7 +3651,7 @@ def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]:
36513651
if key.rfind("observation") >= 0:
36523652
yield key
36533653

3654-
def close(self):
3654+
def close(self, *, raise_if_closed: bool = True):
36553655
self.is_closed = True
36563656

36573657
def __del__(self):
@@ -3843,7 +3843,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
38433843
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
38443844
raise NotImplementedError
38453845

3846-
def close(self) -> None:
3846+
def close(self, *, raise_if_closed: bool = True) -> None:
38473847
"""Closes the contained environment if possible."""
38483848
self.is_closed = True
38493849
try:

torchrl/envs/libs/pettingzoo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ def _update_agent_mask(self, td):
837837
if agent not in agents_acting:
838838
group_mask[index] = False
839839

840-
def close(self) -> None:
840+
def close(self, *, raise_if_closed: bool = True) -> None:
841841
self._env.close()
842842

843843

torchrl/envs/libs/smacv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _update_action_mask(self):
410410
self.action_spec.update_mask(mask)
411411
return mask
412412

413-
def close(self):
413+
def close(self, *, raise_if_closed: bool = True):
414414
# Closes StarCraft II
415415
self._env.close()
416416

torchrl/envs/libs/unity_mlagents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def _reset(
497497
self._env.reset()
498498
return self._make_td_out(tensordict, is_reset=True)
499499

500-
def close(self):
500+
def close(self, *, raise_if_closed: bool = True):
501501
self._env.close()
502502

503503
@_classproperty

torchrl/envs/transforms/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
10261026
tensordict = tensordict.select(
10271027
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
10281028
)
1029-
tensordict = self.transform._reset_env_preprocess(tensordict)
1029+
tensordict = self.transform._reset_env_preprocess(tensordict)
10301030
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
10311031
if tensordict is None:
10321032
# make sure all transforms see a source tensordict
@@ -1083,8 +1083,8 @@ def is_closed(self) -> bool:
10831083
def is_closed(self, value: bool):
10841084
self.base_env.is_closed = value
10851085

1086-
def close(self):
1087-
self.base_env.close()
1086+
def close(self, *, raise_if_closed: bool = True):
1087+
self.base_env.close(raise_if_closed=raise_if_closed)
10881088
self.is_closed = True
10891089

10901090
def empty_cache(self):

0 commit comments

Comments
 (0)