Skip to content

Commit 86b8918

Browse files
author
Vincent Moens
authored
[BugFix] thread setting bug (#1852)
1 parent 017bcd0 commit 86b8918

File tree

5 files changed

+15
-26
lines changed

5 files changed

+15
-26
lines changed

test/test_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2361,7 +2361,7 @@ def make_env():
23612361
class TestLibThreading:
23622362
@pytest.mark.skipif(
23632363
IS_OSX,
2364-
reason="setting different threads across workeres can randomly fail on OSX.",
2364+
reason="setting different threads across workers can randomly fail on OSX.",
23652365
)
23662366
def test_num_threads(self):
23672367
from torchrl.collectors import collectors
@@ -2396,7 +2396,7 @@ def test_num_threads(self):
23962396

23972397
@pytest.mark.skipif(
23982398
IS_OSX,
2399-
reason="setting different threads across workeres can randomly fail on OSX.",
2399+
reason="setting different threads across workers can randomly fail on OSX.",
24002400
)
24012401
def test_auto_num_threads(self):
24022402
init_threads = torch.get_num_threads()

test/test_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,7 +2337,7 @@ def test_terminated_or_truncated_spec(self):
23372337
class TestLibThreading:
23382338
@pytest.mark.skipif(
23392339
IS_OSX,
2340-
reason="setting different threads across workeres can randomly fail on OSX.",
2340+
reason="setting different threads across workers can randomly fail on OSX.",
23412341
)
23422342
def test_num_threads(self):
23432343
from torchrl.envs import batched_envs
@@ -2363,18 +2363,18 @@ def test_num_threads(self):
23632363

23642364
@pytest.mark.skipif(
23652365
IS_OSX,
2366-
reason="setting different threads across workeres can randomly fail on OSX.",
2366+
reason="setting different threads across workers can randomly fail on OSX.",
23672367
)
23682368
def test_auto_num_threads(self):
23692369
init_threads = torch.get_num_threads()
23702370

23712371
try:
2372-
env3 = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
2372+
env3 = ParallelEnv(3, ContinuousActionVecMockEnv)
23732373
env3.rollout(2)
23742374

23752375
assert torch.get_num_threads() == max(1, init_threads - 3)
23762376

2377-
env2 = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
2377+
env2 = ParallelEnv(2, ContinuousActionVecMockEnv)
23782378
env2.rollout(2)
23792379

23802380
assert torch.get_num_threads() == max(1, init_threads - 5)

torchrl/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,3 @@
5151
filter_warnings_subprocess = True
5252

5353
_THREAD_POOL_INIT = torch.get_num_threads()
54-
_THREAD_POOL = torch.get_num_threads()

torchrl/collectors/collectors.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,18 +1607,12 @@ def _queue_len(self) -> int:
16071607

16081608
def _run_processes(self) -> None:
16091609
if self.num_threads is None:
1610-
import torchrl
1611-
16121610
total_workers = self._total_workers_from_env(self.create_env_fn)
16131611
self.num_threads = max(
1614-
1, torchrl._THREAD_POOL - total_workers
1612+
1, torch.get_num_threads() - total_workers
16151613
) # 1 more thread for this proc
16161614

16171615
torch.set_num_threads(self.num_threads)
1618-
assert torch.get_num_threads() == self.num_threads
1619-
import torchrl
1620-
1621-
torchrl._THREAD_POOL = self.num_threads
16221616
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
16231617
self.procs = []
16241618
self.pipes = []
@@ -1727,11 +1721,12 @@ def _shutdown_main(self) -> None:
17271721
finally:
17281722
import torchrl
17291723

1730-
torchrl._THREAD_POOL = min(
1724+
num_threads = min(
17311725
torchrl._THREAD_POOL_INIT,
1732-
torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn),
1726+
torch.get_num_threads()
1727+
+ self._total_workers_from_env(self.create_env_fn),
17331728
)
1734-
torch.set_num_threads(torchrl._THREAD_POOL)
1729+
torch.set_num_threads(num_threads)
17351730

17361731
for proc in self.procs:
17371732
if proc.is_alive():

torchrl/envs/batched_envs.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -633,10 +633,10 @@ def close(self) -> None:
633633
self.is_closed = True
634634
import torchrl
635635

636-
torchrl._THREAD_POOL = min(
637-
torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers
636+
num_threads = min(
637+
torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
638638
)
639-
torch.set_num_threads(torchrl._THREAD_POOL)
639+
torch.set_num_threads(num_threads)
640640

641641
def _shutdown_workers(self) -> None:
642642
raise NotImplementedError
@@ -1015,16 +1015,11 @@ def _start_workers(self) -> None:
10151015
from torchrl.envs.env_creator import EnvCreator
10161016

10171017
if self.num_threads is None:
1018-
import torchrl
1019-
10201018
self.num_threads = max(
1021-
1, torchrl._THREAD_POOL - self.num_workers
1019+
1, torch.get_num_threads() - self.num_workers
10221020
) # 1 more thread for this proc
10231021

10241022
torch.set_num_threads(self.num_threads)
1025-
import torchrl
1026-
1027-
torchrl._THREAD_POOL = self.num_threads
10281023

10291024
ctx = mp.get_context("spawn")
10301025

0 commit comments

Comments
 (0)