Skip to content

Commit 8b5e0d8

Browse files
committed
Add AccessToken adapter
1 parent 41fffc7 commit 8b5e0d8

File tree

4 files changed

+170
-6
lines changed

4 files changed

+170
-6
lines changed

fastapi_users_db_beanie/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymongo import IndexModel
88
from pymongo.collation import Collation
99

10-
1110
__version__ = "0.0.0"
1211

1312

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from datetime import datetime, timezone
2+
from typing import Any, Dict, Generic, Optional, Type, TypeVar
3+
4+
from beanie import Document, PydanticObjectId
5+
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
6+
from pydantic import Field
7+
from pymongo import IndexModel
8+
9+
10+
class BeanieBaseAccessToken(Document):
11+
token: str
12+
user_id: PydanticObjectId
13+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
14+
15+
class Collection:
16+
indexes = [IndexModel("token", unique=True)]
17+
18+
19+
AP_BEANIE = TypeVar("AP_BEANIE", bound=BeanieBaseAccessToken)
20+
21+
22+
class BeanieAccessTokenDatabase(Generic[AP_BEANIE], AccessTokenDatabase[AP_BEANIE]):
23+
"""
24+
Access token database adapter for Beanie.
25+
26+
:param access_token_model: Beanie access token model.
27+
"""
28+
29+
def __init__(self, access_token_model: Type[AP_BEANIE]):
30+
self.access_token_model = access_token_model
31+
32+
async def get_by_token(
33+
self, token: str, max_age: Optional[datetime] = None
34+
) -> Optional[AP_BEANIE]:
35+
query: Dict[str, Any] = {"token": token}
36+
if max_age is not None:
37+
query["created_at"] = {"$gte": max_age}
38+
return await self.access_token_model.find_one(query)
39+
40+
async def create(self, create_dict: Dict[str, Any]) -> AP_BEANIE:
41+
access_token = self.access_token_model(**create_dict)
42+
return await access_token.save()
43+
44+
async def update(
45+
self, access_token: AP_BEANIE, update_dict: Dict[str, Any]
46+
) -> AP_BEANIE:
47+
for key, value in update_dict.items():
48+
setattr(access_token, key, value)
49+
return await access_token.save()
50+
51+
async def delete(self, access_token: AP_BEANIE) -> None:
52+
await access_token.delete()

tests/test_access_token.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from datetime import datetime, timedelta, timezone
2+
from typing import AsyncGenerator
3+
4+
import pymongo.errors
5+
import pytest
6+
from beanie import PydanticObjectId, init_beanie
7+
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
8+
9+
from fastapi_users_db_beanie.access_token import (
10+
BeanieAccessTokenDatabase,
11+
BeanieBaseAccessToken,
12+
)
13+
14+
15+
class AccessToken(BeanieBaseAccessToken):
16+
pass
17+
18+
19+
@pytest.fixture(scope="module")
20+
async def mongodb_client():
21+
client = AsyncIOMotorClient(
22+
"mongodb://localhost:27017",
23+
serverSelectionTimeoutMS=10000,
24+
uuidRepresentation="standard",
25+
)
26+
27+
try:
28+
await client.server_info()
29+
yield client
30+
client.close()
31+
except pymongo.errors.ServerSelectionTimeoutError:
32+
pytest.skip("MongoDB not available", allow_module_level=True)
33+
return
34+
35+
36+
@pytest.fixture
37+
async def beanie_access_token_db(
38+
mongodb_client: AsyncIOMotorClient,
39+
) -> AsyncGenerator[BeanieAccessTokenDatabase, None]:
40+
database: AsyncIOMotorDatabase = mongodb_client["test_database"]
41+
await init_beanie(database=database, document_models=[AccessToken])
42+
43+
yield BeanieAccessTokenDatabase(AccessToken)
44+
45+
await mongodb_client.drop_database("test_database")
46+
47+
48+
@pytest.fixture
49+
def user_id() -> PydanticObjectId:
50+
return PydanticObjectId()
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_queries(
55+
beanie_access_token_db: BeanieAccessTokenDatabase[AccessToken],
56+
user_id: PydanticObjectId,
57+
):
58+
access_token_create = {"token": "TOKEN", "user_id": user_id}
59+
60+
# Create
61+
access_token = await beanie_access_token_db.create(access_token_create)
62+
assert access_token.token == "TOKEN"
63+
assert access_token.user_id == user_id
64+
65+
# Update
66+
update_dict = {"created_at": datetime.now(timezone.utc)}
67+
updated_access_token = await beanie_access_token_db.update(
68+
access_token, update_dict
69+
)
70+
assert updated_access_token.created_at == update_dict["created_at"]
71+
72+
# Get by token
73+
access_token_by_token = await beanie_access_token_db.get_by_token(
74+
access_token.token
75+
)
76+
assert access_token_by_token is not None
77+
78+
# Get by token expired
79+
access_token_by_token = await beanie_access_token_db.get_by_token(
80+
access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1)
81+
)
82+
assert access_token_by_token is None
83+
84+
# Get by token not expired
85+
access_token_by_token = await beanie_access_token_db.get_by_token(
86+
access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1)
87+
)
88+
assert access_token_by_token is not None
89+
90+
# Get by token unknown
91+
access_token_by_token = await beanie_access_token_db.get_by_token(
92+
"NOT_EXISTING_TOKEN"
93+
)
94+
assert access_token_by_token is None
95+
96+
# Delete token
97+
await beanie_access_token_db.delete(access_token)
98+
deleted_access_token = await beanie_access_token_db.get_by_token(access_token.token)
99+
assert deleted_access_token is None
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_insert_existing_token(
104+
beanie_access_token_db: BeanieAccessTokenDatabase[AccessToken],
105+
user_id: PydanticObjectId,
106+
):
107+
access_token_create = {"token": "TOKEN", "user_id": user_id}
108+
await beanie_access_token_db.create(access_token_create)
109+
110+
with pytest.raises(pymongo.errors.DuplicateKeyError):
111+
await beanie_access_token_db.create(access_token_create)

tests/test_fastapi_users_db_beanie.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import Any, AsyncGenerator, Dict, List, Optional
22

3-
from beanie import init_beanie
4-
from pydantic import Field
5-
import pytest
63
import pymongo.errors
4+
import pytest
5+
from beanie import init_beanie
76
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
7+
from pydantic import Field
88

9-
from fastapi_users_db_beanie import BaseOAuthAccount, BeanieUserDatabase, BeanieBaseUser
9+
from fastapi_users_db_beanie import BaseOAuthAccount, BeanieBaseUser, BeanieUserDatabase
1010

1111

1212
class User(BeanieBaseUser):
@@ -63,7 +63,9 @@ async def beanie_user_db_oauth(
6363

6464

6565
@pytest.mark.asyncio
66-
async def test_queries(beanie_user_db: BeanieUserDatabase[User], oauth_account1: Dict[str, Any]):
66+
async def test_queries(
67+
beanie_user_db: BeanieUserDatabase[User], oauth_account1: Dict[str, Any]
68+
):
6769
user_create = {
6870
"email": "lancelot@camelot.bt",
6971
"hashed_password": "guinevere",

0 commit comments

Comments
 (0)