Skip to content

Commit 71f4212

Browse files
Fix typing of wrapper. (#391)
* Fix typing of wrapper. * Line wrapping. * Comma. Co-authored-by: Sam Bull <git@sambull.org>
1 parent 5eeeb7c commit 71f4212

File tree

2 files changed

+41
-29
lines changed

2 files changed

+41
-29
lines changed

aiohttp_jinja2/__init__.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Iterable,
1111
Mapping,
1212
Optional,
13+
TypeVar,
1314
Union,
1415
cast,
1516
overload,
@@ -33,35 +34,38 @@
3334

3435
_TemplateReturnType = Awaitable[Union[web.StreamResponse, Mapping[str, Any]]]
3536
_SimpleTemplateHandler = Callable[[web.Request], _TemplateReturnType]
36-
_MethodTemplateHandler = Callable[[Any, web.Request], _TemplateReturnType]
37-
_ViewTemplateHandler = Callable[[AbstractView], _TemplateReturnType]
38-
_TemplateHandler = Union[
39-
_SimpleTemplateHandler, _MethodTemplateHandler, _ViewTemplateHandler
40-
]
41-
4237
_ContextProcessor = Callable[[web.Request], Awaitable[Dict[str, Any]]]
4338

39+
_T = TypeVar("_T")
40+
_AbstractView = TypeVar("_AbstractView", bound=AbstractView)
41+
4442
if sys.version_info >= (3, 8):
4543
from typing import Protocol
4644

47-
class _TemplateWrapped(Protocol):
45+
class _TemplateWrapper(Protocol):
4846
@overload
49-
async def __call__(self, request: web.Request) -> web.StreamResponse:
47+
def __call__(
48+
self, func: _SimpleTemplateHandler
49+
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
5050
...
5151

5252
@overload
53-
async def __call__(self, view: AbstractView) -> web.StreamResponse:
53+
def __call__(
54+
self, func: Callable[[_AbstractView], _TemplateReturnType]
55+
) -> Callable[[_AbstractView], Awaitable[web.StreamResponse]]:
5456
...
5557

5658
@overload
57-
async def __call__(
58-
self, _self: Any, request: web.Request
59-
) -> web.StreamResponse:
59+
def __call__(
60+
self, func: Callable[[_T, web.Request], _TemplateReturnType]
61+
) -> Callable[[_T, web.Request], Awaitable[web.StreamResponse]]:
6062
...
6163

6264

6365
else:
64-
_TemplateWrapped = Callable[..., web.StreamResponse]
66+
_TemplateHandler = Callable[..., _TemplateReturnType]
67+
_WebHandler = Callable[..., Awaitable[web.StreamResponse]]
68+
_TemplateWrapper = Callable[[_TemplateHandler], _WebHandler]
6569

6670

6771
def setup(
@@ -151,20 +155,28 @@ def template(
151155
app_key: str = APP_KEY,
152156
encoding: str = "utf-8",
153157
status: int = 200,
154-
) -> Callable[[_TemplateHandler], _TemplateWrapped]:
155-
def wrapper(func: _TemplateHandler) -> _TemplateWrapped:
156-
@overload
157-
async def wrapped(request: web.Request) -> web.StreamResponse:
158-
...
159-
160-
@overload
161-
async def wrapped(view: AbstractView) -> web.StreamResponse:
162-
...
163-
164-
@overload
165-
async def wrapped(_self: Any, request: web.Request) -> web.StreamResponse:
166-
...
167-
158+
) -> _TemplateWrapper:
159+
@overload
160+
def wrapper(
161+
func: _SimpleTemplateHandler,
162+
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
163+
...
164+
165+
@overload
166+
def wrapper(
167+
func: Callable[[_AbstractView], _TemplateReturnType]
168+
) -> Callable[[_AbstractView], Awaitable[web.StreamResponse]]:
169+
...
170+
171+
@overload
172+
def wrapper(
173+
func: Callable[[_T, web.Request], _TemplateReturnType]
174+
) -> Callable[[_T, web.Request], Awaitable[web.StreamResponse]]:
175+
...
176+
177+
def wrapper(
178+
func: Callable[..., _TemplateReturnType]
179+
) -> Callable[..., Awaitable[web.StreamResponse]]:
168180
@functools.wraps(func)
169181
async def wrapped(*args: Any) -> web.StreamResponse:
170182
if asyncio.iscoroutinefunction(func):

tests/test_simple_renderer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ async def func(request: web.Request) -> web.Response:
201201

202202

203203
async def test_render_not_mapping():
204-
@aiohttp_jinja2.template("tmpl.jinja2")
205-
async def func(request):
204+
@aiohttp_jinja2.template("tmpl.jinja2") # type: ignore[arg-type]
205+
async def func(request: web.Request) -> int:
206206
return 123
207207

208208
app = web.Application()

0 commit comments

Comments
 (0)