Skip to content

Commit c961ebf

Browse files
cleanup patches after test in pytest plugin (#1148) (#1164)
* cleanup patches after test in pytest plugin (#1148) * pin dependency in docs causing problems with breaking change * simplify Generator annotations; use mapping proxy to ensure immutable module constant
1 parent 15cbb83 commit c961ebf

File tree

4 files changed

+123
-130
lines changed

4 files changed

+123
-130
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ See [0Ver](https://0ver.org/).
1010
### Bugfixes
1111

1212
- Fixes `__slots__` not being set properly in containers and their base classes
13+
- Fixes patching of containers in pytest plugin not undone after each test
1314

1415
## 0.17.0
1516

docs/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ hypothesis==6.30.1
2020
# TODO: Remove this lock when we found and fix the route case.
2121
# See: https://github.com/typlog/sphinx-typlog-theme/issues/22
2222
jinja2==3.0.3
23+
24+
# TODO: Remove this lock when this dependency issue is resolved.
25+
# See: https://github.com/miyakogi/m2r/issues/66
26+
mistune<2.0.0

returns/contrib/pytest/plugin.py

Lines changed: 112 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
11
import inspect
22
import sys
3-
from contextlib import contextmanager
3+
from contextlib import ExitStack, contextmanager
44
from functools import partial, wraps
5-
from types import FrameType
5+
from types import FrameType, MappingProxyType
66
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, TypeVar, Union
7+
from unittest import mock
78

89
import pytest
910
from typing_extensions import Final, final
1011

1112
if TYPE_CHECKING:
1213
from returns.interfaces.specific.result import ResultLikeN
1314

14-
_ERROR_HANDLERS: Final = (
15-
'lash',
16-
)
17-
_ERRORS_COPIERS: Final = (
18-
'map',
19-
'alt',
20-
)
21-
2215
# We keep track of errors handled by keeping a mapping of <object id>: object.
2316
# If an error is handled, it is in the mapping.
2417
# If it isn't in the mapping, the error is not handled.
@@ -28,7 +21,7 @@
2821
# Also, the object itself cannot be (in) the key because
2922
# (1) we cannot always assume hashability and
3023
# (2) we need to track the object identity, not its value
31-
_ERRORS_HANDLED: Final[Dict[int, Any]] = {} # noqa: WPS407
24+
_ErrorsHandled = Dict[int, Any]
3225

3326
_FunctionType = TypeVar('_FunctionType', bound=Callable)
3427
_ReturnsResultType = TypeVar(
@@ -41,7 +34,11 @@
4134
class ReturnsAsserts(object):
4235
"""Class with helpers assertions to check containers."""
4336

44-
__slots__ = ()
37+
__slots__ = ('_errors_handled', )
38+
39+
def __init__(self, errors_handled: _ErrorsHandled) -> None:
40+
"""Constructor for this type."""
41+
self._errors_handled = errors_handled
4542

4643
@staticmethod # noqa: WPS602
4744
def assert_equal( # noqa: WPS602
@@ -55,10 +52,9 @@ def assert_equal( # noqa: WPS602
5552
from returns.primitives.asserts import assert_equal
5653
assert_equal(first, second, deps=deps, backend=backend)
5754

58-
@staticmethod # noqa: WPS602
59-
def is_error_handled(container) -> bool: # noqa: WPS602
55+
def is_error_handled(self, container) -> bool:
6056
"""Ensures that container has its error handled in the end."""
61-
return id(container) in _ERRORS_HANDLED
57+
return id(container) in self._errors_handled
6258

6359
@staticmethod # noqa: WPS602
6460
@contextmanager
@@ -86,59 +82,6 @@ def assert_trace( # noqa: WPS602
8682
sys.settrace(old_tracer)
8783

8884

89-
@pytest.fixture(scope='session')
90-
def returns(_patch_containers) -> ReturnsAsserts:
91-
"""Returns our own class with helpers assertions to check containers."""
92-
return ReturnsAsserts()
93-
94-
95-
@pytest.fixture(autouse=True)
96-
def _clear_errors_handled():
97-
"""Ensures the 'errors handled' registry doesn't leak memory."""
98-
yield
99-
_ERRORS_HANDLED.clear()
100-
101-
102-
def pytest_configure(config) -> None:
103-
"""
104-
Hook to be executed on import.
105-
106-
We use it define custom markers.
107-
"""
108-
config.addinivalue_line(
109-
'markers',
110-
(
111-
'returns_lawful: all tests under `check_all_laws` ' +
112-
'is marked this way, ' +
113-
'use `-m "not returns_lawful"` to skip them.'
114-
),
115-
)
116-
117-
118-
@pytest.fixture(scope='session')
119-
def _patch_containers() -> None:
120-
"""
121-
Fixture to add test specifics into our containers.
122-
123-
Currently we inject:
124-
125-
- Error handling state, this is required to test that ``Result``-based
126-
containers do handle errors
127-
128-
Even more things to come!
129-
"""
130-
_patch_error_handling(_ERROR_HANDLERS, _PatchedContainer.error_handler)
131-
_patch_error_handling(_ERRORS_COPIERS, _PatchedContainer.copy_handler)
132-
133-
134-
def _patch_error_handling(methods, patch_handler) -> None:
135-
for container in _PatchedContainer.containers_to_patch():
136-
for method in methods:
137-
original = getattr(container, method, None)
138-
if original:
139-
setattr(container, method, patch_handler(original))
140-
141-
14285
def _trace_function(
14386
trace_type: _ReturnsResultType,
14487
function_to_search: _FunctionType,
@@ -166,65 +109,107 @@ def _trace_function(
166109
raise _DesiredFunctionFound()
167110

168111

169-
@final
170-
class _PatchedContainer(object):
171-
"""Class with helper methods to patched containers."""
172-
173-
__slots__ = ()
174-
175-
@classmethod
176-
def containers_to_patch(cls) -> tuple:
177-
"""We need this method so coverage will work correctly."""
178-
from returns.context import (
179-
RequiresContextFutureResult,
180-
RequiresContextIOResult,
181-
RequiresContextResult,
182-
)
183-
from returns.future import FutureResult
184-
from returns.io import IOFailure, IOSuccess
185-
from returns.result import Failure, Success
186-
187-
return (
188-
Success,
189-
Failure,
190-
IOSuccess,
191-
IOFailure,
192-
RequiresContextResult,
193-
RequiresContextIOResult,
194-
RequiresContextFutureResult,
195-
FutureResult,
196-
)
112+
class _DesiredFunctionFound(BaseException): # noqa: WPS418
113+
"""Exception to raise when expected function is found."""
197114

198-
@classmethod
199-
def error_handler(cls, original):
200-
if inspect.iscoroutinefunction(original):
201-
async def factory(self, *args, **kwargs):
202-
original_result = await original(self, *args, **kwargs)
203-
_ERRORS_HANDLED[id(original_result)] = original_result
204-
return original_result
205-
else:
206-
def factory(self, *args, **kwargs):
207-
original_result = original(self, *args, **kwargs)
208-
_ERRORS_HANDLED[id(original_result)] = original_result
209-
return original_result
210-
return wraps(original)(factory)
211-
212-
@classmethod
213-
def copy_handler(cls, original):
214-
if inspect.iscoroutinefunction(original):
215-
async def factory(self, *args, **kwargs):
216-
original_result = await original(self, *args, **kwargs)
217-
if id(self) in _ERRORS_HANDLED:
218-
_ERRORS_HANDLED[id(original_result)] = original_result
219-
return original_result
220-
else:
221-
def factory(self, *args, **kwargs):
222-
original_result = original(self, *args, **kwargs)
223-
if id(self) in _ERRORS_HANDLED:
224-
_ERRORS_HANDLED[id(original_result)] = original_result
225-
return original_result
226-
return wraps(original)(factory)
227115

116+
def pytest_configure(config) -> None:
117+
"""
118+
Hook to be executed on import.
228119
229-
class _DesiredFunctionFound(BaseException): # noqa: WPS418
230-
"""Exception to raise when expected function is found."""
120+
We use it define custom markers.
121+
"""
122+
config.addinivalue_line(
123+
'markers',
124+
(
125+
'returns_lawful: all tests under `check_all_laws` ' +
126+
'is marked this way, ' +
127+
'use `-m "not returns_lawful"` to skip them.'
128+
),
129+
)
130+
131+
132+
@pytest.fixture()
133+
def returns() -> Iterator[ReturnsAsserts]:
134+
"""Returns class with helpers assertions to check containers."""
135+
with _spy_error_handling() as errors_handled:
136+
yield ReturnsAsserts(errors_handled)
137+
138+
139+
@contextmanager
140+
def _spy_error_handling() -> Iterator[_ErrorsHandled]:
141+
"""Track error handling of containers."""
142+
errs: _ErrorsHandled = {}
143+
with ExitStack() as cleanup:
144+
for container in _containers_to_patch():
145+
for method, patch in _ERROR_HANDLING_PATCHERS.items():
146+
cleanup.enter_context(mock.patch.object(
147+
container,
148+
method,
149+
patch(getattr(container, method), errs=errs),
150+
))
151+
yield errs
152+
153+
154+
# delayed imports are needed to prevent messing up coverage
155+
def _containers_to_patch() -> tuple:
156+
from returns.context import (
157+
RequiresContextFutureResult,
158+
RequiresContextIOResult,
159+
RequiresContextResult,
160+
)
161+
from returns.future import FutureResult
162+
from returns.io import IOFailure, IOSuccess
163+
from returns.result import Failure, Success
164+
165+
return (
166+
Success,
167+
Failure,
168+
IOSuccess,
169+
IOFailure,
170+
RequiresContextResult,
171+
RequiresContextIOResult,
172+
RequiresContextFutureResult,
173+
FutureResult,
174+
)
175+
176+
177+
def _patched_error_handler(
178+
original: _FunctionType, errs: _ErrorsHandled,
179+
) -> _FunctionType:
180+
if inspect.iscoroutinefunction(original):
181+
async def wrapper(self, *args, **kwargs):
182+
original_result = await original(self, *args, **kwargs)
183+
errs[id(original_result)] = original_result
184+
return original_result
185+
else:
186+
def wrapper(self, *args, **kwargs):
187+
original_result = original(self, *args, **kwargs)
188+
errs[id(original_result)] = original_result
189+
return original_result
190+
return wraps(original)(wrapper) # type: ignore
191+
192+
193+
def _patched_error_copier(
194+
original: _FunctionType, errs: _ErrorsHandled,
195+
) -> _FunctionType:
196+
if inspect.iscoroutinefunction(original):
197+
async def wrapper(self, *args, **kwargs):
198+
original_result = await original(self, *args, **kwargs)
199+
if id(self) in errs:
200+
errs[id(original_result)] = original_result
201+
return original_result
202+
else:
203+
def wrapper(self, *args, **kwargs):
204+
original_result = original(self, *args, **kwargs)
205+
if id(self) in errs:
206+
errs[id(original_result)] = original_result
207+
return original_result
208+
return wraps(original)(wrapper) # type: ignore
209+
210+
211+
_ERROR_HANDLING_PATCHERS: Final = MappingProxyType({
212+
'lash': _patched_error_handler,
213+
'map': _patched_error_copier,
214+
'alt': _patched_error_copier,
215+
})

tests/test_contrib/test_pytest/test_plugin_error_handler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
RequiresContextResult,
77
)
88
from returns.contrib.pytest import ReturnsAsserts
9-
from returns.contrib.pytest.plugin import _ERRORS_HANDLED
109
from returns.functions import identity
1110
from returns.future import FutureResult
1211
from returns.io import IOFailure, IOSuccess
@@ -42,13 +41,15 @@ def _under_test(
4241
])
4342
def test_error_handled(returns: ReturnsAsserts, container, kwargs):
4443
"""Demo on how to use ``pytest`` helpers to work with error handling."""
45-
assert not _ERRORS_HANDLED
44+
assert not returns._errors_handled # noqa: WPS437
4645
error_handled = _under_test(container, **kwargs)
4746

4847
assert returns.is_error_handled(error_handled)
4948
assert returns.is_error_handled(error_handled.map(identity))
5049
assert returns.is_error_handled(error_handled.alt(identity))
5150

51+
assert returns._errors_handled # noqa: WPS437
52+
5253

5354
@pytest.mark.parametrize('container', [
5455
Success(1),
@@ -64,14 +65,16 @@ def test_error_handled(returns: ReturnsAsserts, container, kwargs):
6465
])
6566
def test_error_not_handled(returns: ReturnsAsserts, container):
6667
"""Demo on how to use ``pytest`` helpers to work with error handling."""
67-
assert not _ERRORS_HANDLED
68+
assert not returns._errors_handled # noqa: WPS437
6869
error_handled = _under_test(container)
6970

7071
assert not returns.is_error_handled(container)
7172
assert not returns.is_error_handled(error_handled)
7273
assert not returns.is_error_handled(error_handled.map(identity))
7374
assert not returns.is_error_handled(error_handled.alt(identity))
7475

76+
assert not returns._errors_handled # noqa: WPS437
77+
7578

7679
@pytest.mark.anyio()
7780
@pytest.mark.parametrize('container', [

0 commit comments

Comments
 (0)