Skip to content

Commit 8a8600b

Browse files
committed
Make Beanie generic in terms of document ID
1 parent cd5dc19 commit 8a8600b

File tree

4 files changed

+21
-16
lines changed

4 files changed

+21
-16
lines changed

fastapi_users_db_beanie/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from beanie import Document, PydanticObjectId
55
from fastapi_users.db.base import BaseUserDatabase
6-
from fastapi_users.models import OAP
6+
from fastapi_users.models import ID, OAP
77
from pydantic import BaseModel, Field
88
from pymongo import IndexModel
99
from pymongo.collation import Collation
@@ -41,9 +41,7 @@ class BaseOAuthAccount(BaseModel):
4141
refresh_token: Optional[str] = None
4242

4343

44-
class BeanieUserDatabase(
45-
Generic[UP_BEANIE], BaseUserDatabase[UP_BEANIE, PydanticObjectId]
46-
):
44+
class BeanieUserDatabase(Generic[UP_BEANIE, ID], BaseUserDatabase[UP_BEANIE, ID]):
4745
"""
4846
Database adapter for Beanie.
4947
@@ -59,9 +57,9 @@ def __init__(
5957
self.user_model = user_model
6058
self.oauth_account_model = oauth_account_model
6159

62-
async def get(self, id: PydanticObjectId) -> Optional[UP_BEANIE]:
60+
async def get(self, id: ID) -> Optional[UP_BEANIE]:
6361
"""Get a single user by id."""
64-
return await self.user_model.get(id)
62+
return await self.user_model.get(id) # type: ignore
6563

6664
async def get_by_email(self, email: str) -> Optional[UP_BEANIE]:
6765
"""Get a single user by email."""

fastapi_users_db_beanie/access_token.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from datetime import datetime, timezone
22
from typing import Any, Dict, Generic, Optional, Type, TypeVar
33

4-
from beanie import Document, PydanticObjectId
4+
from beanie import Document
55
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
6+
from fastapi_users.models import ID
67
from pydantic import Field
78
from pymongo import IndexModel
89

910

10-
class BeanieBaseAccessToken(Document):
11+
class BeanieBaseAccessToken(Generic[ID], Document):
1112
token: str
12-
user_id: PydanticObjectId
13+
user_id: ID
1314
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
1415

1516
class Collection:

tests/test_access_token.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
class AccessToken(BeanieBaseAccessToken):
15+
class AccessToken(BeanieBaseAccessToken[PydanticObjectId]):
1616
pass
1717

1818

tests/test_fastapi_users_db_beanie.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pymongo.errors
44
import pytest
5-
from beanie import init_beanie
5+
from beanie import PydanticObjectId, init_beanie
66
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
77
from pydantic import Field
88

@@ -64,7 +64,8 @@ async def beanie_user_db_oauth(
6464

6565
@pytest.mark.asyncio
6666
async def test_queries(
67-
beanie_user_db: BeanieUserDatabase[User], oauth_account1: Dict[str, Any]
67+
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId],
68+
oauth_account1: Dict[str, Any],
6869
):
6970
user_create = {
7071
"email": "lancelot@camelot.bt",
@@ -133,7 +134,10 @@ async def test_queries(
133134
],
134135
)
135136
async def test_email_query(
136-
beanie_user_db: BeanieUserDatabase[User], email: str, query: str, found: bool
137+
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId],
138+
email: str,
139+
query: str,
140+
found: bool,
137141
):
138142
user_create = {
139143
"email": email,
@@ -151,7 +155,9 @@ async def test_email_query(
151155

152156

153157
@pytest.mark.asyncio
154-
async def test_insert_existing_email(beanie_user_db: BeanieUserDatabase[User]):
158+
async def test_insert_existing_email(
159+
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId]
160+
):
155161
user_create = {
156162
"email": "lancelot@camelot.bt",
157163
"hashed_password": "guinevere",
@@ -164,7 +170,7 @@ async def test_insert_existing_email(beanie_user_db: BeanieUserDatabase[User]):
164170

165171
@pytest.mark.asyncio
166172
async def test_queries_custom_fields(
167-
beanie_user_db: BeanieUserDatabase[User],
173+
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId],
168174
):
169175
"""It should output custom fields in query result."""
170176
user_create = {
@@ -183,7 +189,7 @@ async def test_queries_custom_fields(
183189

184190
@pytest.mark.asyncio
185191
async def test_queries_oauth(
186-
beanie_user_db_oauth: BeanieUserDatabase[UserOAuth],
192+
beanie_user_db_oauth: BeanieUserDatabase[UserOAuth, PydanticObjectId],
187193
oauth_account1: Dict[str, Any],
188194
oauth_account2: Dict[str, Any],
189195
):

0 commit comments

Comments
 (0)