Skip to content

Commit 09e148b

Browse files
Vincent Moensmatteobettini
andauthored
[Feature] Threaded collection and parallel envs (#1559)
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
1 parent 95773f7 commit 09e148b

File tree

7 files changed

+111
-8
lines changed

7 files changed

+111
-8
lines changed

test/_utils_internal.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from tensordict import tensorclass, TensorDict
2121
from torchrl._utils import implement_for, seed_generator
22+
from torchrl.data.utils import CloudpickleWrapper
2223

2324
from torchrl.envs import MultiThreadedEnv, ObservationNorm
2425
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
@@ -433,3 +434,11 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
433434
== td["nested_2", "observation"][~action_is_count]
434435
).all()
435436
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()
437+
438+
439+
def decorate_thread_sub_func(func, num_threads):
440+
def new_func(*args, **kwargs):
441+
assert torch.get_num_threads() == num_threads
442+
return func(*args, **kwargs)
443+
444+
return CloudpickleWrapper(new_func)

test/test_collector.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import argparse
7+
78
import sys
89

910
import numpy as np
1011
import pytest
1112
import torch
1213
from _utils_internal import (
1314
check_rollout_consistency_multikey_env,
15+
decorate_thread_sub_func,
1416
generate_seeds,
1517
PENDULUM_VERSIONED,
1618
PONG_VERSIONED,
@@ -1783,6 +1785,35 @@ def make_env():
17831785
collector.shutdown()
17841786

17851787

1788+
def test_num_threads():
1789+
from torchrl.collectors import collectors
1790+
1791+
_main_async_collector_saved = collectors._main_async_collector
1792+
collectors._main_async_collector = decorate_thread_sub_func(
1793+
collectors._main_async_collector, num_threads=3
1794+
)
1795+
num_threads = torch.get_num_threads()
1796+
try:
1797+
env = ContinuousActionVecMockEnv()
1798+
c = MultiSyncDataCollector(
1799+
[env],
1800+
policy=RandomPolicy(env.action_spec),
1801+
num_threads=7,
1802+
num_sub_threads=3,
1803+
total_frames=200,
1804+
frames_per_batch=200,
1805+
)
1806+
assert torch.get_num_threads() == 7
1807+
for _ in c:
1808+
pass
1809+
c.shutdown()
1810+
del c
1811+
finally:
1812+
# reset vals
1813+
collectors._main_async_collector = _main_async_collector_saved
1814+
torch.set_num_threads(num_threads)
1815+
1816+
17861817
if __name__ == "__main__":
17871818
args, unknown = argparse.ArgumentParser().parse_known_args()
17881819
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_env.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_make_envs,
1919
CARTPOLE_VERSIONED,
2020
check_rollout_consistency_multikey_env,
21+
decorate_thread_sub_func,
2122
get_default_devices,
2223
HALFCHEETAH_VERSIONED,
2324
PENDULUM_VERSIONED,
@@ -2088,6 +2089,29 @@ def test_mocking_envs(envclass):
20882089
check_env_specs(env, seed=100, return_contiguous=False)
20892090

20902091

2092+
def test_num_threads():
2093+
from torchrl.envs import batched_envs
2094+
2095+
_run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
2096+
batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
2097+
batched_envs._run_worker_pipe_shared_mem, num_threads=3
2098+
)
2099+
num_threads = torch.get_num_threads()
2100+
try:
2101+
env = ParallelEnv(
2102+
2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
2103+
)
2104+
# We could test that the number of threads isn't changed until we start the procs.
2105+
# Even though it's unlikely that we have 7 threads, we still disable this for safety
2106+
# assert torch.get_num_threads() != 7
2107+
env.rollout(3)
2108+
assert torch.get_num_threads() == 7
2109+
finally:
2110+
# reset vals
2111+
batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
2112+
torch.set_num_threads(num_threads)
2113+
2114+
20912115
if __name__ == "__main__":
20922116
args, unknown = argparse.ArgumentParser().parse_known_args()
20932117
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
import collections
78

@@ -534,21 +535,23 @@ def get_trace():
534535

535536

536537
class _ProcessNoWarn(mp.Process):
537-
"""A private Process class that shuts down warnings on the subprocess."""
538+
"""A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess."""
538539

539540
@wraps(mp.Process.__init__)
540-
def __init__(self, *args, **kwargs):
541+
def __init__(self, *args, num_threads=None, **kwargs):
541542
import torchrl
542543

543-
if torchrl.filter_warnings_subprocess:
544-
self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
544+
self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
545+
self.num_threads = num_threads
545546
super().__init__(*args, **kwargs)
546547

547548
def run(self, *args, **kwargs):
549+
if self.num_threads is not None:
550+
torch.set_num_threads(self.num_threads)
548551
if self.filter_warnings_subprocess:
549552
import warnings
550553

551554
with warnings.catch_warnings():
552555
warnings.simplefilter("ignore")
553-
return mp.Process.run(self, *args, **kwargs)
556+
return mp.Process.run(self, *args, **kwargs)
554557
return mp.Process.run(self, *args, **kwargs)

torchrl/collectors/collectors.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,14 @@ class _MultiDataCollector(DataCollectorBase):
10981098
Defaults to ``False``.
10991099
preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
11001100
that will be allowed to finished collecting their rollout before the rest are forced to end early.
1101+
num_threads (int, optional): number of threads for this process.
1102+
Defaults to the number of workers.
1103+
num_sub_threads (int, optional): number of threads of the subprocesses.
1104+
Should be equal to one plus the number of processes launched within
1105+
each subprocess (or one if a single process is launched).
1106+
Defaults to 1 for safety: if none is indicated, launching multiple
1107+
workers may charge the cpu load too much and harm performance.
1108+
11011109
"""
11021110

11031111
def __init__(
@@ -1127,11 +1135,17 @@ def __init__(
11271135
update_at_each_batch: bool = False,
11281136
devices=None,
11291137
storing_devices=None,
1138+
num_threads: int = None,
1139+
num_sub_threads: int = 1,
11301140
):
11311141
exploration_type = _convert_exploration_type(
11321142
exploration_mode=exploration_mode, exploration_type=exploration_type
11331143
)
11341144
self.closed = True
1145+
if num_threads is None:
1146+
num_threads = len(create_env_fn) + 1 # 1 more thread for this proc
1147+
self.num_sub_threads = num_sub_threads
1148+
self.num_threads = num_threads
11351149
self.create_env_fn = create_env_fn
11361150
self.num_workers = len(create_env_fn)
11371151
self.create_env_kwargs = (
@@ -1308,6 +1322,7 @@ def _queue_len(self) -> int:
13081322
raise NotImplementedError
13091323

13101324
def _run_processes(self) -> None:
1325+
torch.set_num_threads(self.num_threads)
13111326
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
13121327
self.procs = []
13131328
self.pipes = []
@@ -1339,7 +1354,11 @@ def _run_processes(self) -> None:
13391354
"idx": i,
13401355
"interruptor": self.interruptor,
13411356
}
1342-
proc = _ProcessNoWarn(target=_main_async_collector, kwargs=kwargs)
1357+
proc = _ProcessNoWarn(
1358+
target=_main_async_collector,
1359+
num_threads=self.num_sub_threads,
1360+
kwargs=kwargs,
1361+
)
13431362
# proc.daemon can't be set as daemonic processes may be launched by the process itself
13441363
try:
13451364
proc.start()

torchrl/data/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,8 @@ def __setstate__(self, ob: bytes):
261261
self.fn, self.kwargs = pickle.loads(ob)
262262

263263
def __call__(self, *args, **kwargs) -> Any:
264-
kwargs = {k: item for k, item in kwargs.items()}
265264
kwargs.update(self.kwargs)
266-
return self.fn(**kwargs)
265+
return self.fn(*args, **kwargs)
267266

268267

269268
def _process_action_space_spec(action_space, spec):

torchrl/envs/batched_envs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ class _BatchedEnv(EnvBase):
121121
It is assumed that all environments will run on the same device as a common shared
122122
tensordict will be used to pass data from process to process. The device can be
123123
changed after instantiation using :obj:`env.to(device)`.
124+
num_threads (int, optional): number of threads for this process.
125+
Defaults to the number of workers.
126+
This parameter has no effect for the :class:`~SerialEnv` class.
127+
num_sub_threads (int, optional): number of threads of the subprocesses.
128+
Should be equal to one plus the number of processes launched within
129+
each subprocess (or one if a single process is launched).
130+
Defaults to 1 for safety: if none is indicated, launching multiple
131+
workers may charge the cpu load too much and harm performance.
132+
This parameter has no effect for the :class:`~SerialEnv` class.
124133
125134
"""
126135

@@ -144,6 +153,8 @@ def __init__(
144153
policy_proof: Optional[Callable] = None,
145154
device: Optional[DEVICE_TYPING] = None,
146155
allow_step_when_done: bool = False,
156+
num_threads: int = None,
157+
num_sub_threads: int = 1,
147158
):
148159
if device is not None:
149160
raise ValueError(
@@ -154,6 +165,10 @@ def __init__(
154165

155166
super().__init__(device=None)
156167
self.is_closed = True
168+
if num_threads is None:
169+
num_threads = num_workers + 1 # 1 more thread for this proc
170+
self.num_sub_threads = num_sub_threads
171+
self.num_threads = num_threads
157172
self._cache_in_keys = None
158173

159174
self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
@@ -692,6 +707,8 @@ class ParallelEnv(_BatchedEnv):
692707
def _start_workers(self) -> None:
693708
from torchrl.envs.env_creator import EnvCreator
694709

710+
torch.set_num_threads(self.num_threads)
711+
695712
ctx = mp.get_context("spawn")
696713

697714
_num_workers = self.num_workers
@@ -717,6 +734,7 @@ def _start_workers(self) -> None:
717734

718735
process = _ProcessNoWarn(
719736
target=_run_worker_pipe_shared_mem,
737+
num_threads=self.num_sub_threads,
720738
args=(
721739
parent_pipe,
722740
child_pipe,

0 commit comments

Comments
 (0)