Skip to content

Commit da6f8c1

Browse files
committed
Moved route function parameters to type annotation and still support assignment pattern
1 parent 2e9f8f2 commit da6f8c1

37 files changed

+429
-330
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ fmt format:clean ## Run code formatters
2626
black ellar tests examples
2727
ruff check --fix ellar tests examples
2828

29-
test: ## Run tests
29+
test:clean ## Run tests
3030
pytest
3131

3232
test-cov: ## Run tests with coverage

ellar/common/__init__.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,13 @@
6060
)
6161
from .params.decorators import (
6262
Body,
63-
Context,
6463
Cookie,
6564
File,
6665
Form,
6766
Header,
68-
Host,
69-
Http,
67+
Inject,
7068
Path,
71-
Provide,
7269
Query,
73-
Req,
74-
Res,
75-
Session,
76-
Ws,
7770
WsBody,
7871
)
7972
from .params.params import ParamFieldInfo as Param
@@ -130,6 +123,7 @@
130123
"ModuleRouter",
131124
"render",
132125
"Module",
126+
"Inject",
133127
"UseGuards",
134128
"Param",
135129
"ParamTypes",
@@ -154,19 +148,11 @@
154148
"Path",
155149
"Query",
156150
"WsBody",
157-
"Context",
158-
"Provide",
159-
"Req",
160-
"Ws",
161151
"middleware",
162152
"exception_handler",
163153
"serializer_filter",
164154
"template_filter",
165155
"template_global",
166-
"Res",
167-
"Session",
168-
"Host",
169-
"Http",
170156
"UploadFile",
171157
"file",
172158
"extra_args",

ellar/common/params/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
RequestEndpointArgsModel,
55
WebsocketEndpointArgsModel,
66
)
7+
from .decorators import add_default_resolver
78
from .resolvers import (
89
BaseConnectionParameterResolver,
910
IRouteParameterResolver,
1011
NonParameterResolver,
1112
)
1213

1314
__all__ = [
15+
"add_default_resolver",
1416
"WebsocketEndpointArgsModel",
1517
"RequestEndpointArgsModel",
1618
"ExtraEndpointArg",

ellar/common/params/args/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import EndpointArgsModel, add_default_resolver
1+
from .base import EndpointArgsModel
22
from .extra_args import ExtraEndpointArg
33
from .request_model import RequestEndpointArgsModel
44
from .websocket_model import WebsocketEndpointArgsModel
@@ -8,5 +8,4 @@
88
"RequestEndpointArgsModel",
99
"WebsocketEndpointArgsModel",
1010
"EndpointArgsModel",
11-
"add_default_resolver",
1211
]

ellar/common/params/args/base.py

Lines changed: 48 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,8 @@
1515
from pydantic.fields import FieldInfo, ModelField
1616
from pydantic.typing import ForwardRef, evaluate_forwardref # type:ignore
1717
from pydantic.utils import Representation, lenient_issubclass
18-
from starlette.background import BackgroundTasks
1918
from starlette.convertors import Convertor
20-
from starlette.requests import (
21-
HTTPConnection as StarletteHTTPConnection,
22-
)
23-
from starlette.requests import (
24-
Request as StarletteRequest,
25-
)
26-
from starlette.responses import Response
27-
from starlette.websockets import WebSocket as StarletteWebSocket
19+
from typing_extensions import Annotated, get_args, get_origin
2820

2921
from .. import params
3022
from ..helpers import is_scalar_field, is_scalar_sequence_field
@@ -33,14 +25,6 @@
3325
IRouteParameterResolver,
3426
NonParameterResolver,
3527
)
36-
from ..resolvers.non_parameter import (
37-
BackgroundTasksParameter,
38-
ConnectionParam,
39-
ExecutionContextParameter,
40-
RequestParameter,
41-
ResponseRequestParam,
42-
WebSocketParameter,
43-
)
4428
from .extra_args import ExtraEndpointArg
4529
from .factory import get_parameter_field
4630
from .resolver_generators import (
@@ -50,21 +34,6 @@
5034
QueryHeaderResolverGenerator,
5135
)
5236

53-
DEFAULT_RESOLVERS: t.Dict[t.Type, t.Type[NonParameterResolver]] = {
54-
StarletteRequest: RequestParameter,
55-
StarletteWebSocket: WebSocketParameter,
56-
Response: ResponseRequestParam,
57-
StarletteHTTPConnection: ConnectionParam,
58-
IExecutionContext: ExecutionContextParameter,
59-
BackgroundTasks: BackgroundTasksParameter,
60-
}
61-
62-
63-
def add_default_resolver(
64-
type_identifier: t.Type, resolver_type: t.Type[NonParameterResolver]
65-
) -> None: # pragma: no cover
66-
DEFAULT_RESOLVERS.update({type_identifier: resolver_type})
67-
6837

6938
class EndpointArgsModel:
7039
_bulk_resolvers_generators = {
@@ -185,27 +154,30 @@ def compute_route_parameter_list(
185154
self, body_field_class: t.Type[FieldInfo] = params.BodyFieldInfo
186155
) -> None:
187156
for param_name, param in self.endpoint_signature.parameters.items():
157+
param_name, param_default, param_annotation, param_kind = (
158+
param.name,
159+
param.default,
160+
param.annotation,
161+
param.kind,
162+
)
163+
param_annotation, param_default = self._get_annotation_type_and_default(
164+
param_annotation, param_default
165+
)
166+
188167
if (
189-
param.kind == param.VAR_KEYWORD
190-
or param.kind == param.VAR_POSITIONAL
168+
param_kind == param.VAR_KEYWORD
169+
or param_kind == param.VAR_POSITIONAL
191170
or (
192-
param.name == "self" and param.annotation == inspect.Parameter.empty
171+
param_name == "self" and param_annotation == inspect.Parameter.empty
193172
)
194173
):
195174
# Skipping **kwargs, *args, self
196175
continue
197176

198-
if self._add_non_pydantic_field_to_dependency(
199-
param_name=param.name,
200-
param_default=param.default,
201-
param_annotation=param.annotation,
202-
):
203-
continue
204-
205177
if self._add_non_field_param_to_dependency(
206-
param_name=param.name,
207-
param_default=param.default,
208-
param_annotation=param.annotation,
178+
param_name=param_name,
179+
param_default=param_default,
180+
param_annotation=param_annotation,
209181
):
210182
continue
211183

@@ -215,8 +187,8 @@ def compute_route_parameter_list(
215187
else:
216188
ignore_default = True
217189
param_field = get_parameter_field(
218-
param_default=param.default,
219-
param_annotation=param.annotation,
190+
param_default=param_default,
191+
param_annotation=param_annotation,
220192
param_name=param_name,
221193
default_field_info=params.PathFieldInfo,
222194
ignore_default=ignore_default,
@@ -233,8 +205,8 @@ def compute_route_parameter_list(
233205
else params.QueryFieldInfo,
234206
)
235207
param_field = get_parameter_field(
236-
param_default=param.default,
237-
param_annotation=param.annotation,
208+
param_default=param_default,
209+
param_annotation=param_annotation,
238210
default_field_info=default_field_info,
239211
param_name=param_name,
240212
body_field_class=body_field_class,
@@ -332,14 +304,38 @@ async def resolve_dependencies(
332304
def compute_extra_route_args(self) -> None:
333305
self._add_extra_route_args(*self._extra_endpoint_args)
334306

307+
def _get_annotation_type_and_default(
308+
self, param_annotation: t.Any, default_param_default: t.Any
309+
) -> t.Tuple:
310+
if get_origin(param_annotation) is Annotated:
311+
annotated_args = get_args(param_annotation)
312+
if len(annotated_args) == 2:
313+
return annotated_args
314+
else:
315+
raise ImproperConfiguration(
316+
f"Cannot specify multiple `Annotated` Ellar arguments for {annotated_args!r}"
317+
)
318+
319+
return param_annotation, default_param_default
320+
335321
def _add_extra_route_args(
336322
self, *extra_operation_args: ExtraEndpointArg, key: t.Optional[str] = None
337323
) -> None:
338324
for param in extra_operation_args:
325+
param_name, param_default, param_annotation = (
326+
param.name,
327+
param.default,
328+
param.annotation,
329+
)
330+
331+
param_annotation, param_default = self._get_annotation_type_and_default(
332+
param_annotation, param_default
333+
)
334+
339335
if self._add_non_field_param_to_dependency(
340-
param_name=param.name,
341-
param_default=param.default,
342-
param_annotation=param.annotation,
336+
param_name=param_name,
337+
param_default=param_default,
338+
param_annotation=param_annotation,
343339
key=key,
344340
):
345341
continue
@@ -379,13 +375,3 @@ def __copy__(
379375

380376
def build_body_field(self) -> None: # pragma: no cover
381377
raise NotImplementedError
382-
383-
def _add_non_pydantic_field_to_dependency(
384-
self, param_name: str, param_default: t.Any, param_annotation: t.Any
385-
) -> bool:
386-
resolver_class = DEFAULT_RESOLVERS.get(param_annotation)
387-
if resolver_class and param_default == inspect.Parameter.empty:
388-
_inject = resolver_class()(param_name, param_annotation)
389-
self._computation_models[_inject.in_].append(_inject)
390-
return True
391-
return False
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import typing as t
2+
3+
from typing_extensions import Annotated
4+
5+
from . import models as param_functions
6+
from .inject import InjectShortcut, add_default_resolver
7+
8+
__all__ = [
9+
"add_default_resolver",
10+
"Body",
11+
"Cookie",
12+
"File",
13+
"Form",
14+
"Header",
15+
"Path",
16+
"Query",
17+
"WsBody",
18+
"Inject",
19+
]
20+
21+
22+
class _ParamShortcut:
23+
def __init__(self, base_func: t.Callable) -> None:
24+
self._base_func = base_func
25+
26+
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
27+
return self._base_func(*args, **kwargs)
28+
29+
def __getitem__(self, args: t.Any) -> t.Any:
30+
if isinstance(args, tuple):
31+
return Annotated[args[0], self._base_func(**args[1])]
32+
return Annotated[args, self._base_func()]
33+
34+
@classmethod
35+
def P(
36+
cls,
37+
default: t.Any = ...,
38+
*,
39+
alias: t.Optional[str] = None,
40+
title: t.Optional[str] = None,
41+
description: t.Optional[str] = None,
42+
gt: t.Optional[float] = None,
43+
ge: t.Optional[float] = None,
44+
lt: t.Optional[float] = None,
45+
le: t.Optional[float] = None,
46+
min_length: t.Optional[int] = None,
47+
max_length: t.Optional[int] = None,
48+
regex: t.Optional[str] = None,
49+
example: t.Any = None,
50+
examples: t.Optional[t.Dict[str, t.Any]] = None,
51+
deprecated: t.Optional[bool] = None,
52+
include_in_schema: bool = True,
53+
**extra: t.Any,
54+
) -> t.Dict[str, t.Any]:
55+
"""Arguments for Body, Query, Header, Cookie, etc."""
56+
return dict(
57+
default=default,
58+
alias=alias,
59+
title=title,
60+
description=description,
61+
gt=gt,
62+
ge=ge,
63+
lt=lt,
64+
le=le,
65+
min_length=min_length,
66+
max_length=max_length,
67+
regex=regex,
68+
example=example,
69+
examples=examples,
70+
deprecated=deprecated,
71+
include_in_schema=include_in_schema,
72+
**extra,
73+
)
74+
75+
76+
if t.TYPE_CHECKING: # pragma: nocover
77+
# mypy cheats
78+
T = t.TypeVar("T")
79+
Body = Annotated[T, param_functions.Body()]
80+
Cookie = Annotated[T, param_functions.Cookie()]
81+
File = Annotated[T, param_functions.File()]
82+
Form = Annotated[T, param_functions.Form()]
83+
Header = Annotated[T, param_functions.Header()]
84+
Path = Annotated[T, param_functions.Path()]
85+
Query = Annotated[T, param_functions.Query()]
86+
WsBody = Annotated[T, param_functions.WsBody()]
87+
Inject = Annotated[T, t.Any]
88+
89+
else:
90+
Body = _ParamShortcut(param_functions.Body)
91+
Cookie = _ParamShortcut(param_functions.Cookie)
92+
File = _ParamShortcut(param_functions.File)
93+
Form = _ParamShortcut(param_functions.Form)
94+
Header = _ParamShortcut(param_functions.Header)
95+
Path = _ParamShortcut(param_functions.Path)
96+
Query = _ParamShortcut(param_functions.Query)
97+
WsBody = _ParamShortcut(param_functions.WsBody)
98+
Inject = InjectShortcut

0 commit comments

Comments
 (0)