Skip to content

Commit d9b4c1c

Browse files
committed
feat: implement session restore when cookies are not reliable
When cookies are created with SameSite policy, they won't be available during the authentication flow which uses POST such as OpenID or SAML. This adds support in Strategy to get session ID and restore it later in the login flow. See python-social-auth/social-app-django#481
1 parent 8f2669c commit d9b4c1c

File tree

6 files changed

+51
-11
lines changed

6 files changed

+51
-11
lines changed

social_core/backends/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def continue_pipeline(self, partial):
224224
self, pipeline_index=partial.next_step, *partial.args, **partial.kwargs
225225
)
226226

227-
def auth_extra_arguments(self):
227+
def auth_extra_arguments(self) -> dict[str, str]:
228228
"""Return extra arguments needed on auth process. The defaults can be
229229
overridden by GET parameters."""
230230
extra_arguments = self.setting("AUTH_EXTRA_ARGUMENTS", {}).copy()

social_core/backends/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@ def oauth_authorization_request(self, token):
280280
token = parse_qs(token)
281281
params = self.auth_extra_arguments() or {}
282282
params.update(self.get_scope_argument())
283-
params[self.OAUTH_TOKEN_PARAMETER_NAME] = token.get(
284-
self.OAUTH_TOKEN_PARAMETER_NAME
283+
params[self.OAUTH_TOKEN_PARAMETER_NAME] = cast(
284+
"str", token.get(self.OAUTH_TOKEN_PARAMETER_NAME)
285285
)
286286
state = self.get_or_create_state()
287287
params[self.REDIRECT_URI_PARAMETER_NAME] = self.get_redirect_uri(state)

social_core/backends/open_id.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,25 @@ def extra_data(self, user, uid, response, details=None, *args, **kwargs):
144144
values.update(from_details)
145145
return values
146146

147+
def get_return_to(self) -> str:
148+
params: dict[str, str] = {}
149+
if session_id := self.strategy.get_session_id():
150+
params[self.strategy.SESSION_SAVE_KEY] = session_id
151+
152+
return url_add_parameters(self.strategy.absolute_uri(self.redirect_uri), params)
153+
147154
def auth_url(self):
148155
"""Return auth URL returned by service"""
149156
openid_request = self.setup_request(self.auth_extra_arguments())
150157
# Construct completion URL, including page we should redirect to
151-
return_to = self.strategy.absolute_uri(self.redirect_uri)
152-
return openid_request.redirectURL(self.trust_root(), return_to)
158+
return openid_request.redirectURL(self.trust_root(), self.get_return_to())
153159

154160
def auth_html(self):
155161
"""Return auth HTML returned by service"""
156162
openid_request = self.setup_request(self.auth_extra_arguments())
157-
return_to = self.strategy.absolute_uri(self.redirect_uri)
158163
form_tag = {"id": "openid_message"}
159164
return openid_request.htmlMarkup(
160-
self.trust_root(), return_to, form_tag_attrs=form_tag
165+
self.trust_root(), self.get_return_to(), form_tag_attrs=form_tag
161166
)
162167

163168
def trust_root(self):
@@ -167,7 +172,7 @@ def trust_root(self):
167172
def continue_pipeline(self, partial):
168173
"""Continue previous halted pipeline"""
169174
response = self.consumer().complete(
170-
dict(self.data.items()), self.strategy.absolute_uri(self.redirect_uri)
175+
dict(self.data.items()), self.get_return_to()
171176
)
172177
return self.strategy.authenticate(
173178
self,
@@ -180,9 +185,11 @@ def continue_pipeline(self, partial):
180185
def auth_complete(self, *args, **kwargs):
181186
"""Complete auth process"""
182187
response = self.consumer().complete(
183-
dict(self.data.items()), self.strategy.absolute_uri(self.redirect_uri)
188+
dict(self.data.items()), self.get_return_to()
184189
)
185190
self.process_error(response)
191+
if session_id := self.data.get(self.strategy.SESSION_SAVE_KEY):
192+
self.strategy.restore_session(session_id, kwargs)
186193
return self.strategy.authenticate(self, response=response, *args, **kwargs)
187194

188195
def process_error(self, data):

social_core/backends/saml.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ def auth_url(self):
286286
"idp": idp_name,
287287
"next": self.data.get("next"),
288288
}
289+
if session_id := self.strategy.get_session_id():
290+
relay_state[self.strategy.SESSION_SAVE_KEY] = session_id
289291
return auth.login(return_to=json.dumps(relay_state))
290292

291293
def get_user_details(self, response):
@@ -326,7 +328,9 @@ def auth_complete(self, *args, **kwargs):
326328
idp_name = relay_state_str
327329
else:
328330
idp_name = relay_state["idp"]
329-
if next_url := relay_state.get("next"):
331+
if session_id := relay_state.get(self.strategy.SESSION_SAVE_KEY):
332+
self.strategy.restore_session(session_id, kwargs)
333+
elif next_url := relay_state.get("next"):
330334
# The do_complete action expects the "next" URL to be in session state or the request params.
331335
self.strategy.session_set(kwargs.get("redirect_name", "next"), next_url)
332336

social_core/exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,22 @@ class SocialAuthBaseException(ValueError):
1010
"""Base class for pipeline exceptions."""
1111

1212

13+
class StrategyMissingFeatureError(SocialAuthBaseException):
14+
"""Strategy does not support this."""
15+
16+
def __init__(self, strategy_name: str, feature_name: str):
17+
self.strategy_name = strategy_name
18+
self.feature_name = feature_name
19+
super().__init__()
20+
21+
def __str__(self):
22+
return f"Strategy {self.strategy_name} does not support {self.feature_name}"
23+
24+
1325
class WrongBackend(SocialAuthBaseException):
1426
def __init__(self, backend_name: str):
1527
self.backend_name = backend_name
28+
super().__init__()
1629

1730
def __str__(self):
1831
return f'Incorrect authentication service "{self.backend_name}"'

social_core/strategy.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import secrets
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55

66
from .backends.utils import get_backend
7+
from .exceptions import StrategyMissingFeatureError
78
from .pipeline import DEFAULT_AUTH_PIPELINE, DEFAULT_DISCONNECT_PIPELINE
89
from .pipeline.utils import partial_load, partial_prepare, partial_store
910
from .store import OpenIdSessionWrapper, OpenIdStore
@@ -35,6 +36,7 @@ def render_string(self, html, context):
3536
class BaseStrategy:
3637
ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
3738
DEFAULT_TEMPLATE_STRATEGY = BaseTemplateStrategy
39+
SESSION_SAVE_KEY = "psa_session_id"
3840

3941
def __init__(self, storage=None, tpl=None):
4042
self.storage = storage
@@ -61,6 +63,20 @@ def session_setdefault(self, name, value):
6163
self.session_set(name, value)
6264
return self.session_get(name)
6365

66+
def get_session_id(self) -> str | None:
67+
"""
68+
Return session ID to be used by restore_session.
69+
"""
70+
return None
71+
72+
def restore_session(self, session_id: str, kwargs: dict[str, Any]) -> None:
73+
"""
74+
Restores session and updates kwargs to match it.
75+
76+
This is only called if get_session_id returns a value.
77+
"""
78+
raise StrategyMissingFeatureError(self.__class__.__name__, "session restore")
79+
6480
def openid_session_dict(self, name):
6581
# Many frameworks are switching the session serialization from Pickle
6682
# to JSON to avoid code execution risks. Flask did this from Flask

0 commit comments

Comments
 (0)