
Commit 967bad2

Author: Vincent Moens
[Feature, BugFix] Better thread control in penv and collectors (#1848)
1 parent b1cc796 commit 967bad2

File tree

6 files changed (+188, -64 lines)

test/test_collector.py

Lines changed: 71 additions & 30 deletions
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import argparse
+import gc
 import logging
 
 import sys
@@ -2357,39 +2358,79 @@ def make_env():
     del collector
 
 
-@pytest.mark.skipif(
-    IS_OSX, reason="setting different threads across workeres can randomly fail on OSX."
-)
-def test_num_threads():
-    from torchrl.collectors import collectors
-
-    _main_async_collector_saved = collectors._main_async_collector
-    collectors._main_async_collector = decorate_thread_sub_func(
-        collectors._main_async_collector, num_threads=3
+class TestLibThreading:
+    @pytest.mark.skipif(
+        IS_OSX,
+        reason="setting different threads across workeres can randomly fail on OSX.",
     )
-    num_threads = torch.get_num_threads()
-    try:
-        env = ContinuousActionVecMockEnv()
-        c = MultiSyncDataCollector(
-            [env],
-            policy=RandomPolicy(env.action_spec),
-            num_threads=7,
-            num_sub_threads=3,
-            total_frames=200,
-            frames_per_batch=200,
+    def test_num_threads(self):
+        from torchrl.collectors import collectors
+
+        _main_async_collector_saved = collectors._main_async_collector
+        collectors._main_async_collector = decorate_thread_sub_func(
+            collectors._main_async_collector, num_threads=3
         )
-        assert torch.get_num_threads() == 7
-        for _ in c:
-            pass
-    finally:
+        num_threads = torch.get_num_threads()
+        try:
+            env = ContinuousActionVecMockEnv()
+            c = MultiSyncDataCollector(
+                [env],
+                policy=RandomPolicy(env.action_spec),
+                num_threads=7,
+                num_sub_threads=3,
+                total_frames=200,
+                frames_per_batch=200,
+            )
+            assert torch.get_num_threads() == 7
+            for _ in c:
+                pass
+        finally:
+            try:
+                c.shutdown()
+                del c
+            except Exception:
+                logging.info("Failed to shut down collector")
+            # reset vals
+            collectors._main_async_collector = _main_async_collector_saved
+            torch.set_num_threads(num_threads)
+
+    @pytest.mark.skipif(
+        IS_OSX,
+        reason="setting different threads across workeres can randomly fail on OSX.",
+    )
+    def test_auto_num_threads(self):
+        init_threads = torch.get_num_threads()
+        try:
+            collector = MultiSyncDataCollector(
+                [ContinuousActionVecMockEnv],
+                RandomPolicy(ContinuousActionVecMockEnv().full_action_spec),
+                frames_per_batch=3,
+            )
+            for _ in collector:
+                assert torch.get_num_threads() == init_threads - 1
+                break
+            collector.shutdown()
+            assert torch.get_num_threads() == init_threads
+            del collector
+            gc.collect()
+        finally:
+            torch.set_num_threads(init_threads)
+
         try:
-            c.shutdown()
-            del c
-        except Exception:
-            logging.info("Failed to shut down collector")
-        # reset vals
-        collectors._main_async_collector = _main_async_collector_saved
-        torch.set_num_threads(num_threads)
+            collector = MultiSyncDataCollector(
+                [ParallelEnv(2, ContinuousActionVecMockEnv)],
+                RandomPolicy(ContinuousActionVecMockEnv().full_action_spec.expand(2)),
+                frames_per_batch=3,
+            )
+            for _ in collector:
+                assert torch.get_num_threads() == init_threads - 2
+                break
+            collector.shutdown()
+            assert torch.get_num_threads() == init_threads
+            del collector
+            gc.collect()
+        finally:
+            torch.set_num_threads(init_threads)
 
 
 if __name__ == "__main__":

test/test_env.py

Lines changed: 57 additions & 22 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import gc
 import os.path
 import re
 from collections import defaultdict
@@ -2333,30 +2334,64 @@ def test_terminated_or_truncated_spec(self):
         assert not data["nested", "_reset"].any()
 
 
-@pytest.mark.skipif(
-    IS_OSX, reason="setting different threads across workeres can randomly fail on OSX."
-)
-def test_num_threads():
-    from torchrl.envs import batched_envs
-
-    _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
-    batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
-        batched_envs._run_worker_pipe_shared_mem, num_threads=3
+class TestLibThreading:
+    @pytest.mark.skipif(
+        IS_OSX,
+        reason="setting different threads across workeres can randomly fail on OSX.",
     )
-    num_threads = torch.get_num_threads()
-    try:
-        env = ParallelEnv(
-            2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
+    def test_num_threads(self):
+        from torchrl.envs import batched_envs
+
+        _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
+        batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
+            batched_envs._run_worker_pipe_shared_mem, num_threads=3
         )
-        # We could test that the number of threads isn't changed until we start the procs.
-        # Even though it's unlikely that we have 7 threads, we still disable this for safety
-        # assert torch.get_num_threads() != 7
-        env.rollout(3)
-        assert torch.get_num_threads() == 7
-    finally:
-        # reset vals
-        batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
-        torch.set_num_threads(num_threads)
+        num_threads = torch.get_num_threads()
+        try:
+            env = ParallelEnv(
+                2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
+            )
+            # We could test that the number of threads isn't changed until we start the procs.
+            # Even though it's unlikely that we have 7 threads, we still disable this for safety
+            # assert torch.get_num_threads() != 7
+            env.rollout(3)
+            assert torch.get_num_threads() == 7
+        finally:
+            # reset vals
+            batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
+            torch.set_num_threads(num_threads)
+
+    @pytest.mark.skipif(
+        IS_OSX,
+        reason="setting different threads across workeres can randomly fail on OSX.",
+    )
+    def test_auto_num_threads(self):
+        init_threads = torch.get_num_threads()
+
+        try:
+            env3 = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
+            env3.rollout(2)
+
+            assert torch.get_num_threads() == max(1, init_threads - 3)
+
+            env2 = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
+            env2.rollout(2)
+
+            assert torch.get_num_threads() == max(1, init_threads - 5)
+
+            env2.close()
+            del env2
+            gc.collect()
+
+            assert torch.get_num_threads() == max(1, init_threads - 3)
+
+            env3.close()
+            del env3
+            gc.collect()
+
+            assert torch.get_num_threads() == init_threads
+        finally:
+            torch.set_num_threads(init_threads)
 
 
 def test_run_type_checks():

torchrl/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -49,3 +49,6 @@
 # Filter warnings in subprocesses: True by default given the multiple optional
 # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.
 filter_warnings_subprocess = True
+
+_THREAD_POOL_INIT = torch.get_num_threads()
+_THREAD_POOL = torch.get_num_threads()
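These two module-level globals give the library a process-wide thread budget: _THREAD_POOL_INIT remembers the thread count at import time and never changes, while _THREAD_POOL tracks what is currently left after workers have been spawned. A minimal sketch of the bookkeeping pattern that the collector and batched-env changes below both follow (the standalone helper functions are illustrative only, not part of the torchrl API):

import torch
import torchrl

def _reserve_threads(num_workers: int) -> None:
    # Illustrative only: give up one main-process thread per spawned worker,
    # but never drop below a single thread.
    torchrl._THREAD_POOL = max(1, torchrl._THREAD_POOL - num_workers)
    torch.set_num_threads(torchrl._THREAD_POOL)

def _release_threads(num_workers: int) -> None:
    # Illustrative only: hand the threads back on shutdown, capped at the
    # value recorded when torchrl was first imported.
    torchrl._THREAD_POOL = min(
        torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + num_workers
    )
    torch.set_num_threads(torchrl._THREAD_POOL)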

torchrl/collectors/collectors.py

Lines changed: 34 additions & 3 deletions
@@ -1368,12 +1368,11 @@ def __init__(
             exploration_mode=exploration_mode, exploration_type=exploration_type
         )
         self.closed = True
-        if num_threads is None:
-            num_threads = len(create_env_fn) + 1  # 1 more thread for this proc
+        self.num_workers = len(create_env_fn)
+
         self.num_sub_threads = num_sub_threads
         self.num_threads = num_threads
         self.create_env_fn = create_env_fn
-        self.num_workers = len(create_env_fn)
         self.create_env_kwargs = (
             create_env_kwargs
             if create_env_kwargs is not None
@@ -1521,6 +1520,18 @@ def _get_weight_fn(weights=policy_weights):
         self._frames = 0
         self._iter = -1
 
+    @classmethod
+    def _total_workers_from_env(cls, env_creators):
+        if isinstance(env_creators, (tuple, list)):
+            return sum(
+                cls._total_workers_from_env(env_creator) for env_creator in env_creators
+            )
+        from torchrl.envs import ParallelEnv
+
+        if isinstance(env_creators, ParallelEnv):
+            return env_creators.num_workers
+        return 1
+
     def _get_devices(
         self,
         *,
@@ -1595,7 +1606,19 @@ def _queue_len(self) -> int:
         raise NotImplementedError
 
     def _run_processes(self) -> None:
+        if self.num_threads is None:
+            import torchrl
+
+            total_workers = self._total_workers_from_env(self.create_env_fn)
+            self.num_threads = max(
+                1, torchrl._THREAD_POOL - total_workers
+            )  # 1 more thread for this proc
+
         torch.set_num_threads(self.num_threads)
+        assert torch.get_num_threads() == self.num_threads
+        import torchrl
+
+        torchrl._THREAD_POOL = self.num_threads
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self.pipes = []
@@ -1702,6 +1725,14 @@ def _shutdown_main(self) -> None:
             for proc in self.procs:
                 proc.join(1.0)
         finally:
+            import torchrl
+
+            torchrl._THREAD_POOL = min(
+                torchrl._THREAD_POOL_INIT,
+                torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn),
+            )
+            torch.set_num_threads(torchrl._THREAD_POOL)
+
             for proc in self.procs:
                 if proc.is_alive():
                     proc.terminate()
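The new _total_workers_from_env classmethod counts every ParallelEnv passed to the collector as its full number of sub-workers, so the automatic thread budget reflects the total number of environment processes rather than the length of the constructor list. A small re-implementation of the counting rule for illustration (not the collector's own code path; make_env in the comment is a placeholder factory):

from torchrl.envs import ParallelEnv

def total_workers(env_creators) -> int:
    # Mirrors the recursion in _total_workers_from_env above: lists/tuples are
    # summed, a ParallelEnv counts as its num_workers, and anything else counts
    # as a single worker (one environment process).
    if isinstance(env_creators, (tuple, list)):
        return sum(total_workers(creator) for creator in env_creators)
    if isinstance(env_creators, ParallelEnv):
        return env_creators.num_workers
    return 1

# With num_threads=None, _run_processes reserves one thread per counted worker:
# e.g. [ParallelEnv(2, make_env), make_env] counts as 3 workers, so the main
# process keeps max(1, torchrl._THREAD_POOL - 3) threads while collecting.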

torchrl/envs/batched_envs.py

Lines changed: 16 additions & 2 deletions
@@ -270,8 +270,6 @@ def __init__(
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
         self.is_closed = True
-        if num_threads is None:
-            num_threads = num_workers + 1  # 1 more thread for this proc
         self.num_sub_threads = num_sub_threads
         self.num_threads = num_threads
         self._cache_in_keys = None
@@ -633,6 +631,12 @@ def close(self) -> None:
 
         self._shutdown_workers()
         self.is_closed = True
+        import torchrl
+
+        torchrl._THREAD_POOL = min(
+            torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers
+        )
+        torch.set_num_threads(torchrl._THREAD_POOL)
 
     def _shutdown_workers(self) -> None:
         raise NotImplementedError
@@ -1010,7 +1014,17 @@ class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
     def _start_workers(self) -> None:
         from torchrl.envs.env_creator import EnvCreator
 
+        if self.num_threads is None:
+            import torchrl
+
+            self.num_threads = max(
+                1, torchrl._THREAD_POOL - self.num_workers
+            )  # 1 more thread for this proc
+
         torch.set_num_threads(self.num_threads)
+        import torchrl
+
+        torchrl._THREAD_POOL = self.num_threads
 
         ctx = mp.get_context("spawn")
 
torchrl/envs/libs/gym.py

Lines changed: 7 additions & 7 deletions
@@ -306,16 +306,16 @@ def _gym_to_torchrl_spec_transform(
             shape = torch.Size([1])
         if dtype is None:
             dtype = numpy_to_torch_dtype_dict[spec.dtype]
-        low = torch.tensor(spec.low, device=device, dtype=dtype)
-        high = torch.tensor(spec.high, device=device, dtype=dtype)
+        low = torch.as_tensor(spec.low, device=device, dtype=dtype)
+        high = torch.as_tensor(spec.high, device=device, dtype=dtype)
         is_unbounded = low.isinf().all() and high.isinf().all()
 
         minval, maxval = _minmax_dtype(dtype)
         minval = torch.as_tensor(minval).to(low.device, dtype)
         maxval = torch.as_tensor(maxval).to(low.device, dtype)
         is_unbounded = is_unbounded or (
-            torch.isclose(low, torch.tensor(minval, dtype=dtype)).all()
-            and torch.isclose(high, torch.tensor(maxval, dtype=dtype)).all()
+            torch.isclose(low, torch.as_tensor(minval, dtype=dtype)).all()
+            and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all()
         )
         return (
             UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype)
@@ -1480,7 +1480,7 @@ def _read_obs(self, obs, key, tensor, index):
            # Simplest case: there is one observation,
            # presented as a np.ndarray. The key should be pixels or observation.
            # We just write that value at its location in the tensor
-           tensor[index] = torch.tensor(obs, device=tensor.device)
+           tensor[index] = torch.as_tensor(obs, device=tensor.device)
        elif isinstance(obs, dict):
            if key not in obs:
                raise KeyError(
@@ -1491,13 +1491,13 @@ def _read_obs(self, obs, key, tensor, index):
            # if the obs is a dict, we expect that the key points also to
            # a value in the obs. We retrieve this value and write it in the
            # tensor
-           tensor[index] = torch.tensor(subobs, device=tensor.device)
+           tensor[index] = torch.as_tensor(subobs, device=tensor.device)
 
        elif isinstance(obs, (list, tuple)):
            # tuples are stacked along the first dimension when passing gym spaces
            # to torchrl specs. As such, we can simply stack the tuple and set it
            # at the relevant index (assuming stacking can be achieved)
-           tensor[index] = torch.tensor(obs, device=tensor.device)
+           tensor[index] = torch.as_tensor(obs, device=tensor.device)
        else:
            raise NotImplementedError(
                f"Observations of type {type(obs)} are not supported yet."
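The torch.tensor to torch.as_tensor changes avoid an extra copy (and the associated user warning) when the incoming NumPy array already has the requested dtype and device: as_tensor reuses the existing buffer and only copies when a conversion is actually required. A minimal illustration of the difference:

import numpy as np
import torch

obs = np.zeros((4,), dtype=np.float32)

copied = torch.tensor(obs)     # always allocates and copies the data
shared = torch.as_tensor(obs)  # reuses the NumPy buffer when dtype/device already match

obs[0] = 1.0
print(copied[0].item(), shared[0].item())  # 0.0 1.0 -> only the shared tensor sees the write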
