Skip to content

Commit fa4d207

Browse files
committed
Added more test
1 parent 75a55ec commit fa4d207

File tree

18 files changed

+134
-73
lines changed

18 files changed

+134
-73
lines changed

ellar/common/commands/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
hidden: bool = Default(False),
5252
deprecated: bool = Default(False),
5353
add_completion: bool = True,
54-
) -> None:
54+
) -> None: # pragma: no cover
5555
assert name is not None and name != "", "Typer name is required"
5656
super().__init__(
5757
name=name,

ellar/common/commands/decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def command(
2020
no_args_is_help: bool = False,
2121
hidden: bool = False,
2222
deprecated: bool = False,
23-
) -> t.Callable[[CommandFunctionType], CommandFunctionType]:
23+
) -> t.Callable[[CommandFunctionType], CommandFunctionType]: # pragma: no cover
2424
"""
2525
========= FUNCTION DECORATOR ==============
2626

ellar/common/routing/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def get_controller_type(self) -> t.Type:
7171
:return: a type that wraps the operation
7272
"""
7373
if not self._controller_type:
74-
_control_type = reflect.get_metadata(CONTROLLER_CLASS_KEY, self.endpoint)
75-
if _control_type is None:
74+
_controller_type = reflect.get_metadata(CONTROLLER_CLASS_KEY, self.endpoint)
75+
if _controller_type is None or not isinstance(_controller_type, type):
7676
raise Exception("Operation must have a single control type.")
77-
self._controller_type = t.cast(t.Type, _control_type)
77+
self._controller_type = t.cast(t.Type, _controller_type)
7878

7979
return self._controller_type
8080

ellar/common/routing/controller/route.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from starlette.concurrency import run_in_threadpool
44

5-
from ellar.common.exceptions import RequestValidationError
65
from ellar.common.interfaces import IExecutionContext
76
from ellar.common.routing.route import RouteOperation
87

@@ -12,18 +11,9 @@
1211
class ControllerRouteOperation(ControllerRouteOperationBase, RouteOperation):
1312
methods: t.Set[str]
1413

15-
async def handle_request(self, context: IExecutionContext) -> t.Any:
14+
async def run(self, context: IExecutionContext, **kwargs: t.Any) -> t.Any:
1615
controller_instance = self._get_controller_instance(ctx=context)
17-
18-
func_kwargs, errors = await self.endpoint_parameter_model.resolve_dependencies(
19-
ctx=context
20-
)
21-
if errors:
22-
raise RequestValidationError(errors)
23-
2416
if self._is_coroutine:
25-
return await self.endpoint(controller_instance, **func_kwargs)
17+
return await self.endpoint(controller_instance, **kwargs)
2618
else:
27-
return await run_in_threadpool(
28-
self.endpoint, controller_instance, **func_kwargs
29-
)
19+
return await run_in_threadpool(self.endpoint, controller_instance, **kwargs)
Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
import typing as t
22

3-
from starlette.status import WS_1008_POLICY_VIOLATION
4-
from starlette.websockets import WebSocketState
5-
6-
from ellar.common.exceptions import WebSocketRequestValidationError
73
from ellar.common.interfaces import IExecutionContext
84

95
from ...websocket import WebsocketRouteOperation
@@ -20,22 +16,8 @@ class ControllerWebsocketRouteOperation(
2016
def get_websocket_handler(cls) -> t.Type[ControllerWebSocketExtraHandler]:
2117
return ControllerWebSocketExtraHandler
2218

23-
async def handle_request(self, context: IExecutionContext) -> t.Any:
19+
async def run(self, context: IExecutionContext, **kwargs: t.Any) -> t.Any:
2420
controller_instance = self._get_controller_instance(ctx=context)
25-
func_kwargs, errors = await self.endpoint_parameter_model.resolve_dependencies(
26-
ctx=context
27-
)
28-
if errors:
29-
websocket = context.switch_to_websocket().get_client()
30-
exc = WebSocketRequestValidationError(errors)
31-
if websocket.client_state == WebSocketState.CONNECTING:
32-
await websocket.accept()
33-
await websocket.send_json(
34-
dict(code=WS_1008_POLICY_VIOLATION, errors=exc.errors())
35-
)
36-
await websocket.close(code=WS_1008_POLICY_VIOLATION)
37-
raise exc
38-
3921
if self._use_extra_handler:
4022
ws_extra_handler_type = (
4123
self._extra_handler_type or self.get_websocket_handler()
@@ -45,6 +27,6 @@ async def handle_request(self, context: IExecutionContext) -> t.Any:
4527
controller_instance=controller_instance,
4628
**self._handlers_kwargs,
4729
)
48-
return await ws_extra_handler.dispatch(context=context, **func_kwargs)
30+
return await ws_extra_handler.dispatch(context=context, **kwargs)
4931
else:
50-
return await self.endpoint(controller_instance, **func_kwargs)
32+
return await self.endpoint(controller_instance, **kwargs)

ellar/common/routing/route.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _load_model(self) -> None:
6969
extra_route_args: t.Union[t.List["ExtraEndpointArg"], "ExtraEndpointArg"] = (
7070
reflect.get_metadata(EXTRA_ROUTE_ARGS_KEY, self.endpoint) or []
7171
)
72-
if not isinstance(extra_route_args, list):
72+
if not isinstance(extra_route_args, list): # pragma: no cover
7373
extra_route_args = [extra_route_args]
7474

7575
if self.endpoint_parameter_model is NOT_SET:
@@ -106,16 +106,20 @@ def get_operation_unique_id(
106106
name=self.name, path=self.path_format, methods=_methods
107107
)
108108

109+
async def run(self, context: IExecutionContext, **kwargs: t.Any) -> t.Any:
110+
if self._is_coroutine:
111+
return await self.endpoint(**kwargs)
112+
else:
113+
return await run_in_threadpool(self.endpoint, **kwargs)
114+
109115
async def handle_request(self, context: IExecutionContext) -> t.Any:
110116
func_kwargs, errors = await self.endpoint_parameter_model.resolve_dependencies(
111117
ctx=context
112118
)
113119
if errors:
114120
raise RequestValidationError(errors)
115-
if self._is_coroutine:
116-
return await self.endpoint(**func_kwargs)
117-
else:
118-
return await run_in_threadpool(self.endpoint, **func_kwargs)
121+
122+
return await self.run(context, **func_kwargs)
119123

120124
async def handle_response(
121125
self, context: IExecutionContext, response_obj: t.Any

ellar/common/routing/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ def validate_methods(cls, value: t.Any) -> t.List[str]:
5656

5757
@validator("endpoint")
5858
def validate_endpoint(cls, value: t.Any) -> t.Any:
59-
if not callable(value):
59+
if not callable(value): # pragma: no cover
6060
raise ValueError("An endpoint must be a callable")
6161
return value
6262

6363
@root_validator
6464
def validate_root(cls, values: t.Any) -> t.Any:
65-
if "response" not in values:
65+
if "response" not in values: # pragma: no cover
6666
raise ValueError(
6767
"Expected "
6868
"IResponseModel | Dict[int, Any | Type[BaseModel] | "

ellar/common/routing/websocket/route.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ def add_websocket_handler(self, handler_name: str, handler: t.Callable) -> None:
8484
)
8585
self._handlers_kwargs.update({handler_name: handler})
8686

87+
async def run(self, context: IExecutionContext, **kwargs: t.Any) -> t.Any:
88+
if self._use_extra_handler:
89+
ws_extra_handler_type = (
90+
self._extra_handler_type or self.get_websocket_handler()
91+
)
92+
ws_extra_handler = ws_extra_handler_type(
93+
route_parameter_model=self.endpoint_parameter_model,
94+
**self._handlers_kwargs,
95+
)
96+
return await ws_extra_handler.dispatch(context=context, **kwargs)
97+
else:
98+
return await self.endpoint(**kwargs)
99+
87100
async def handle_request(self, context: IExecutionContext) -> t.Any:
88101
func_kwargs, errors = await self.endpoint_parameter_model.resolve_dependencies(
89102
ctx=context
@@ -99,17 +112,7 @@ async def handle_request(self, context: IExecutionContext) -> t.Any:
99112
await websocket.close(code=WS_1008_POLICY_VIOLATION)
100113
raise exc
101114

102-
if self._use_extra_handler:
103-
ws_extra_handler_type = (
104-
self._extra_handler_type or self.get_websocket_handler()
105-
)
106-
ws_extra_handler = ws_extra_handler_type(
107-
route_parameter_model=self.endpoint_parameter_model,
108-
**self._handlers_kwargs,
109-
)
110-
return await ws_extra_handler.dispatch(context=context, **func_kwargs)
111-
else:
112-
return await self.endpoint(**func_kwargs)
115+
return await self.run(context, **func_kwargs)
113116

114117
async def handle_response(
115118
self, context: IExecutionContext, response_obj: t.Any

ellar/common/serializer/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
__pydantic_root__ = "__root__"
1313

1414

15+
@t.no_type_check
1516
def get_dataclass_pydantic_model(
1617
dataclass_type: t.Type,
1718
) -> t.Optional[t.Type[BaseModel]]:
1819

1920
if hasattr(dataclass_type, __pydantic_model__):
2021
return t.cast(t.Type[BaseModel], dataclass_type.__dict__[__pydantic_model__])
21-
return None
2222

2323

2424
class SerializerConfig(BaseConfig):

ellar/core/modules/ref.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,6 @@ def routes(self) -> t.List[BaseRoute]:
230230
# self._flatten_routes.append(router)
231231

232232
def _register_module(self) -> None:
233-
if not is_decorated_with_injectable(self.module):
234-
self._module_type = injectable()(self.module)
235233
self.container.register(
236234
self.module, ModuleProvider(self.module, **self._init_kwargs)
237235
)

0 commit comments

Comments
 (0)