Skip to content

Commit a8a83dd

Browse files
committed
Added test for functional based middleware
1 parent a80b8e4 commit a8a83dd

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from starlette.requests import Request
2+
from starlette.responses import PlainTextResponse
3+
4+
from ellar.common import Module, ModuleRouter, middleware
5+
from ellar.core import TestClientFactory
6+
from ellar.core.context import IExecutionContext
7+
8+
mr = ModuleRouter()
9+
10+
11+
@mr.get()
12+
def homepage(request: Request):
13+
if request.headers.get("modified_header"):
14+
return "homepage modified_header"
15+
16+
if request.state.user:
17+
return request.state.user
18+
return "homepage"
19+
20+
21+
@Module(routers=[mr])
22+
class ModuleMiddleware:
23+
@middleware()
24+
async def middleware_modify_response(cls, context: IExecutionContext, call_next):
25+
response = context.get_response()
26+
response.headers.setdefault("modified-header", "Ellar")
27+
await call_next()
28+
29+
@middleware()
30+
async def middleware_modify_request(cls, context: IExecutionContext, call_next):
31+
request = context.switch_to_request()
32+
request.state.user = None
33+
if request.headers.get("set-user"):
34+
request.state.user = dict(username="Ellar")
35+
await call_next()
36+
37+
@middleware()
38+
async def middleware_return_response(cls, context: IExecutionContext, call_next):
39+
request = context.switch_to_request()
40+
if request.headers.get("ellar"):
41+
return PlainTextResponse("middleware_return_response returned a response")
42+
await call_next()
43+
44+
45+
def test_middleware_modifying_response():
46+
tm = TestClientFactory.create_test_module_from_module(ModuleMiddleware)
47+
client = tm.get_client()
48+
49+
response = client.get("/")
50+
assert response.status_code == 200
51+
assert response.headers["modified-header"] == "Ellar"
52+
53+
54+
def test_middleware_modifying_request():
55+
tm = TestClientFactory.create_test_module_from_module(ModuleMiddleware)
56+
client = tm.get_client()
57+
58+
response = client.get("/", headers={"set-user": "set"})
59+
assert response.status_code == 200
60+
assert response.json() == {"username": "Ellar"}
61+
62+
63+
def test_middleware_returns_response():
64+
tm = TestClientFactory.create_test_module_from_module(ModuleMiddleware)
65+
client = tm.get_client()
66+
67+
response = client.get("/", headers={"ellar": "set"})
68+
assert response.status_code == 200
69+
assert response.text == "middleware_return_response returned a response"

0 commit comments

Comments
 (0)