Skip to content

Commit 05b74d7

Browse files
authored
Simplify OAuth2 model and optimize auth service (#655)
1 parent 10c0c69 commit 05b74d7

File tree

7 files changed

+74
-52
lines changed

7 files changed

+74
-52
lines changed

backend/app/admin/crud/crud_user.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ async def update(self, db: AsyncSession, input_user: User, obj: UpdateUserParam)
114114
input_user.roles = roles.scalars().all()
115115
return count
116116

117+
async def update_avatar(self, db: AsyncSession, user_id: int, avatar: str) -> int:
118+
"""
119+
更新用户头像
120+
121+
:param db: 数据库会话
122+
:param user_id: 用户 ID
123+
:param avatar: 头像地址
124+
:return:
125+
"""
126+
return await self.update_model(db, user_id, {'avatar': avatar})
127+
117128
async def delete(self, db: AsyncSession, user_id: int) -> int:
118129
"""
119130
删除用户

backend/app/admin/schema/user.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class AddOAuth2UserParam(AuthSchemaBase):
3838

3939
nickname: str | None = Field(None, description='昵称')
4040
email: EmailStr = Field(description='邮箱')
41+
avatar: HttpUrl | None = Field(None, description='头像地址')
4142

4243

4344
class ResetPasswordParam(SchemaBase):
@@ -54,7 +55,7 @@ class UserInfoSchemaBase(SchemaBase):
5455
dept_id: int | None = Field(None, description='部门 ID')
5556
username: str = Field(description='用户名')
5657
nickname: str = Field(description='昵称')
57-
avatar: HttpUrl | None = Field(None, description='头像')
58+
avatar: HttpUrl | None = Field(None, description='头像地址')
5859

5960

6061
class UpdateUserParam(UserInfoSchemaBase):

backend/plugin/oauth2/crud/crud_user_social.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
class CRUDUserSocial(CRUDPlus[UserSocial]):
1111
"""用户社交账号数据库操作类"""
1212

13-
async def get(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None:
13+
async def check_binding(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None:
1414
"""
15-
获取用户社交账号绑定详情
15+
检查系统用户社交账号绑定
1616
1717
:param db: 数据库会话
1818
:param pk: 用户 ID
@@ -21,6 +21,17 @@ async def get(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None
2121
"""
2222
return await self.select_model_by_column(db, user_id=pk, source=source)
2323

24+
async def get_by_sid(self, db: AsyncSession, sid: str, source: str) -> UserSocial | None:
25+
"""
26+
通过 UUID 获取社交用户
27+
28+
:param db: 数据库会话
29+
:param sid: 第三方 UUID
30+
:param source: 社交账号类型
31+
:return:
32+
"""
33+
return await self.select_model_by_column(db, sid=sid, source=source)
34+
2435
async def create(self, db: AsyncSession, obj: CreateUserSocialParam) -> None:
2536
"""
2637
创建用户社交账号绑定

backend/plugin/oauth2/model/user_social.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,9 @@ class UserSocial(Base):
1919
__tablename__ = 'sys_user_social'
2020

2121
id: Mapped[id_key] = mapped_column(init=False)
22+
sid: Mapped[str] = mapped_column(String(20), comment='第三方用户 ID')
2223
source: Mapped[str] = mapped_column(String(20), comment='第三方用户来源')
23-
open_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 open id')
24-
sid: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 ID')
25-
union_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 union id')
26-
scope: Mapped[str | None] = mapped_column(String(120), default=None, comment='第三方用户授予的权限')
27-
code: Mapped[str | None] = mapped_column(String(50), default=None, comment='用户的授权 code')
2824

2925
# 用户社交信息一对多
30-
user_id: Mapped[int | None] = mapped_column(
31-
ForeignKey('sys_user.id', ondelete='SET NULL'), default=None, comment='用户关联ID'
32-
)
26+
user_id: Mapped[int] = mapped_column(ForeignKey('sys_user.id', ondelete='CASCADE'), comment='用户关联ID')
3327
user: Mapped[User | None] = relationship(init=False, backref='socials')

backend/plugin/oauth2/plugin.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[plugin]
22
summary = 'OAuth 2.0'
3-
version = '0.0.1'
3+
version = '0.0.2'
44
description = '通过 OAuth 2.0 的方式登录系统'
55
author = 'wu-clan'
66

backend/plugin/oauth2/schema/user_social.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,8 @@
99
class UserSocialSchemaBase(SchemaBase):
1010
"""用户社交基础模型"""
1111

12+
sid: str = Field(description='第三方用户 ID')
1213
source: UserSocialType = Field(description='社交平台')
13-
open_id: str | None = Field(None, description='开放平台 ID')
14-
sid: str | None = Field(None, description='第三方用户 ID')
15-
union_id: str | None = Field(None, description='开放平台唯一 ID')
16-
scope: str | None = Field(None, description='授权范围')
17-
code: str | None = Field(None, description='授权码')
1814

1915

2016
class CreateUserSocialParam(UserSocialSchemaBase):

backend/plugin/oauth2/service/oauth2_service.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77

88
from backend.app.admin.crud.crud_user import user_dao
99
from backend.app.admin.schema.token import GetLoginToken
10-
from backend.app.admin.schema.user import RegisterUserParam
10+
from backend.app.admin.schema.user import AddOAuth2UserParam
1111
from backend.app.admin.service.login_log_service import login_log_service
1212
from backend.common.enums import LoginLogStatusType, UserSocialType
13-
from backend.common.exception.errors import AuthorizationError
1413
from backend.common.security import jwt
1514
from backend.core.conf import settings
1615
from backend.database.db import async_db_session
@@ -43,55 +42,65 @@ async def create_with_login(
4342
:return:
4443
"""
4544
async with async_db_session.begin() as db:
46-
# 获取 OAuth2 平台用户信息
47-
social_id = user.get('id')
48-
social_nickname = user.get('name')
45+
sid = user.get('uuid')
46+
username = user.get('username')
47+
nickname = user.get('nickname')
48+
email = user.get('email')
49+
avatar = user.get('avatar_url')
4950

50-
social_username = user.get('username')
5151
if social == UserSocialType.github:
52-
social_username = user.get('login')
52+
sid = user.get('id')
53+
username = user.get('login')
54+
nickname = user.get('name')
5355

54-
social_email = user.get('email')
5556
if social == UserSocialType.linux_do:
56-
social_email = f'{social_username}@linux.do'
57-
if not social_email:
58-
raise AuthorizationError(msg=f'授权失败,{social.value} 账户未绑定邮箱')
57+
sid = user.get('id')
58+
nickname = user.get('name')
5959

60-
# 创建系统用户
61-
sys_user = await user_dao.check_email(db, social_email)
62-
if not sys_user:
63-
sys_user = await user_dao.get_by_username(db, social_username)
64-
if sys_user:
65-
social_username = f'{social_username}#{text_captcha(5)}'
66-
sys_user = await user_dao.get_by_nickname(db, social_nickname)
67-
if sys_user:
68-
social_username = f'{social_nickname}#{text_captcha(5)}'
69-
new_sys_user = RegisterUserParam(
70-
username=social_username, password=None, nickname=social_username, email=social_email
71-
)
72-
await user_dao.create(db, new_sys_user, social=True)
73-
await db.flush()
74-
sys_user = await user_dao.check_email(db, social_email)
75-
# 绑定社交用户
76-
sys_user_id = sys_user.id
77-
user_social = await user_social_dao.get(db, sys_user_id, social.value)
60+
sys_user = None
61+
user_social = await user_social_dao.get_by_sid(db, str(sid), str(social.value))
7862
if not user_social:
79-
new_user_social = CreateUserSocialParam(source=social.value, sid=str(social_id), user_id=sys_user_id)
80-
await user_social_dao.create(db, new_user_social)
63+
if email:
64+
sys_user = await user_dao.check_email(db, email)
65+
66+
# 创建系统用户
67+
if not sys_user:
68+
while await user_dao.get_by_username(db, username):
69+
username = f'{username}_{text_captcha(5)}'
70+
new_sys_user = AddOAuth2UserParam(
71+
username=username,
72+
password='123456', # 默认密码,可修改系统用户表进行默认密码检测并配合前端进行修改密码提示
73+
nickname=nickname,
74+
email=email,
75+
avatar=avatar,
76+
)
77+
await user_dao.add_by_oauth2(db, new_sys_user)
78+
await db.flush()
79+
sys_user = await user_dao.get_by_username(db, username)
80+
81+
# 绑定社交用户
82+
new_user_social = CreateUserSocialParam(sid=str(sid), source=social.value, user_id=sys_user.id)
83+
await user_social_dao.create(db, new_user_social)
84+
85+
if not sys_user:
86+
sys_user = await user_dao.get(db, user_social.user_id)
87+
if avatar:
88+
await user_dao.update_avatar(db, sys_user.id, avatar)
89+
8190
# 创建 token
8291
access_token = await jwt.create_access_token(
83-
str(sys_user_id),
92+
str(sys_user.id),
8493
sys_user.is_multi_login,
8594
# extra info
8695
username=sys_user.username,
87-
nickname=sys_user.nickname,
96+
nickname=sys_user.nickname or f'#{text_captcha(5)}',
8897
last_login_time=timezone.t_str(timezone.now()),
8998
ip=request.state.ip,
9099
os=request.state.os,
91100
browser=request.state.browser,
92101
device=request.state.device,
93102
)
94-
refresh_token = await jwt.create_refresh_token(str(sys_user_id), multi_login=sys_user.is_multi_login)
103+
refresh_token = await jwt.create_refresh_token(str(sys_user.id), multi_login=sys_user.is_multi_login)
95104
await user_dao.update_login_time(db, sys_user.username)
96105
await db.refresh(sys_user)
97106
login_log = dict(
@@ -115,8 +124,8 @@ async def create_with_login(
115124
data = GetLoginToken(
116125
access_token=access_token.access_token,
117126
access_token_expire_time=access_token.access_token_expire_time,
118-
user=sys_user, # type: ignore
119127
session_uuid=access_token.session_uuid,
128+
user=sys_user, # type: ignore
120129
)
121130
return data
122131

0 commit comments

Comments
 (0)