diff --git a/social_django/models.py b/social_django/models.py index a5314cb3..abbd77ff 100644 --- a/social_django/models.py +++ b/social_django/models.py @@ -1,5 +1,7 @@ """Django ORM models for Social Auth""" +from typing import Union + from django.conf import settings from django.db import models from django.db.utils import IntegrityError @@ -54,7 +56,9 @@ class Meta: abstract = True @classmethod - def get_social_auth(cls, provider: str, uid: str): + def get_social_auth(cls, provider: str, uid: Union[str, int]): + if not isinstance(uid, str): + uid = str(uid) for social in cls.objects.select_related("user").filter( provider=provider, uid=uid ): diff --git a/tests/test_models.py b/tests/test_models.py index 3fc0df5e..24e7a83d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -144,7 +144,7 @@ def test_get_social_auth(self): self.assertEqual( UserSocialAuth.get_social_auth(provider=usa.provider, uid=usa.uid), usa ) - self.assertIsNone(UserSocialAuth.get_social_auth(provider="a", uid=1)) + self.assertIsNone(UserSocialAuth.get_social_auth(provider="a", uid="1")) # Mixin self.assertEqual( @@ -154,7 +154,7 @@ def test_get_social_auth(self): usa, ) self.assertIsNone( - super(AbstractUserSocialAuth, usa).get_social_auth(provider="a", uid=1) + super(AbstractUserSocialAuth, usa).get_social_auth(provider="a", uid="1") ) # Manager @@ -162,7 +162,30 @@ def test_get_social_auth(self): UserSocialAuth.objects.get_social_auth(provider=usa.provider, uid=usa.uid), usa, ) - self.assertIsNone(UserSocialAuth.objects.get_social_auth(provider="a", uid=1)) + self.assertIsNone(UserSocialAuth.objects.get_social_auth(provider="a", uid="1")) + + def test_get_social_auth_int_uid(self): + usa = self.usa + int_uid = int(usa.uid) + + # Model + self.assertEqual( + UserSocialAuth.get_social_auth(provider=usa.provider, uid=int_uid), usa + ) + + # Mixin + self.assertEqual( + super(AbstractUserSocialAuth, usa).get_social_auth( + provider=usa.provider, uid=usa.uid + ), + usa, + ) + + # Manager + self.assertEqual( + UserSocialAuth.get_social_auth(provider=usa.provider, uid=int_uid), + usa, + ) def test_get_social_auth_for_user(self): qs = UserSocialAuth.get_social_auth_for_user(