Skip to content
This repository was archived by the owner on Jun 11, 2025. It is now read-only.

Commit f3a73ee

Browse files
committed
Implement access token strategy db adapter
1 parent d61ee21 commit f3a73ee

File tree

7 files changed

+270
-8
lines changed

7 files changed

+270
-8
lines changed

fastapi_users_db_sqlmodel/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414

1515
class SQLModelBaseUserDB(BaseUserDB, SQLModel):
16+
__tablename__ = "user"
17+
1618
id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False)
1719
email: EmailStr = Field(
1820
sa_column_kwargs={"unique": True, "index": True}, nullable=False
@@ -27,7 +29,10 @@ class Config:
2729

2830

2931
class SQLModelBaseOAuthAccount(BaseOAuthAccount, SQLModel):
32+
__tablename__ = "oauthaccount"
33+
3034
id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True)
35+
user_id: UUID4 = Field(foreign_key="user.id", nullable=False)
3136

3237
class Config:
3338
orm_mode = True
@@ -128,7 +133,7 @@ class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]):
128133
Database adapter for SQLModel working purely asynchronously.
129134
130135
:param user_db_model: SQLModel model of a DB representation of a user.
131-
:param engine: SQLAlchemy async engine.
136+
:param session: SQLAlchemy async session.
132137
"""
133138

134139
session: AsyncSession
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from datetime import datetime, timezone
2+
from typing import Generic, Optional, Type, TypeVar
3+
4+
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
5+
from fastapi_users.authentication.strategy.db.models import BaseAccessToken
6+
from pydantic import UUID4
7+
from sqlalchemy import Column, types
8+
from sqlalchemy.ext.asyncio import AsyncSession
9+
from sqlmodel import Field, Session, SQLModel, select
10+
11+
12+
def now_utc():
13+
return datetime.now(timezone.utc)
14+
15+
16+
class SQLModelBaseAccessToken(BaseAccessToken, SQLModel):
17+
__tablename__ = "accesstoken"
18+
19+
token: str = Field(
20+
sa_column=Column("token", types.String(length=43), primary_key=True)
21+
)
22+
created_at: datetime = Field(default_factory=now_utc, nullable=False)
23+
user_id: UUID4 = Field(foreign_key="user.id", nullable=False)
24+
25+
class Config:
26+
orm_mode = True
27+
28+
29+
A = TypeVar("A", bound=SQLModelBaseAccessToken)
30+
31+
32+
class SQLModelAccessTokenDatabase(Generic[A], AccessTokenDatabase[A]):
33+
"""
34+
Access token database adapter for SQLModel.
35+
36+
:param user_db_model: SQLModel model of a DB representation of an access token.
37+
:param session: SQLAlchemy session.
38+
"""
39+
40+
def __init__(self, access_token_model: Type[A], session: Session):
41+
self.access_token_model = access_token_model
42+
self.session = session
43+
44+
async def get_by_token(
45+
self, token: str, max_age: Optional[datetime] = None
46+
) -> Optional[A]:
47+
statement = select(self.access_token_model).where(
48+
self.access_token_model.token == token
49+
)
50+
if max_age is not None:
51+
statement = statement.where(self.access_token_model.created_at >= max_age)
52+
53+
results = self.session.exec(statement)
54+
return results.first()
55+
56+
async def create(self, access_token: A) -> A:
57+
self.session.add(access_token)
58+
self.session.commit()
59+
self.session.refresh(access_token)
60+
return access_token
61+
62+
async def update(self, access_token: A) -> A:
63+
self.session.add(access_token)
64+
self.session.commit()
65+
self.session.refresh(access_token)
66+
return access_token
67+
68+
async def delete(self, access_token: A) -> None:
69+
self.session.delete(access_token)
70+
self.session.commit()
71+
72+
73+
class SQLModelAccessTokenDatabaseAsync(Generic[A], AccessTokenDatabase[A]):
74+
"""
75+
Access token database adapter for SQLModel working purely asynchronously.
76+
77+
:param user_db_model: SQLModel model of a DB representation of an access token.
78+
:param session: SQLAlchemy async session.
79+
"""
80+
81+
def __init__(self, access_token_model: Type[A], session: AsyncSession):
82+
self.access_token_model = access_token_model
83+
self.session = session
84+
85+
async def get_by_token(
86+
self, token: str, max_age: Optional[datetime] = None
87+
) -> Optional[A]:
88+
statement = select(self.access_token_model).where(
89+
self.access_token_model.token == token
90+
)
91+
if max_age is not None:
92+
statement = statement.where(self.access_token_model.created_at >= max_age)
93+
94+
results = await self.session.execute(statement)
95+
object = results.first()
96+
if object is None:
97+
return None
98+
return object[0]
99+
100+
async def create(self, access_token: A) -> A:
101+
self.session.add(access_token)
102+
await self.session.commit()
103+
await self.session.refresh(access_token)
104+
return access_token
105+
106+
async def update(self, access_token: A) -> A:
107+
self.session.add(access_token)
108+
await self.session.commit()
109+
await self.session.refresh(access_token)
110+
return access_token
111+
112+
async def delete(self, access_token: A) -> None:
113+
await self.session.delete(access_token)
114+
await self.session.commit()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
description-file = "README.md"
2323
requires-python = ">=3.7"
2424
requires = [
25-
"fastapi-users >= 7.0.0",
25+
"fastapi-users >= 9.1.0",
2626
"sqlmodel",
2727
]
2828

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
aiosqlite >= 0.17.0
2-
fastapi-users >= 8.1.1
2+
fastapi-users >= 9.1.0
33
sqlmodel

tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import uuid
23
from typing import List, Optional
34

45
import pytest
@@ -31,15 +32,15 @@ class UserOAuth(User):
3132

3233

3334
class UserDBOAuth(SQLModelBaseUserDB, table=True):
34-
__tablename__ = "user"
35+
__tablename__ = "user_oauth"
3536
oauth_accounts: List["OAuthAccount"] = Relationship(
3637
back_populates="user",
3738
sa_relationship_kwargs={"lazy": "joined", "cascade": "all, delete"},
3839
)
3940

4041

4142
class OAuthAccount(SQLModelBaseOAuthAccount, table=True):
42-
user_id: UUID4 = Field(foreign_key="user.id")
43+
user_id: UUID4 = Field(foreign_key="user_oauth.id")
4344
user: Optional[UserDBOAuth] = Relationship(back_populates="oauth_accounts")
4445

4546

@@ -53,6 +54,7 @@ def event_loop():
5354
@pytest.fixture
5455
def oauth_account1() -> OAuthAccount:
5556
return OAuthAccount(
57+
id=uuid.UUID("b9089e5d-2642-406d-a7c0-cbc641aca0ec"),
5658
oauth_name="service1",
5759
access_token="TOKEN",
5860
expires_at=1579000751,
@@ -64,6 +66,7 @@ def oauth_account1() -> OAuthAccount:
6466
@pytest.fixture
6567
def oauth_account2() -> OAuthAccount:
6668
return OAuthAccount(
69+
id=uuid.UUID("c9089e5d-2642-406d-a7c0-cbc641aca0ec"),
6770
oauth_name="service2",
6871
access_token="TOKEN",
6972
expires_at=1579000751,

tests/test_access_token.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import uuid
2+
from datetime import datetime, timedelta, timezone
3+
from typing import AsyncGenerator
4+
5+
import pytest
6+
from pydantic import UUID4
7+
from sqlalchemy import exc
8+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
9+
from sqlalchemy.orm import sessionmaker
10+
from sqlmodel import Session, SQLModel, create_engine
11+
12+
from fastapi_users_db_sqlmodel import SQLModelUserDatabase, SQLModelUserDatabaseAsync
13+
from fastapi_users_db_sqlmodel.access_token import (
14+
SQLModelAccessTokenDatabase,
15+
SQLModelAccessTokenDatabaseAsync,
16+
SQLModelBaseAccessToken,
17+
)
18+
from tests.conftest import UserDB
19+
20+
21+
class AccessToken(SQLModelBaseAccessToken, table=True):
22+
pass
23+
24+
25+
@pytest.fixture
26+
def user_id() -> UUID4:
27+
return uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec")
28+
29+
30+
async def init_sync_session(url: str) -> AsyncGenerator[Session, None]:
31+
engine = create_engine(url, connect_args={"check_same_thread": False})
32+
SQLModel.metadata.create_all(engine)
33+
with Session(engine) as session:
34+
yield session
35+
SQLModel.metadata.drop_all(engine)
36+
37+
38+
async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]:
39+
engine = create_async_engine(url, connect_args={"check_same_thread": False})
40+
make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
41+
async with engine.begin() as conn:
42+
await conn.run_sync(SQLModel.metadata.create_all)
43+
async with make_session() as session:
44+
yield session
45+
await conn.run_sync(SQLModel.metadata.drop_all)
46+
47+
48+
@pytest.fixture(
49+
params=[
50+
(
51+
init_sync_session,
52+
"sqlite:///./test-sqlmodel-access-token.db",
53+
SQLModelAccessTokenDatabase,
54+
SQLModelUserDatabase,
55+
),
56+
(
57+
init_async_session,
58+
"sqlite+aiosqlite:///./test-sqlmodel-access-token.db",
59+
SQLModelAccessTokenDatabaseAsync,
60+
SQLModelUserDatabaseAsync,
61+
),
62+
],
63+
ids=["sync", "async"],
64+
)
65+
async def sqlmodel_access_token_db(
66+
request, user_id: UUID4
67+
) -> AsyncGenerator[SQLModelAccessTokenDatabase, None]:
68+
create_session = request.param[0]
69+
database_url = request.param[1]
70+
access_token_database_class = request.param[2]
71+
user_database_class = request.param[3]
72+
async for session in create_session(database_url):
73+
user = UserDB(
74+
id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere"
75+
)
76+
user_db = user_database_class(UserDB, session)
77+
await user_db.create(user)
78+
yield access_token_database_class(AccessToken, session)
79+
80+
81+
@pytest.mark.asyncio
82+
@pytest.mark.db
83+
async def test_queries(
84+
sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken],
85+
user_id: UUID4,
86+
):
87+
access_token = AccessToken(token="TOKEN", user_id=user_id)
88+
89+
# Create
90+
access_token_db = await sqlmodel_access_token_db.create(access_token)
91+
assert access_token_db.token == "TOKEN"
92+
assert access_token_db.user_id == user_id
93+
94+
# Update
95+
access_token_db.created_at = datetime.now(timezone.utc)
96+
await sqlmodel_access_token_db.update(access_token_db)
97+
98+
# Get by token
99+
access_token_by_token = await sqlmodel_access_token_db.get_by_token(
100+
access_token_db.token
101+
)
102+
assert access_token_by_token is not None
103+
104+
# Get by token expired
105+
access_token_by_token = await sqlmodel_access_token_db.get_by_token(
106+
access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1)
107+
)
108+
assert access_token_by_token is None
109+
110+
# Get by token not expired
111+
access_token_by_token = await sqlmodel_access_token_db.get_by_token(
112+
access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1)
113+
)
114+
assert access_token_by_token is not None
115+
116+
# Get by token unknown
117+
access_token_by_token = await sqlmodel_access_token_db.get_by_token(
118+
"NOT_EXISTING_TOKEN"
119+
)
120+
assert access_token_by_token is None
121+
122+
# Delete token
123+
await sqlmodel_access_token_db.delete(access_token_db)
124+
deleted_access_token = await sqlmodel_access_token_db.get_by_token(
125+
access_token_db.token
126+
)
127+
assert deleted_access_token is None
128+
129+
130+
@pytest.mark.asyncio
131+
@pytest.mark.db
132+
async def test_insert_existing_token(
133+
sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], user_id: UUID4
134+
):
135+
access_token = AccessToken(token="TOKEN", user_id=user_id)
136+
await sqlmodel_access_token_db.create(access_token)
137+
138+
with pytest.raises(exc.IntegrityError):
139+
await sqlmodel_access_token_db.create(
140+
AccessToken(token="TOKEN", user_id=user_id)
141+
)

tests/test_fastapi_users_db_sqlmodel.py renamed to tests/test_users.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import pytest
55
from sqlalchemy import exc
6-
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
6+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
77
from sqlalchemy.orm import sessionmaker
8-
from sqlmodel import SQLModel, create_engine, Session
8+
from sqlmodel import Session, SQLModel, create_engine
99

1010
from fastapi_users_db_sqlmodel import (
1111
NotSetOAuthAccountTableError,
@@ -14,7 +14,6 @@
1414
)
1515
from tests.conftest import OAuthAccount, UserDB, UserDBOAuth
1616

17-
1817
safe_uuid = uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec")
1918

2019

0 commit comments

Comments
 (0)