Skip to content

Commit 82a40dc

Browse files
Make @safe((TypeError, ValueError)) variant (#1199)
* Changes `safe` decorator adding a overload that accepts a tuple of exceptions to handle * Updates `CHANGELOG.md` * Fixes naming * Fixes CI * Fixes CI * Fixes CI Co-authored-by: sobolevn <mail@sobolevn.me>
1 parent b5f7c18 commit 82a40dc

File tree

5 files changed

+225
-10
lines changed

5 files changed

+225
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ See [0Ver](https://0ver.org/).
3030
- Enables Pattern Matching support for `IOResult` container
3131
- Improves `hypothesis` plugin, now we detect
3232
when type cannot be constructed and give a clear error message
33+
- Adds the option to pass what exceptions `@safe` will handle
3334

3435

3536
## 0.16.0

docs/pages/result.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,22 @@ use :func:`future_safe <returns.future.future_safe>` instead.
9999
>>> str(divide(0))
100100
'<Failure: division by zero>'
101101
102+
If you want to `safe` handle only a set of exceptions:
103+
104+
.. code:: python
105+
106+
>>> @safe(exceptions=(ZeroDivisionError,)) # Other exceptions will be raised
107+
... def divide(number: int) -> float:
108+
... if number > 10:
109+
... raise ValueError('Too big')
110+
... return number / number
111+
112+
>>> assert divide(5) == Success(1.0)
113+
>>> assert divide(0).failure()
114+
>>> divide(15)
115+
Traceback (most recent call last):
116+
...
117+
ValueError: Too big
102118
103119
FAQ
104120
---

returns/result.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
List,
1010
NoReturn,
1111
Optional,
12+
Tuple,
1213
Type,
1314
TypeVar,
1415
Union,
16+
overload,
1517
)
1618

1719
from typing_extensions import ParamSpec, final
@@ -442,9 +444,33 @@ def failure(self) -> NoReturn:
442444

443445
# Decorators:
444446

447+
@overload
445448
def safe(
446449
function: Callable[_FuncParams, _ValueType],
447450
) -> Callable[_FuncParams, ResultE[_ValueType]]:
451+
"""Decorator to convert exception-throwing for any kind of Exception."""
452+
453+
454+
@overload
455+
def safe(
456+
exceptions: Tuple[Type[Exception], ...],
457+
) -> Callable[
458+
[Callable[_FuncParams, _ValueType]],
459+
Callable[_FuncParams, ResultE[_ValueType]],
460+
]:
461+
"""Decorator to convert exception-throwing just for a set of Exceptions."""
462+
463+
464+
def safe( # type: ignore # noqa: WPS234, C901
465+
function: Optional[Callable[_FuncParams, _ValueType]] = None,
466+
exceptions: Optional[Tuple[Type[Exception], ...]] = None,
467+
) -> Union[
468+
Callable[_FuncParams, ResultE[_ValueType]],
469+
Callable[
470+
[Callable[_FuncParams, _ValueType]],
471+
Callable[_FuncParams, ResultE[_ValueType]],
472+
],
473+
]:
448474
"""
449475
Decorator to convert exception-throwing function to ``Result`` container.
450476
@@ -466,16 +492,40 @@ def safe(
466492
>>> assert might_raise(1) == Success(1.0)
467493
>>> assert isinstance(might_raise(0), Result.failure_type)
468494
495+
You can also use it with explicit exception types as the first argument:
496+
497+
.. code:: python
498+
499+
>>> from returns.result import Result, Success, safe
500+
501+
>>> @safe(exceptions=(ZeroDivisionError,))
502+
... def might_raise(arg: int) -> float:
503+
... return 1 / arg
504+
505+
>>> assert might_raise(1) == Success(1.0)
506+
>>> assert isinstance(might_raise(0), Result.failure_type)
507+
508+
In this case, only exceptions that are explicitly
509+
listed are going to be caught.
510+
469511
Similar to :func:`returns.io.impure_safe`
470512
and :func:`returns.future.future_safe` decorators.
471513
"""
472-
@wraps(function)
473-
def decorator(
474-
*args: _FuncParams.args,
475-
**kwargs: _FuncParams.kwargs,
476-
) -> ResultE[_ValueType]:
477-
try:
478-
return Success(function(*args, **kwargs))
479-
except Exception as exc:
480-
return Failure(exc)
481-
return decorator
514+
def factory(
515+
inner_function: Callable[_FuncParams, _ValueType],
516+
inner_exceptions: Tuple[Type[Exception], ...],
517+
) -> Callable[_FuncParams, ResultE[_ValueType]]:
518+
@wraps(inner_function)
519+
def decorator(*args: _FuncParams.args, **kwargs: _FuncParams.kwargs):
520+
try:
521+
return Success(inner_function(*args, **kwargs))
522+
except inner_exceptions as exc:
523+
return Failure(exc)
524+
return decorator
525+
526+
if callable(function):
527+
return factory(function, (Exception,))
528+
if isinstance(function, tuple):
529+
exceptions = function # type: ignore
530+
function = None
531+
return lambda function: factory(function, exceptions) # type: ignore

tests/test_result/test_result_functions/test_safe.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Union
2+
3+
import pytest
14

25
from returns.result import Success, safe
36

@@ -7,6 +10,18 @@ def _function(number: int) -> float:
710
return number / number
811

912

13+
@safe(exceptions=(ZeroDivisionError,))
14+
def _function_two(number: Union[int, str]) -> float:
15+
assert isinstance(number, int)
16+
return number / number
17+
18+
19+
@safe((ZeroDivisionError,)) # no name
20+
def _function_three(number: Union[int, str]) -> float:
21+
assert isinstance(number, int)
22+
return number / number
23+
24+
1025
def test_safe_success():
1126
"""Ensures that safe decorator works correctly for Success case."""
1227
assert _function(1) == Success(1.0)
@@ -16,3 +31,18 @@ def test_safe_failure():
1631
"""Ensures that safe decorator works correctly for Failure case."""
1732
failed = _function(0)
1833
assert isinstance(failed.failure(), ZeroDivisionError)
34+
35+
36+
def test_safe_failure_with_expected_error():
37+
"""Ensures that safe decorator works correctly for Failure case."""
38+
failed = _function_two(0)
39+
assert isinstance(failed.failure(), ZeroDivisionError)
40+
41+
failed2 = _function_three(0)
42+
assert isinstance(failed2.failure(), ZeroDivisionError)
43+
44+
45+
def test_safe_failure_with_non_expected_error():
46+
"""Ensures that safe decorator works correctly for Failure case."""
47+
with pytest.raises(AssertionError):
48+
_function_two('0')

typesafety/test_result/test_safe.yml

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@
1010
reveal_type(test) # N: Revealed type is "def () -> returns.result.Result[builtins.int, builtins.Exception]"
1111
1212
13+
- case: safe_decorator_passing_exceptions_no_params
14+
disable_cache: false
15+
main: |
16+
from returns.result import safe
17+
18+
@safe((ValueError,))
19+
def test() -> int:
20+
return 1
21+
22+
reveal_type(test) # N: Revealed type is "def () -> returns.result.Result[builtins.int, builtins.Exception]"
23+
24+
@safe(exceptions=(ValueError,))
25+
def test2() -> int:
26+
return 1
27+
28+
reveal_type(test2) # N: Revealed type is "def () -> returns.result.Result[builtins.int, builtins.Exception]"
29+
30+
1331
- case: safe_composition_no_params
1432
disable_cache: false
1533
main: |
@@ -21,6 +39,17 @@
2139
reveal_type(safe(test)) # N: Revealed type is "def () -> returns.result.Result[builtins.int, builtins.Exception]"
2240
2341
42+
- case: safe_composition_passing_exceptions_no_params
43+
disable_cache: false
44+
main: |
45+
from returns.result import safe
46+
47+
def test() -> int:
48+
return 1
49+
50+
reveal_type(safe((EOFError,))(test)) # N: Revealed type is "def () -> returns.result.Result[builtins.int, builtins.Exception]"
51+
52+
2453
- case: safe_decorator_with_args
2554
disable_cache: false
2655
main: |
@@ -34,6 +63,19 @@
3463
reveal_type(test) # N: Revealed type is "def (first: builtins.int, second: Union[builtins.str, None] =, *, kw: builtins.bool =) -> returns.result.Result[builtins.int, builtins.Exception]"
3564
3665
66+
- case: safe_decorator_passing_exceptions_with_args
67+
disable_cache: false
68+
main: |
69+
from typing import Optional
70+
from returns.result import safe
71+
72+
@safe((ValueError, EOFError))
73+
def test(first: int, second: Optional[str] = None, *, kw: bool = True) -> int:
74+
return 1
75+
76+
reveal_type(test) # N: Revealed type is "def (first: builtins.int, second: Union[builtins.str, None] =, *, kw: builtins.bool =) -> returns.result.Result[builtins.int, builtins.Exception]"
77+
78+
3779
- case: safe_composition_with_args
3880
disable_cache: false
3981
main: |
@@ -46,6 +88,18 @@
4688
reveal_type(safe(test)) # N: Revealed type is "def (first: builtins.int, second: Union[builtins.str, None] =, *, kw: builtins.bool =) -> returns.result.Result[builtins.int, builtins.Exception]"
4789
4890
91+
- case: safe_composition_passing_exceptions_with_args
92+
disable_cache: false
93+
main: |
94+
from typing import Optional
95+
from returns.result import safe
96+
97+
def test(first: int, second: Optional[str] = None, *, kw: bool = True) -> int:
98+
return 1
99+
100+
reveal_type(safe((ValueError,))(test)) # N: Revealed type is "def (first: builtins.int, second: Union[builtins.str, None] =, *, kw: builtins.bool =) -> returns.result.Result[builtins.int, builtins.Exception]"
101+
102+
49103
- case: safe_regression333
50104
disable_cache: false
51105
main: |
@@ -59,6 +113,19 @@
59113
reveal_type(send) # N: Revealed type is "def (text: builtins.str) -> returns.result.Result[Any, builtins.Exception]"
60114
61115
116+
- case: safe_passing_exceptions_regression333
117+
disable_cache: false
118+
main: |
119+
from returns.result import safe
120+
from typing import Any
121+
122+
@safe((Exception,))
123+
def send(text: str) -> Any:
124+
return "test"
125+
126+
reveal_type(send) # N: Revealed type is "def (text: builtins.str) -> returns.result.Result[Any, builtins.Exception]"
127+
128+
62129
- case: safe_regression641
63130
disable_cache: false
64131
main: |
@@ -72,6 +139,19 @@
72139
reveal_type(safe(tap(Response.raise_for_status))) # N: Revealed type is "def (main.Response*) -> returns.result.Result[main.Response, builtins.Exception]"
73140
74141
142+
- case: safe_passing_exceptions_regression641
143+
disable_cache: false
144+
main: |
145+
from returns.result import safe
146+
from returns.functions import tap
147+
148+
class Response(object):
149+
def raise_for_status(self) -> None:
150+
...
151+
152+
reveal_type(safe((EOFError,))(tap(Response.raise_for_status))) # N: Revealed type is "def (main.Response*) -> returns.result.Result[main.Response, builtins.Exception]"
153+
154+
75155
- case: safe_decorator_with_args_kwargs
76156
disable_cache: false
77157
main: |
@@ -84,6 +164,18 @@
84164
reveal_type(test) # N: Revealed type is "def (*args: Any, **kwargs: Any) -> returns.result.Result[builtins.int, builtins.Exception]"
85165
86166
167+
- case: safe_decorator_passing_exceptions_with_args_kwargs
168+
disable_cache: false
169+
main: |
170+
from returns.result import safe
171+
172+
@safe((EOFError,))
173+
def test(*args, **kwargs) -> int:
174+
return 1
175+
176+
reveal_type(test) # N: Revealed type is "def (*args: Any, **kwargs: Any) -> returns.result.Result[builtins.int, builtins.Exception]"
177+
178+
87179
- case: safe_decorator_with_args_kwargs
88180
disable_cache: false
89181
main: |
@@ -96,6 +188,18 @@
96188
reveal_type(test) # N: Revealed type is "def (*args: builtins.int, **kwargs: builtins.str) -> returns.result.Result[builtins.int, builtins.Exception]"
97189
98190
191+
- case: safe_decorator_passing_exceptions_with_args_kwargs
192+
disable_cache: false
193+
main: |
194+
from returns.result import safe
195+
196+
@safe((Exception,))
197+
def test(*args: int, **kwargs: str) -> int:
198+
return 1
199+
200+
reveal_type(test) # N: Revealed type is "def (*args: builtins.int, **kwargs: builtins.str) -> returns.result.Result[builtins.int, builtins.Exception]"
201+
202+
99203
- case: safe_decorator_composition
100204
disable_cache: false
101205
main: |
@@ -108,3 +212,17 @@
108212
return 1
109213
110214
reveal_type(test) # N: Revealed type is "def (*args: builtins.int, **kwargs: builtins.str) -> returns.io.IO[returns.result.Result*[builtins.int*, builtins.Exception]]"
215+
216+
217+
- case: safe_decorator_passing_exceptions_composition
218+
disable_cache: false
219+
main: |
220+
from returns.io import impure
221+
from returns.result import safe
222+
223+
@impure
224+
@safe((ValueError,))
225+
def test(*args: int, **kwargs: str) -> int:
226+
return 1
227+
228+
reveal_type(test) # N: Revealed type is "def (*args: builtins.int, **kwargs: builtins.str) -> returns.io.IO[returns.result.Result*[builtins.int*, builtins.Exception]]"

0 commit comments

Comments
 (0)