Skip to content

Commit 100e9d4

Browse files
committed
Fixed test due to refactor
1 parent a193cee commit 100e9d4

File tree

9 files changed

+210
-42
lines changed

9 files changed

+210
-42
lines changed

tests/test_application/test_application_functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import typing as t
33

4-
from starlette.responses import JSONResponse, Response
4+
from starlette.responses import JSONResponse, PlainTextResponse, Response
55

66
from ellar.common import Module, get, template_filter, template_global
77
from ellar.compatible import asynccontextmanager
@@ -196,6 +196,19 @@ async def homepage(request: Request):
196196

197197

198198
class TestEllarApp:
199+
def test_ellar_as_asgi_app(self):
200+
@get("/")
201+
async def homepage(request: Request, ctx: IExecutionContext):
202+
res = PlainTextResponse("Ellar Route Handler as an ASGI app")
203+
await res(*ctx.get_args())
204+
205+
app = AppFactory.create_app()
206+
app.router.append(homepage)
207+
client = TestClient(app)
208+
response = client.get("/")
209+
assert response.status_code == 200
210+
assert response.text == "Ellar Route Handler as an ASGI app"
211+
199212
def test_app_staticfiles_route(self, tmpdir):
200213
path = os.path.join(tmpdir, "example.txt")
201214
with open(path, "w") as file:

tests/test_application/test_functional_middleware.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from ellar.common import Module, ModuleRouter, middleware
55
from ellar.core import TestClientFactory
6-
from ellar.core.context import IExecutionContext
6+
from ellar.core.context import IHostContext
77

88
mr = ModuleRouter()
99

@@ -21,22 +21,22 @@ def homepage(request: Request):
2121
@Module(routers=[mr])
2222
class ModuleMiddleware:
2323
@middleware()
24-
async def middleware_modify_response(cls, context: IExecutionContext, call_next):
25-
response = context.get_response()
24+
async def middleware_modify_response(cls, context: IHostContext, call_next):
25+
response = context.switch_to_http_connection().get_response()
2626
response.headers.setdefault("modified-header", "Ellar")
2727
await call_next()
2828

2929
@middleware()
30-
async def middleware_modify_request(cls, context: IExecutionContext, call_next):
31-
request = context.switch_to_request()
30+
async def middleware_modify_request(cls, context: IHostContext, call_next):
31+
request = context.switch_to_http_connection().get_request()
3232
request.state.user = None
3333
if request.headers.get("set-user"):
3434
request.state.user = dict(username="Ellar")
3535
await call_next()
3636

3737
@middleware()
38-
async def middleware_return_response(cls, context: IExecutionContext, call_next):
39-
request = context.switch_to_request()
38+
async def middleware_return_response(cls, context: IHostContext, call_next):
39+
request = context.switch_to_http_connection().get_request()
4040
if request.headers.get("ellar"):
4141
return PlainTextResponse("middleware_return_response returned a response")
4242
await call_next()
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from ellar.common import Controller
2+
from ellar.constants import CONTROLLER_METADATA, NOT_SET
3+
from ellar.reflect import reflect
4+
5+
6+
@Controller(
7+
prefix="/decorator",
8+
description="Some description",
9+
external_doc_description="external",
10+
guards=[],
11+
version=("v1",),
12+
tag="dec",
13+
external_doc_url="https://example.com",
14+
name="test",
15+
)
16+
class ControllerDecorationTest:
17+
pass
18+
19+
20+
@Controller
21+
class ControllerDefaultTest:
22+
pass
23+
24+
25+
def test_controller_decoration_default():
26+
assert (
27+
reflect.get_metadata(CONTROLLER_METADATA.NAME, ControllerDefaultTest)
28+
== "defaulttest"
29+
)
30+
31+
assert reflect.get_metadata(CONTROLLER_METADATA.OPENAPI, ControllerDefaultTest) == {
32+
"tag": NOT_SET,
33+
"description": None,
34+
"external_doc_description": None,
35+
"external_doc_url": None,
36+
}
37+
assert reflect.get_metadata(CONTROLLER_METADATA.GUARDS, ControllerDefaultTest) == []
38+
assert (
39+
reflect.get_metadata(CONTROLLER_METADATA.VERSION, ControllerDefaultTest)
40+
== set()
41+
)
42+
assert (
43+
reflect.get_metadata(CONTROLLER_METADATA.PATH, ControllerDefaultTest)
44+
== "/defaulttest"
45+
)
46+
assert (
47+
reflect.get_metadata(
48+
CONTROLLER_METADATA.INCLUDE_IN_SCHEMA, ControllerDefaultTest
49+
)
50+
is True
51+
)
52+
53+
54+
def test_controller_decoration_test():
55+
assert (
56+
reflect.get_metadata(CONTROLLER_METADATA.NAME, ControllerDecorationTest)
57+
== "test"
58+
)
59+
60+
assert reflect.get_metadata(
61+
CONTROLLER_METADATA.OPENAPI, ControllerDecorationTest
62+
) == {
63+
"tag": "dec",
64+
"description": "Some description",
65+
"external_doc_description": "external",
66+
"external_doc_url": "https://example.com",
67+
}
68+
assert (
69+
reflect.get_metadata(CONTROLLER_METADATA.GUARDS, ControllerDecorationTest) == []
70+
)
71+
assert reflect.get_metadata(
72+
CONTROLLER_METADATA.VERSION, ControllerDecorationTest
73+
) == {
74+
"v1",
75+
}
76+
assert (
77+
reflect.get_metadata(CONTROLLER_METADATA.PATH, ControllerDecorationTest)
78+
== "/decorator"
79+
)
80+
assert (
81+
reflect.get_metadata(
82+
CONTROLLER_METADATA.INCLUDE_IN_SCHEMA, ControllerDecorationTest
83+
)
84+
is True
85+
)

tests/test_di/test_middleware.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ellar.constants import SCOPE_SERVICE_PROVIDER
66
from ellar.core.connection import HTTPConnection, Request, WebSocket
7-
from ellar.core.context import IExecutionContext
7+
from ellar.core.context import IHostContext
88
from ellar.core.middleware import RequestServiceProviderMiddleware
99
from ellar.core.response import Response
1010
from ellar.di import EllarInjector
@@ -39,17 +39,23 @@ async def assert_iexecute_context_app(scope, receive, send):
3939
assert scope[SCOPE_SERVICE_PROVIDER]
4040

4141
service_provider = scope[SCOPE_SERVICE_PROVIDER]
42-
execution_context: IExecutionContext = service_provider.get(IExecutionContext)
42+
host_context: IHostContext = service_provider.get(IHostContext)
4343
assert (
4444
service_provider.get(HTTPConnection)
45-
is execution_context.switch_to_http_connection()
45+
is host_context.switch_to_http_connection().get_client()
4646
)
47-
assert service_provider.get(Request) is execution_context.switch_to_request()
48-
assert service_provider.get(Response) is execution_context.get_response()
49-
assert service_provider is execution_context.get_service_provider()
47+
assert (
48+
service_provider.get(Request)
49+
is host_context.switch_to_http_connection().get_request()
50+
)
51+
assert (
52+
service_provider.get(Response)
53+
is host_context.switch_to_http_connection().get_response()
54+
)
55+
assert service_provider is host_context.get_service_provider()
5056

5157
with pytest.raises(Exception):
52-
execution_context.switch_to_websocket()
58+
host_context.switch_to_websocket()
5359

5460
with pytest.raises(Exception):
5561
service_provider.get(WebSocket)
@@ -69,6 +75,32 @@ async def assert_iexecute_context_app(scope, receive, send):
6975
)
7076

7177

78+
async def assert_iexecute_context_app_websocket(scope, receive, send):
79+
assert scope[SCOPE_SERVICE_PROVIDER]
80+
81+
service_provider = scope[SCOPE_SERVICE_PROVIDER]
82+
host_context: IHostContext = service_provider.get(IHostContext)
83+
84+
websocket = host_context.switch_to_websocket().get_client()
85+
assert service_provider.get(WebSocket) is websocket
86+
assert service_provider is host_context.get_service_provider()
87+
88+
assert (
89+
service_provider.get(HTTPConnection)
90+
is host_context.switch_to_http_connection().get_client()
91+
)
92+
93+
with pytest.raises(Exception):
94+
service_provider.get(Request)
95+
96+
with pytest.raises(Exception):
97+
service_provider.get(Response)
98+
99+
await websocket.accept()
100+
await websocket.send_text("Hello, world!")
101+
await websocket.close()
102+
103+
72104
def test_di_middleware(test_client_factory):
73105
injector_ = EllarInjector()
74106
injector_.container.install(DummyModule)
@@ -96,3 +128,14 @@ def test_di_middleware_execution_context_initialization(test_client_factory):
96128
assert response.status_code == 200
97129
data = response.json()
98130
assert data["message"] == "execution context work"
131+
132+
133+
def test_di_middleware_execution_context_initialization_websocket(test_client_factory):
134+
asgi_app = RequestServiceProviderMiddleware(
135+
assert_iexecute_context_app_websocket, debug=False, injector=EllarInjector()
136+
)
137+
138+
client = test_client_factory(asgi_app)
139+
with client.websocket_connect("/") as session:
140+
text = session.receive_text()
141+
assert text == "Hello, world!"

tests/test_exceptions/test_custom_exceptions.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ellar.common import get
1010
from ellar.core import Config, TestClientFactory
11-
from ellar.core.context import IExecutionContext
11+
from ellar.core.context import IHostContext
1212
from ellar.core.exceptions.callable_exceptions import CallableExceptionHandler
1313
from ellar.core.exceptions.handlers import APIException, APIExceptionHandler
1414
from ellar.core.exceptions.interfaces import IExceptionHandler
@@ -28,14 +28,14 @@ class NewExceptionHandler(IExceptionHandler):
2828
exception_type_or_code = NewException
2929

3030
async def catch(
31-
self, ctx: IExecutionContext, exc: t.Union[t.Any, Exception]
31+
self, ctx: IHostContext, exc: t.Union[t.Any, Exception]
3232
) -> t.Union[Response, t.Any]:
3333
return JSONResponse({"detail": str(exc)}, status_code=400)
3434

3535

3636
class OverrideAPIExceptionHandler(APIExceptionHandler):
3737
async def catch(
38-
self, ctx: IExecutionContext, exc: t.Union[t.Any, Exception]
38+
self, ctx: IHostContext, exc: t.Union[t.Any, Exception]
3939
) -> t.Union[Response, t.Any]:
4040
return JSONResponse({"detail": str(exc)}, status_code=404)
4141

@@ -44,7 +44,7 @@ class OverrideHTTPException(IExceptionHandler):
4444
exception_type_or_code = HTTPException
4545

4646
async def catch(
47-
self, ctx: IExecutionContext, exc: t.Union[t.Any, Exception]
47+
self, ctx: IHostContext, exc: t.Union[t.Any, Exception]
4848
) -> t.Union[Response, t.Any]:
4949
return JSONResponse({"detail": "HttpException Override"}, status_code=400)
5050

@@ -53,13 +53,13 @@ class ServerErrorHandler(IExceptionHandler):
5353
exception_type_or_code = 500
5454

5555
async def catch(
56-
self, ctx: IExecutionContext, exc: t.Union[t.Any, Exception]
56+
self, ctx: IHostContext, exc: t.Union[t.Any, Exception]
5757
) -> t.Union[Response, t.Any]:
5858
return JSONResponse({"detail": "Server Error"}, status_code=500)
5959

6060

61-
def error_500(ctx: IExecutionContext, exc: Exception):
62-
assert isinstance(ctx, IExecutionContext)
61+
def error_500(ctx: IHostContext, exc: Exception):
62+
assert isinstance(ctx, IHostContext)
6363
return JSONResponse({"detail": "Server Error"}, status_code=500)
6464

6565

@@ -92,9 +92,7 @@ def test_invalid_exception_handler_setup_raise_exception():
9292
with pytest.raises(AssertionError) as ex:
9393

9494
class InvalidExceptionSetup(IExceptionHandler):
95-
def catch(
96-
self, ctx: IExecutionContext, exc: t.Any
97-
) -> t.Union[Response, t.Any]:
95+
def catch(self, ctx: IHostContext, exc: t.Any) -> t.Union[Response, t.Any]:
9896
pass
9997

10098
assert "'exception_type_or_code' must be defined" in str(ex.value)
@@ -106,9 +104,7 @@ def test_invalid_exception_type_setup_raise_exception():
106104
class InvalidExceptionSetup(IExceptionHandler):
107105
exception_type_or_code = ""
108106

109-
def catch(
110-
self, ctx: IExecutionContext, exc: t.Any
111-
) -> t.Union[Response, t.Any]:
107+
def catch(self, ctx: IHostContext, exc: t.Any) -> t.Union[Response, t.Any]:
112108
pass
113109

114110
assert "'exception_type_or_code' must be defined" in str(ex.value)
@@ -118,9 +114,7 @@ def catch(
118114
class InvalidExceptionSetup2(IExceptionHandler):
119115
exception_type_or_code = InvalidExceptionHandler
120116

121-
def catch(
122-
self, ctx: IExecutionContext, exc: t.Any
123-
) -> t.Union[Response, t.Any]:
117+
def catch(self, ctx: IHostContext, exc: t.Any) -> t.Union[Response, t.Any]:
124118
pass
125119

126120
assert "'exception_type_or_code' is not a valid type" in str(ex.value)

0 commit comments

Comments
 (0)