1
1
import inspect
2
2
import sys
3
- from contextlib import contextmanager
3
+ from contextlib import ExitStack , contextmanager
4
4
from functools import partial , wraps
5
- from types import FrameType
5
+ from types import FrameType , MappingProxyType
6
6
from typing import TYPE_CHECKING , Any , Callable , Dict , Iterator , TypeVar , Union
7
+ from unittest import mock
7
8
8
9
import pytest
9
10
from typing_extensions import Final , final
10
11
11
12
if TYPE_CHECKING :
12
13
from returns .interfaces .specific .result import ResultLikeN
13
14
14
- _ERROR_HANDLERS : Final = (
15
- 'lash' ,
16
- )
17
- _ERRORS_COPIERS : Final = (
18
- 'map' ,
19
- 'alt' ,
20
- )
21
-
22
15
# We keep track of errors handled by keeping a mapping of <object id>: object.
23
16
# If an error is handled, it is in the mapping.
24
17
# If it isn't in the mapping, the error is not handled.
28
21
# Also, the object itself cannot be (in) the key because
29
22
# (1) we cannot always assume hashability and
30
23
# (2) we need to track the object identity, not its value
31
- _ERRORS_HANDLED : Final [ Dict [int , Any ]] = {} # noqa: WPS407
24
+ _ErrorsHandled = Dict [int , Any ]
32
25
33
26
_FunctionType = TypeVar ('_FunctionType' , bound = Callable )
34
27
_ReturnsResultType = TypeVar (
41
34
class ReturnsAsserts (object ):
42
35
"""Class with helpers assertions to check containers."""
43
36
44
- __slots__ = ()
37
+ __slots__ = ('_errors_handled' , )
38
+
39
+ def __init__ (self , errors_handled : _ErrorsHandled ) -> None :
40
+ """Constructor for this type."""
41
+ self ._errors_handled = errors_handled
45
42
46
43
@staticmethod # noqa: WPS602
47
44
def assert_equal ( # noqa: WPS602
@@ -55,10 +52,9 @@ def assert_equal( # noqa: WPS602
55
52
from returns .primitives .asserts import assert_equal
56
53
assert_equal (first , second , deps = deps , backend = backend )
57
54
58
- @staticmethod # noqa: WPS602
59
- def is_error_handled (container ) -> bool : # noqa: WPS602
55
+ def is_error_handled (self , container ) -> bool :
60
56
"""Ensures that container has its error handled in the end."""
61
- return id (container ) in _ERRORS_HANDLED
57
+ return id (container ) in self . _errors_handled
62
58
63
59
@staticmethod # noqa: WPS602
64
60
@contextmanager
@@ -86,59 +82,6 @@ def assert_trace( # noqa: WPS602
86
82
sys .settrace (old_tracer )
87
83
88
84
89
- @pytest .fixture (scope = 'session' )
90
- def returns (_patch_containers ) -> ReturnsAsserts :
91
- """Returns our own class with helpers assertions to check containers."""
92
- return ReturnsAsserts ()
93
-
94
-
95
- @pytest .fixture (autouse = True )
96
- def _clear_errors_handled ():
97
- """Ensures the 'errors handled' registry doesn't leak memory."""
98
- yield
99
- _ERRORS_HANDLED .clear ()
100
-
101
-
102
- def pytest_configure (config ) -> None :
103
- """
104
- Hook to be executed on import.
105
-
106
- We use it define custom markers.
107
- """
108
- config .addinivalue_line (
109
- 'markers' ,
110
- (
111
- 'returns_lawful: all tests under `check_all_laws` ' +
112
- 'is marked this way, ' +
113
- 'use `-m "not returns_lawful"` to skip them.'
114
- ),
115
- )
116
-
117
-
118
- @pytest .fixture (scope = 'session' )
119
- def _patch_containers () -> None :
120
- """
121
- Fixture to add test specifics into our containers.
122
-
123
- Currently we inject:
124
-
125
- - Error handling state, this is required to test that ``Result``-based
126
- containers do handle errors
127
-
128
- Even more things to come!
129
- """
130
- _patch_error_handling (_ERROR_HANDLERS , _PatchedContainer .error_handler )
131
- _patch_error_handling (_ERRORS_COPIERS , _PatchedContainer .copy_handler )
132
-
133
-
134
- def _patch_error_handling (methods , patch_handler ) -> None :
135
- for container in _PatchedContainer .containers_to_patch ():
136
- for method in methods :
137
- original = getattr (container , method , None )
138
- if original :
139
- setattr (container , method , patch_handler (original ))
140
-
141
-
142
85
def _trace_function (
143
86
trace_type : _ReturnsResultType ,
144
87
function_to_search : _FunctionType ,
@@ -166,65 +109,107 @@ def _trace_function(
166
109
raise _DesiredFunctionFound ()
167
110
168
111
169
- @final
170
- class _PatchedContainer (object ):
171
- """Class with helper methods to patched containers."""
172
-
173
- __slots__ = ()
174
-
175
- @classmethod
176
- def containers_to_patch (cls ) -> tuple :
177
- """We need this method so coverage will work correctly."""
178
- from returns .context import (
179
- RequiresContextFutureResult ,
180
- RequiresContextIOResult ,
181
- RequiresContextResult ,
182
- )
183
- from returns .future import FutureResult
184
- from returns .io import IOFailure , IOSuccess
185
- from returns .result import Failure , Success
186
-
187
- return (
188
- Success ,
189
- Failure ,
190
- IOSuccess ,
191
- IOFailure ,
192
- RequiresContextResult ,
193
- RequiresContextIOResult ,
194
- RequiresContextFutureResult ,
195
- FutureResult ,
196
- )
112
+ class _DesiredFunctionFound (BaseException ): # noqa: WPS418
113
+ """Exception to raise when expected function is found."""
197
114
198
- @classmethod
199
- def error_handler (cls , original ):
200
- if inspect .iscoroutinefunction (original ):
201
- async def factory (self , * args , ** kwargs ):
202
- original_result = await original (self , * args , ** kwargs )
203
- _ERRORS_HANDLED [id (original_result )] = original_result
204
- return original_result
205
- else :
206
- def factory (self , * args , ** kwargs ):
207
- original_result = original (self , * args , ** kwargs )
208
- _ERRORS_HANDLED [id (original_result )] = original_result
209
- return original_result
210
- return wraps (original )(factory )
211
-
212
- @classmethod
213
- def copy_handler (cls , original ):
214
- if inspect .iscoroutinefunction (original ):
215
- async def factory (self , * args , ** kwargs ):
216
- original_result = await original (self , * args , ** kwargs )
217
- if id (self ) in _ERRORS_HANDLED :
218
- _ERRORS_HANDLED [id (original_result )] = original_result
219
- return original_result
220
- else :
221
- def factory (self , * args , ** kwargs ):
222
- original_result = original (self , * args , ** kwargs )
223
- if id (self ) in _ERRORS_HANDLED :
224
- _ERRORS_HANDLED [id (original_result )] = original_result
225
- return original_result
226
- return wraps (original )(factory )
227
115
116
+ def pytest_configure (config ) -> None :
117
+ """
118
+ Hook to be executed on import.
228
119
229
- class _DesiredFunctionFound (BaseException ): # noqa: WPS418
230
- """Exception to raise when expected function is found."""
120
+ We use it define custom markers.
121
+ """
122
+ config .addinivalue_line (
123
+ 'markers' ,
124
+ (
125
+ 'returns_lawful: all tests under `check_all_laws` ' +
126
+ 'is marked this way, ' +
127
+ 'use `-m "not returns_lawful"` to skip them.'
128
+ ),
129
+ )
130
+
131
+
132
+ @pytest .fixture ()
133
+ def returns () -> Iterator [ReturnsAsserts ]:
134
+ """Returns class with helpers assertions to check containers."""
135
+ with _spy_error_handling () as errors_handled :
136
+ yield ReturnsAsserts (errors_handled )
137
+
138
+
139
+ @contextmanager
140
+ def _spy_error_handling () -> Iterator [_ErrorsHandled ]:
141
+ """Track error handling of containers."""
142
+ errs : _ErrorsHandled = {}
143
+ with ExitStack () as cleanup :
144
+ for container in _containers_to_patch ():
145
+ for method , patch in _ERROR_HANDLING_PATCHERS .items ():
146
+ cleanup .enter_context (mock .patch .object (
147
+ container ,
148
+ method ,
149
+ patch (getattr (container , method ), errs = errs ),
150
+ ))
151
+ yield errs
152
+
153
+
154
+ # delayed imports are needed to prevent messing up coverage
155
+ def _containers_to_patch () -> tuple :
156
+ from returns .context import (
157
+ RequiresContextFutureResult ,
158
+ RequiresContextIOResult ,
159
+ RequiresContextResult ,
160
+ )
161
+ from returns .future import FutureResult
162
+ from returns .io import IOFailure , IOSuccess
163
+ from returns .result import Failure , Success
164
+
165
+ return (
166
+ Success ,
167
+ Failure ,
168
+ IOSuccess ,
169
+ IOFailure ,
170
+ RequiresContextResult ,
171
+ RequiresContextIOResult ,
172
+ RequiresContextFutureResult ,
173
+ FutureResult ,
174
+ )
175
+
176
+
177
+ def _patched_error_handler (
178
+ original : _FunctionType , errs : _ErrorsHandled ,
179
+ ) -> _FunctionType :
180
+ if inspect .iscoroutinefunction (original ):
181
+ async def wrapper (self , * args , ** kwargs ):
182
+ original_result = await original (self , * args , ** kwargs )
183
+ errs [id (original_result )] = original_result
184
+ return original_result
185
+ else :
186
+ def wrapper (self , * args , ** kwargs ):
187
+ original_result = original (self , * args , ** kwargs )
188
+ errs [id (original_result )] = original_result
189
+ return original_result
190
+ return wraps (original )(wrapper ) # type: ignore
191
+
192
+
193
+ def _patched_error_copier (
194
+ original : _FunctionType , errs : _ErrorsHandled ,
195
+ ) -> _FunctionType :
196
+ if inspect .iscoroutinefunction (original ):
197
+ async def wrapper (self , * args , ** kwargs ):
198
+ original_result = await original (self , * args , ** kwargs )
199
+ if id (self ) in errs :
200
+ errs [id (original_result )] = original_result
201
+ return original_result
202
+ else :
203
+ def wrapper (self , * args , ** kwargs ):
204
+ original_result = original (self , * args , ** kwargs )
205
+ if id (self ) in errs :
206
+ errs [id (original_result )] = original_result
207
+ return original_result
208
+ return wraps (original )(wrapper ) # type: ignore
209
+
210
+
211
+ _ERROR_HANDLING_PATCHERS : Final = MappingProxyType ({
212
+ 'lash' : _patched_error_handler ,
213
+ 'map' : _patched_error_copier ,
214
+ 'alt' : _patched_error_copier ,
215
+ })
0 commit comments