Skip to content

Commit fab7f93

Browse files
committed
Make BeanieBaseUser and BeanieBaseAccessToken pure mixins to avoid duplicate collections
1 parent 206a1e5 commit fab7f93

File tree

4 files changed

+36
-24
lines changed

4 files changed

+36
-24
lines changed

fastapi_users_db_beanie/__init__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""FastAPI Users database adapter for Beanie."""
2-
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, TypeVar
2+
from typing import Any, Dict, Generic, Optional, Type, TypeVar
33

44
import bson.errors
55
from beanie import Document, PydanticObjectId
@@ -13,9 +13,7 @@
1313
__version__ = "1.1.4"
1414

1515

16-
class BeanieBaseUser(Generic[ID], Document):
17-
if TYPE_CHECKING:
18-
id: ID # type: ignore # pragma: no cover
16+
class BeanieBaseUser(BaseModel):
1917
email: str
2018
hashed_password: str
2119
is_active: bool = True
@@ -32,7 +30,11 @@ class Settings:
3230
]
3331

3432

35-
UP_BEANIE = TypeVar("UP_BEANIE", bound=BeanieBaseUser)
33+
class BeanieBaseUserDocument(BeanieBaseUser, Document): # type: ignore
34+
pass
35+
36+
37+
UP_BEANIE = TypeVar("UP_BEANIE", bound=BeanieBaseUserDocument)
3638

3739

3840
class BaseOAuthAccount(BaseModel):
@@ -45,7 +47,9 @@ class BaseOAuthAccount(BaseModel):
4547
refresh_token: Optional[str] = None
4648

4749

48-
class BeanieUserDatabase(Generic[UP_BEANIE, ID], BaseUserDatabase[UP_BEANIE, ID]):
50+
class BeanieUserDatabase(
51+
Generic[UP_BEANIE], BaseUserDatabase[UP_BEANIE, PydanticObjectId]
52+
):
4953
"""
5054
Database adapter for Beanie.
5155

fastapi_users_db_beanie/access_token.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
from datetime import datetime, timezone
2-
from typing import Any, Dict, Generic, Optional, Type, TypeVar
2+
from typing import (
3+
Any,
4+
Dict,
5+
Generic,
6+
Optional,
7+
Type,
8+
TypeVar,
9+
)
310

4-
from beanie import Document
11+
from beanie import Document, PydanticObjectId
512
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
6-
from fastapi_users.models import ID
7-
from pydantic import Field
13+
from pydantic import BaseModel, Field
814
from pymongo import IndexModel
915

1016

11-
class BeanieBaseAccessToken(Generic[ID], Document):
17+
class BeanieBaseAccessToken(BaseModel):
1218
token: str
13-
user_id: ID
19+
user_id: PydanticObjectId
1420
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
1521

1622
class Settings:
1723
indexes = [IndexModel("token", unique=True)]
1824

1925

20-
AP_BEANIE = TypeVar("AP_BEANIE", bound=BeanieBaseAccessToken)
26+
class BeanieBaseAccessTokenDocument(BeanieBaseAccessToken, Document): # type: ignore
27+
pass
28+
29+
30+
AP_BEANIE = TypeVar("AP_BEANIE", bound=BeanieBaseAccessTokenDocument)
2131

2232

2333
class BeanieAccessTokenDatabase(Generic[AP_BEANIE], AccessTokenDatabase[AP_BEANIE]):

tests/test_access_token.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pymongo.errors
55
import pytest
6-
from beanie import PydanticObjectId, init_beanie
6+
from beanie import Document, PydanticObjectId, init_beanie
77
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
88

99
from fastapi_users_db_beanie.access_token import (
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
class AccessToken(BeanieBaseAccessToken[PydanticObjectId]):
15+
class AccessToken(BeanieBaseAccessToken, Document):
1616
pass
1717

1818

tests/test_fastapi_users_db_beanie.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pymongo.errors
44
import pytest
5-
from beanie import PydanticObjectId, init_beanie
5+
from beanie import Document, PydanticObjectId, init_beanie
66
from fastapi_users import InvalidID
77
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
88
from pydantic import Field
@@ -15,7 +15,7 @@
1515
)
1616

1717

18-
class User(BeanieBaseUser[PydanticObjectId]):
18+
class User(Document, BeanieBaseUser):
1919
first_name: Optional[str] = None
2020

2121

@@ -70,7 +70,7 @@ async def beanie_user_db_oauth(
7070

7171
@pytest.mark.asyncio
7272
async def test_queries(
73-
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId],
73+
beanie_user_db: BeanieUserDatabase[User],
7474
oauth_account1: Dict[str, Any],
7575
):
7676
user_create = {
@@ -140,7 +140,7 @@ async def test_queries(
140140
],
141141
)
142142
async def test_email_query(
143-
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId],
143+
beanie_user_db: BeanieUserDatabase[User],
144144
email: str,
145145
query: str,
146146
found: bool,
@@ -161,9 +161,7 @@ async def test_email_query(
161161

162162

163163
@pytest.mark.asyncio
164-
async def test_insert_existing_email(
165-
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId]
166-
):
164+
async def test_insert_existing_email(beanie_user_db: BeanieUserDatabase[User]):
167165
user_create = {
168166
"email": "lancelot@camelot.bt",
169167
"hashed_password": "guinevere",
@@ -176,7 +174,7 @@ async def test_insert_existing_email(
176174

177175
@pytest.mark.asyncio
178176
async def test_queries_custom_fields(
179-
beanie_user_db: BeanieUserDatabase[User, PydanticObjectId],
177+
beanie_user_db: BeanieUserDatabase[User],
180178
):
181179
"""It should output custom fields in query result."""
182180
user_create = {
@@ -195,7 +193,7 @@ async def test_queries_custom_fields(
195193

196194
@pytest.mark.asyncio
197195
async def test_queries_oauth(
198-
beanie_user_db_oauth: BeanieUserDatabase[UserOAuth, PydanticObjectId],
196+
beanie_user_db_oauth: BeanieUserDatabase[UserOAuth],
199197
oauth_account1: Dict[str, Any],
200198
oauth_account2: Dict[str, Any],
201199
):

0 commit comments

Comments
 (0)