Skip to content

Commit 1ab51a2

Browse files
authored
Add trampolines support, refs #413 (#1668)
* Add trampolines support, refs #413 * More tests
1 parent dff119a commit 1ab51a2

File tree

9 files changed

+685
-528
lines changed

9 files changed

+685
-528
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,20 @@ on:
1111
- 'docs/requirements.txt'
1212
workflow_dispatch:
1313

14+
permissions:
15+
contents: read
16+
1417
concurrency:
1518
group: test-${{ github.head_ref || github.run_id }}
1619
cancel-in-progress: true
1720

1821
jobs:
1922
build:
20-
runs-on: ubuntu-20.04
23+
runs-on: ubuntu-latest
2124
strategy:
2225
fail-fast: false
2326
matrix:
24-
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
27+
python-version: ['3.8', '3.9', '3.10', '3.11']
2528
task: ['tests', 'typesafety']
2629

2730
steps:

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@ incremental in minor, bugfixes only are patches.
66
See [0Ver](https://0ver.org/).
77

88

9+
## 0.22.0 WIP
10+
11+
### Features
12+
13+
- *Breaking*: Drops `python3.7` support
14+
- Adds `trampolines` support
15+
16+
917
## 0.21.0
1018

1119
### Features

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Contents
3939
pages/do-notation.rst
4040
pages/functions.rst
4141
pages/curry.rst
42+
pages/trampolines.rst
4243
pages/types.rst
4344

4445
.. toctree::

docs/pages/trampolines.rst

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
.. _trampolines:
2+
3+
Trampolines
4+
===========
5+
6+
Python does not support TCO (tail call optimization), so any recursion-based
7+
algorithms become dangerous.
8+
9+
We cannot be sure that
10+
they won't cause ``RecursionError`` on deeply nested data.
11+
12+
Here's why we need trampolines: they allow to replicate tail call optimization
13+
by wrapping function calls into :class:`returns.trampolines.Trampoline` objects,
14+
making recursion-based function *always* safe.
15+
16+
Example:
17+
18+
.. code:: python
19+
20+
>>> from typing import Union, List
21+
>>> from returns.trampolines import Trampoline, trampoline
22+
23+
>>> @trampoline
24+
... def accumulate(
25+
... numbers: List[int],
26+
... acc: int = 0,
27+
... ) -> Union[int, Trampoline[int]]:
28+
... if not numbers:
29+
... return acc
30+
... number = number = numbers.pop()
31+
... return Trampoline(accumulate, numbers, acc + number)
32+
33+
>>> assert accumulate([1, 2]) == 3
34+
>>> assert accumulate([1, 2, 3]) == 6
35+
36+
The following function is still fully type-safe:
37+
- ``Trampoline`` object uses ``ParamSpec`` to be sure that passed arguments are correct
38+
- Final return type of the function is narrowed to contain only an original type (without ``Trampoline`` implementation detail)
39+
40+
API Reference
41+
-------------
42+
43+
.. automodule:: returns.trampolines
44+
:members:

poetry.lock

Lines changed: 435 additions & 525 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ _ = "returns.contrib.hypothesis._entrypoint"
4646

4747

4848
[tool.poetry.dependencies]
49-
python = "^3.7"
49+
python = "^3.8"
5050

5151
typing-extensions = ">=4.0,<5.0"
5252
mypy = { version = "^1.4.0", optional = true }

returns/trampolines.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from functools import wraps
2+
from typing import Callable, Generic, TypeVar, Union
3+
4+
from typing_extensions import ParamSpec, final
5+
6+
_ReturnType = TypeVar('_ReturnType')
7+
_FuncParams = ParamSpec('_FuncParams')
8+
9+
10+
@final
11+
class Trampoline(Generic[_ReturnType]):
12+
"""
13+
Represents a wrapped function call.
14+
15+
Primitive to convert recursion into an actual object.
16+
"""
17+
18+
__slots__ = ('func', 'args', 'kwargs')
19+
20+
def __init__( # noqa: WPS451
21+
self,
22+
func: Callable[_FuncParams, _ReturnType],
23+
/, # We use pos-only here to be able to store `kwargs` correctly.
24+
*args: _FuncParams.args,
25+
**kwargs: _FuncParams.kwargs,
26+
) -> None:
27+
"""Save function and given arguments."""
28+
self.func = getattr(func, '_orig_func', func)
29+
self.args = args
30+
self.kwargs = kwargs
31+
32+
def __call__(self) -> _ReturnType:
33+
"""Call wrapped function with given arguments."""
34+
return self.func(*self.args, **self.kwargs)
35+
36+
37+
def trampoline(
38+
func: Callable[_FuncParams, Union[_ReturnType, Trampoline[_ReturnType]]],
39+
) -> Callable[_FuncParams, _ReturnType]:
40+
"""
41+
Convert functions using recursion to regular functions.
42+
43+
Trampolines allow to unwrap recursion into a regular ``while`` loop,
44+
which does not raise any ``RecursionError`` ever.
45+
46+
Since python does not have TCO (tail call optimization),
47+
we have to provide this helper.
48+
49+
This is done by wrapping real function calls into
50+
:class:`returns.trampolines.Trampoline` objects:
51+
52+
.. code:: python
53+
54+
>>> from typing import Union
55+
>>> from returns.trampolines import Trampoline, trampoline
56+
57+
>>> @trampoline
58+
... def get_factorial(
59+
... for_number: int,
60+
... current_number: int = 0,
61+
... acc: int = 1,
62+
... ) -> Union[int, Trampoline[int]]:
63+
... assert for_number >= 0
64+
... if for_number <= current_number:
65+
... return acc
66+
... return Trampoline(
67+
... get_factorial,
68+
... for_number,
69+
... current_number=current_number + 1,
70+
... acc=acc * (current_number + 1),
71+
... )
72+
73+
>>> assert get_factorial(0) == 1
74+
>>> assert get_factorial(3) == 6
75+
>>> assert get_factorial(4) == 24
76+
77+
See also:
78+
- eli.thegreenplace.net/2017/on-recursion-continuations-and-trampolines
79+
- https://en.wikipedia.org/wiki/Tail_call
80+
81+
"""
82+
83+
@wraps(func)
84+
def decorator(
85+
*args: _FuncParams.args,
86+
**kwargs: _FuncParams.kwargs,
87+
) -> _ReturnType:
88+
trampoline_result = func(*args, **kwargs)
89+
while isinstance(trampoline_result, Trampoline):
90+
trampoline_result = trampoline_result()
91+
return trampoline_result
92+
93+
decorator._orig_func = func # type: ignore[attr-defined] # noqa: WPS437
94+
return decorator
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import sys
2+
from typing import Callable, Iterator, Union
3+
4+
import pytest
5+
6+
from returns.trampolines import Trampoline, trampoline
7+
8+
9+
@trampoline
10+
def _accumulate(
11+
numbers: Iterator[int],
12+
acc: int = 0,
13+
) -> Union[int, Trampoline[int]]:
14+
number = next(numbers, None)
15+
if number is None:
16+
return acc
17+
return Trampoline(_accumulate, numbers, acc + number)
18+
19+
20+
@trampoline
21+
def _with_func_kwarg(
22+
numbers: Iterator[int],
23+
func: int = 0, # we need this name to match `Trampoline` constructor
24+
) -> Union[int, Trampoline[int]]:
25+
number = next(numbers, None)
26+
if number is None:
27+
return func
28+
return Trampoline(_with_func_kwarg, numbers, func=func + number)
29+
30+
31+
@pytest.mark.parametrize('trampoline_func', [
32+
_accumulate,
33+
_with_func_kwarg,
34+
])
35+
@pytest.mark.parametrize('given_range', [
36+
range(0),
37+
range(1),
38+
range(2),
39+
range(5),
40+
range(sys.getrecursionlimit()),
41+
range(sys.getrecursionlimit() + 1),
42+
])
43+
def test_recursion_limit(
44+
trampoline_func: Callable[[Iterator[int]], int],
45+
given_range: range,
46+
) -> None:
47+
"""Test that accumulation is correct and no ``RecursionError`` happens."""
48+
accumulated = trampoline_func(iter(given_range))
49+
50+
assert accumulated == sum(given_range)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
- case: trampoline_missing_args
2+
disable_cache: false
3+
main: |
4+
from typing import List, Union
5+
from returns.trampolines import Trampoline, trampoline
6+
7+
@trampoline
8+
def _accumulate(
9+
numbers: List[int],
10+
acc: int = 0,
11+
) -> Union[int, Trampoline[int]]:
12+
return Trampoline(_accumulate)
13+
out: |
14+
main:9: error: Missing positional argument "numbers" in call to "Trampoline" [call-arg]
15+
16+
17+
- case: trampoline_wrong_args
18+
disable_cache: false
19+
main: |
20+
from typing import List, Union
21+
from returns.trampolines import Trampoline, trampoline
22+
23+
@trampoline
24+
def _accumulate(
25+
numbers: List[int],
26+
acc: int = 0,
27+
) -> Union[int, Trampoline[int]]:
28+
return Trampoline(_accumulate, ['a'], 'b')
29+
out: |
30+
main:9: error: List item 0 has incompatible type "str"; expected "int" [list-item]
31+
main:9: error: Argument 3 to "Trampoline" has incompatible type "str"; expected "int" [arg-type]
32+
33+
34+
- case: trampoline_return_type
35+
disable_cache: false
36+
main: |
37+
from typing import List, Union
38+
from returns.trampolines import Trampoline, trampoline
39+
40+
@trampoline
41+
def _accumulate(
42+
numbers: List[int],
43+
acc: int = 0,
44+
) -> Union[int, Trampoline[int]]:
45+
return Trampoline(_accumulate, [1], 2)
46+
47+
reveal_type(_accumulate([1, 2])) # N: Revealed type is "builtins.int"

0 commit comments

Comments
 (0)