Skip to content

Commit 4e3a3a3

Browse files
committed
Refactored route function parameters inject functions
1 parent 741b1e8 commit 4e3a3a3

22 files changed

+806
-654
lines changed

ellar/common/routing/params.py

Lines changed: 17 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import typing as t
22

3-
from pydantic.error_wrappers import ErrorWrapper
43
from pydantic.fields import Undefined
54
from starlette.responses import Response
65

76
from ellar.core.connection import HTTPConnection, Request, WebSocket
87
from ellar.core.context import IExecutionContext
98
from ellar.core.params import params
10-
from ellar.core.params.resolvers import (
11-
BaseRequestRouteParameterResolver,
12-
NonFieldRouteParameterResolver,
13-
ParameterInjectable,
9+
from ellar.core.params.resolvers.non_parameter import (
10+
ConnectionParam,
11+
ExecutionContextParameter,
12+
HostRequestParam,
13+
ProviderParameterInjector,
14+
RequestParameter,
15+
ResponseRequestParam,
16+
SessionRequestParam,
17+
WebSocketParameter,
1418
)
1519
from ellar.types import T
1620

@@ -341,108 +345,44 @@ def WsBody(
341345
)
342346

343347

344-
class _RequestParameter(NonFieldRouteParameterResolver):
345-
async def resolve(
346-
self, ctx: IExecutionContext, **kwargs: t.Any
347-
) -> t.Tuple[t.Dict, t.List]:
348-
try:
349-
request = ctx.switch_to_http_connection().get_request()
350-
return {self.parameter_name: request}, []
351-
except Exception as ex:
352-
return {}, [ErrorWrapper(ex, loc=self.parameter_name or "request")]
353-
354-
355-
class _WebSocketParameter(NonFieldRouteParameterResolver):
356-
async def resolve(
357-
self, ctx: IExecutionContext, **kwargs: t.Any
358-
) -> t.Tuple[t.Dict, t.List]:
359-
try:
360-
websocket = ctx.switch_to_websocket().get_client()
361-
return {self.parameter_name: websocket}, []
362-
except Exception as ex:
363-
return {}, [ErrorWrapper(ex, loc=self.parameter_name or "websocket")]
364-
365-
366-
class _ExecutionContextParameter(NonFieldRouteParameterResolver):
367-
async def resolve(
368-
self, ctx: IExecutionContext, **kwargs: t.Any
369-
) -> t.Tuple[t.Dict, t.List]:
370-
return {self.parameter_name: ctx}, []
371-
372-
373-
class _HostRequestParam(BaseRequestRouteParameterResolver):
374-
lookup_connection_field = None
375-
376-
async def get_value(self, ctx: IExecutionContext) -> t.Any:
377-
connection = ctx.switch_to_http_connection().get_client()
378-
if connection.client:
379-
return connection.client.host
380-
381-
382-
class _SessionRequestParam(BaseRequestRouteParameterResolver):
383-
lookup_connection_field = "session"
384-
385-
386-
class _ConnectionParam(NonFieldRouteParameterResolver):
387-
async def resolve(
388-
self, ctx: IExecutionContext, **kwargs: t.Any
389-
) -> t.Tuple[t.Dict, t.List]:
390-
try:
391-
connection = ctx.switch_to_http_connection().get_client()
392-
return {self.parameter_name: connection}, []
393-
except Exception as ex:
394-
return {}, [ErrorWrapper(ex, loc=self.parameter_name or "connection")]
395-
396-
397-
class _ResponseRequestParam(NonFieldRouteParameterResolver):
398-
async def resolve(
399-
self, ctx: IExecutionContext, **kwargs: t.Any
400-
) -> t.Tuple[t.Dict, t.List]:
401-
try:
402-
response = ctx.switch_to_http_connection().get_response()
403-
return {self.parameter_name: response}, []
404-
except Exception as ex:
405-
return {}, [ErrorWrapper(ex, loc=self.parameter_name or "response")]
406-
407-
408348
def Http() -> HTTPConnection:
409349
"""
410350
Route Function Parameter for retrieving Current Request Instance
411351
:return: Request
412352
"""
413-
return t.cast(Request, _ConnectionParam())
353+
return t.cast(Request, ConnectionParam())
414354

415355

416356
def Req() -> Request:
417357
"""
418358
Route Function Parameter for retrieving Current Request Instance
419359
:return: Request
420360
"""
421-
return t.cast(Request, _RequestParameter())
361+
return t.cast(Request, RequestParameter())
422362

423363

424364
def Ws() -> WebSocket:
425365
"""
426366
Route Function Parameter for retrieving Current WebSocket Instance
427367
:return: WebSocket
428368
"""
429-
return t.cast(WebSocket, _WebSocketParameter())
369+
return t.cast(WebSocket, WebSocketParameter())
430370

431371

432372
def Context() -> IExecutionContext:
433373
"""
434374
Route Function Parameter for retrieving Current IExecutionContext Instance
435375
:return: IExecutionContext
436376
"""
437-
return t.cast(IExecutionContext, _ExecutionContextParameter())
377+
return t.cast(IExecutionContext, ExecutionContextParameter())
438378

439379

440380
def Provide(service: t.Optional[t.Type[T]] = None) -> T:
441381
"""
442382
Route Function Parameter for resolving registered Provider
443383
:return: T
444384
"""
445-
return t.cast(T, ParameterInjectable(service))
385+
return t.cast(T, ProviderParameterInjector(service))
446386

447387

448388
def Session() -> t.Dict:
@@ -451,20 +391,20 @@ def Session() -> t.Dict:
451391
Ensure SessionMiddleware is registered to application middlewares
452392
:return: Dict
453393
"""
454-
return t.cast(t.Dict, _SessionRequestParam())
394+
return t.cast(t.Dict, SessionRequestParam())
455395

456396

457397
def Host() -> str:
458398
"""
459399
Route Function Parameter for resolving registered `HTTPConnection.client.host`
460400
:return: str
461401
"""
462-
return t.cast(str, _HostRequestParam())
402+
return t.cast(str, HostRequestParam())
463403

464404

465405
def Res() -> Response:
466406
"""
467407
Route Function Parameter for resolving registered Response
468408
:return: Response
469409
"""
470-
return t.cast(Response, _ResponseRequestParam())
410+
return t.cast(Response, ResponseRequestParam())

ellar/core/params/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66
)
77
from .params import Body, Cookie, File, Form, Header, Param, ParamTypes, Path, Query
88
from .resolvers import (
9-
BaseRequestRouteParameterResolver,
10-
BaseRouteParameterResolver,
11-
NonFieldRouteParameterResolver,
9+
BaseConnectionParameterResolver,
10+
IRouteParameterResolver,
11+
NonParameterResolver,
1212
)
1313

1414
__all__ = [
1515
"WebsocketEndpointArgsModel",
1616
"RequestEndpointArgsModel",
1717
"ExtraEndpointArg",
1818
"EndpointArgsModel",
19-
"NonFieldRouteParameterResolver",
20-
"BaseRequestRouteParameterResolver",
21-
"BaseRouteParameterResolver",
19+
"NonParameterResolver",
20+
"BaseConnectionParameterResolver",
21+
"IRouteParameterResolver",
2222
"Body",
2323
"Cookie",
2424
"File",

ellar/core/params/args/base.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,15 @@
2626
from ..helpers import is_scalar_field, is_scalar_sequence_field
2727
from ..resolvers import (
2828
BaseRouteParameterResolver,
29-
NonFieldRouteParameterResolver,
30-
ParameterInjectable,
31-
RouteParameterResolver,
29+
IRouteParameterResolver,
30+
NonParameterResolver,
31+
)
32+
from ..resolvers.non_parameter import (
33+
ConnectionParam,
34+
ExecutionContextParameter,
35+
RequestParameter,
36+
ResponseRequestParam,
37+
WebSocketParameter,
3238
)
3339
from .extra_args import ExtraEndpointArg
3440
from .factory import get_parameter_field
@@ -39,6 +45,24 @@
3945
QueryHeaderResolverGenerator,
4046
)
4147

48+
DEFAULT_RESOLVERS: t.Dict[t.Type, t.Type[NonParameterResolver]] = {
49+
Request: RequestParameter,
50+
StarletteRequest: RequestParameter,
51+
WebSocket: WebSocketParameter,
52+
StarletteWebSocket: WebSocketParameter,
53+
Response: ResponseRequestParam,
54+
HTTPConnection: ConnectionParam,
55+
StarletteHTTPConnection: ConnectionParam,
56+
IExecutionContext: ExecutionContextParameter,
57+
ExecutionContext: ExecutionContextParameter,
58+
}
59+
60+
61+
def add_default_resolver(
62+
type_identifier: t.Type, resolver_type: t.Type[NonParameterResolver]
63+
) -> None:
64+
DEFAULT_RESOLVERS.update({type_identifier: resolver_type})
65+
4266

4367
class EndpointArgsModel:
4468
_bulk_resolvers_generators = {
@@ -72,12 +96,14 @@ def __init__(
7296
self.path = path
7397
self.param_converters = param_converters
7498
self._computation_models: t.DefaultDict[
75-
str, t.List[BaseRouteParameterResolver]
99+
str, t.List[IRouteParameterResolver]
76100
] = defaultdict(list)
77101
self.path_param_names = self.get_path_param_names(path)
78102
self.endpoint_signature = self.get_typed_signature(endpoint)
79-
self.body_resolver: t.Optional[t.Union[t.Any, RouteParameterResolver]] = None
80-
self._route_models: t.List[BaseRouteParameterResolver] = []
103+
self.body_resolver: t.Optional[
104+
t.Union[t.Any, BaseRouteParameterResolver]
105+
] = None
106+
self._route_models: t.List[IRouteParameterResolver] = []
81107
self._extra_endpoint_args: t.List[ExtraEndpointArg] = (
82108
list(extra_endpoint_args) if extra_endpoint_args else []
83109
)
@@ -89,14 +115,14 @@ def get_resolver_generator(
89115
str(type(param)), BulkArgsResolverGenerator
90116
)
91117

92-
def get_route_models(self) -> t.List[BaseRouteParameterResolver]:
118+
def get_route_models(self) -> t.List[IRouteParameterResolver]:
93119
"""
94120
Returns all computed endpoint resolvers required for function execution
95121
:return: List[BaseRouteParameterResolver]
96122
"""
97123
return self._route_models
98124

99-
def get_all_models(self) -> t.List[BaseRouteParameterResolver]:
125+
def get_all_models(self) -> t.List[IRouteParameterResolver]:
100126
"""
101127
Returns all computed endpoint resolvers + omitted resolvers
102128
:return: List[BaseRouteParameterResolver]
@@ -150,7 +176,7 @@ def build_model(self) -> None:
150176
+ self._computation_models[params.Path.in_.value]
151177
+ self._computation_models[params.Query.in_.value]
152178
+ self._computation_models[params.Cookie.in_.value]
153-
+ self._computation_models[NonFieldRouteParameterResolver.in_]
179+
+ self._computation_models[NonParameterResolver.in_]
154180
)
155181

156182
def compute_route_parameter_list(
@@ -236,7 +262,7 @@ def _add_non_field_param_to_dependency(
236262
param_annotation: t.Optional[t.Type],
237263
key: str = None,
238264
) -> t.Optional[bool]:
239-
if isinstance(param_default, NonFieldRouteParameterResolver):
265+
if isinstance(param_default, NonParameterResolver):
240266
model = param_default(param_name, param_annotation) # type:ignore
241267
self._computation_models[key or model.in_].append(model)
242268
return True
@@ -349,22 +375,9 @@ def build_body_field(self) -> None:
349375
def _add_non_pydantic_field_to_dependency(
350376
self, param_name: str, param_default: t.Any, param_annotation: t.Any
351377
) -> bool:
352-
if (
353-
param_annotation
354-
in (
355-
Request,
356-
WebSocket,
357-
HTTPConnection,
358-
Response,
359-
StarletteRequest,
360-
StarletteHTTPConnection,
361-
IExecutionContext,
362-
ExecutionContext,
363-
StarletteWebSocket,
364-
)
365-
and param_default == inspect.Parameter.empty
366-
):
367-
_inject = ParameterInjectable()(param_name, param_annotation)
378+
resolver_class = DEFAULT_RESOLVERS.get(param_annotation)
379+
if resolver_class and param_default == inspect.Parameter.empty:
380+
_inject = resolver_class()(param_name, param_annotation)
368381
self._computation_models[_inject.in_].append(_inject)
369382
return True
370383
return False

ellar/core/params/args/extra_args.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,45 @@
55

66

77
class ExtraEndpointArg(t.Generic[T]):
8+
"""
9+
Add more route function parameters programmatically.
10+
For example:
11+
lets add `limit` and `offset` to a route function.
12+
13+
def limit(func):
14+
limit_args = ExtraEndpointArg(name='limit', annotation=int, default_value=10)
15+
16+
offset_args = ExtraEndpointArg(name='offset', annotation=int, default_value=0)
17+
18+
extra_args = [limit_args, offset_args]
19+
20+
set_metadata(EXTRA_ROUTE_ARGS_KEY, extra_args)(func)
21+
22+
@wraps(func)
23+
def _wrapper(*args, **kwargs):
24+
# RESOLVING EXTRA ARGS
25+
26+
resolved_limit_args = limit_args.resolve(kwargs)
27+
28+
resolved_offset_args = offset_args.resolve(kwargs)
29+
30+
response = func(*args, **kwargs)
31+
32+
response = response[resolved_offset_args: resolved_limit_args]
33+
34+
return response
35+
36+
return _wrapper
37+
38+
router = ModuleRouter('/testing')
39+
40+
@router.get('/list')
41+
@limit
42+
def route(request: Request):
43+
44+
return [i=1 for i in range(40)]
45+
"""
46+
847
__slots__ = ("name", "annotation", "default")
948

1049
empty = inspect.Parameter.empty

ellar/core/params/args/request_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ellar.helper.modelfield import create_model_field
99

1010
from .. import params
11-
from ..resolvers import RouteParameterResolver
11+
from ..resolvers import BaseRouteParameterResolver
1212
from .base import EndpointArgsModel
1313
from .extra_args import ExtraEndpointArg
1414

@@ -56,7 +56,7 @@ def build_body_field(self) -> None:
5656
elif body_resolvers:
5757
# if body_resolvers is more than one, we create a bulk_body_resolver instead
5858
_body_resolvers_model_fields = (
59-
t.cast(RouteParameterResolver, item).model_field
59+
t.cast(BaseRouteParameterResolver, item).model_field
6060
for item in body_resolvers
6161
)
6262
model_name = "body_" + self.operation_unique_id

0 commit comments

Comments
 (0)