From e820346991e406a61e1a3279dc467dbfeee23c07 Mon Sep 17 00:00:00 2001 From: Lysandros Nikolaou Date: Wed, 11 Jun 2025 11:53:09 +0200 Subject: [PATCH 1/2] Refactor plugin into multiple files --- src/pytest_run_parallel/cpu_detection.py | 36 ++ src/pytest_run_parallel/plugin.py | 90 +-- src/pytest_run_parallel/thread_comparator.py | 73 +++ .../thread_unsafe_detection.py | 138 +++++ src/pytest_run_parallel/thread_wrapper.py | 74 +++ src/pytest_run_parallel/utils.py | 243 +------- tests/test_cpu_detection.py | 160 ++++++ tests/test_run_parallel.py | 539 ------------------ tests/test_thread_comparator.py | 54 ++ tests/test_thread_unsafe_detection.py | 326 +++++++++++ uv.lock | 2 +- 11 files changed, 869 insertions(+), 866 deletions(-) create mode 100644 src/pytest_run_parallel/cpu_detection.py create mode 100644 src/pytest_run_parallel/thread_comparator.py create mode 100644 src/pytest_run_parallel/thread_unsafe_detection.py create mode 100644 src/pytest_run_parallel/thread_wrapper.py create mode 100644 tests/test_cpu_detection.py create mode 100644 tests/test_thread_comparator.py create mode 100644 tests/test_thread_unsafe_detection.py diff --git a/src/pytest_run_parallel/cpu_detection.py b/src/pytest_run_parallel/cpu_detection.py new file mode 100644 index 0000000..936e99f --- /dev/null +++ b/src/pytest_run_parallel/cpu_detection.py @@ -0,0 +1,36 @@ +def get_logical_cpus(): + try: + import psutil + except ImportError: + pass + else: + process = psutil.Process() + try: + cpu_cores = process.cpu_affinity() + return len(cpu_cores) + except AttributeError: + cpu_cores = psutil.cpu_count() + if cpu_cores is not None: + return cpu_cores + + try: + from os import process_cpu_count + except ImportError: + pass + else: + cpu_cores = process_cpu_count() + if cpu_cores is not None: + return cpu_cores + + try: + from os import sched_getaffinity + except ImportError: + pass + else: + cpu_cores = sched_getaffinity(0) + if cpu_cores is not None: + return len(cpu_cores) + + from os import cpu_count + + return cpu_count() diff --git a/src/pytest_run_parallel/plugin.py b/src/pytest_run_parallel/plugin.py index 35020bb..b9d97dc 100644 --- a/src/pytest_run_parallel/plugin.py +++ b/src/pytest_run_parallel/plugin.py @@ -1,18 +1,15 @@ -import functools import os -import sys -import threading import warnings -import _pytest.outcomes import pytest -from pytest_run_parallel.utils import ( - ThreadComparator, - get_configured_num_workers, - get_num_workers, +from pytest_run_parallel.thread_comparator import ThreadComparator +from pytest_run_parallel.thread_unsafe_detection import ( + THREAD_UNSAFE_FIXTURES, identify_thread_unsafe_nodes, ) +from pytest_run_parallel.thread_wrapper import wrap_function_parallel +from pytest_run_parallel.utils import get_configured_num_workers, get_num_workers def pytest_addoption(parser): @@ -72,81 +69,6 @@ def pytest_configure(config): ) -def wrap_function_parallel(fn, n_workers, n_iterations): - @functools.wraps(fn) - def inner(*args, **kwargs): - errors = [] - skip = None - failed = None - barrier = threading.Barrier(n_workers) - original_switch = sys.getswitchinterval() - new_switch = 1e-6 - for _ in range(3): - try: - sys.setswitchinterval(new_switch) - break - except ValueError: - new_switch *= 10 - else: - sys.setswitchinterval(original_switch) - - try: - - def closure(*args, **kwargs): - for _ in range(n_iterations): - barrier.wait() - try: - fn(*args, **kwargs) - except Warning: - pass - except Exception as e: - errors.append(e) - except _pytest.outcomes.Skipped as s: - nonlocal skip - skip = s.msg - except _pytest.outcomes.Failed as f: - nonlocal failed - failed = f - - workers = [] - for _ in range(0, n_workers): - worker_kwargs = kwargs - workers.append( - threading.Thread(target=closure, args=args, kwargs=worker_kwargs) - ) - - num_completed = 0 - try: - for worker in workers: - worker.start() - num_completed += 1 - finally: - if num_completed < len(workers): - barrier.abort() - - for worker in workers: - worker.join() - - finally: - sys.setswitchinterval(original_switch) - - if skip is not None: - pytest.skip(skip) - elif failed is not None: - raise failed - elif errors: - raise errors[0] - - return inner - - -_thread_unsafe_fixtures = { - "capsys", - "monkeypatch", - "recwarn", -} - - @pytest.hookimpl(trylast=True) def pytest_itemcollected(item): n_workers = get_num_workers(item.config, item) @@ -207,7 +129,7 @@ def pytest_itemcollected(item): else: item.add_marker(pytest.mark.parallel_threads(1)) - unsafe_fixtures = _thread_unsafe_fixtures | set( + unsafe_fixtures = THREAD_UNSAFE_FIXTURES | set( item.config.getini("thread_unsafe_fixtures") ) diff --git a/src/pytest_run_parallel/thread_comparator.py b/src/pytest_run_parallel/thread_comparator.py new file mode 100644 index 0000000..75e2a08 --- /dev/null +++ b/src/pytest_run_parallel/thread_comparator.py @@ -0,0 +1,73 @@ +import threading +import types + +try: + import numpy as np + + numpy_available = True +except ImportError: + numpy_available = False + + +class ThreadComparator: + def __init__(self, n_threads): + self._barrier = threading.Barrier(n_threads) + self._reset_evt = threading.Event() + self._entry_barrier = threading.Barrier(n_threads) + + self._thread_ids = [] + self._values = {} + self._entry_lock = threading.Lock() + self._entry_counter = 0 + + def __call__(self, **values): + """ + Compares a set of values across threads. + + For each value, type equality as well as comparison takes place. If any + of the values is a function, then address comparison is performed. + Also, if any of the values is a `numpy.ndarray`, then approximate + numerical comparison is performed. + """ + tid = id(threading.current_thread()) + self._entry_barrier.wait() + with self._entry_lock: + if self._entry_counter == 0: + # Reset state before comparison + self._barrier.reset() + self._reset_evt.clear() + self._thread_ids = [] + self._values = {} + self._entry_barrier.reset() + self._entry_counter += 1 + + self._values[tid] = values + self._thread_ids.append(tid) + self._barrier.wait() + + if tid == self._thread_ids[0]: + thread_ids = list(self._values) + try: + for value_name in values: + for i in range(1, len(thread_ids)): + tid_a = thread_ids[i - 1] + tid_b = thread_ids[i] + value_a = self._values[tid_a][value_name] + value_b = self._values[tid_b][value_name] + assert type(value_a) is type(value_b) + if numpy_available and isinstance(value_a, np.ndarray): + if len(value_a.shape) == 0: + assert value_a == value_b + else: + assert np.allclose(value_a, value_b, equal_nan=True) + elif isinstance(value_a, types.FunctionType): + assert id(value_a) == id(value_b) + elif value_a != value_a: + assert value_b != value_b + else: + assert value_a == value_b + finally: + self._entry_counter = 0 + self._reset_evt.set() + else: + self._reset_evt.wait() diff --git a/src/pytest_run_parallel/thread_unsafe_detection.py b/src/pytest_run_parallel/thread_unsafe_detection.py new file mode 100644 index 0000000..2687a44 --- /dev/null +++ b/src/pytest_run_parallel/thread_unsafe_detection.py @@ -0,0 +1,138 @@ +import ast +import functools +import inspect +from textwrap import dedent + +try: + # added in hypothesis 6.131.0 + from hypothesis import is_hypothesis_test +except ImportError: + try: + # hypothesis versions < 6.131.0 + from hypothesis.internal.detection import is_hypothesis_test + except ImportError: + # hypothesis isn't installed + def is_hypothesis_test(fn): + return False + + +THREAD_UNSAFE_FIXTURES = { + "capsys", + "monkeypatch", + "recwarn", +} + + +class ThreadUnsafeNodeVisitor(ast.NodeVisitor): + def __init__(self, fn, skip_set, level=0): + self.thread_unsafe = False + self.thread_unsafe_reason = None + self.blacklist = { + ("pytest", "warns"), + ("pytest", "deprecated_call"), + ("_pytest.recwarn", "warns"), + ("_pytest.recwarn", "deprecated_call"), + ("warnings", "catch_warnings"), + ("mock", "patch"), # unittest.mock + } | set(skip_set) + modules = {mod.split(".")[0] for mod, _ in self.blacklist} + modules |= {mod for mod, _ in self.blacklist} + + self.fn = fn + self.skip_set = skip_set + self.level = level + self.modules_aliases = {} + self.func_aliases = {} + for var_name in getattr(fn, "__globals__", {}): + value = fn.__globals__[var_name] + if inspect.ismodule(value) and value.__name__ in modules: + self.modules_aliases[var_name] = value.__name__ + elif inspect.isfunction(value): + real_name = value.__name__ + for mod in modules: + if mod == value.__module__: + self.func_aliases[var_name] = (mod, real_name) + break + + super().__init__() + + def visit_Call(self, node): + if self.thread_unsafe: + return + + if isinstance(node.func, ast.Attribute): + if isinstance(node.func.value, ast.Name): + real_mod = node.func.value.id + if real_mod in self.modules_aliases: + real_mod = self.modules_aliases[real_mod] + if (real_mod, node.func.attr) in self.blacklist: + self.thread_unsafe = True + self.thread_unsafe_reason = ( + "calls thread-unsafe function: " f"{real_mod}.{node.func.attr}" + ) + elif self.level < 2: + if node.func.value.id in getattr(self.fn, "__globals__", {}): + mod = self.fn.__globals__[node.func.value.id] + child_fn = getattr(mod, node.func.attr, None) + if child_fn is not None: + self.thread_unsafe, self.thread_unsafe_reason = ( + identify_thread_unsafe_nodes( + child_fn, self.skip_set, self.level + 1 + ) + ) + elif isinstance(node.func, ast.Name): + recurse = True + if node.func.id in self.func_aliases: + if self.func_aliases[node.func.id] in self.blacklist: + self.thread_unsafe = True + self.thread_unsafe_reason = ( + f"calls thread-unsafe function: {node.func.id}" + ) + recurse = False + if recurse and self.level < 2: + if node.func.id in getattr(self.fn, "__globals__", {}): + child_fn = self.fn.__globals__[node.func.id] + self.thread_unsafe, self.thread_unsafe_reason = ( + identify_thread_unsafe_nodes( + child_fn, self.skip_set, self.level + 1 + ) + ) + + def visit_Assign(self, node): + if self.thread_unsafe: + return + + if len(node.targets) == 1: + name_node = node.targets[0] + value_node = node.value + if getattr(name_node, "id", None) == "__thread_safe__": + self.thread_unsafe = not bool(value_node.value) + self.thread_unsafe_reason = ( + f"calls thread-unsafe function: f{name_node} " + "(inferred via func.__thread_safe__ == False)" + ) + else: + self.generic_visit(node) + + +def _identify_thread_unsafe_nodes(fn, skip_set, level=0): + if is_hypothesis_test(fn): + return True, "uses hypothesis" + try: + src = inspect.getsource(fn) + tree = ast.parse(dedent(src)) + except Exception: + return False, None + visitor = ThreadUnsafeNodeVisitor(fn, skip_set, level=level) + visitor.visit(tree) + return visitor.thread_unsafe, visitor.thread_unsafe_reason + + +cached_thread_unsafe_identify = functools.lru_cache(_identify_thread_unsafe_nodes) + + +def identify_thread_unsafe_nodes(fn, skip_set, level=0): + try: + return cached_thread_unsafe_identify(fn, skip_set, level=level) + except TypeError: + return _identify_thread_unsafe_nodes(fn, skip_set, level=level) diff --git a/src/pytest_run_parallel/thread_wrapper.py b/src/pytest_run_parallel/thread_wrapper.py new file mode 100644 index 0000000..525bdfb --- /dev/null +++ b/src/pytest_run_parallel/thread_wrapper.py @@ -0,0 +1,74 @@ +import functools +import sys +import threading + +import _pytest.outcomes +import pytest + + +def wrap_function_parallel(fn, n_workers, n_iterations): + @functools.wraps(fn) + def inner(*args, **kwargs): + errors = [] + skip = None + failed = None + barrier = threading.Barrier(n_workers) + original_switch = sys.getswitchinterval() + new_switch = 1e-6 + for _ in range(3): + try: + sys.setswitchinterval(new_switch) + break + except ValueError: + new_switch *= 10 + else: + sys.setswitchinterval(original_switch) + + try: + + def closure(*args, **kwargs): + for _ in range(n_iterations): + barrier.wait() + try: + fn(*args, **kwargs) + except Warning: + pass + except Exception as e: + errors.append(e) + except _pytest.outcomes.Skipped as s: + nonlocal skip + skip = s.msg + except _pytest.outcomes.Failed as f: + nonlocal failed + failed = f + + workers = [] + for _ in range(0, n_workers): + worker_kwargs = kwargs + workers.append( + threading.Thread(target=closure, args=args, kwargs=worker_kwargs) + ) + + num_completed = 0 + try: + for worker in workers: + worker.start() + num_completed += 1 + finally: + if num_completed < len(workers): + barrier.abort() + + for worker in workers: + worker.join() + + finally: + sys.setswitchinterval(original_switch) + + if skip is not None: + pytest.skip(skip) + elif failed is not None: + raise failed + elif errors: + raise errors[0] + + return inner diff --git a/src/pytest_run_parallel/utils.py b/src/pytest_run_parallel/utils.py index 8ad5410..29dea62 100644 --- a/src/pytest_run_parallel/utils.py +++ b/src/pytest_run_parallel/utils.py @@ -1,245 +1,4 @@ -import ast -import functools -import inspect -import threading -import types -from textwrap import dedent - -try: - import numpy as np - - numpy_available = True -except ImportError: - numpy_available = False - -try: - # added in hypothesis 6.131.0 - from hypothesis import is_hypothesis_test -except ImportError: - try: - # hypothesis versions < 6.131.0 - from hypothesis.internal.detection import is_hypothesis_test - except ImportError: - # hypothesis isn't installed - def is_hypothesis_test(fn): - return False - - -class ThreadUnsafeNodeVisitor(ast.NodeVisitor): - def __init__(self, fn, skip_set, level=0): - self.thread_unsafe = False - self.thread_unsafe_reason = None - self.blacklist = { - ("pytest", "warns"), - ("pytest", "deprecated_call"), - ("_pytest.recwarn", "warns"), - ("_pytest.recwarn", "deprecated_call"), - ("warnings", "catch_warnings"), - ("mock", "patch"), # unittest.mock - } | set(skip_set) - modules = {mod.split(".")[0] for mod, _ in self.blacklist} - modules |= {mod for mod, _ in self.blacklist} - - self.fn = fn - self.skip_set = skip_set - self.level = level - self.modules_aliases = {} - self.func_aliases = {} - for var_name in getattr(fn, "__globals__", {}): - value = fn.__globals__[var_name] - if inspect.ismodule(value) and value.__name__ in modules: - self.modules_aliases[var_name] = value.__name__ - elif inspect.isfunction(value): - real_name = value.__name__ - for mod in modules: - if mod == value.__module__: - self.func_aliases[var_name] = (mod, real_name) - break - - super().__init__() - - def visit_Call(self, node): - if self.thread_unsafe: - return - - if isinstance(node.func, ast.Attribute): - if isinstance(node.func.value, ast.Name): - real_mod = node.func.value.id - if real_mod in self.modules_aliases: - real_mod = self.modules_aliases[real_mod] - if (real_mod, node.func.attr) in self.blacklist: - self.thread_unsafe = True - self.thread_unsafe_reason = ( - "calls thread-unsafe function: " f"{real_mod}.{node.func.attr}" - ) - elif self.level < 2: - if node.func.value.id in getattr(self.fn, "__globals__", {}): - mod = self.fn.__globals__[node.func.value.id] - child_fn = getattr(mod, node.func.attr, None) - if child_fn is not None: - self.thread_unsafe, self.thread_unsafe_reason = ( - identify_thread_unsafe_nodes( - child_fn, self.skip_set, self.level + 1 - ) - ) - elif isinstance(node.func, ast.Name): - recurse = True - if node.func.id in self.func_aliases: - if self.func_aliases[node.func.id] in self.blacklist: - self.thread_unsafe = True - self.thread_unsafe_reason = ( - f"calls thread-unsafe function: {node.func.id}" - ) - recurse = False - if recurse and self.level < 2: - if node.func.id in getattr(self.fn, "__globals__", {}): - child_fn = self.fn.__globals__[node.func.id] - self.thread_unsafe, self.thread_unsafe_reason = ( - identify_thread_unsafe_nodes( - child_fn, self.skip_set, self.level + 1 - ) - ) - - def visit_Assign(self, node): - if self.thread_unsafe: - return - - if len(node.targets) == 1: - name_node = node.targets[0] - value_node = node.value - if getattr(name_node, "id", None) == "__thread_safe__": - self.thread_unsafe = not bool(value_node.value) - self.thread_unsafe_reason = ( - f"calls thread-unsafe function: f{name_node} " - "(inferred via func.__thread_safe__ == False)" - ) - else: - self.generic_visit(node) - - -def _identify_thread_unsafe_nodes(fn, skip_set, level=0): - if is_hypothesis_test(fn): - return True, "uses hypothesis" - try: - src = inspect.getsource(fn) - tree = ast.parse(dedent(src)) - except Exception: - return False, None - visitor = ThreadUnsafeNodeVisitor(fn, skip_set, level=level) - visitor.visit(tree) - return visitor.thread_unsafe, visitor.thread_unsafe_reason - - -cached_thread_unsafe_identify = functools.lru_cache(_identify_thread_unsafe_nodes) - - -def identify_thread_unsafe_nodes(fn, skip_set, level=0): - try: - return cached_thread_unsafe_identify(fn, skip_set, level=level) - except TypeError: - return _identify_thread_unsafe_nodes(fn, skip_set, level=level) - - -class ThreadComparator: - def __init__(self, n_threads): - self._barrier = threading.Barrier(n_threads) - self._reset_evt = threading.Event() - self._entry_barrier = threading.Barrier(n_threads) - - self._thread_ids = [] - self._values = {} - self._entry_lock = threading.Lock() - self._entry_counter = 0 - - def __call__(self, **values): - """ - Compares a set of values across threads. - - For each value, type equality as well as comparison takes place. If any - of the values is a function, then address comparison is performed. - Also, if any of the values is a `numpy.ndarray`, then approximate - numerical comparison is performed. - """ - tid = id(threading.current_thread()) - self._entry_barrier.wait() - with self._entry_lock: - if self._entry_counter == 0: - # Reset state before comparison - self._barrier.reset() - self._reset_evt.clear() - self._thread_ids = [] - self._values = {} - self._entry_barrier.reset() - self._entry_counter += 1 - - self._values[tid] = values - self._thread_ids.append(tid) - self._barrier.wait() - - if tid == self._thread_ids[0]: - thread_ids = list(self._values) - try: - for value_name in values: - for i in range(1, len(thread_ids)): - tid_a = thread_ids[i - 1] - tid_b = thread_ids[i] - value_a = self._values[tid_a][value_name] - value_b = self._values[tid_b][value_name] - assert type(value_a) is type(value_b) - if numpy_available and isinstance(value_a, np.ndarray): - if len(value_a.shape) == 0: - assert value_a == value_b - else: - assert np.allclose(value_a, value_b, equal_nan=True) - elif isinstance(value_a, types.FunctionType): - assert id(value_a) == id(value_b) - elif value_a != value_a: - assert value_b != value_b - else: - assert value_a == value_b - finally: - self._entry_counter = 0 - self._reset_evt.set() - else: - self._reset_evt.wait() - - -def get_logical_cpus(): - try: - import psutil - except ImportError: - pass - else: - process = psutil.Process() - try: - cpu_cores = process.cpu_affinity() - return len(cpu_cores) - except AttributeError: - cpu_cores = psutil.cpu_count() - if cpu_cores is not None: - return cpu_cores - - try: - from os import process_cpu_count - except ImportError: - pass - else: - cpu_cores = process_cpu_count() - if cpu_cores is not None: - return cpu_cores - - try: - from os import sched_getaffinity - except ImportError: - pass - else: - cpu_cores = sched_getaffinity(0) - if cpu_cores is not None: - return len(cpu_cores) - - from os import cpu_count - - return cpu_count() +from pytest_run_parallel.cpu_detection import get_logical_cpus def get_configured_num_workers(config): diff --git a/tests/test_cpu_detection.py b/tests/test_cpu_detection.py new file mode 100644 index 0000000..b2f08c1 --- /dev/null +++ b/tests/test_cpu_detection.py @@ -0,0 +1,160 @@ +from contextlib import suppress + +import pytest + +try: + import psutil +except ImportError: + psutil = None + +try: + from os import process_cpu_count +except ImportError: + process_cpu_count = None + +try: + from os import sched_getaffinity +except ImportError: + sched_getaffinity = None + + +@pytest.mark.skipif(psutil is None, reason="psutil needs to be installed") +def test_auto_detect_cpus_psutil_affinity( + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch +) -> None: + import psutil + + monkeypatch.setattr( + psutil.Process, "cpu_affinity", lambda self: list(range(10)), raising=False + ) + + pytester.makepyfile(""" + def test_auto_detect_cpus(num_parallel_threads): + assert num_parallel_threads == 10 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=auto", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_auto_detect_cpus PARALLEL PASSED*", + ] + ) + + +@pytest.mark.skipif(psutil is None, reason="psutil needs to be installed") +def test_auto_detect_cpus_psutil_cpu_count( + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch +) -> None: + import psutil + + monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) + monkeypatch.setattr(psutil, "cpu_count", lambda: 10) + + pytester.makepyfile(""" + def test_auto_detect_cpus(num_parallel_threads): + assert num_parallel_threads == 10 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=auto", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_auto_detect_cpus PARALLEL PASSED*", + ] + ) + + +@pytest.mark.skipif( + process_cpu_count is None, reason="process_cpu_count is available in >=3.13" +) +def test_auto_detect_process_cpu_count( + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch +) -> None: + with suppress(ImportError): + import psutil + + monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) + monkeypatch.setattr(psutil, "cpu_count", lambda: None) + + monkeypatch.setattr("os.process_cpu_count", lambda: 10) + + pytester.makepyfile(""" + def test_auto_detect_cpus(num_parallel_threads): + assert num_parallel_threads == 10 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=auto", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_auto_detect_cpus PARALLEL PASSED*", + ] + ) + + +@pytest.mark.skipif( + sched_getaffinity is None, + reason="sched_getaffinity is available certain platforms only", +) +def test_auto_detect_sched_getaffinity( + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch +) -> None: + with suppress(ImportError): + import psutil + + monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) + monkeypatch.setattr(psutil, "cpu_count", lambda: None) + + monkeypatch.setattr("os.process_cpu_count", lambda: None, raising=False) + monkeypatch.setattr("os.sched_getaffinity", lambda pid: list(range(10))) + + pytester.makepyfile(""" + def test_auto_detect_cpus(num_parallel_threads): + assert num_parallel_threads == 10 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=auto", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_auto_detect_cpus PARALLEL PASSED*", + ] + ) + + +def test_auto_detect_cpu_count( + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch +) -> None: + with suppress(ImportError): + import psutil + + monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) + monkeypatch.setattr(psutil, "cpu_count", lambda: None) + + monkeypatch.setattr("os.process_cpu_count", lambda: None, raising=False) + monkeypatch.setattr("os.sched_getaffinity", lambda pid: None, raising=False) + monkeypatch.setattr("os.cpu_count", lambda: 10) + + pytester.makepyfile(""" + def test_auto_detect_cpus(num_parallel_threads): + assert num_parallel_threads == 10 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=auto", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_auto_detect_cpus PARALLEL PASSED*", + ] + ) diff --git a/tests/test_run_parallel.py b/tests/test_run_parallel.py index ea51386..b3c023e 100644 --- a/tests/test_run_parallel.py +++ b/tests/test_run_parallel.py @@ -1,27 +1,4 @@ import os -from contextlib import suppress - -import pytest - -try: - import psutil -except ImportError: - psutil = None - -try: - import hypothesis -except ImportError: - hypothesis = None - -try: - from os import process_cpu_count -except ImportError: - process_cpu_count = None - -try: - from os import sched_getaffinity -except ImportError: - sched_getaffinity = None def test_default_threads(pytester): @@ -321,62 +298,6 @@ def test_single_threaded(num_parallel_threads): ) -def test_thread_comp_fixture(pytester): - """Test that ThreadComparator works as expected.""" - - # create a temporary pytest test module - pytester.makepyfile(""" - import threading - import pytest - - class Counter: - def __init__(self): - self._value = 0 - self._lock = threading.Lock() - - def get_value_and_increment(self): - with self._lock: - value = int(self._value) - self._value += 1 - return value - - def test_value_comparison(num_parallel_threads, thread_comp): - assert num_parallel_threads == 10 - a = 1 - b = [2, 'string', 1.0] - c = {'a': -4, 'b': 'str'} - d = float('nan') - e = float('inf') - f = {'a', 'b', '#'} - thread_comp(a=a, b=b, c=c, d=d, e=e, f=f) - - # Ensure that the comparator can be used again - thread_comp(g=4) - - @pytest.fixture - def counter(num_parallel_threads): - return Counter() - - def test_comparison_fail(thread_comp, counter): - a = 4 - pos = counter.get_value_and_increment() - if pos % 2 == 0: - a = -1 - thread_comp(a=a) - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=10", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_value_comparison PARALLEL PASSED*", - "*::test_comparison_fail PARALLEL FAILED*", - ] - ) - - def test_iterations_marker_one_thread(pytester): # create a temporary pytest test module pytester.makepyfile(""" @@ -564,294 +485,6 @@ def test_should_skip(): ) -def test_thread_unsafe_marker(pytester): - # create a temporary pytest test module - pytester.makepyfile(""" - import pytest - - @pytest.mark.thread_unsafe - def test_should_run_single(num_parallel_threads): - assert num_parallel_threads == 1 - - @pytest.mark.thread_unsafe(reason='this is thread-unsafe') - def test_should_run_single_2(num_parallel_threads): - assert num_parallel_threads == 1 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=10", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_should_run_single PASSED*", - "*::test_should_run_single_2 PASSED *thread-unsafe*: this is thread-unsafe*", - ] - ) - - # check that skipping works too - result = pytester.runpytest( - "--parallel-threads=10", "--skip-thread-unsafe=True", "-v" - ) - - result.stdout.fnmatch_lines( - ["*::test_should_run_single SKIPPED*", "*::test_should_run_single_2 SKIPPED*"] - ) - - -def test_pytest_warns_detection(pytester): - # create a temporary pytest test module - pytester.makepyfile(""" - import pytest - import warnings - import pytest as pyt - import warnings as w - from pytest import warns, deprecated_call - from warnings import catch_warnings - - warns_alias = warns - - def test_single_thread_warns_1(num_parallel_threads): - with pytest.warns(UserWarning): - warnings.warn('example', UserWarning) - assert num_parallel_threads == 1 - - def test_single_thread_warns_2(num_parallel_threads): - with warns(UserWarning): - warnings.warn('example', UserWarning) - assert num_parallel_threads == 1 - - def test_single_thread_warns_3(num_parallel_threads): - with pyt.warns(UserWarning): - warnings.warn('example', UserWarning) - assert num_parallel_threads == 1 - - def test_single_thread_warns_4(num_parallel_threads): - with warns_alias(UserWarning): - warnings.warn('example', UserWarning) - assert num_parallel_threads == 1 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=10", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_single_thread_warns_1 PASSED*", - "*::test_single_thread_warns_2 PASSED*", - "*::test_single_thread_warns_3 PASSED*", - "*::test_single_thread_warns_4 PASSED*", - ] - ) - - # check that skipping works too - result = pytester.runpytest( - "--parallel-threads=10", "--skip-thread-unsafe=True", "-v" - ) - - result.stdout.fnmatch_lines( - [ - "*::test_single_thread_warns_1 SKIPPED*", - "*::test_single_thread_warns_2 SKIPPED*", - "*::test_single_thread_warns_3 SKIPPED*", - "*::test_single_thread_warns_4 SKIPPED*", - ] - ) - - -@pytest.mark.skipif(psutil is None, reason="psutil needs to be installed") -def test_auto_detect_cpus_psutil_affinity( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch -) -> None: - import psutil - - monkeypatch.setattr( - psutil.Process, "cpu_affinity", lambda self: list(range(10)), raising=False - ) - - pytester.makepyfile(""" - def test_auto_detect_cpus(num_parallel_threads): - assert num_parallel_threads == 10 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=auto", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_auto_detect_cpus PARALLEL PASSED*", - ] - ) - - -@pytest.mark.skipif(psutil is None, reason="psutil needs to be installed") -def test_auto_detect_cpus_psutil_cpu_count( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch -) -> None: - import psutil - - monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) - monkeypatch.setattr(psutil, "cpu_count", lambda: 10) - - pytester.makepyfile(""" - def test_auto_detect_cpus(num_parallel_threads): - assert num_parallel_threads == 10 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=auto", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_auto_detect_cpus PARALLEL PASSED*", - ] - ) - - -@pytest.mark.skipif( - process_cpu_count is None, reason="process_cpu_count is available in >=3.13" -) -def test_auto_detect_process_cpu_count( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch -) -> None: - with suppress(ImportError): - import psutil - - monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) - monkeypatch.setattr(psutil, "cpu_count", lambda: None) - - monkeypatch.setattr("os.process_cpu_count", lambda: 10) - - pytester.makepyfile(""" - def test_auto_detect_cpus(num_parallel_threads): - assert num_parallel_threads == 10 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=auto", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_auto_detect_cpus PARALLEL PASSED*", - ] - ) - - -@pytest.mark.skipif( - sched_getaffinity is None, - reason="sched_getaffinity is available certain platforms only", -) -def test_auto_detect_sched_getaffinity( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch -) -> None: - with suppress(ImportError): - import psutil - - monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) - monkeypatch.setattr(psutil, "cpu_count", lambda: None) - - monkeypatch.setattr("os.process_cpu_count", lambda: None, raising=False) - monkeypatch.setattr("os.sched_getaffinity", lambda pid: list(range(10))) - - pytester.makepyfile(""" - def test_auto_detect_cpus(num_parallel_threads): - assert num_parallel_threads == 10 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=auto", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_auto_detect_cpus PARALLEL PASSED*", - ] - ) - - -def test_auto_detect_cpu_count( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch -) -> None: - with suppress(ImportError): - import psutil - - monkeypatch.delattr(psutil.Process, "cpu_affinity", raising=False) - monkeypatch.setattr(psutil, "cpu_count", lambda: None) - - monkeypatch.setattr("os.process_cpu_count", lambda: None, raising=False) - monkeypatch.setattr("os.sched_getaffinity", lambda pid: None, raising=False) - monkeypatch.setattr("os.cpu_count", lambda: 10) - - pytester.makepyfile(""" - def test_auto_detect_cpus(num_parallel_threads): - assert num_parallel_threads == 10 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=auto", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_auto_detect_cpus PARALLEL PASSED*", - ] - ) - - -def test_thread_unsafe_fixtures(pytester): - # create a temporary pytest test module - pytester.makepyfile(""" - import pytest - - @pytest.fixture - def my_unsafe_fixture(): - pass - - @pytest.fixture - def my_unsafe_fixture_2(): - pass - - def test_capsys(capsys, num_parallel_threads): - assert num_parallel_threads == 1 - - def test_monkeypatch(monkeypatch, num_parallel_threads): - assert num_parallel_threads == 1 - - def test_recwarn(recwarn, num_parallel_threads): - assert num_parallel_threads == 1 - - def test_custom_fixture_skip(my_unsafe_fixture, num_parallel_threads): - assert num_parallel_threads == 1 - - def test_custom_fixture_skip_2(my_unsafe_fixture_2, num_parallel_threads): - assert num_parallel_threads == 1 - """) - - pytester.makeini(""" - [pytest] - thread_unsafe_fixtures = - my_unsafe_fixture - my_unsafe_fixture_2 - """) - - # run pytest with the following cmd args - result = pytester.runpytest("--parallel-threads=10", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*::test_capsys PASSED *thread-unsafe*: uses thread-unsafe fixture*", - "*::test_recwarn PASSED *thread-unsafe*: uses thread-unsafe fixture*", - "*::test_custom_fixture_skip PASSED *thread-unsafe*: uses thread-unsafe fixture*", - "*::test_custom_fixture_skip_2 PASSED *thread-unsafe*: uses thread-unsafe fixture*", - ] - ) - - def test_incompatible_test_item(pytester): pytester.makeconftest(""" import inspect @@ -929,140 +562,6 @@ def test_incompatible_item(): assert "warnings" not in result.parseoutcomes().keys() -def test_thread_unsafe_function_attr(pytester): - pytester.makepyfile( - mod_1=""" - def to_skip(): - __thread_safe__ = False - - def not_to_skip(): - __thread_safe__ = True - """ - ) - - pytester.makepyfile( - mod_2=""" - import mod_1 - from mod_1 import not_to_skip - - def some_fn_calls_skip(): - mod_1.to_skip() - - def some_fn_should_not_skip(): - not_to_skip() - - def marked_for_skip(): - pass - """ - ) - - pytester.makepyfile(""" - import mod_2 - from mod_2 import some_fn_calls_skip - - def test_should_be_marked_1(num_parallel_threads): - mod_2.some_fn_calls_skip() - assert num_parallel_threads == 1 - - def test_should_not_be_marked(num_parallel_threads): - mod_2.some_fn_should_not_skip() - assert num_parallel_threads == 10 - - def test_should_be_marked_2(num_parallel_threads): - mod_2.marked_for_skip() - assert num_parallel_threads == 1 - - def test_should_be_marked_3(num_parallel_threads): - some_fn_calls_skip() - assert num_parallel_threads == 1 - """) - - pytester.makeini(""" - [pytest] - thread_unsafe_functions = - mod_2.marked_for_skip - """) - - # run pytest with the following cmd args - orig = os.environ.get("PYTEST_RUN_PARALLEL_VERBOSE", "0") - os.environ["PYTEST_RUN_PARALLEL_VERBOSE"] = "0" - result = pytester.runpytest("--parallel-threads=10", "-v") - - # fnmatch_lines does an assertion internally - result.stdout.fnmatch_lines( - [ - "*Collected 1 items to run in parallel*", - "*::test_should_be_marked_1 PASSED *thread-unsafe*inferred via func.__thread_safe__*", - "*::test_should_not_be_marked PARALLEL PASSED*", - "*::test_should_be_marked_2 PASSED *thread-unsafe*marked_for_skip*", - "*::test_should_be_marked_3 PASSED *thread-unsafe*inferred via func.__thread_safe__*", - ] - ) - - result.stdout.fnmatch_lines( - [ - "*3 tests were not run in parallel because of use of thread-unsafe " - "functionality, to list the tests that were skipped, " - "re-run while setting PYTEST_RUN_PARALLEL_VERBOSE=1 in your " - "shell environment*", - ] - ) - - # re-run with PYTEST_RUN_PARALLEL_VERBOSE=1 - os.environ["PYTEST_RUN_PARALLEL_VERBOSE"] = "1" - result = pytester.runpytest("--parallel-threads=10", "-v") - os.environ["PYTEST_RUN_PARALLEL_VERBOSE"] = orig - - result.stdout.fnmatch_lines( - [ - "*Collected 1 items to run in parallel*", - "*::test_should_be_marked_1 PASSED *thread-unsafe*: calls thread-unsafe function*", - "*::test_should_not_be_marked PARALLEL PASSED*", - "*::test_should_be_marked_2 PASSED*", - "*::test_should_be_marked_3 PASSED*", - "*::test_should_be_marked_1*", - "*::test_should_be_marked_2*", - "*::test_should_be_marked_3*", - ] - ) - - -@pytest.mark.skipif(hypothesis is None, reason="hypothesis needs to be installed") -def test_detect_hypothesis(pytester): - pytester.makepyfile(""" - from hypothesis import given, strategies as st, settings, HealthCheck - - @given(a=st.none()) - @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_uses_hypothesis(a, num_parallel_threads): - assert num_parallel_threads == 1 - """) - result = pytester.runpytest("--parallel-threads=10", "-v") - result.stdout.fnmatch_lines( - [ - "*::test_uses_hypothesis PASSED*", - ] - ) - - -def test_detect_unittest_mock(pytester): - pytester.makepyfile(""" - import sys - from unittest import mock - - @mock.patch("sys.platform", "VAX") - def test_uses_mock(num_parallel_threads): - assert sys.platform == "VAX" - assert num_parallel_threads == 1 - """) - result = pytester.runpytest("--parallel-threads=10", "-v") - result.stdout.fnmatch_lines( - [ - r"*::test_uses_mock PASSED*" r"calls thread-unsafe function: mock.patch*", - ] - ) - - def test_all_tests_in_parallel(pytester): pytester.makepyfile(""" def test_parallel_1(num_parallel_threads): @@ -1090,41 +589,3 @@ def test_parallel_2(num_parallel_threads): "*All tests were run in parallel! 🎉*", ] ) - - -def test_recurse_assign(pytester): - pytester.makepyfile(""" - import pytest - - def test_function_recurse_on_assign(num_parallel_threads): - w = pytest.warns(UserWarning) - assert num_parallel_threads == 1 - """) - - result = pytester.runpytest("--parallel-threads=10", "-v") - result.stdout.fnmatch_lines( - [ - "*::test_function_recurse_on_assign PASSED*", - ] - ) - - -def test_failed_thread_unsafe(pytester): - pytester.makepyfile(""" - import pytest - - @pytest.mark.thread_unsafe - def test1(): - assert False - """) - - result = pytester.runpytest("--parallel-threads=10", "-v") - assert result.ret == 1 - print(result.stdout) - result.stdout.fnmatch_lines( - [ - "*::test1 FAILED *thread-unsafe*: uses thread_unsafe marker*", - "* FAILURES *", - "*1 failed*", - ] - ) diff --git a/tests/test_thread_comparator.py b/tests/test_thread_comparator.py new file mode 100644 index 0000000..8c9b790 --- /dev/null +++ b/tests/test_thread_comparator.py @@ -0,0 +1,54 @@ +def test_thread_comp_fixture(pytester): + """Test that ThreadComparator works as expected.""" + + # create a temporary pytest test module + pytester.makepyfile(""" + import threading + import pytest + + class Counter: + def __init__(self): + self._value = 0 + self._lock = threading.Lock() + + def get_value_and_increment(self): + with self._lock: + value = int(self._value) + self._value += 1 + return value + + def test_value_comparison(num_parallel_threads, thread_comp): + assert num_parallel_threads == 10 + a = 1 + b = [2, 'string', 1.0] + c = {'a': -4, 'b': 'str'} + d = float('nan') + e = float('inf') + f = {'a', 'b', '#'} + thread_comp(a=a, b=b, c=c, d=d, e=e, f=f) + + # Ensure that the comparator can be used again + thread_comp(g=4) + + @pytest.fixture + def counter(num_parallel_threads): + return Counter() + + def test_comparison_fail(thread_comp, counter): + a = 4 + pos = counter.get_value_and_increment() + if pos % 2 == 0: + a = -1 + thread_comp(a=a) + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=10", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_value_comparison PARALLEL PASSED*", + "*::test_comparison_fail PARALLEL FAILED*", + ] + ) diff --git a/tests/test_thread_unsafe_detection.py b/tests/test_thread_unsafe_detection.py new file mode 100644 index 0000000..54e2a05 --- /dev/null +++ b/tests/test_thread_unsafe_detection.py @@ -0,0 +1,326 @@ +import os + +import pytest + +try: + import hypothesis +except ImportError: + hypothesis = None + + +def test_thread_unsafe_marker(pytester): + # create a temporary pytest test module + pytester.makepyfile(""" + import pytest + + @pytest.mark.thread_unsafe + def test_should_run_single(num_parallel_threads): + assert num_parallel_threads == 1 + + @pytest.mark.thread_unsafe(reason='this is thread-unsafe') + def test_should_run_single_2(num_parallel_threads): + assert num_parallel_threads == 1 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=10", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_should_run_single PASSED*", + "*::test_should_run_single_2 PASSED *thread-unsafe*: this is thread-unsafe*", + ] + ) + + # check that skipping works too + result = pytester.runpytest( + "--parallel-threads=10", "--skip-thread-unsafe=True", "-v" + ) + + result.stdout.fnmatch_lines( + ["*::test_should_run_single SKIPPED*", "*::test_should_run_single_2 SKIPPED*"] + ) + + +def test_pytest_warns_detection(pytester): + # create a temporary pytest test module + pytester.makepyfile(""" + import pytest + import warnings + import pytest as pyt + import warnings as w + from pytest import warns, deprecated_call + from warnings import catch_warnings + + warns_alias = warns + + def test_single_thread_warns_1(num_parallel_threads): + with pytest.warns(UserWarning): + warnings.warn('example', UserWarning) + assert num_parallel_threads == 1 + + def test_single_thread_warns_2(num_parallel_threads): + with warns(UserWarning): + warnings.warn('example', UserWarning) + assert num_parallel_threads == 1 + + def test_single_thread_warns_3(num_parallel_threads): + with pyt.warns(UserWarning): + warnings.warn('example', UserWarning) + assert num_parallel_threads == 1 + + def test_single_thread_warns_4(num_parallel_threads): + with warns_alias(UserWarning): + warnings.warn('example', UserWarning) + assert num_parallel_threads == 1 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=10", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_single_thread_warns_1 PASSED*", + "*::test_single_thread_warns_2 PASSED*", + "*::test_single_thread_warns_3 PASSED*", + "*::test_single_thread_warns_4 PASSED*", + ] + ) + + # check that skipping works too + result = pytester.runpytest( + "--parallel-threads=10", "--skip-thread-unsafe=True", "-v" + ) + + result.stdout.fnmatch_lines( + [ + "*::test_single_thread_warns_1 SKIPPED*", + "*::test_single_thread_warns_2 SKIPPED*", + "*::test_single_thread_warns_3 SKIPPED*", + "*::test_single_thread_warns_4 SKIPPED*", + ] + ) + + +def test_thread_unsafe_fixtures(pytester): + # create a temporary pytest test module + pytester.makepyfile(""" + import pytest + + @pytest.fixture + def my_unsafe_fixture(): + pass + + @pytest.fixture + def my_unsafe_fixture_2(): + pass + + def test_capsys(capsys, num_parallel_threads): + assert num_parallel_threads == 1 + + def test_monkeypatch(monkeypatch, num_parallel_threads): + assert num_parallel_threads == 1 + + def test_recwarn(recwarn, num_parallel_threads): + assert num_parallel_threads == 1 + + def test_custom_fixture_skip(my_unsafe_fixture, num_parallel_threads): + assert num_parallel_threads == 1 + + def test_custom_fixture_skip_2(my_unsafe_fixture_2, num_parallel_threads): + assert num_parallel_threads == 1 + """) + + pytester.makeini(""" + [pytest] + thread_unsafe_fixtures = + my_unsafe_fixture + my_unsafe_fixture_2 + """) + + # run pytest with the following cmd args + result = pytester.runpytest("--parallel-threads=10", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*::test_capsys PASSED *thread-unsafe*: uses thread-unsafe fixture*", + "*::test_recwarn PASSED *thread-unsafe*: uses thread-unsafe fixture*", + "*::test_custom_fixture_skip PASSED *thread-unsafe*: uses thread-unsafe fixture*", + "*::test_custom_fixture_skip_2 PASSED *thread-unsafe*: uses thread-unsafe fixture*", + ] + ) + + +def test_thread_unsafe_function_attr(pytester): + pytester.makepyfile( + mod_1=""" + def to_skip(): + __thread_safe__ = False + + def not_to_skip(): + __thread_safe__ = True + """ + ) + + pytester.makepyfile( + mod_2=""" + import mod_1 + from mod_1 import not_to_skip + + def some_fn_calls_skip(): + mod_1.to_skip() + + def some_fn_should_not_skip(): + not_to_skip() + + def marked_for_skip(): + pass + """ + ) + + pytester.makepyfile(""" + import mod_2 + from mod_2 import some_fn_calls_skip + + def test_should_be_marked_1(num_parallel_threads): + mod_2.some_fn_calls_skip() + assert num_parallel_threads == 1 + + def test_should_not_be_marked(num_parallel_threads): + mod_2.some_fn_should_not_skip() + assert num_parallel_threads == 10 + + def test_should_be_marked_2(num_parallel_threads): + mod_2.marked_for_skip() + assert num_parallel_threads == 1 + + def test_should_be_marked_3(num_parallel_threads): + some_fn_calls_skip() + assert num_parallel_threads == 1 + """) + + pytester.makeini(""" + [pytest] + thread_unsafe_functions = + mod_2.marked_for_skip + """) + + # run pytest with the following cmd args + orig = os.environ.get("PYTEST_RUN_PARALLEL_VERBOSE", "0") + os.environ["PYTEST_RUN_PARALLEL_VERBOSE"] = "0" + result = pytester.runpytest("--parallel-threads=10", "-v") + + # fnmatch_lines does an assertion internally + result.stdout.fnmatch_lines( + [ + "*Collected 1 items to run in parallel*", + "*::test_should_be_marked_1 PASSED *thread-unsafe*inferred via func.__thread_safe__*", + "*::test_should_not_be_marked PARALLEL PASSED*", + "*::test_should_be_marked_2 PASSED *thread-unsafe*marked_for_skip*", + "*::test_should_be_marked_3 PASSED *thread-unsafe*inferred via func.__thread_safe__*", + ] + ) + + result.stdout.fnmatch_lines( + [ + "*3 tests were not run in parallel because of use of thread-unsafe " + "functionality, to list the tests that were skipped, " + "re-run while setting PYTEST_RUN_PARALLEL_VERBOSE=1 in your " + "shell environment*", + ] + ) + + # re-run with PYTEST_RUN_PARALLEL_VERBOSE=1 + os.environ["PYTEST_RUN_PARALLEL_VERBOSE"] = "1" + result = pytester.runpytest("--parallel-threads=10", "-v") + os.environ["PYTEST_RUN_PARALLEL_VERBOSE"] = orig + + result.stdout.fnmatch_lines( + [ + "*Collected 1 items to run in parallel*", + "*::test_should_be_marked_1 PASSED *thread-unsafe*: calls thread-unsafe function*", + "*::test_should_not_be_marked PARALLEL PASSED*", + "*::test_should_be_marked_2 PASSED*", + "*::test_should_be_marked_3 PASSED*", + "*::test_should_be_marked_1*", + "*::test_should_be_marked_2*", + "*::test_should_be_marked_3*", + ] + ) + + +@pytest.mark.skipif(hypothesis is None, reason="hypothesis needs to be installed") +def test_detect_hypothesis(pytester): + pytester.makepyfile(""" + from hypothesis import given, strategies as st, settings, HealthCheck + + @given(a=st.none()) + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_uses_hypothesis(a, num_parallel_threads): + assert num_parallel_threads == 1 + """) + result = pytester.runpytest("--parallel-threads=10", "-v") + result.stdout.fnmatch_lines( + [ + "*::test_uses_hypothesis PASSED*", + ] + ) + + +def test_detect_unittest_mock(pytester): + pytester.makepyfile(""" + import sys + from unittest import mock + + @mock.patch("sys.platform", "VAX") + def test_uses_mock(num_parallel_threads): + assert sys.platform == "VAX" + assert num_parallel_threads == 1 + """) + result = pytester.runpytest("--parallel-threads=10", "-v") + result.stdout.fnmatch_lines( + [ + r"*::test_uses_mock PASSED*" r"calls thread-unsafe function: mock.patch*", + ] + ) + + +def test_recurse_assign(pytester): + pytester.makepyfile(""" + import pytest + + def test_function_recurse_on_assign(num_parallel_threads): + w = pytest.warns(UserWarning) + assert num_parallel_threads == 1 + """) + + result = pytester.runpytest("--parallel-threads=10", "-v") + result.stdout.fnmatch_lines( + [ + "*::test_function_recurse_on_assign PASSED*", + ] + ) + + +def test_failed_thread_unsafe(pytester): + pytester.makepyfile(""" + import pytest + + @pytest.mark.thread_unsafe + def test1(): + assert False + """) + + result = pytester.runpytest("--parallel-threads=10", "-v") + assert result.ret == 1 + print(result.stdout) + result.stdout.fnmatch_lines( + [ + "*::test1 FAILED *thread-unsafe*: uses thread_unsafe marker*", + "* FAILURES *", + "*1 failed*", + ] + ) diff --git a/uv.lock b/uv.lock index 52cf219..d41b383 100644 --- a/uv.lock +++ b/uv.lock @@ -295,7 +295,7 @@ wheels = [ [[package]] name = "pytest-run-parallel" -version = "0.4.4.dev0" +version = "0.4.5.dev0" source = { editable = "." } dependencies = [ { name = "pytest" }, From 0cc5de1f4454de6beab7f9ef302faf0d1cc3b254 Mon Sep 17 00:00:00 2001 From: Lysandros Nikolaou Date: Wed, 11 Jun 2025 17:43:48 +0200 Subject: [PATCH 2/2] Move wrapper into plugin.py --- src/pytest_run_parallel/plugin.py | 73 +++++++++++++++++++++- src/pytest_run_parallel/thread_wrapper.py | 74 ----------------------- 2 files changed, 72 insertions(+), 75 deletions(-) delete mode 100644 src/pytest_run_parallel/thread_wrapper.py diff --git a/src/pytest_run_parallel/plugin.py b/src/pytest_run_parallel/plugin.py index b9d97dc..5924f6d 100644 --- a/src/pytest_run_parallel/plugin.py +++ b/src/pytest_run_parallel/plugin.py @@ -1,6 +1,10 @@ +import functools import os +import sys +import threading import warnings +import _pytest.outcomes import pytest from pytest_run_parallel.thread_comparator import ThreadComparator @@ -8,7 +12,6 @@ THREAD_UNSAFE_FIXTURES, identify_thread_unsafe_nodes, ) -from pytest_run_parallel.thread_wrapper import wrap_function_parallel from pytest_run_parallel.utils import get_configured_num_workers, get_num_workers @@ -69,6 +72,74 @@ def pytest_configure(config): ) +def wrap_function_parallel(fn, n_workers, n_iterations): + @functools.wraps(fn) + def inner(*args, **kwargs): + errors = [] + skip = None + failed = None + barrier = threading.Barrier(n_workers) + original_switch = sys.getswitchinterval() + new_switch = 1e-6 + for _ in range(3): + try: + sys.setswitchinterval(new_switch) + break + except ValueError: + new_switch *= 10 + else: + sys.setswitchinterval(original_switch) + + try: + + def closure(*args, **kwargs): + for _ in range(n_iterations): + barrier.wait() + try: + fn(*args, **kwargs) + except Warning: + pass + except Exception as e: + errors.append(e) + except _pytest.outcomes.Skipped as s: + nonlocal skip + skip = s.msg + except _pytest.outcomes.Failed as f: + nonlocal failed + failed = f + + workers = [] + for _ in range(0, n_workers): + worker_kwargs = kwargs + workers.append( + threading.Thread(target=closure, args=args, kwargs=worker_kwargs) + ) + + num_completed = 0 + try: + for worker in workers: + worker.start() + num_completed += 1 + finally: + if num_completed < len(workers): + barrier.abort() + + for worker in workers: + worker.join() + + finally: + sys.setswitchinterval(original_switch) + + if skip is not None: + pytest.skip(skip) + elif failed is not None: + raise failed + elif errors: + raise errors[0] + + return inner + + @pytest.hookimpl(trylast=True) def pytest_itemcollected(item): n_workers = get_num_workers(item.config, item) diff --git a/src/pytest_run_parallel/thread_wrapper.py b/src/pytest_run_parallel/thread_wrapper.py deleted file mode 100644 index 525bdfb..0000000 --- a/src/pytest_run_parallel/thread_wrapper.py +++ /dev/null @@ -1,74 +0,0 @@ -import functools -import sys -import threading - -import _pytest.outcomes -import pytest - - -def wrap_function_parallel(fn, n_workers, n_iterations): - @functools.wraps(fn) - def inner(*args, **kwargs): - errors = [] - skip = None - failed = None - barrier = threading.Barrier(n_workers) - original_switch = sys.getswitchinterval() - new_switch = 1e-6 - for _ in range(3): - try: - sys.setswitchinterval(new_switch) - break - except ValueError: - new_switch *= 10 - else: - sys.setswitchinterval(original_switch) - - try: - - def closure(*args, **kwargs): - for _ in range(n_iterations): - barrier.wait() - try: - fn(*args, **kwargs) - except Warning: - pass - except Exception as e: - errors.append(e) - except _pytest.outcomes.Skipped as s: - nonlocal skip - skip = s.msg - except _pytest.outcomes.Failed as f: - nonlocal failed - failed = f - - workers = [] - for _ in range(0, n_workers): - worker_kwargs = kwargs - workers.append( - threading.Thread(target=closure, args=args, kwargs=worker_kwargs) - ) - - num_completed = 0 - try: - for worker in workers: - worker.start() - num_completed += 1 - finally: - if num_completed < len(workers): - barrier.abort() - - for worker in workers: - worker.join() - - finally: - sys.setswitchinterval(original_switch) - - if skip is not None: - pytest.skip(skip) - elif failed is not None: - raise failed - elif errors: - raise errors[0] - - return inner