Skip to content

Commit 46d5663

Browse files
committed
Refactor plugin into multiple files
1 parent 48180f2 commit 46d5663

12 files changed

+892
-886
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def get_logical_cpus():
2+
try:
3+
import psutil
4+
except ImportError:
5+
pass
6+
else:
7+
process = psutil.Process()
8+
try:
9+
cpu_cores = process.cpu_affinity()
10+
return len(cpu_cores)
11+
except AttributeError:
12+
cpu_cores = psutil.cpu_count()
13+
if cpu_cores is not None:
14+
return cpu_cores
15+
16+
try:
17+
from os import process_cpu_count
18+
except ImportError:
19+
pass
20+
else:
21+
cpu_cores = process_cpu_count()
22+
if cpu_cores is not None:
23+
return cpu_cores
24+
25+
try:
26+
from os import sched_getaffinity
27+
except ImportError:
28+
pass
29+
else:
30+
cpu_cores = sched_getaffinity(0)
31+
if cpu_cores is not None:
32+
return len(cpu_cores)
33+
34+
from os import cpu_count
35+
36+
return cpu_count()

src/pytest_run_parallel/fixtures.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
3+
from pytest_run_parallel.thread_comparator import ThreadComparator
4+
from pytest_run_parallel.utils import get_num_workers
5+
6+
7+
@pytest.fixture
8+
def num_parallel_threads(request):
9+
return get_num_workers(request.config, request.node)
10+
11+
12+
@pytest.fixture
13+
def num_iterations(request):
14+
node = request.node
15+
n_iterations = request.config.option.iterations
16+
m = node.get_closest_marker("iterations")
17+
if m is not None:
18+
n_iterations = int(m.args[0])
19+
return n_iterations
20+
21+
22+
@pytest.fixture
23+
def thread_comp(num_parallel_threads):
24+
return ThreadComparator(num_parallel_threads)

src/pytest_run_parallel/plugin.py renamed to src/pytest_run_parallel/hooks.py

Lines changed: 5 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
import functools
21
import os
3-
import sys
4-
import threading
52
import warnings
63

7-
import _pytest.outcomes
84
import pytest
95

10-
from pytest_run_parallel.utils import (
11-
ThreadComparator,
12-
get_configured_num_workers,
13-
get_num_workers,
6+
from pytest_run_parallel.thread_unsafe_detection import (
7+
THREAD_UNSAFE_FIXTURES,
148
identify_thread_unsafe_nodes,
159
)
10+
from pytest_run_parallel.thread_wrapper import wrap_function_parallel
11+
from pytest_run_parallel.utils import get_configured_num_workers, get_num_workers
1612

1713

1814
def pytest_addoption(parser):
@@ -72,81 +68,6 @@ def pytest_configure(config):
7268
)
7369

7470

75-
def wrap_function_parallel(fn, n_workers, n_iterations):
76-
@functools.wraps(fn)
77-
def inner(*args, **kwargs):
78-
errors = []
79-
skip = None
80-
failed = None
81-
barrier = threading.Barrier(n_workers)
82-
original_switch = sys.getswitchinterval()
83-
new_switch = 1e-6
84-
for _ in range(3):
85-
try:
86-
sys.setswitchinterval(new_switch)
87-
break
88-
except ValueError:
89-
new_switch *= 10
90-
else:
91-
sys.setswitchinterval(original_switch)
92-
93-
try:
94-
95-
def closure(*args, **kwargs):
96-
for _ in range(n_iterations):
97-
barrier.wait()
98-
try:
99-
fn(*args, **kwargs)
100-
except Warning:
101-
pass
102-
except Exception as e:
103-
errors.append(e)
104-
except _pytest.outcomes.Skipped as s:
105-
nonlocal skip
106-
skip = s.msg
107-
except _pytest.outcomes.Failed as f:
108-
nonlocal failed
109-
failed = f
110-
111-
workers = []
112-
for _ in range(0, n_workers):
113-
worker_kwargs = kwargs
114-
workers.append(
115-
threading.Thread(target=closure, args=args, kwargs=worker_kwargs)
116-
)
117-
118-
num_completed = 0
119-
try:
120-
for worker in workers:
121-
worker.start()
122-
num_completed += 1
123-
finally:
124-
if num_completed < len(workers):
125-
barrier.abort()
126-
127-
for worker in workers:
128-
worker.join()
129-
130-
finally:
131-
sys.setswitchinterval(original_switch)
132-
133-
if skip is not None:
134-
pytest.skip(skip)
135-
elif failed is not None:
136-
raise failed
137-
elif errors:
138-
raise errors[0]
139-
140-
return inner
141-
142-
143-
_thread_unsafe_fixtures = {
144-
"capsys",
145-
"monkeypatch",
146-
"recwarn",
147-
}
148-
149-
15071
@pytest.hookimpl(trylast=True)
15172
def pytest_itemcollected(item):
15273
n_workers = get_num_workers(item.config, item)
@@ -207,7 +128,7 @@ def pytest_itemcollected(item):
207128
else:
208129
item.add_marker(pytest.mark.parallel_threads(1))
209130

210-
unsafe_fixtures = _thread_unsafe_fixtures | set(
131+
unsafe_fixtures = THREAD_UNSAFE_FIXTURES | set(
211132
item.config.getini("thread_unsafe_fixtures")
212133
)
213134

@@ -307,23 +228,3 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
307228
)
308229
if n_workers > 1 and num_serial == 0:
309230
terminalreporter.line("All tests were run in parallel! 🎉")
310-
311-
312-
@pytest.fixture
313-
def num_parallel_threads(request):
314-
return get_num_workers(request.config, request.node)
315-
316-
317-
@pytest.fixture
318-
def num_iterations(request):
319-
node = request.node
320-
n_iterations = request.config.option.iterations
321-
m = node.get_closest_marker("iterations")
322-
if m is not None:
323-
n_iterations = int(m.args[0])
324-
return n_iterations
325-
326-
327-
@pytest.fixture
328-
def thread_comp(num_parallel_threads):
329-
return ThreadComparator(num_parallel_threads)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import threading
2+
import types
3+
4+
try:
5+
import numpy as np
6+
7+
numpy_available = True
8+
except ImportError:
9+
numpy_available = False
10+
11+
12+
class ThreadComparator:
13+
def __init__(self, n_threads):
14+
self._barrier = threading.Barrier(n_threads)
15+
self._reset_evt = threading.Event()
16+
self._entry_barrier = threading.Barrier(n_threads)
17+
18+
self._thread_ids = []
19+
self._values = {}
20+
self._entry_lock = threading.Lock()
21+
self._entry_counter = 0
22+
23+
def __call__(self, **values):
24+
"""
25+
Compares a set of values across threads.
26+
27+
For each value, type equality as well as comparison takes place. If any
28+
of the values is a function, then address comparison is performed.
29+
Also, if any of the values is a `numpy.ndarray`, then approximate
30+
numerical comparison is performed.
31+
"""
32+
tid = id(threading.current_thread())
33+
self._entry_barrier.wait()
34+
with self._entry_lock:
35+
if self._entry_counter == 0:
36+
# Reset state before comparison
37+
self._barrier.reset()
38+
self._reset_evt.clear()
39+
self._thread_ids = []
40+
self._values = {}
41+
self._entry_barrier.reset()
42+
self._entry_counter += 1
43+
44+
self._values[tid] = values
45+
self._thread_ids.append(tid)
46+
self._barrier.wait()
47+
48+
if tid == self._thread_ids[0]:
49+
thread_ids = list(self._values)
50+
try:
51+
for value_name in values:
52+
for i in range(1, len(thread_ids)):
53+
tid_a = thread_ids[i - 1]
54+
tid_b = thread_ids[i]
55+
value_a = self._values[tid_a][value_name]
56+
value_b = self._values[tid_b][value_name]
57+
assert type(value_a) is type(value_b)
58+
if numpy_available and isinstance(value_a, np.ndarray):
59+
if len(value_a.shape) == 0:
60+
assert value_a == value_b
61+
else:
62+
assert np.allclose(value_a, value_b, equal_nan=True)
63+
elif isinstance(value_a, types.FunctionType):
64+
assert id(value_a) == id(value_b)
65+
elif value_a != value_a:
66+
assert value_b != value_b
67+
else:
68+
assert value_a == value_b
69+
finally:
70+
self._entry_counter = 0
71+
self._reset_evt.set()
72+
else:
73+
self._reset_evt.wait()

0 commit comments

Comments
 (0)