Skip to content

Commit 942a80b

Browse files
committed
Added dependency to inject swagger in the route.
1 parent b2e6742 commit 942a80b

File tree

6 files changed

+217
-13
lines changed

6 files changed

+217
-13
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,15 @@ async def my_handler(var: str = Depends(Path())):
357357
```
358358

359359

360-
## Overridiing dependencies
360+
## ExtraOpenAPI
361+
362+
This dependency is used to add additional swagger fields to the endpoint's swagger
363+
that is using this dependency. It might be even indirect dependency.
364+
365+
You can check how this thing can be used in our [examples/swagger_auth.py](https://github.com/taskiq-python/aiohttp-deps/tree/master/examples/swagger_auth.py).
366+
367+
368+
## Overriding dependencies
361369

362370
Sometimes for tests you don't want to calculate actual functions
363371
and you want to pass another functions instead.

aiohttp_deps/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from aiohttp_deps.keys import DEPENDENCY_OVERRIDES_KEY, VALUES_OVERRIDES_KEY
77
from aiohttp_deps.router import Router
88
from aiohttp_deps.swagger import extra_openapi, openapi_response, setup_swagger
9-
from aiohttp_deps.utils import Form, Header, Json, Path, Query
9+
from aiohttp_deps.utils import ExtraOpenAPI, Form, Header, Json, Path, Query
1010
from aiohttp_deps.view import View
1111

1212
__all__ = [
@@ -21,6 +21,7 @@
2121
"Query",
2222
"Form",
2323
"Path",
24+
"ExtraOpenAPI",
2425
"openapi_response",
2526
"DEPENDENCY_OVERRIDES_KEY",
2627
"VALUES_OVERRIDES_KEY",

aiohttp_deps/swagger.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Awaitable,
77
Callable,
88
Dict,
9+
List,
910
Optional,
1011
Tuple,
1112
TypeVar,
@@ -19,7 +20,7 @@
1920

2021
from aiohttp_deps.initializer import InjectableFuncHandler, InjectableViewHandler
2122
from aiohttp_deps.keys import SWAGGER_SCHEMA_KEY
22-
from aiohttp_deps.utils import Form, Header, Json, Path, Query
23+
from aiohttp_deps.utils import ExtraOpenAPI, Form, Header, Json, Path, Query
2324

2425
_T = TypeVar("_T")
2526

@@ -119,6 +120,7 @@ def _add_route_def( # noqa: C901
119120
openapi_schema["components"]["schemas"].update(extra_openapi_schemas)
120121

121122
params: Dict[Tuple[str, str], Any] = {}
123+
updaters: List[Callable[[Dict[str, Any]], None]] = []
122124

123125
def _insert_in_params(data: Dict[str, Any]) -> None:
124126
element = params.get((data["name"], data["in"]))
@@ -191,8 +193,18 @@ def _insert_in_params(data: Dict[str, Any]) -> None:
191193
"schema": schema,
192194
},
193195
)
196+
elif isinstance(dependency.dependency, ExtraOpenAPI):
197+
if dependency.dependency.updater is not None:
198+
updaters.append(dependency.dependency.updater)
199+
if dependency.dependency.extra_openapi is not None:
200+
extra_openapi = always_merger.merge(
201+
extra_openapi,
202+
dependency.dependency.extra_openapi,
203+
)
194204

195205
route_info["parameters"] = list(params.values())
206+
for updater in updaters:
207+
updater(route_info)
196208
openapi_schema["paths"][route.resource.canonical].update(
197209
{method.lower(): always_merger.merge(route_info, extra_openapi)},
198210
)
@@ -207,6 +219,7 @@ def setup_swagger( # noqa: C901
207219
title: str = "AioHTTP",
208220
description: Optional[str] = None,
209221
version: str = "1.0.0",
222+
extra_openapi: Optional[Dict[str, Any]] = None,
210223
) -> Callable[[web.Application], Awaitable[None]]:
211224
"""
212225
Add swagger documentation.
@@ -230,8 +243,11 @@ def setup_swagger( # noqa: C901
230243
:param title: Title of an application.
231244
:param description: description of an application.
232245
:param version: version of an application.
246+
:param extra_openapi: extra openAPI dict that will be merged with generated schema.
233247
:return: startup event handler.
234248
"""
249+
if extra_openapi is None:
250+
extra_openapi = {}
235251

236252
async def event_handler(app: web.Application) -> None: # noqa: C901
237253
openapi_schema = {
@@ -252,12 +268,12 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
252268
if hide_options and route.method.upper() == "OPTIONS":
253269
continue
254270
if isinstance(route._handler, InjectableFuncHandler):
255-
extra_openapi = getattr(
271+
route_extra_openapi = getattr(
256272
route._handler.original_handler,
257273
"__extra_openapi__",
258274
{},
259275
)
260-
extra_schemas = getattr(
276+
route_extra_schemas = getattr(
261277
route._handler.original_handler,
262278
"__extra_openapi_schemas__",
263279
{},
@@ -268,8 +284,8 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
268284
route, # type: ignore
269285
route.method,
270286
route._handler.graph,
271-
extra_openapi=extra_openapi,
272-
extra_openapi_schemas=extra_schemas,
287+
extra_openapi=route_extra_openapi,
288+
extra_openapi_schemas=route_extra_schemas,
273289
)
274290
except Exception as exc: # pragma: no cover
275291
logger.warn(
@@ -280,12 +296,12 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
280296

281297
elif isinstance(route._handler, InjectableViewHandler):
282298
for key, graph in route._handler.graph_map.items():
283-
extra_openapi = getattr(
299+
route_extra_openapi = getattr(
284300
getattr(route._handler.original_handler, key),
285301
"__extra_openapi__",
286302
{},
287303
)
288-
extra_schemas = getattr(
304+
route_extra_schemas = getattr(
289305
getattr(route._handler.original_handler, key),
290306
"__extra_openapi_schemas__",
291307
{},
@@ -296,8 +312,8 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
296312
route, # type: ignore
297313
key,
298314
graph,
299-
extra_openapi=extra_openapi,
300-
extra_openapi_schemas=extra_schemas,
315+
extra_openapi=route_extra_openapi,
316+
extra_openapi_schemas=route_extra_schemas,
301317
)
302318
except Exception as exc: # pragma: no cover
303319
logger.warn(
@@ -306,7 +322,7 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
306322
exc_info=True,
307323
)
308324

309-
app[SWAGGER_SCHEMA_KEY] = openapi_schema
325+
app[SWAGGER_SCHEMA_KEY] = always_merger.merge(openapi_schema, extra_openapi)
310326

311327
app.router.add_get(
312328
schema_url,

aiohttp_deps/utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
import json
3-
from typing import Any, Optional, Union
3+
from typing import Any, Callable, Dict, Optional, Union
44

55
import pydantic
66
from aiohttp import web
@@ -344,3 +344,37 @@ def __call__(
344344
headers={"Content-Type": "application/json"},
345345
text=json.dumps(errors),
346346
) from err
347+
348+
349+
class ExtraOpenAPI:
350+
"""
351+
Update swagger for the endpoint.
352+
353+
You can use this dependency to add swagger to an endpoint from
354+
a dependency. It's useful when you want to add some extra swagger
355+
to the route when some specific dependency is used by it.
356+
"""
357+
358+
def __init__(
359+
self,
360+
extra_openapi: Optional[Dict[str, Any]] = None,
361+
swagger_updater: Optional[Callable[[Dict[str, Any]], None]] = None,
362+
) -> None:
363+
"""
364+
Initialize the dependency.
365+
366+
:param swagger_updater: function that takes final swagger endpoint and
367+
updates it.
368+
:param extra_swagger: extra swagger to add to the endpoint. This one might
369+
override other extra_swagger on the endpoint.
370+
"""
371+
self.updater = swagger_updater
372+
self.extra_openapi = extra_openapi
373+
374+
def __call__(self) -> None:
375+
"""
376+
This method is called when dependency is resolved.
377+
378+
It's empty, becuase it's used by the swagger function and
379+
there is no actual dependency.
380+
"""

examples/swagger_auth.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import base64
2+
3+
from aiohttp import web
4+
from pydantic import BaseModel
5+
6+
from aiohttp_deps import Depends, ExtraOpenAPI, Header, Router, init, setup_swagger
7+
8+
9+
class UserInfo(BaseModel):
10+
"""Abstract user model."""
11+
12+
id: int
13+
name: str
14+
password: str
15+
16+
17+
router = Router()
18+
19+
# Here we create a simple user storage.
20+
# In real-world applications, you would use a database.
21+
users = {
22+
"john": UserInfo(id=1, name="John Doe", password="123"), # noqa: S106
23+
"caren": UserInfo(id=2, name="Caren Doe", password="321"), # noqa: S106
24+
}
25+
26+
27+
def get_current_user(
28+
# Current auth header.
29+
authorization: str = Depends(Header()),
30+
# We don't need a name to this variable,
31+
# because it will only affect the API schema,
32+
# but won't be used in runtime.
33+
_: None = Depends(
34+
ExtraOpenAPI(
35+
extra_openapi={
36+
"security": [{"basicAuth": []}],
37+
},
38+
),
39+
),
40+
) -> UserInfo:
41+
"""This function checks if the user authorized."""
42+
# Here we check if the authorization header is present.
43+
if not authorization.startswith("Basic"):
44+
raise web.HTTPUnauthorized(reason="Unsupported authorization type")
45+
# We decode credentials from the header.
46+
# And check if the user exists.
47+
creds = base64.b64decode(authorization.split(" ")[1]).decode()
48+
username, password = creds.split(":")
49+
found_user = users.get(username)
50+
if found_user is None:
51+
raise web.HTTPUnauthorized(reason="User not found")
52+
if found_user.password != password:
53+
raise web.HTTPUnauthorized(reason="Invalid password")
54+
return found_user
55+
56+
57+
@router.get("/")
58+
async def index(current_user: UserInfo = Depends(get_current_user)) -> web.Response:
59+
"""Index handler returns current user."""
60+
return web.json_response(current_user.model_dump(mode="json"))
61+
62+
63+
app = web.Application()
64+
app.router.add_routes(router)
65+
app.on_startup.extend(
66+
[
67+
init,
68+
setup_swagger(
69+
# Here we add security schemes used
70+
# to authorize users.
71+
extra_openapi={
72+
"components": {
73+
"securitySchemes": {
74+
# We only support basic auth.
75+
"basicAuth": {
76+
"type": "http",
77+
"scheme": "basic",
78+
},
79+
},
80+
},
81+
},
82+
),
83+
],
84+
)
85+
86+
if __name__ == "__main__":
87+
web.run_app(app)

tests/test_swagger.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from aiohttp_deps import (
1010
Depends,
11+
ExtraOpenAPI,
1112
Form,
1213
Header,
1314
Json,
@@ -780,3 +781,60 @@ async def my_handler() -> None:
780781
schema = await response.json()
781782
assert "get" in schema["paths"]["/"]
782783
assert method.lower() not in schema["paths"]["/"]
784+
785+
786+
@pytest.mark.anyio
787+
async def test_extra_openapi_dep_func(
788+
my_app: web.Application,
789+
aiohttp_client: ClientGenerator,
790+
) -> None:
791+
openapi_url = "/my_api_def.json"
792+
my_app.on_startup.append(setup_swagger(schema_url=openapi_url))
793+
794+
async def dep(
795+
_: None = Depends(ExtraOpenAPI(extra_openapi={"responses": {"200": {}}})),
796+
) -> None:
797+
"""Test dep that adds swagger through a dependency."""
798+
799+
async def my_handler(_: None = Depends(dep)) -> None:
800+
"""Nothing."""
801+
802+
my_app.router.add_get("/a", my_handler)
803+
804+
client = await aiohttp_client(my_app)
805+
resp = await client.get(openapi_url)
806+
assert resp.status == 200
807+
resp_json = await resp.json()
808+
809+
handler_info = resp_json["paths"]["/a"]["get"]
810+
assert handler_info["responses"] == {"200": {}}
811+
812+
813+
@pytest.mark.anyio
814+
async def test_extra_openapi_dep_updater_func(
815+
my_app: web.Application,
816+
aiohttp_client: ClientGenerator,
817+
) -> None:
818+
openapi_url = "/my_api_def.json"
819+
my_app.on_startup.append(setup_swagger(schema_url=openapi_url))
820+
821+
def schema_updater(schema: Dict[str, Any]) -> None:
822+
schema["responses"] = {"200": {}}
823+
824+
async def dep(
825+
_: None = Depends(ExtraOpenAPI(swagger_updater=schema_updater)),
826+
) -> None:
827+
"""Test dep that adds swagger through a dependency."""
828+
829+
async def my_handler(_: None = Depends(dep)) -> None:
830+
"""Nothing."""
831+
832+
my_app.router.add_get("/a", my_handler)
833+
834+
client = await aiohttp_client(my_app)
835+
resp = await client.get(openapi_url)
836+
assert resp.status == 200
837+
resp_json = await resp.json()
838+
839+
handler_info = resp_json["paths"]["/a"]["get"]
840+
assert handler_info["responses"] == {"200": {}}

0 commit comments

Comments
 (0)