Skip to content

Commit e94d734

Browse files
committed
Fixed OpenApi for multiple Module Routers
1 parent 6fffac4 commit e94d734

File tree

3 files changed

+49
-63
lines changed

3 files changed

+49
-63
lines changed

ellar/openapi/builder.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,19 @@ def _get_openapi_route_document_models(self, app: App) -> t.List[OpenAPIRoute]:
117117
openapi_route_models: t.List = []
118118
reflector = app.injector.get(Reflector)
119119
for route in app.routes:
120-
if isinstance(route, Mount) and len(route.routes) > 0:
120+
if (
121+
isinstance(route, ModuleMount)
122+
and len(route.routes) > 0
123+
and route.include_in_schema
124+
):
121125
openapi: t.Dict = dict()
122126
guards = app.get_guards()
127+
123128
openapi_route_models.append(
124129
OpenAPIMountDocumentation(
125130
mount=route, global_guards=guards, **openapi
126131
)
127132
)
128-
continue
129133
elif (
130134
isinstance(route, (RouteOperation, ControllerRouteOperation))
131135
and route.include_in_schema
@@ -203,11 +207,10 @@ def build_document(self, app: App) -> OpenAPI:
203207
if definitions:
204208
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
205209

206-
for mount in mounts:
207-
if isinstance(mount, ModuleMount):
208-
data = mount.get_tag()
209-
if data and mount.include_in_schema:
210-
self._build.setdefault("tags", []).append(data)
210+
for route_model in openapi_route_models:
211+
if isinstance(route_model, OpenAPIMountDocumentation):
212+
data = route_model.get_tag()
213+
self._build.setdefault("tags", []).append(data)
211214
if components:
212215
self._build.setdefault("components", {}).update(components)
213216
return OpenAPI(**self._build)

ellar/openapi/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def setup_swagger_doc(
5454
title: str = "Ellar Swagger Doc",
5555
swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js",
5656
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
57-
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
57+
swagger_favicon_url: str = "https://eadwincode.github.io/ellar/img/Icon.svg",
5858
) -> None:
5959
self._setup_docs(
6060
template_name="swagger",
@@ -72,7 +72,7 @@ def setup_redocs(
7272
path: str = "redoc",
7373
title: str = "Ellar Redoc",
7474
redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
75-
redoc_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
75+
redoc_favicon_url: str = "https://eadwincode.github.io/ellar/img/Icon.svg",
7676
with_google_fonts: bool = True,
7777
) -> None:
7878
self._setup_docs(

ellar/openapi/route_doc_models.py

Lines changed: 37 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,13 @@
55
from pydantic import BaseModel
66
from pydantic.fields import ModelField, Undefined
77
from pydantic.schema import field_schema
8-
from starlette.convertors import (
9-
FloatConvertor,
10-
IntegerConvertor,
11-
PathConvertor,
12-
StringConvertor,
13-
UUIDConvertor,
14-
)
158
from starlette.routing import Mount, compile_path
169
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
1710

18-
from ellar.compatible import cached_property
11+
from ellar.compatible import AttributeDict, cached_property
1912
from ellar.constants import GUARDS_KEY, METHODS_WITH_BODY, OPENAPI_KEY, REF_PREFIX
2013
from ellar.core.guard import BaseAuthGuard
14+
from ellar.core.params.args import EndpointArgsModel
2115
from ellar.core.params.params import Body, Param
2216
from ellar.core.params.resolvers import (
2317
BodyParameterResolver,
@@ -33,15 +27,6 @@
3327
from ellar.core.guard import GuardCanActivate
3428

3529

36-
CONVERTOR_TYPES = {
37-
StringConvertor: dict(title="title", type="string"),
38-
PathConvertor: dict(title="title", type="string"),
39-
IntegerConvertor: dict(title="title", type="integer"),
40-
FloatConvertor: dict(title="title", type="number"),
41-
UUIDConvertor: dict(title="title", type="string", format="uuid"),
42-
}
43-
44-
4530
class OpenAPIRoute(ABC):
4631
@abstractmethod
4732
def get_route_models(self) -> t.List[t.Union[ModelField, RouteParameterModelField]]:
@@ -70,29 +55,24 @@ def __init__(
7055
t.Union[t.Type["GuardCanActivate"], "GuardCanActivate"]
7156
] = None,
7257
) -> None:
73-
self.tag = tag
74-
self.description = description
75-
self.external_doc_description = external_doc_description
76-
self.external_doc_url = external_doc_url
58+
meta = mount.get_meta() if isinstance(mount, ModuleMount) else AttributeDict()
59+
self.tag = tag or meta.tag # type: ignore
60+
self.description = description or meta.description # type: ignore
61+
self.external_doc_description = (
62+
external_doc_description or meta.external_doc_description # type: ignore
63+
)
64+
self.external_doc_url = external_doc_url or meta.external_doc_url # type: ignore
7765
self.mount = mount
7866
self.path_regex, self.path_format, self.param_convertors = compile_path(
7967
self.mount.path
8068
)
81-
self.global_route_parameters = []
69+
# if there is some convertor on ModuleMount Object, then we need to convert it to ModelField
70+
self.global_route_parameters: t.List[ModelField] = [
71+
EndpointArgsModel.get_convertor_model_field(name, convertor)
72+
for name, convertor in self.param_convertors.items()
73+
]
8274
self.global_guards = global_guards or []
8375

84-
if self.param_convertors:
85-
for name, converter in self.param_convertors.items():
86-
schema = CONVERTOR_TYPES[converter.__class__]
87-
schema.update(title=name)
88-
parameter = {
89-
"name": name,
90-
"in": "path",
91-
"required": True,
92-
"schema": schema,
93-
}
94-
self.global_route_parameters.append(parameter)
95-
9676
self._routes: t.List["OpenAPIRouteDocumentation"] = self._build_routes()
9777

9878
def get_tag(self) -> t.Dict:
@@ -116,11 +96,13 @@ def _build_routes(self) -> t.List["OpenAPIRouteDocumentation"]:
11696
if isinstance(route, RouteOperation) and route.include_in_schema:
11797
openapi = reflector.get(OPENAPI_KEY, route.endpoint) or dict()
11898
guards = reflector.get(GUARDS_KEY, route.endpoint)
119-
self._routes.append(
99+
100+
_routes.append(
120101
OpenAPIRouteDocumentation(
121102
route=route,
122103
global_route_parameters=self.global_route_parameters,
123104
guards=guards or self.global_guards,
105+
tags=[self.tag],
124106
**openapi,
125107
)
126108
)
@@ -150,17 +132,12 @@ def get_openapi_path(
150132
else self.path_format
151133
)
152134
for openapi_route in self._routes:
153-
path, _security_schemes = openapi_route.get_child_openapi_path(
154-
model_name_map=model_name_map
135+
openapi_route.get_openapi_path(
136+
model_name_map=model_name_map,
137+
paths=paths,
138+
security_schemes=security_schemes,
139+
path_prefix=path_prefix,
155140
)
156-
if path:
157-
route_path = (
158-
normalize_path(f"{path_prefix}/{openapi_route.route.path_format}")
159-
if path_prefix
160-
else openapi_route.route.path_format
161-
)
162-
paths.setdefault(route_path, {}).update(path)
163-
security_schemes.update(_security_schemes)
164141

165142

166143
class OpenAPIRouteDocumentation(OpenAPIRoute):
@@ -173,7 +150,7 @@ def __init__(
173150
description: t.Optional[str] = None,
174151
tags: t.Optional[t.List[str]] = None,
175152
deprecated: t.Optional[bool] = None,
176-
global_route_parameters: t.List[t.Dict] = None,
153+
global_route_parameters: t.List[ModelField] = None,
177154
guards: t.List[t.Union["GuardCanActivate", t.Type["GuardCanActivate"]]] = None,
178155
) -> None:
179156
self.operation_id = operation_id
@@ -198,7 +175,11 @@ def _openapi_models(self) -> t.List[t.Union[ModelField, RouteParameterModelField
198175

199176
@cached_property
200177
def input_fields(self) -> t.List[ModelField]:
201-
_models: t.List[ModelField] = []
178+
omitted_path_parameter_fields = (
179+
self.route.endpoint_parameter_model.get_omitted_prefix()
180+
)
181+
_models: t.List[ModelField] = self.global_route_parameters
182+
202183
for item in self.route.endpoint_parameter_model.get_all_models():
203184
if isinstance(item, BodyParameterResolver):
204185
continue
@@ -210,6 +191,12 @@ def input_fields(self) -> t.List[ModelField]:
210191
if isinstance(item, RouteParameterResolver):
211192
_models.append(item.model_field)
212193

194+
already_existing_parameter_names = [model.name for model in _models]
195+
for omitted_path_parameter_field in omitted_path_parameter_fields:
196+
if omitted_path_parameter_field.name in already_existing_parameter_names:
197+
continue
198+
_models.append(omitted_path_parameter_field)
199+
213200
return _models
214201

215202
@cached_property
@@ -282,7 +269,7 @@ def get_openapi_operation_parameters(
282269
if field_info.deprecated:
283270
parameter["deprecated"] = field_info.deprecated
284271
parameters.append(parameter)
285-
return parameters + self.global_route_parameters
272+
return parameters
286273

287274
def get_openapi_operation_request_body(
288275
self,
@@ -345,14 +332,10 @@ def _get_openapi_path_object(
345332
operation_responses = operation.setdefault("responses", {})
346333
for status, response_model in self.route.response_model.models.items():
347334
operation_responses_status = operation_responses.setdefault(status, {})
348-
operation_responses_status["description"] = getattr(
349-
response_model, "description", ""
350-
)
335+
operation_responses_status["description"] = response_model.description
351336

352337
content = operation_responses_status.setdefault("content", {})
353-
media_type = content.setdefault(
354-
getattr(response_model, "media_type", "text/plain"), {}
355-
)
338+
media_type = content.setdefault(response_model.media_type, {})
356339
media_type.setdefault("schema", {"type": "string"})
357340

358341
model_field = response_model.get_model_field()

0 commit comments

Comments
 (0)