61
61
from torchrl .modules import Actor , LSTMNet , OrnsteinUhlenbeckProcessWrapper , SafeModule
62
62
63
63
# torch.set_default_dtype(torch.double)
64
- _os_is_windows = sys .platform == "win32"
65
- _python_is_3_10 = sys .version_info . major == 3 and sys . version_info . minor == 10
66
- _python_is_3_7 = sys .version_info .major == 3 and sys .version_info .minor == 7
67
- _os_is_osx = sys .platform == "darwin"
64
+ IS_WINDOWS = sys .platform == "win32"
65
+ IS_OSX = sys .platform == "darwin"
66
+ PYTHON_3_10 = sys .version_info .major == 3 and sys .version_info .minor == 10
67
+ PYTHON_3_7 = sys .version_info . major == 3 and sys . version_info . minor == 7
68
68
69
69
70
70
class WrappablePolicy (nn .Module ):
@@ -172,7 +172,7 @@ def _is_consistent_device_type(
172
172
173
173
174
174
@pytest .mark .skipif (
175
- _os_is_windows and _python_is_3_10 ,
175
+ IS_WINDOWS and PYTHON_3_10 ,
176
176
reason = "Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection" ,
177
177
)
178
178
@pytest .mark .parametrize ("num_env" , [2 ])
@@ -187,7 +187,7 @@ def test_output_device_consistency(
187
187
) and not torch .cuda .is_available ():
188
188
pytest .skip ("cuda is not available" )
189
189
190
- if _os_is_windows and _python_is_3_7 :
190
+ if IS_WINDOWS and PYTHON_3_7 :
191
191
if device == "cuda" and policy_device == "cuda" and device is None :
192
192
pytest .skip (
193
193
"BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows"
@@ -509,7 +509,7 @@ def test_collector_batch_size(
509
509
num_env , env_name , seed = 100 , num_workers = 2 , frames_per_batch = 20
510
510
):
511
511
"""Tests that there are 'frames_per_batch' frames in each batch of a collection."""
512
- if num_env == 3 and _os_is_windows :
512
+ if num_env == 3 and IS_WINDOWS :
513
513
pytest .skip ("Test timeout (> 10 min) on CI pipeline Windows machine with GPU" )
514
514
if num_env == 1 :
515
515
@@ -1053,12 +1053,7 @@ def test_collector_output_keys(
1053
1053
@pytest .mark .parametrize ("storing_device" , ["cuda" , "cpu" ])
1054
1054
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "no cuda device found" )
1055
1055
def test_collector_device_combinations (device , storing_device ):
1056
- if (
1057
- _os_is_windows
1058
- and _python_is_3_10
1059
- and storing_device == "cuda"
1060
- and device == "cuda"
1061
- ):
1056
+ if IS_WINDOWS and PYTHON_3_10 and storing_device == "cuda" and device == "cuda" :
1062
1057
pytest .skip ("Windows fatal exception: access violation in torch.storage" )
1063
1058
1064
1059
def env_fn (seed ):
@@ -1274,7 +1269,7 @@ def weight_reset(m):
1274
1269
m .reset_parameters ()
1275
1270
1276
1271
1277
- @pytest .mark .skipif (_os_is_osx , reason = "Queue.qsize does not work on osx." )
1272
+ @pytest .mark .skipif (IS_OSX , reason = "Queue.qsize does not work on osx." )
1278
1273
class TestPreemptiveThreshold :
1279
1274
@pytest .mark .parametrize ("env_name" , ["conv" , "vec" ])
1280
1275
def test_sync_collector_interruptor_mechanism (self , env_name , seed = 100 ):
@@ -1785,6 +1780,9 @@ def make_env():
1785
1780
collector .shutdown ()
1786
1781
1787
1782
1783
+ @pytest .mark .skipif (
1784
+ IS_OSX , reason = "setting different threads across workeres can randomly fail on OSX."
1785
+ )
1788
1786
def test_num_threads ():
1789
1787
from torchrl .collectors import collectors
1790
1788
0 commit comments