5
5
from pydantic import BaseModel
6
6
from pydantic .fields import ModelField , Undefined
7
7
from pydantic .schema import field_schema
8
- from starlette .convertors import (
9
- FloatConvertor ,
10
- IntegerConvertor ,
11
- PathConvertor ,
12
- StringConvertor ,
13
- UUIDConvertor ,
14
- )
15
8
from starlette .routing import Mount , compile_path
16
9
from starlette .status import HTTP_422_UNPROCESSABLE_ENTITY
17
10
18
- from ellar .compatible import cached_property
11
+ from ellar .compatible import AttributeDict , cached_property
19
12
from ellar .constants import GUARDS_KEY , METHODS_WITH_BODY , OPENAPI_KEY , REF_PREFIX
20
13
from ellar .core .guard import BaseAuthGuard
14
+ from ellar .core .params .args import EndpointArgsModel
21
15
from ellar .core .params .params import Body , Param
22
16
from ellar .core .params .resolvers import (
23
17
BodyParameterResolver ,
33
27
from ellar .core .guard import GuardCanActivate
34
28
35
29
36
- CONVERTOR_TYPES = {
37
- StringConvertor : dict (title = "title" , type = "string" ),
38
- PathConvertor : dict (title = "title" , type = "string" ),
39
- IntegerConvertor : dict (title = "title" , type = "integer" ),
40
- FloatConvertor : dict (title = "title" , type = "number" ),
41
- UUIDConvertor : dict (title = "title" , type = "string" , format = "uuid" ),
42
- }
43
-
44
-
45
30
class OpenAPIRoute (ABC ):
46
31
@abstractmethod
47
32
def get_route_models (self ) -> t .List [t .Union [ModelField , RouteParameterModelField ]]:
@@ -70,29 +55,24 @@ def __init__(
70
55
t .Union [t .Type ["GuardCanActivate" ], "GuardCanActivate" ]
71
56
] = None ,
72
57
) -> None :
73
- self .tag = tag
74
- self .description = description
75
- self .external_doc_description = external_doc_description
76
- self .external_doc_url = external_doc_url
58
+ meta = mount .get_meta () if isinstance (mount , ModuleMount ) else AttributeDict ()
59
+ self .tag = tag or meta .tag # type: ignore
60
+ self .description = description or meta .description # type: ignore
61
+ self .external_doc_description = (
62
+ external_doc_description or meta .external_doc_description # type: ignore
63
+ )
64
+ self .external_doc_url = external_doc_url or meta .external_doc_url # type: ignore
77
65
self .mount = mount
78
66
self .path_regex , self .path_format , self .param_convertors = compile_path (
79
67
self .mount .path
80
68
)
81
- self .global_route_parameters = []
69
+ # if there is some convertor on ModuleMount Object, then we need to convert it to ModelField
70
+ self .global_route_parameters : t .List [ModelField ] = [
71
+ EndpointArgsModel .get_convertor_model_field (name , convertor )
72
+ for name , convertor in self .param_convertors .items ()
73
+ ]
82
74
self .global_guards = global_guards or []
83
75
84
- if self .param_convertors :
85
- for name , converter in self .param_convertors .items ():
86
- schema = CONVERTOR_TYPES [converter .__class__ ]
87
- schema .update (title = name )
88
- parameter = {
89
- "name" : name ,
90
- "in" : "path" ,
91
- "required" : True ,
92
- "schema" : schema ,
93
- }
94
- self .global_route_parameters .append (parameter )
95
-
96
76
self ._routes : t .List ["OpenAPIRouteDocumentation" ] = self ._build_routes ()
97
77
98
78
def get_tag (self ) -> t .Dict :
@@ -116,11 +96,13 @@ def _build_routes(self) -> t.List["OpenAPIRouteDocumentation"]:
116
96
if isinstance (route , RouteOperation ) and route .include_in_schema :
117
97
openapi = reflector .get (OPENAPI_KEY , route .endpoint ) or dict ()
118
98
guards = reflector .get (GUARDS_KEY , route .endpoint )
119
- self ._routes .append (
99
+
100
+ _routes .append (
120
101
OpenAPIRouteDocumentation (
121
102
route = route ,
122
103
global_route_parameters = self .global_route_parameters ,
123
104
guards = guards or self .global_guards ,
105
+ tags = [self .tag ],
124
106
** openapi ,
125
107
)
126
108
)
@@ -150,17 +132,12 @@ def get_openapi_path(
150
132
else self .path_format
151
133
)
152
134
for openapi_route in self ._routes :
153
- path , _security_schemes = openapi_route .get_child_openapi_path (
154
- model_name_map = model_name_map
135
+ openapi_route .get_openapi_path (
136
+ model_name_map = model_name_map ,
137
+ paths = paths ,
138
+ security_schemes = security_schemes ,
139
+ path_prefix = path_prefix ,
155
140
)
156
- if path :
157
- route_path = (
158
- normalize_path (f"{ path_prefix } /{ openapi_route .route .path_format } " )
159
- if path_prefix
160
- else openapi_route .route .path_format
161
- )
162
- paths .setdefault (route_path , {}).update (path )
163
- security_schemes .update (_security_schemes )
164
141
165
142
166
143
class OpenAPIRouteDocumentation (OpenAPIRoute ):
@@ -173,7 +150,7 @@ def __init__(
173
150
description : t .Optional [str ] = None ,
174
151
tags : t .Optional [t .List [str ]] = None ,
175
152
deprecated : t .Optional [bool ] = None ,
176
- global_route_parameters : t .List [t . Dict ] = None ,
153
+ global_route_parameters : t .List [ModelField ] = None ,
177
154
guards : t .List [t .Union ["GuardCanActivate" , t .Type ["GuardCanActivate" ]]] = None ,
178
155
) -> None :
179
156
self .operation_id = operation_id
@@ -198,7 +175,11 @@ def _openapi_models(self) -> t.List[t.Union[ModelField, RouteParameterModelField
198
175
199
176
@cached_property
200
177
def input_fields (self ) -> t .List [ModelField ]:
201
- _models : t .List [ModelField ] = []
178
+ omitted_path_parameter_fields = (
179
+ self .route .endpoint_parameter_model .get_omitted_prefix ()
180
+ )
181
+ _models : t .List [ModelField ] = self .global_route_parameters
182
+
202
183
for item in self .route .endpoint_parameter_model .get_all_models ():
203
184
if isinstance (item , BodyParameterResolver ):
204
185
continue
@@ -210,6 +191,12 @@ def input_fields(self) -> t.List[ModelField]:
210
191
if isinstance (item , RouteParameterResolver ):
211
192
_models .append (item .model_field )
212
193
194
+ already_existing_parameter_names = [model .name for model in _models ]
195
+ for omitted_path_parameter_field in omitted_path_parameter_fields :
196
+ if omitted_path_parameter_field .name in already_existing_parameter_names :
197
+ continue
198
+ _models .append (omitted_path_parameter_field )
199
+
213
200
return _models
214
201
215
202
@cached_property
@@ -282,7 +269,7 @@ def get_openapi_operation_parameters(
282
269
if field_info .deprecated :
283
270
parameter ["deprecated" ] = field_info .deprecated
284
271
parameters .append (parameter )
285
- return parameters + self . global_route_parameters
272
+ return parameters
286
273
287
274
def get_openapi_operation_request_body (
288
275
self ,
@@ -345,14 +332,10 @@ def _get_openapi_path_object(
345
332
operation_responses = operation .setdefault ("responses" , {})
346
333
for status , response_model in self .route .response_model .models .items ():
347
334
operation_responses_status = operation_responses .setdefault (status , {})
348
- operation_responses_status ["description" ] = getattr (
349
- response_model , "description" , ""
350
- )
335
+ operation_responses_status ["description" ] = response_model .description
351
336
352
337
content = operation_responses_status .setdefault ("content" , {})
353
- media_type = content .setdefault (
354
- getattr (response_model , "media_type" , "text/plain" ), {}
355
- )
338
+ media_type = content .setdefault (response_model .media_type , {})
356
339
media_type .setdefault ("schema" , {"type" : "string" })
357
340
358
341
model_field = response_model .get_model_field ()
0 commit comments