Skip to content

Commit c00b62a

Browse files
author
Vincent Moens
authored
[Test] Skip threading tests in OSX (#1571)
1 parent f016fa0 commit c00b62a

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

test/test_collector.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@
6161
from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule
6262

6363
# torch.set_default_dtype(torch.double)
64-
_os_is_windows = sys.platform == "win32"
65-
_python_is_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
66-
_python_is_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
67-
_os_is_osx = sys.platform == "darwin"
64+
IS_WINDOWS = sys.platform == "win32"
65+
IS_OSX = sys.platform == "darwin"
66+
PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
67+
PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
6868

6969

7070
class WrappablePolicy(nn.Module):
@@ -172,7 +172,7 @@ def _is_consistent_device_type(
172172

173173

174174
@pytest.mark.skipif(
175-
_os_is_windows and _python_is_3_10,
175+
IS_WINDOWS and PYTHON_3_10,
176176
reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection",
177177
)
178178
@pytest.mark.parametrize("num_env", [2])
@@ -187,7 +187,7 @@ def test_output_device_consistency(
187187
) and not torch.cuda.is_available():
188188
pytest.skip("cuda is not available")
189189

190-
if _os_is_windows and _python_is_3_7:
190+
if IS_WINDOWS and PYTHON_3_7:
191191
if device == "cuda" and policy_device == "cuda" and device is None:
192192
pytest.skip(
193193
"BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows"
@@ -509,7 +509,7 @@ def test_collector_batch_size(
509509
num_env, env_name, seed=100, num_workers=2, frames_per_batch=20
510510
):
511511
"""Tests that there are 'frames_per_batch' frames in each batch of a collection."""
512-
if num_env == 3 and _os_is_windows:
512+
if num_env == 3 and IS_WINDOWS:
513513
pytest.skip("Test timeout (> 10 min) on CI pipeline Windows machine with GPU")
514514
if num_env == 1:
515515

@@ -1053,12 +1053,7 @@ def test_collector_output_keys(
10531053
@pytest.mark.parametrize("storing_device", ["cuda", "cpu"])
10541054
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
10551055
def test_collector_device_combinations(device, storing_device):
1056-
if (
1057-
_os_is_windows
1058-
and _python_is_3_10
1059-
and storing_device == "cuda"
1060-
and device == "cuda"
1061-
):
1056+
if IS_WINDOWS and PYTHON_3_10 and storing_device == "cuda" and device == "cuda":
10621057
pytest.skip("Windows fatal exception: access violation in torch.storage")
10631058

10641059
def env_fn(seed):
@@ -1274,7 +1269,7 @@ def weight_reset(m):
12741269
m.reset_parameters()
12751270

12761271

1277-
@pytest.mark.skipif(_os_is_osx, reason="Queue.qsize does not work on osx.")
1272+
@pytest.mark.skipif(IS_OSX, reason="Queue.qsize does not work on osx.")
12781273
class TestPreemptiveThreshold:
12791274
@pytest.mark.parametrize("env_name", ["conv", "vec"])
12801275
def test_sync_collector_interruptor_mechanism(self, env_name, seed=100):
@@ -1785,6 +1780,9 @@ def make_env():
17851780
collector.shutdown()
17861781

17871782

1783+
@pytest.mark.skipif(
1784+
IS_OSX, reason="setting different threads across workeres can randomly fail on OSX."
1785+
)
17881786
def test_num_threads():
17891787
from torchrl.collectors import collectors
17901788

test/test_env.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
from collections import defaultdict
1010
from functools import partial
11+
from sys import platform
1112

1213
import numpy as np
1314
import pytest
@@ -91,6 +92,7 @@
9192
_atari_found = False
9293
atari_confs = defaultdict(lambda: "")
9394

95+
IS_OSX = platform == "darwin"
9496

9597
## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between
9698
## the serial and parallel batched envs
@@ -2089,6 +2091,9 @@ def test_mocking_envs(envclass):
20892091
check_env_specs(env, seed=100, return_contiguous=False)
20902092

20912093

2094+
@pytest.mark.skipif(
2095+
IS_OSX, reason="setting different threads across workeres can randomly fail on OSX."
2096+
)
20922097
def test_num_threads():
20932098
from torchrl.envs import batched_envs
20942099

0 commit comments

Comments
 (0)