Skip to content

Commit d14de37

Browse files
committed
refactor auth module, support for renew auth cookie if token expired #947
1 parent 31d8f21 commit d14de37

File tree

5 files changed

+96
-63
lines changed

5 files changed

+96
-63
lines changed
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
from typing import Dict, Optional
3+
4+
5+
class AuthCookies:
6+
def __init__(self, cookies):
7+
# type: (Dict) -> None
8+
self._cookies = cookies
9+
10+
@property
11+
def is_valid(self) -> bool:
12+
"""Validates authorization cookies."""
13+
return bool(self._cookies) and (
14+
self.fed_auth is not None or self.spo_idcrl is not None
15+
)
16+
17+
@property
18+
def cookie_header(self):
19+
"""Converts stored cookies into an HTTP Cookie header string."""
20+
return "; ".join(f"{key}={val}" for key, val in self._cookies.items())
21+
22+
@property
23+
def fed_auth(self):
24+
# type: () -> Optional[str]
25+
"""Returns the Primary authentication token."""
26+
return self._cookies.get("FedAuth", None)
27+
28+
@property
29+
def spo_idcrl(self):
30+
# type: () -> Optional[str]
31+
"""Returns the secondary authentication token. (SharePoint Online Identity CRL)"""
32+
return self._cookies.get("SPOIDCRL", None)
33+
34+
@property
35+
def rt_fa(self):
36+
# type: () -> Optional[str]
37+
"""Returns the refresh token for Federated Authentication."""
38+
return self._cookies.get("rtFa", None)
39+
40+
def to_json(self):
41+
"""Serializes cookies to JSON format."""
42+
return json.dumps(self._cookies, indent=2)
43+
44+
@classmethod
45+
def from_json(cls, json_data):
46+
"""Deserializes cookies from JSON format."""
47+
cookies_dict = json.loads(json_data)
48+
return cls(cookies_dict)

office365/runtime/auth/providers/acs_token_provider.py

+4-16
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ def __init__(self, url, client_id, client_secret, environment="commercial"):
2626
self.SharePointPrincipal = "00000003-0000-0ff1-ce00-000000000000"
2727
self._client_id = client_id
2828
self._client_secret = client_secret
29-
self._cached_token = None
29+
self._cached_token = None # type: Optional[TokenResponse]
3030
self._environment = environment
3131

3232
def authenticate_request(self, request):
3333
# type: (RequestOptions) -> None
34-
self._ensure_app_only_access_token()
35-
request.set_header("Authorization", self._get_authorization_header())
34+
if self._cached_token is None:
35+
self._cached_token = self.get_app_only_access_token()
36+
request.set_header("Authorization", self._cached_token.authorization_header)
3637

3738
def get_app_only_access_token(self):
3839
"""Retrieves an app-only access token from ACS"""
@@ -48,11 +49,6 @@ def get_app_only_access_token(self):
4849
)
4950
raise ValueError(self.error)
5051

51-
def _ensure_app_only_access_token(self):
52-
if self._cached_token is None:
53-
self._cached_token = self.get_app_only_access_token()
54-
return self._cached_token and self._cached_token.is_valid
55-
5652
def _get_app_only_access_token(self, target_host, target_realm):
5753
"""
5854
Retrieves an app-only access token from ACS to call the specified principal
@@ -86,11 +82,6 @@ def _get_app_only_access_token(self, target_host, target_realm):
8682
def _get_realm_from_target_url(self):
8783
"""Get the realm for the URL"""
8884
response = requests.head(url=self.url, headers={"Authorization": "Bearer"})
89-
return self.process_realm_response(response)
90-
91-
@staticmethod
92-
def process_realm_response(response):
93-
# type: (requests.Response) -> Optional[str]
9485
header_key = "WWW-Authenticate"
9586
if header_key in response.headers:
9687
auth_values = response.headers[header_key].split(",")
@@ -116,6 +107,3 @@ def get_security_token_service_url(realm, environment):
116107
realm
117108
)
118109
)
119-
120-
def _get_authorization_header(self):
121-
return "Bearer {0}".format(self._cached_token.accessToken)

office365/runtime/auth/providers/saml_token_provider.py

+32-39
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import os
22
import uuid
3+
from datetime import datetime, timezone
4+
from typing import Optional
35
from xml.dom import minidom
46
from xml.etree import ElementTree
57

68
import requests
79
import requests.utils
810

911
import office365.logger
12+
from office365.runtime.auth.auth_cookies import AuthCookies
1013
from office365.runtime.auth.authentication_provider import AuthenticationProvider
1114
from office365.runtime.auth.sts_profile import STSProfile
1215
from office365.runtime.auth.user_realm_info import UserRealmInfo
@@ -20,23 +23,17 @@ def resolve_base_url(url):
2023
return parts[0] + "://" + host_name
2124

2225

23-
def xml_escape(s_val):
24-
s_val = s_val.replace("&", "&")
25-
s_val = s_val.replace("<", "&lt;")
26-
s_val = s_val.replace(">", "&gt;")
27-
s_val = s_val.replace('"', "&quot;")
28-
s_val = s_val.replace("'", "&apos;")
29-
return s_val
26+
def string_escape(value):
27+
value = value.replace("&", "&amp;")
28+
value = value.replace("<", "&lt;")
29+
value = value.replace(">", "&gt;")
30+
value = value.replace('"', "&quot;")
31+
value = value.replace("'", "&apos;")
32+
return value
3033

3134

32-
def is_valid_auth_cookies(values):
33-
"""
34-
Validates authorization cookies
35-
"""
36-
return any(values) and (
37-
values.get("FedAuth", None) is not None
38-
or values.get("SPOIDCRL", None) is not None
39-
)
35+
def datetime_escape(value):
36+
return value.isoformat("T")[:-9] + "Z"
4037

4138

4239
class SamlTokenProvider(AuthenticationProvider, office365.logger.LoggerContext):
@@ -60,7 +57,7 @@ def __init__(self, url, username, password, browser_mode, environment="commercia
6057
self.error = ""
6158
self._username = username
6259
self._password = password
63-
self._cached_auth_cookies = None
60+
self._cached_auth_cookies = None # type: Optional[AuthCookies]
6461
self.__ns_prefixes = {
6562
"S": "{http://www.w3.org/2003/05/soap-envelope}",
6663
"s": "{http://www.w3.org/2003/05/soap-envelope}",
@@ -82,24 +79,20 @@ def authenticate_request(self, request):
8279
Authenticate request handler
8380
"""
8481
logger = self.logger(self.authenticate_request.__name__)
85-
self.ensure_authentication_cookie()
86-
logger.debug_secrets(self._cached_auth_cookies)
87-
cookie_header_value = "; ".join(
88-
[
89-
"=".join([key, str(val)])
90-
for key, val in self._cached_auth_cookies.items()
91-
]
92-
)
93-
request.set_header("Cookie", cookie_header_value)
9482

95-
def ensure_authentication_cookie(self):
96-
if self._cached_auth_cookies is None:
83+
request_time = datetime.now(timezone.utc)
84+
if (
85+
self._cached_auth_cookies is None
86+
or request_time >= self._sts_profile.expires
87+
):
88+
self._sts_profile.reset()
9789
self._cached_auth_cookies = self.get_authentication_cookie()
98-
return True
90+
logger.debug_secrets(self._cached_auth_cookies)
91+
request.set_header("Cookie", self._cached_auth_cookies.cookie_header)
9992

10093
def get_authentication_cookie(self):
10194
"""Acquire authentication cookie"""
102-
logger = self.logger(self.ensure_authentication_cookie.__name__)
95+
logger = self.logger(self.get_authentication_cookie.__name__)
10396
logger.debug("get_authentication_cookie called")
10497

10598
try:
@@ -140,10 +133,10 @@ def _acquire_service_token_from_adfs(self, adfs_url):
140133
{
141134
"auth_url": adfs_url,
142135
"message_id": str(uuid.uuid4()),
143-
"username": xml_escape(self._username),
144-
"password": xml_escape(self._password),
145-
"created": self._sts_profile.created,
146-
"expires": self._sts_profile.expires,
136+
"username": string_escape(self._username),
137+
"password": string_escape(self._password),
138+
"created": datetime_escape(self._sts_profile.created),
139+
"expires": datetime_escape(self._sts_profile.expires),
147140
"issuer": self._sts_profile.tokenIssuer,
148141
},
149142
)
@@ -192,11 +185,11 @@ def _acquire_service_token(self):
192185
"SAML.xml",
193186
{
194187
"auth_url": self._sts_profile.authorityUrl,
195-
"username": xml_escape(self._username),
196-
"password": xml_escape(self._password),
188+
"username": string_escape(self._username),
189+
"password": string_escape(self._password),
197190
"message_id": str(uuid.uuid4()),
198-
"created": self._sts_profile.created,
199-
"expires": self._sts_profile.expires,
191+
"created": datetime_escape(self._sts_profile.created),
192+
"expires": datetime_escape(self._sts_profile.expires),
200193
"issuer": self._sts_profile.tokenIssuer,
201194
},
202195
)
@@ -297,9 +290,9 @@ def _get_authentication_cookie(self, security_token, federated=False):
297290
},
298291
)
299292
logger.debug_secrets("session.cookies: %s", session.cookies)
300-
cookies = requests.utils.dict_from_cookiejar(session.cookies)
293+
cookies = AuthCookies(requests.utils.dict_from_cookiejar(session.cookies))
301294
logger.debug_secrets("cookies: %s", cookies)
302-
if not is_valid_auth_cookies(cookies):
295+
if not cookies.is_valid:
303296
self.error = (
304297
"An error occurred while retrieving auth cookies from {0}".format(
305298
self._sts_profile.signin_page_url

office365/runtime/auth/sts_profile.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
class STSProfile(object):
77
def __init__(self, authority_url, environment):
8-
"""
9-
:type authority_url: str
10-
"""
8+
# type: (str, str) -> None
119
self.authorityUrl = authority_url
1210
if environment == "GCCH":
1311
self.serviceUrl = "https://login.microsoftonline.us"
@@ -16,13 +14,15 @@ def __init__(self, authority_url, environment):
1614
self.securityTokenServicePath = "extSTS.srf"
1715
self.userRealmServicePath = "GetUserRealm.srf"
1816
self.tokenIssuer = "urn:federation:MicrosoftOnline"
19-
now = datetime.now(tz=timezone.utc)
20-
self.created = now.astimezone(timezone.utc).isoformat("T")[:-9] + "Z"
21-
self.expires = (now + timedelta(minutes=10)).astimezone(timezone.utc).isoformat(
22-
"T"
23-
)[:-9] + "Z"
17+
self.created = datetime.now(tz=timezone.utc)
18+
self.expires = self.created + timedelta(minutes=30)
2419
self.signInPage = "_forms/default.aspx?wa=wsignin1.0"
2520

21+
def reset(self):
22+
"""Renew the expiration time."""
23+
self.created = datetime.now(tz=timezone.utc)
24+
self.expires = self.created + timedelta(minutes=30)
25+
2626
@property
2727
def tenant(self):
2828
return urlparse(self.authorityUrl).netloc

office365/runtime/auth/token_response.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ def __init__(self, access_token=None, token_type=None, **kwargs):
99
def is_valid(self):
1010
return self.accessToken is not None and self.tokenType == "Bearer"
1111

12+
@property
13+
def authorization_header(self):
14+
return "Bearer {0}".format(self.accessToken)
15+
1216
@staticmethod
1317
def from_json(value):
1418
error = value.get("error", None)

0 commit comments

Comments
 (0)