Skip to content

Commit e820346

Browse files
committed
Refactor plugin into multiple files
1 parent 48180f2 commit e820346

11 files changed

+869
-866
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def get_logical_cpus():
2+
try:
3+
import psutil
4+
except ImportError:
5+
pass
6+
else:
7+
process = psutil.Process()
8+
try:
9+
cpu_cores = process.cpu_affinity()
10+
return len(cpu_cores)
11+
except AttributeError:
12+
cpu_cores = psutil.cpu_count()
13+
if cpu_cores is not None:
14+
return cpu_cores
15+
16+
try:
17+
from os import process_cpu_count
18+
except ImportError:
19+
pass
20+
else:
21+
cpu_cores = process_cpu_count()
22+
if cpu_cores is not None:
23+
return cpu_cores
24+
25+
try:
26+
from os import sched_getaffinity
27+
except ImportError:
28+
pass
29+
else:
30+
cpu_cores = sched_getaffinity(0)
31+
if cpu_cores is not None:
32+
return len(cpu_cores)
33+
34+
from os import cpu_count
35+
36+
return cpu_count()

src/pytest_run_parallel/plugin.py

Lines changed: 6 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
1-
import functools
21
import os
3-
import sys
4-
import threading
52
import warnings
63

7-
import _pytest.outcomes
84
import pytest
95

10-
from pytest_run_parallel.utils import (
11-
ThreadComparator,
12-
get_configured_num_workers,
13-
get_num_workers,
6+
from pytest_run_parallel.thread_comparator import ThreadComparator
7+
from pytest_run_parallel.thread_unsafe_detection import (
8+
THREAD_UNSAFE_FIXTURES,
149
identify_thread_unsafe_nodes,
1510
)
11+
from pytest_run_parallel.thread_wrapper import wrap_function_parallel
12+
from pytest_run_parallel.utils import get_configured_num_workers, get_num_workers
1613

1714

1815
def pytest_addoption(parser):
@@ -72,81 +69,6 @@ def pytest_configure(config):
7269
)
7370

7471

75-
def wrap_function_parallel(fn, n_workers, n_iterations):
76-
@functools.wraps(fn)
77-
def inner(*args, **kwargs):
78-
errors = []
79-
skip = None
80-
failed = None
81-
barrier = threading.Barrier(n_workers)
82-
original_switch = sys.getswitchinterval()
83-
new_switch = 1e-6
84-
for _ in range(3):
85-
try:
86-
sys.setswitchinterval(new_switch)
87-
break
88-
except ValueError:
89-
new_switch *= 10
90-
else:
91-
sys.setswitchinterval(original_switch)
92-
93-
try:
94-
95-
def closure(*args, **kwargs):
96-
for _ in range(n_iterations):
97-
barrier.wait()
98-
try:
99-
fn(*args, **kwargs)
100-
except Warning:
101-
pass
102-
except Exception as e:
103-
errors.append(e)
104-
except _pytest.outcomes.Skipped as s:
105-
nonlocal skip
106-
skip = s.msg
107-
except _pytest.outcomes.Failed as f:
108-
nonlocal failed
109-
failed = f
110-
111-
workers = []
112-
for _ in range(0, n_workers):
113-
worker_kwargs = kwargs
114-
workers.append(
115-
threading.Thread(target=closure, args=args, kwargs=worker_kwargs)
116-
)
117-
118-
num_completed = 0
119-
try:
120-
for worker in workers:
121-
worker.start()
122-
num_completed += 1
123-
finally:
124-
if num_completed < len(workers):
125-
barrier.abort()
126-
127-
for worker in workers:
128-
worker.join()
129-
130-
finally:
131-
sys.setswitchinterval(original_switch)
132-
133-
if skip is not None:
134-
pytest.skip(skip)
135-
elif failed is not None:
136-
raise failed
137-
elif errors:
138-
raise errors[0]
139-
140-
return inner
141-
142-
143-
_thread_unsafe_fixtures = {
144-
"capsys",
145-
"monkeypatch",
146-
"recwarn",
147-
}
148-
149-
15072
@pytest.hookimpl(trylast=True)
15173
def pytest_itemcollected(item):
15274
n_workers = get_num_workers(item.config, item)
@@ -207,7 +129,7 @@ def pytest_itemcollected(item):
207129
else:
208130
item.add_marker(pytest.mark.parallel_threads(1))
209131

210-
unsafe_fixtures = _thread_unsafe_fixtures | set(
132+
unsafe_fixtures = THREAD_UNSAFE_FIXTURES | set(
211133
item.config.getini("thread_unsafe_fixtures")
212134
)
213135

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import threading
2+
import types
3+
4+
try:
5+
import numpy as np
6+
7+
numpy_available = True
8+
except ImportError:
9+
numpy_available = False
10+
11+
12+
class ThreadComparator:
13+
def __init__(self, n_threads):
14+
self._barrier = threading.Barrier(n_threads)
15+
self._reset_evt = threading.Event()
16+
self._entry_barrier = threading.Barrier(n_threads)
17+
18+
self._thread_ids = []
19+
self._values = {}
20+
self._entry_lock = threading.Lock()
21+
self._entry_counter = 0
22+
23+
def __call__(self, **values):
24+
"""
25+
Compares a set of values across threads.
26+
27+
For each value, type equality as well as comparison takes place. If any
28+
of the values is a function, then address comparison is performed.
29+
Also, if any of the values is a `numpy.ndarray`, then approximate
30+
numerical comparison is performed.
31+
"""
32+
tid = id(threading.current_thread())
33+
self._entry_barrier.wait()
34+
with self._entry_lock:
35+
if self._entry_counter == 0:
36+
# Reset state before comparison
37+
self._barrier.reset()
38+
self._reset_evt.clear()
39+
self._thread_ids = []
40+
self._values = {}
41+
self._entry_barrier.reset()
42+
self._entry_counter += 1
43+
44+
self._values[tid] = values
45+
self._thread_ids.append(tid)
46+
self._barrier.wait()
47+
48+
if tid == self._thread_ids[0]:
49+
thread_ids = list(self._values)
50+
try:
51+
for value_name in values:
52+
for i in range(1, len(thread_ids)):
53+
tid_a = thread_ids[i - 1]
54+
tid_b = thread_ids[i]
55+
value_a = self._values[tid_a][value_name]
56+
value_b = self._values[tid_b][value_name]
57+
assert type(value_a) is type(value_b)
58+
if numpy_available and isinstance(value_a, np.ndarray):
59+
if len(value_a.shape) == 0:
60+
assert value_a == value_b
61+
else:
62+
assert np.allclose(value_a, value_b, equal_nan=True)
63+
elif isinstance(value_a, types.FunctionType):
64+
assert id(value_a) == id(value_b)
65+
elif value_a != value_a:
66+
assert value_b != value_b
67+
else:
68+
assert value_a == value_b
69+
finally:
70+
self._entry_counter = 0
71+
self._reset_evt.set()
72+
else:
73+
self._reset_evt.wait()
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import ast
2+
import functools
3+
import inspect
4+
from textwrap import dedent
5+
6+
try:
7+
# added in hypothesis 6.131.0
8+
from hypothesis import is_hypothesis_test
9+
except ImportError:
10+
try:
11+
# hypothesis versions < 6.131.0
12+
from hypothesis.internal.detection import is_hypothesis_test
13+
except ImportError:
14+
# hypothesis isn't installed
15+
def is_hypothesis_test(fn):
16+
return False
17+
18+
19+
THREAD_UNSAFE_FIXTURES = {
20+
"capsys",
21+
"monkeypatch",
22+
"recwarn",
23+
}
24+
25+
26+
class ThreadUnsafeNodeVisitor(ast.NodeVisitor):
27+
def __init__(self, fn, skip_set, level=0):
28+
self.thread_unsafe = False
29+
self.thread_unsafe_reason = None
30+
self.blacklist = {
31+
("pytest", "warns"),
32+
("pytest", "deprecated_call"),
33+
("_pytest.recwarn", "warns"),
34+
("_pytest.recwarn", "deprecated_call"),
35+
("warnings", "catch_warnings"),
36+
("mock", "patch"), # unittest.mock
37+
} | set(skip_set)
38+
modules = {mod.split(".")[0] for mod, _ in self.blacklist}
39+
modules |= {mod for mod, _ in self.blacklist}
40+
41+
self.fn = fn
42+
self.skip_set = skip_set
43+
self.level = level
44+
self.modules_aliases = {}
45+
self.func_aliases = {}
46+
for var_name in getattr(fn, "__globals__", {}):
47+
value = fn.__globals__[var_name]
48+
if inspect.ismodule(value) and value.__name__ in modules:
49+
self.modules_aliases[var_name] = value.__name__
50+
elif inspect.isfunction(value):
51+
real_name = value.__name__
52+
for mod in modules:
53+
if mod == value.__module__:
54+
self.func_aliases[var_name] = (mod, real_name)
55+
break
56+
57+
super().__init__()
58+
59+
def visit_Call(self, node):
60+
if self.thread_unsafe:
61+
return
62+
63+
if isinstance(node.func, ast.Attribute):
64+
if isinstance(node.func.value, ast.Name):
65+
real_mod = node.func.value.id
66+
if real_mod in self.modules_aliases:
67+
real_mod = self.modules_aliases[real_mod]
68+
if (real_mod, node.func.attr) in self.blacklist:
69+
self.thread_unsafe = True
70+
self.thread_unsafe_reason = (
71+
"calls thread-unsafe function: " f"{real_mod}.{node.func.attr}"
72+
)
73+
elif self.level < 2:
74+
if node.func.value.id in getattr(self.fn, "__globals__", {}):
75+
mod = self.fn.__globals__[node.func.value.id]
76+
child_fn = getattr(mod, node.func.attr, None)
77+
if child_fn is not None:
78+
self.thread_unsafe, self.thread_unsafe_reason = (
79+
identify_thread_unsafe_nodes(
80+
child_fn, self.skip_set, self.level + 1
81+
)
82+
)
83+
elif isinstance(node.func, ast.Name):
84+
recurse = True
85+
if node.func.id in self.func_aliases:
86+
if self.func_aliases[node.func.id] in self.blacklist:
87+
self.thread_unsafe = True
88+
self.thread_unsafe_reason = (
89+
f"calls thread-unsafe function: {node.func.id}"
90+
)
91+
recurse = False
92+
if recurse and self.level < 2:
93+
if node.func.id in getattr(self.fn, "__globals__", {}):
94+
child_fn = self.fn.__globals__[node.func.id]
95+
self.thread_unsafe, self.thread_unsafe_reason = (
96+
identify_thread_unsafe_nodes(
97+
child_fn, self.skip_set, self.level + 1
98+
)
99+
)
100+
101+
def visit_Assign(self, node):
102+
if self.thread_unsafe:
103+
return
104+
105+
if len(node.targets) == 1:
106+
name_node = node.targets[0]
107+
value_node = node.value
108+
if getattr(name_node, "id", None) == "__thread_safe__":
109+
self.thread_unsafe = not bool(value_node.value)
110+
self.thread_unsafe_reason = (
111+
f"calls thread-unsafe function: f{name_node} "
112+
"(inferred via func.__thread_safe__ == False)"
113+
)
114+
else:
115+
self.generic_visit(node)
116+
117+
118+
def _identify_thread_unsafe_nodes(fn, skip_set, level=0):
119+
if is_hypothesis_test(fn):
120+
return True, "uses hypothesis"
121+
try:
122+
src = inspect.getsource(fn)
123+
tree = ast.parse(dedent(src))
124+
except Exception:
125+
return False, None
126+
visitor = ThreadUnsafeNodeVisitor(fn, skip_set, level=level)
127+
visitor.visit(tree)
128+
return visitor.thread_unsafe, visitor.thread_unsafe_reason
129+
130+
131+
cached_thread_unsafe_identify = functools.lru_cache(_identify_thread_unsafe_nodes)
132+
133+
134+
def identify_thread_unsafe_nodes(fn, skip_set, level=0):
135+
try:
136+
return cached_thread_unsafe_identify(fn, skip_set, level=level)
137+
except TypeError:
138+
return _identify_thread_unsafe_nodes(fn, skip_set, level=level)

0 commit comments

Comments
 (0)