Skip to content

Commit 14e1e20

Browse files
authored
add jwt authentication middleware (#84)
* Add jwt authentication middleware * Fix branch conflicts
1 parent 227d76c commit 14e1e20

File tree

11 files changed

+108
-83
lines changed

11 files changed

+108
-83
lines changed

backend/app/api/v1/api.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# -*- coding: utf-8 -*-
33
from typing import Annotated
44

5-
from fastapi import APIRouter, Query
5+
from fastapi import APIRouter, Query, Request
66

77
from backend.app.common.casbin_rbac import DependsRBAC
8-
from backend.app.common.jwt import DependsUser, CurrentUser
8+
from backend.app.common.jwt import DependsJwtAuth
99
from backend.app.common.pagination import PageDepends, paging_data
1010
from backend.app.common.response.response_schema import response_base
1111
from backend.app.database.db_mysql import CurrentSession
@@ -15,13 +15,13 @@
1515
router = APIRouter()
1616

1717

18-
@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsUser])
18+
@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth])
1919
async def get_api(pk: int):
2020
api = await ApiService.get(pk=pk)
2121
return response_base.success(data=api)
2222

2323

24-
@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsUser, PageDepends])
24+
@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends])
2525
async def get_all_apis(
2626
db: CurrentSession,
2727
name: Annotated[str | None, Query()] = None,
@@ -34,14 +34,14 @@ async def get_all_apis(
3434

3535

3636
@router.post('', summary='创建接口', dependencies=[DependsRBAC])
37-
async def create_api(obj: CreateApi, user: CurrentUser):
38-
await ApiService.create(obj=obj, user_id=user.id)
37+
async def create_api(request: Request, obj: CreateApi):
38+
await ApiService.create(obj=obj, user_id=request.user.id)
3939
return response_base.success()
4040

4141

4242
@router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC])
43-
async def update_api(pk: int, obj: UpdateApi, user: CurrentUser):
44-
count = await ApiService.update(pk=pk, obj=obj, user_id=user.id)
43+
async def update_api(request: Request, pk: int, obj: UpdateApi):
44+
count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id)
4545
if count > 0:
4646
return response_base.success()
4747
return response_base.fail()

backend/app/api/v1/auth/auth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastapi_limiter.depends import RateLimiter
88
from starlette.background import BackgroundTasks
99

10-
from backend.app.common.jwt import DependsUser, CurrentUser
10+
from backend.app.common.jwt import DependsJwtAuth
1111
from backend.app.common.response.response_schema import response_base
1212
from backend.app.schemas.token import LoginToken, SwaggerToken, NewToken
1313
from backend.app.schemas.user import Auth
@@ -42,14 +42,14 @@ async def user_login(request: Request, obj: Auth, background_tasks: BackgroundTa
4242
return response_base.success(data=data)
4343

4444

45-
@router.post('/new_token', summary='创建新 token', dependencies=[DependsUser])
45+
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
4646
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
4747
access_token, access_expire = await AuthService.new_token(refresh_token)
4848
data = NewToken(access_token=access_token, access_token_expire_time=access_expire)
4949
return response_base.success(data=data)
5050

5151

52-
@router.post('/logout', summary='用户登出')
53-
async def user_logout(request: Request, current_user: CurrentUser):
54-
await AuthService.logout(request=request, current_user=current_user)
52+
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
53+
async def user_logout(request: Request):
54+
await AuthService.logout(request)
5555
return response_base.success()

backend/app/api/v1/login_log.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi import APIRouter, Query
66

77
from backend.app.common.casbin_rbac import DependsRBAC
8-
from backend.app.common.jwt import DependsUser
8+
from backend.app.common.jwt import DependsJwtAuth
99
from backend.app.common.pagination import paging_data, PageDepends
1010
from backend.app.common.response.response_schema import response_base
1111
from backend.app.database.db_mysql import CurrentSession
@@ -15,7 +15,7 @@
1515
router = APIRouter()
1616

1717

18-
@router.get('', summary='获取所有登录日志', dependencies=[DependsUser, PageDepends])
18+
@router.get('', summary='获取所有登录日志', dependencies=[DependsJwtAuth, PageDepends])
1919
async def get_all_login_logs(db: CurrentSession):
2020
log_select = await LoginLogService.get_select()
2121
page_data = await paging_data(db, log_select, GetAllLoginLog)

backend/app/api/v1/user.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from fastapi import APIRouter, Query, Request
66

7-
from backend.app.common.jwt import DependsUser, CurrentUser, DependsSuperUser
7+
from backend.app.common.jwt import DependsJwtAuth
88
from backend.app.common.pagination import paging_data, PageDepends
99
from backend.app.common.response.response_schema import response_base
1010
from backend.app.database.db_mysql import CurrentSession
@@ -29,30 +29,30 @@ async def password_reset(obj: ResetPassword):
2929
return response_base.fail()
3030

3131

32-
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsUser])
32+
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
3333
async def get_user(username: str):
3434
current_user = await UserService.get_userinfo(username)
3535
data = GetAllUserInfo(**select_to_json(current_user))
3636
return response_base.success(data=data)
3737

3838

39-
@router.put('/{username}', summary='更新用户信息')
40-
async def update_userinfo(username: str, obj: UpdateUser, current_user: CurrentUser):
41-
count = await UserService.update(username=username, current_user=current_user, obj=obj)
39+
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
40+
async def update_userinfo(request: Request, username: str, obj: UpdateUser):
41+
count = await UserService.update(request=request, username=username, obj=obj)
4242
if count > 0:
4343
return response_base.success()
4444
return response_base.fail()
4545

4646

47-
@router.put('/{username}/avatar', summary='更新头像')
48-
async def update_avatar(username: str, avatar: Avatar, current_user: CurrentUser):
49-
count = await UserService.update_avatar(username=username, current_user=current_user, avatar=avatar)
47+
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
48+
async def update_avatar(request: Request, username: str, avatar: Avatar):
49+
count = await UserService.update_avatar(request=request, username=username, avatar=avatar)
5050
if count > 0:
5151
return response_base.success()
5252
return response_base.fail()
5353

5454

55-
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsUser, PageDepends])
55+
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends])
5656
async def get_all_users(
5757
db: CurrentSession,
5858
username: Annotated[str | None, Query()] = None,
@@ -64,33 +64,38 @@ async def get_all_users(
6464
return response_base.success(data=page_data)
6565

6666

67-
@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsSuperUser])
68-
async def super_set(pk: int):
69-
count = await UserService.update_permission(pk)
67+
@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth])
68+
async def super_set(request: Request, pk: int):
69+
count = await UserService.update_permission(request=request, pk=pk)
7070
if count > 0:
7171
return response_base.success()
7272
return response_base.fail()
7373

7474

75-
@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsSuperUser])
76-
async def active_set(pk: int):
77-
count = await UserService.update_active(pk)
75+
@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth])
76+
async def active_set(request: Request, pk: int):
77+
count = await UserService.update_active(request=request, pk=pk)
7878
if count > 0:
7979
return response_base.success()
8080
return response_base.fail()
8181

8282

83-
@router.post('/{pk}/multi', summary='修改用户多点登录状态')
84-
async def multi_set(request: Request, pk: int, current_user: CurrentUser):
85-
count = await UserService.update_multi_login(request=request, pk=pk, current_user=current_user)
83+
@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth])
84+
async def multi_set(request: Request, pk: int):
85+
count = await UserService.update_multi_login(request=request, pk=pk)
8686
if count > 0:
8787
return response_base.success()
8888
return response_base.fail()
8989

9090

91-
@router.delete('/{username}', summary='用户注销', description='用户注销 != 用户退出,注销之后用户将从数据库删除')
92-
async def delete_user(username: str, current_user: CurrentUser):
93-
count = await UserService.delete(username=username, current_user=current_user)
91+
@router.delete(
92+
path='/{username}',
93+
summary='用户注销',
94+
description='用户注销 != 用户登出,注销之后用户将从数据库删除',
95+
dependencies=[DependsJwtAuth],
96+
)
97+
async def delete_user(request: Request, username: str):
98+
count = await UserService.delete(request=request, username=username)
9499
if count > 0:
95100
return response_base.success()
96101
return response_base.fail()

backend/app/common/casbin_rbac.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi import Request, Depends
66

77
from backend.app.common.exception.errors import AuthorizationError
8-
from backend.app.common.jwt import CurrentUser
8+
from backend.app.common.jwt import DependsJwtAuth
99
from backend.app.core.conf import settings
1010
from backend.app.core.path_conf import RBAC_MODEL_CONF
1111
from backend.app.database.db_mysql import async_engine
@@ -26,18 +26,18 @@ async def get_casbin_enforcer() -> casbin.Enforcer:
2626

2727
return enforcer
2828

29-
async def rbac_verify(self, request: Request, user: CurrentUser) -> None:
29+
async def rbac_verify(self, request: Request, _: str = DependsJwtAuth) -> None:
3030
"""
3131
权限校验,超级用户跳过校验,默认拥有所有权限
3232
3333
:param request:
34-
:param user:
34+
:param _:
3535
:return:
3636
"""
37-
user_uuid = user.user_uuid
38-
user_roles = user.roles
37+
user_uuid = request.user.user_uuid
38+
user_roles = request.user.roles
3939
role_data_scope = [role.data_scope for role in user_roles]
40-
super_user = user.is_superuser
40+
super_user = request.user.is_superuser
4141
path = request.url.path
4242
method = request.method
4343

backend/app/common/jwt.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
from jose import jwt
99
from passlib.context import CryptContext
1010
from pydantic import ValidationError
11-
from typing_extensions import Annotated
11+
from sqlalchemy.ext.asyncio import AsyncSession
1212

1313
from backend.app.common.exception.errors import AuthorizationError, TokenError
1414
from backend.app.common.redis import redis_client
1515
from backend.app.core.conf import settings
1616
from backend.app.crud.crud_user import UserDao
17-
from backend.app.database.db_mysql import CurrentSession
1817
from backend.app.models import User
1918

2019
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
@@ -139,7 +138,7 @@ def jwt_decode(token: str) -> tuple[int, list[int]]:
139138
return user_id, role_ids
140139

141140

142-
async def jwt_authentication(token: str = Depends(oauth2_schema)) -> dict[str, int]:
141+
async def jwt_authentication(token: str) -> dict[str, int]:
143142
"""
144143
JWT authentication
145144
@@ -154,9 +153,9 @@ async def jwt_authentication(token: str = Depends(oauth2_schema)) -> dict[str, i
154153
return {'sub': user_id}
155154

156155

157-
async def get_current_user(db: CurrentSession, data: dict = Depends(jwt_authentication)) -> User:
156+
async def get_current_user(db: AsyncSession, data: dict) -> User:
158157
"""
159-
Get the current user through tokens
158+
Get the current user through token
160159
161160
:param db:
162161
:param data:
@@ -169,24 +168,18 @@ async def get_current_user(db: CurrentSession, data: dict = Depends(jwt_authenti
169168
return user
170169

171170

172-
async def get_current_is_superuser(user: User = Depends(get_current_user)):
171+
async def superuser_verify(request: Request) -> bool:
173172
"""
174173
Verify the current user permissions through token
175174
176-
:param user:
175+
:param request:
177176
:return:
178177
"""
179-
is_superuser = user.is_superuser
178+
is_superuser = request.user.is_superuser
180179
if not is_superuser:
181180
raise AuthorizationError
182181
return is_superuser
183182

184183

185-
# User Annotated
186-
CurrentUser = Annotated[User, Depends(get_current_user)]
187-
CurrentSuperUser = Annotated[bool, Depends(get_current_is_superuser)]
188-
# Token dependency injection
189-
CurrentJwtAuth = Annotated[dict, Depends(jwt_authentication)]
190-
# Permission dependency injection
191-
DependsUser = Depends(get_current_user)
192-
DependsSuperUser = Depends(get_current_is_superuser)
184+
# Jwt verify dependency
185+
DependsJwtAuth = Depends(oauth2_schema)

backend/app/core/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def validator_api_url(cls, values):
9191
# Middleware
9292
MIDDLEWARE_CORS: bool = True
9393
MIDDLEWARE_GZIP: bool = True
94+
MIDDLEWARE_JWT_AUTH: bool = True
9495
MIDDLEWARE_ACCESS: bool = False
9596

9697
# Casbin

backend/app/core/registrar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi import FastAPI
66
from fastapi_limiter import FastAPILimiter
77
from fastapi_pagination import add_pagination
8+
from starlette.middleware.authentication import AuthenticationMiddleware
89

910
from backend.app.api.routers import v1
1011
from backend.app.common.exception.exception_handler import register_exception
@@ -105,6 +106,11 @@ def register_middleware(app: FastAPI):
105106
from fastapi.middleware.gzip import GZipMiddleware
106107

107108
app.add_middleware(GZipMiddleware)
109+
# JWT auth
110+
if settings.MIDDLEWARE_JWT_AUTH:
111+
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
112+
113+
app.add_middleware(AuthenticationMiddleware, backend=JwtAuthMiddleware())
108114
# Api access logs
109115
if settings.MIDDLEWARE_ACCESS:
110116
from backend.app.middleware.access_middleware import AccessMiddleware
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
from starlette.authentication import AuthenticationBackend
4+
from fastapi import Request
5+
6+
from backend.app.common import jwt
7+
from backend.app.database.db_mysql import async_db_session
8+
9+
10+
class JwtAuthMiddleware(AuthenticationBackend):
11+
"""JWT 认证中间件"""
12+
13+
async def authenticate(self, request: Request):
14+
auth = request.headers.get('Authorization')
15+
if not auth:
16+
return
17+
18+
scheme, token = auth.split()
19+
if scheme.lower() != 'bearer':
20+
return
21+
22+
sub = await jwt.jwt_authentication(token)
23+
24+
async with async_db_session() as db:
25+
user = await jwt.get_current_user(db, data=sub)
26+
27+
return auth, user

backend/app/services/auth_service.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010

1111
from backend.app.common import jwt
1212
from backend.app.common.exception import errors
13-
from backend.app.common.jwt import get_token, jwt_decode
13+
from backend.app.common.jwt import get_token
1414
from backend.app.common.redis import redis_client
1515
from backend.app.core.conf import settings
1616
from backend.app.crud.crud_user import UserDao
1717
from backend.app.database.db_mysql import async_db_session
18-
from backend.app.models import User
1918
from backend.app.schemas.user import Auth
2019
from backend.app.services.login_log_service import LoginLogService
2120

@@ -93,12 +92,11 @@ async def new_token(refresh_token: str) -> tuple[str, datetime]:
9392
return access_new_token, access_new_token_expire_time
9493

9594
@staticmethod
96-
async def logout(*, request: Request, current_user: User) -> NoReturn:
95+
async def logout(request: Request) -> NoReturn:
9796
token = get_token(request)
98-
user_id, _ = jwt_decode(token)
99-
if current_user.is_multi_login:
100-
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
97+
if request.user.is_multi_login:
98+
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
10199
await redis_client.delete(key)
102100
else:
103-
prefix = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:'
101+
prefix = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:'
104102
await redis_client.delete_prefix(prefix)

0 commit comments

Comments
 (0)