File tree Expand file tree Collapse file tree 5 files changed +15
-26
lines changed Expand file tree Collapse file tree 5 files changed +15
-26
lines changed Original file line number Diff line number Diff line change @@ -2361,7 +2361,7 @@ def make_env():
2361
2361
class TestLibThreading :
2362
2362
@pytest .mark .skipif (
2363
2363
IS_OSX ,
2364
- reason = "setting different threads across workeres can randomly fail on OSX." ,
2364
+ reason = "setting different threads across workers can randomly fail on OSX." ,
2365
2365
)
2366
2366
def test_num_threads (self ):
2367
2367
from torchrl .collectors import collectors
@@ -2396,7 +2396,7 @@ def test_num_threads(self):
2396
2396
2397
2397
@pytest .mark .skipif (
2398
2398
IS_OSX ,
2399
- reason = "setting different threads across workeres can randomly fail on OSX." ,
2399
+ reason = "setting different threads across workers can randomly fail on OSX." ,
2400
2400
)
2401
2401
def test_auto_num_threads (self ):
2402
2402
init_threads = torch .get_num_threads ()
Original file line number Diff line number Diff line change @@ -2337,7 +2337,7 @@ def test_terminated_or_truncated_spec(self):
2337
2337
class TestLibThreading :
2338
2338
@pytest .mark .skipif (
2339
2339
IS_OSX ,
2340
- reason = "setting different threads across workeres can randomly fail on OSX." ,
2340
+ reason = "setting different threads across workers can randomly fail on OSX." ,
2341
2341
)
2342
2342
def test_num_threads (self ):
2343
2343
from torchrl .envs import batched_envs
@@ -2363,18 +2363,18 @@ def test_num_threads(self):
2363
2363
2364
2364
@pytest .mark .skipif (
2365
2365
IS_OSX ,
2366
- reason = "setting different threads across workeres can randomly fail on OSX." ,
2366
+ reason = "setting different threads across workers can randomly fail on OSX." ,
2367
2367
)
2368
2368
def test_auto_num_threads (self ):
2369
2369
init_threads = torch .get_num_threads ()
2370
2370
2371
2371
try :
2372
- env3 = ParallelEnv (3 , lambda : GymEnv ( "Pendulum-v1" ) )
2372
+ env3 = ParallelEnv (3 , ContinuousActionVecMockEnv )
2373
2373
env3 .rollout (2 )
2374
2374
2375
2375
assert torch .get_num_threads () == max (1 , init_threads - 3 )
2376
2376
2377
- env2 = ParallelEnv (2 , lambda : GymEnv ( "Pendulum-v1" ) )
2377
+ env2 = ParallelEnv (2 , ContinuousActionVecMockEnv )
2378
2378
env2 .rollout (2 )
2379
2379
2380
2380
assert torch .get_num_threads () == max (1 , init_threads - 5 )
Original file line number Diff line number Diff line change 51
51
filter_warnings_subprocess = True
52
52
53
53
_THREAD_POOL_INIT = torch .get_num_threads ()
54
- _THREAD_POOL = torch .get_num_threads ()
Original file line number Diff line number Diff line change @@ -1607,18 +1607,12 @@ def _queue_len(self) -> int:
1607
1607
1608
1608
def _run_processes (self ) -> None :
1609
1609
if self .num_threads is None :
1610
- import torchrl
1611
-
1612
1610
total_workers = self ._total_workers_from_env (self .create_env_fn )
1613
1611
self .num_threads = max (
1614
- 1 , torchrl . _THREAD_POOL - total_workers
1612
+ 1 , torch . get_num_threads () - total_workers
1615
1613
) # 1 more thread for this proc
1616
1614
1617
1615
torch .set_num_threads (self .num_threads )
1618
- assert torch .get_num_threads () == self .num_threads
1619
- import torchrl
1620
-
1621
- torchrl ._THREAD_POOL = self .num_threads
1622
1616
queue_out = mp .Queue (self ._queue_len ) # sends data from proc to main
1623
1617
self .procs = []
1624
1618
self .pipes = []
@@ -1727,11 +1721,12 @@ def _shutdown_main(self) -> None:
1727
1721
finally :
1728
1722
import torchrl
1729
1723
1730
- torchrl . _THREAD_POOL = min (
1724
+ num_threads = min (
1731
1725
torchrl ._THREAD_POOL_INIT ,
1732
- torchrl ._THREAD_POOL + self ._total_workers_from_env (self .create_env_fn ),
1726
+ torch .get_num_threads ()
1727
+ + self ._total_workers_from_env (self .create_env_fn ),
1733
1728
)
1734
- torch .set_num_threads (torchrl . _THREAD_POOL )
1729
+ torch .set_num_threads (num_threads )
1735
1730
1736
1731
for proc in self .procs :
1737
1732
if proc .is_alive ():
Original file line number Diff line number Diff line change @@ -633,10 +633,10 @@ def close(self) -> None:
633
633
self .is_closed = True
634
634
import torchrl
635
635
636
- torchrl . _THREAD_POOL = min (
637
- torchrl ._THREAD_POOL_INIT , torchrl . _THREAD_POOL + self .num_workers
636
+ num_threads = min (
637
+ torchrl ._THREAD_POOL_INIT , torch . get_num_threads () + self .num_workers
638
638
)
639
- torch .set_num_threads (torchrl . _THREAD_POOL )
639
+ torch .set_num_threads (num_threads )
640
640
641
641
def _shutdown_workers (self ) -> None :
642
642
raise NotImplementedError
@@ -1015,16 +1015,11 @@ def _start_workers(self) -> None:
1015
1015
from torchrl .envs .env_creator import EnvCreator
1016
1016
1017
1017
if self .num_threads is None :
1018
- import torchrl
1019
-
1020
1018
self .num_threads = max (
1021
- 1 , torchrl . _THREAD_POOL - self .num_workers
1019
+ 1 , torch . get_num_threads () - self .num_workers
1022
1020
) # 1 more thread for this proc
1023
1021
1024
1022
torch .set_num_threads (self .num_threads )
1025
- import torchrl
1026
-
1027
- torchrl ._THREAD_POOL = self .num_threads
1028
1023
1029
1024
ctx = mp .get_context ("spawn" )
1030
1025
You can’t perform that action at this time.
0 commit comments