@@ -62,71 +62,71 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
62
62
63
63
:param user_db_model: Pydantic model of a DB representation of a user.
64
64
:param session: SQLAlchemy session instance.
65
- :param user_model : SQLAlchemy user model.
66
- :param oauth_account_model : Optional SQLAlchemy OAuth accounts model.
65
+ :param user_table : SQLAlchemy user model.
66
+ :param oauth_account_table : Optional SQLAlchemy OAuth accounts model.
67
67
"""
68
68
69
69
session : AsyncSession
70
- user_model : Type [SQLAlchemyBaseUserTable ]
71
- oauth_account_model : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]]
70
+ user_table : Type [SQLAlchemyBaseUserTable ]
71
+ oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]]
72
72
73
73
def __init__ (
74
74
self ,
75
75
user_db_model : Type [UD ],
76
76
session : AsyncSession ,
77
- user_model : Type [SQLAlchemyBaseUserTable ],
78
- oauth_account_model : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]] = None ,
77
+ user_table : Type [SQLAlchemyBaseUserTable ],
78
+ oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]] = None ,
79
79
):
80
80
super ().__init__ (user_db_model )
81
81
self .session = session
82
- self .user_model = user_model
83
- self .oauth_account_model = oauth_account_model
82
+ self .user_table = user_table
83
+ self .oauth_account_table = oauth_account_table
84
84
85
85
async def get (self , id : UUID4 ) -> Optional [UD ]:
86
- statement = select (self .user_model ).where (self .user_model .id == id )
86
+ statement = select (self .user_table ).where (self .user_table .id == id )
87
87
return await self ._get_user (statement )
88
88
89
89
async def get_by_email (self , email : str ) -> Optional [UD ]:
90
- statement = select (self .user_model ).where (
91
- func .lower (self .user_model .email ) == func .lower (email )
90
+ statement = select (self .user_table ).where (
91
+ func .lower (self .user_table .email ) == func .lower (email )
92
92
)
93
93
return await self ._get_user (statement )
94
94
95
95
async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
96
- if self .oauth_account_model is not None :
96
+ if self .oauth_account_table is not None :
97
97
statement = (
98
- select (self .user_model )
99
- .join (self .oauth_account_model )
100
- .where (self .oauth_account_model .oauth_name == oauth )
101
- .where (self .oauth_account_model .account_id == account_id )
98
+ select (self .user_table )
99
+ .join (self .oauth_account_table )
100
+ .where (self .oauth_account_table .oauth_name == oauth )
101
+ .where (self .oauth_account_table .account_id == account_id )
102
102
)
103
103
return await self ._get_user (statement )
104
104
105
105
async def create (self , user : UD ) -> UD :
106
- user_model = self .user_model (** user .dict (exclude = {"oauth_accounts" }))
107
- self .session .add (user_model )
106
+ user_table = self .user_table (** user .dict (exclude = {"oauth_accounts" }))
107
+ self .session .add (user_table )
108
108
109
- if self .oauth_account_model is not None :
109
+ if self .oauth_account_table is not None :
110
110
for oauth_account in user .oauth_accounts :
111
- oauth_account_model = self .oauth_account_model (
111
+ oauth_account_table = self .oauth_account_table (
112
112
** oauth_account .dict (), user_id = user .id
113
113
)
114
- self .session .add (oauth_account_model )
114
+ self .session .add (oauth_account_table )
115
115
116
116
await self .session .commit ()
117
117
return user
118
118
119
119
async def update (self , user : UD ) -> UD :
120
- user_model = await self .session .get (self .user_model , user .id )
120
+ user_table = await self .session .get (self .user_table , user .id )
121
121
for key , value in user .dict (exclude = {"oauth_accounts" }).items ():
122
- setattr (user_model , key , value )
123
- self .session .add (user_model )
122
+ setattr (user_table , key , value )
123
+ self .session .add (user_table )
124
124
125
- if self .oauth_account_model is not None :
125
+ if self .oauth_account_table is not None :
126
126
for oauth_account in user .oauth_accounts :
127
127
statement = update (
128
- self .oauth_account_model ,
129
- whereclause = self .oauth_account_model .id == oauth_account .id ,
128
+ self .oauth_account_table ,
129
+ whereclause = self .oauth_account_table .id == oauth_account .id ,
130
130
values = {** oauth_account .dict (), "user_id" : user .id },
131
131
)
132
132
await self .session .execute (statement )
@@ -136,11 +136,11 @@ async def update(self, user: UD) -> UD:
136
136
return user
137
137
138
138
async def delete (self , user : UD ) -> None :
139
- statement = delete (self .user_model , self .user_model .id == user .id )
139
+ statement = delete (self .user_table , self .user_table .id == user .id )
140
140
await self .session .execute (statement )
141
141
142
142
async def _get_user (self , statement : Select ) -> Optional [UD ]:
143
- if self .oauth_account_model is not None :
143
+ if self .oauth_account_table is not None :
144
144
statement = statement .options (joinedload ("oauth_accounts" ))
145
145
146
146
results = await self .session .execute (statement )
0 commit comments