@@ -70,13 +70,13 @@ class Auth(Generic[UserModelT]):
70
70
backend : AuthBackend [UserModelT ] = None
71
71
72
72
def __init__ (
73
- self ,
74
- db : Union [AsyncDatabase , Database ],
75
- * ,
76
- token_store : BaseTokenStore = None ,
77
- user_model : Type [UserModelT ] = User ,
78
- pwd_context : CryptContext = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" ),
79
- enforcer : AsyncEnforcer = None ,
73
+ self ,
74
+ db : Union [AsyncDatabase , Database ],
75
+ * ,
76
+ token_store : BaseTokenStore = None ,
77
+ user_model : Type [UserModelT ] = User ,
78
+ pwd_context : CryptContext = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" ),
79
+ enforcer : AsyncEnforcer = None ,
80
80
):
81
81
self .user_model = user_model or self .user_model
82
82
assert self .user_model , "user_model is None"
@@ -152,15 +152,16 @@ async def get_current_user(self, request: Request) -> Optional[UserModelT]:
152
152
if "user" in request .scope : # 防止重复授权
153
153
return request .scope ["user" ]
154
154
token_info = await self ._get_token_info (request )
155
- request .scope ["user" ]: UserModelT = await self .db .async_get (self .user_model , token_info .id ) if token_info else None
155
+ request .scope ["user" ]: UserModelT = await self .db .async_get (self .user_model ,
156
+ token_info .id ) if token_info else None
156
157
return request .scope ["user" ]
157
158
158
159
def requires (
159
- self ,
160
- roles : Union [str , Sequence [str ]] = None ,
161
- status_code : int = 403 ,
162
- redirect : str = None ,
163
- response : Union [bool , Response ] = None ,
160
+ self ,
161
+ roles : Union [str , Sequence [str ]] = None ,
162
+ status_code : int = 403 ,
163
+ redirect : str = None ,
164
+ response : Union [bool , Response ] = None ,
164
165
) -> Callable : # sourcery no-metrics
165
166
# todo 优化
166
167
roles_ = (roles ,) if not roles or isinstance (roles , str ) else tuple (roles )
@@ -173,8 +174,8 @@ async def has_requires(user: UserModelT) -> bool:
173
174
return await self .has_role_for_user (user .username , roles_ )
174
175
175
176
async def depend (
176
- request : Request ,
177
- user : UserModelT = Depends (self .get_current_user ),
177
+ request : Request ,
178
+ user : UserModelT = Depends (self .get_current_user ),
178
179
) -> Union [bool , Response ]:
179
180
user_auth = request .scope .get ("__user_auth__" , None )
180
181
if user_auth is None :
@@ -289,29 +290,31 @@ async def create_role_user(self, role_key: str = "root", commit: bool = True) ->
289
290
await self .db .async_commit ()
290
291
return user
291
292
292
- async def request_login (self , request : Request , response : Response , username : str , password : str ) -> BaseApiOut [UserLoginOut ]:
293
+ async def request_login (self , request : Request , response : Response , username : str , password : str ) -> BaseApiOut [
294
+ UserLoginOut ]:
293
295
if request .scope .get ("user" ):
294
296
return BaseApiOut (code = 1 , msg = _ ("User logged in!" ), data = UserLoginOut .parse_obj (request .user ))
295
297
user = await request .auth .authenticate_user (username = username , password = password )
296
298
# 保存登录记录
297
299
ip = request .client .host # 获取真实ip
298
300
# 获取代理ip
299
- ips = [request .headers .get (key , "" ).strip () for key in ["x-forwarded-for" , "x-real-ip" , "x-client-ip" , "remote-host" ]]
301
+ ips = [request .headers .get (key , "" ).strip () for key in
302
+ ["x-forwarded-for" , "x-real-ip" , "x-client-ip" , "remote-host" ]]
300
303
forwarded_for = "," .join ([i for i in set (ips ) if i and i != ip ])
301
304
history = LoginHistory (
302
305
user_id = user .id if user else None ,
303
306
login_name = username ,
304
307
ip = request .client .host ,
305
308
user_agent = request .headers .get ("user-agent" ),
306
- login_status = "登录成功" ,
309
+ login_status = "Login successful" , # 登录成功
307
310
forwarded_for = forwarded_for ,
308
311
)
309
312
self .db .add (history )
310
313
if not user :
311
- history .login_status = "密码错误"
314
+ history .login_status = "Wrong password" # 密码错误
312
315
return BaseApiOut (status = - 1 , msg = _ ("Incorrect username or password!" ))
313
316
if not user .is_active :
314
- history .login_status = "用户未激活"
317
+ history .login_status = "User is not activated" # 用户未激活
315
318
return BaseApiOut (status = - 2 , msg = _ ("Inactive user status!" ))
316
319
request .scope ["user" ] = user
317
320
token_info = UserLoginOut .parse_obj (request .user )
@@ -359,7 +362,8 @@ def __init__(self, auth: Auth = None):
359
362
)
360
363
# oauth2
361
364
if self .route_gettoken :
362
- self .router .dependencies .append (Depends (self .OAuth2 (tokenUrl = f"{ self .router_path } /gettoken" , auto_error = False )))
365
+ self .router .dependencies .append (
366
+ Depends (self .OAuth2 (tokenUrl = f"{ self .router_path } /gettoken" , auto_error = False )))
363
367
self .router .add_api_route (
364
368
"/gettoken" ,
365
369
self .route_gettoken ,
@@ -395,7 +399,8 @@ async def user_logout(request: Request):
395
399
396
400
@property
397
401
def route_gettoken (self ):
398
- async def oauth_token (request : Request , response : Response , username : str = Form (...), password : str = Form (...)):
402
+ async def oauth_token (request : Request , response : Response , username : str = Form (...),
403
+ password : str = Form (...)):
399
404
return await self .auth .request_login (request , response , username , password )
400
405
401
406
return oauth_token
0 commit comments