Skip to content

[Internal] Implement async token refresh #893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 123 additions & 10 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import urllib.parse
import webbrowser
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -187,21 +189,132 @@ def retrieve_token(client_id,
raise NotImplementedError(f"Not supported yet: {e}")


class _TokenState(Enum):
"""
Represents the state of a token. Each token can be in one of
the following three states:
- FRESH: The token is valid.
- STALE: The token is valid but will expire soon.
- EXPIRED: The token has expired and cannot be used.
"""
FRESH = 1 # The token is valid.
STALE = 2 # The token is valid but will expire soon.
EXPIRED = 3 # The token has expired and cannot be used.


class Refreshable(TokenSource):
"""A token source that supports refreshing expired tokens."""

_EXECUTOR = None
_EXECUTOR_LOCK = threading.Lock()
_DEFAULT_STALE_DURATION = timedelta(minutes=3)

@classmethod
def _get_executor(cls):
"""Lazy initialization of the ThreadPoolExecutor."""
if cls._EXECUTOR is None:
with cls._EXECUTOR_LOCK:
if cls._EXECUTOR is None:
# This thread pool has multiple workers because it is shared by all instances of Refreshable.
cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
return cls._EXECUTOR

def __init__(self, token=None):
self._lock = threading.Lock() # to guard _token
def __init__(self,
token: Token = None,
disable_async: bool = True,
stale_duration: timedelta = _DEFAULT_STALE_DURATION):
# Config properties
self._stale_duration = stale_duration
self._disable_async = disable_async
# Lock
self._lock = threading.Lock()
# Non Thread safe properties. They should be accessed only when protected by the lock above.
self._token = token
self._is_refreshing = False
self._refresh_err = False

# This is the main entry point for the Token. Do not access the token
# using any of the internal functions.
def token(self) -> Token:
self._lock.acquire()
try:
if self._token and self._token.valid:
return self._token
self._token = self.refresh()
"""Returns a valid token, blocking if async refresh is disabled."""
with self._lock:
if self._disable_async:
return self._blocking_token()
return self._async_token()

def _async_token(self) -> Token:
"""
Returns a token.
If the token is stale, triggers an asynchronous refresh.
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
"""
state = self._token_state()
token = self._token

if state == _TokenState.FRESH:
return token
if state == _TokenState.STALE:
self._trigger_async_refresh()
return token
return self._blocking_token()

def _token_state(self) -> _TokenState:
"""Returns the current state of the token."""
if not self._token or not self._token.valid:
return _TokenState.EXPIRED
if not self._token.expiry:
return _TokenState.FRESH

lifespan = self._token.expiry - datetime.now()
if lifespan < timedelta(seconds=0):
return _TokenState.EXPIRED
if lifespan < self._stale_duration:
return _TokenState.STALE
return _TokenState.FRESH

def _blocking_token(self) -> Token:
"""Returns a token, blocking if necessary to refresh it."""
state = self._token_state()
# This is important to recover from potential previous failed attempts
# to refresh the token asynchronously.
self._refresh_err = False
self._is_refreshing = False

# It's possible that the token got refreshed (either by a _blocking_refresh or
# an _async_refresh call) while this particular call was waiting to acquire
# the lock. This check avoids refreshing the token again in such cases.
if state != _TokenState.EXPIRED:
return self._token
finally:
self._lock.release()

self._token = self.refresh()
return self._token

def _trigger_async_refresh(self):
"""Starts an asynchronous refresh if none is in progress."""

def _refresh_internal():
new_token: Token = None
try:
new_token = self.refresh()
except Exception as e:
# This happens on a thread, so we don't want to propagate the error.
# Instead, if there is no new_token for any reason, we will disable async refresh below
# But we will do it inside the lock.
logger.warning(f'Tried to refresh token asynchronously, but failed: {e}')

with self._lock:
if new_token is not None:
self._token = new_token
else:
self._refresh_err = True
self._is_refreshing = False

# The token may have been refreshed by another thread.
if self._token_state() == _TokenState.FRESH:
return
if not self._is_refreshing and not self._refresh_err:
self._is_refreshing = True
Refreshable._get_executor().submit(_refresh_internal)

@abstractmethod
def refresh(self) -> Token:
Expand Down Expand Up @@ -295,7 +408,7 @@ def __init__(self,
super().__init__(token)

def as_dict(self) -> dict:
return {'token': self._token.as_dict()}
return {'token': self.token().as_dict()}

@staticmethod
def from_dict(raw: dict,
Expand Down
216 changes: 216 additions & 0 deletions tests/test_refreshable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import time
from datetime import datetime, timedelta
from time import sleep
from typing import Callable

from databricks.sdk.oauth import Refreshable, Token


class _MockRefreshable(Refreshable):

def __init__(self,
disable_async,
token=None,
stale_duration=timedelta(seconds=60),
refresh_effect: Callable[[], Token] = None):
super().__init__(token, disable_async, stale_duration)
self._refresh_effect = refresh_effect
self._refresh_count = 0

def refresh(self) -> Token:
if self._refresh_effect:
self._token = self._refresh_effect()
self._refresh_count += 1
return self._token


def fail() -> Token:
raise Exception("Simulated token refresh failure")


def static_token(token: Token, wait: int = 0) -> Callable[[], Token]:

def f() -> Token:
time.sleep(wait)
return token

return f


def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[], None]):
"""
Create a refresh function that blocks until unblock is called.

Param:
token: the token that will be returned

Returns:
A tuple containing the refresh function and the unblock function.

"""
blocking = True

def refresh():
while blocking:
sleep(0.1)
return token

def unblock():
nonlocal blocking
blocking = False

return refresh, unblock


def test_disable_async_stale_does_not_refresh():
stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), )
r = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail)
result = r.token()
assert r._refresh_count == 0
assert result == stale_token


def test_disable_async_no_token_does_refresh():
token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), )
r = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token))
result = r.token()
assert r._refresh_count == 1
assert result == token


def test_disable_async_no_expiration_does_not_refresh():
non_expiring_token = Token(access_token="access_token", )
r = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail)
result = r.token()
assert r._refresh_count == 0
assert result == non_expiring_token


def test_disable_async_fresh_does_not_refresh():
# Create a token that is already stale. If async is disabled, the token should not be refreshed.
token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
r = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail)
result = r.token()
assert r._refresh_count == 0
assert result == token


def test_disable_async_expired_does_refresh():
expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), )
new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
# Add one second to the refresh time to ensure that the call is blocking.
# If the call is not blocking, the wait time will ensure that the
# old token is returned.
r = _MockRefreshable(token=expired_token,
disable_async=True,
refresh_effect=static_token(new_token, wait=1))
result = r.token()
assert r._refresh_count == 1
assert result == new_token


def test_expired_does_refresh():
expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), )
new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
# Add one second to the refresh time to ensure that the call is blocking.
# If the call is not blocking, the wait time will ensure that the
# old token is returned.
r = _MockRefreshable(token=expired_token,
disable_async=False,
refresh_effect=static_token(new_token, wait=1))
result = r.token()
assert r._refresh_count == 1
assert result == new_token


def test_stale_does_refresh_async():
stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), )
new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
# Add one second to the refresh to avoid race conditions.
# Without it, the new token may be returned in some cases.
refresh, unblock = blocking_refresh(new_token)
r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
result = r.token()
# NOTE: Do not check for refresh count here, since the
assert result == stale_token
assert r._refresh_count == 0
# Unblock the refresh and wait
unblock()
time.sleep(2)
# Call again and check that you get the new token
result = r.token()
assert result == new_token
# Ensure that all calls have completed
time.sleep(0.1)
assert r._refresh_count == 1


def test_no_token_does_refresh():
new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
# Add one second to the refresh time to ensure that the call is blocking.
# If the call is not blocking, the wait time will ensure that the
# token is not returned.
r = _MockRefreshable(token=None, disable_async=False, refresh_effect=static_token(new_token, wait=1))
result = r.token()
assert r._refresh_count == 1
assert result == new_token


def test_fresh_does_not_refresh():
fresh_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
r = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail)
result = r.token()
assert r._refresh_count == 0
assert result == fresh_token


def test_multiple_calls_dont_start_many_threads():
stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), )
new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
refresh, unblock = blocking_refresh(new_token)
r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
# Call twice. The second call should not start a new thread.
result = r.token()
assert result == stale_token
result = r.token()
assert result == stale_token
unblock()
# Wait for the refresh to complete
time.sleep(1)
result = r.token()
# Check that only one refresh was called
assert r._refresh_count == 1
assert result == new_token


def test_async_failure_disables_async():
stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), )
new_token = Token(access_token="new_token", expiry=datetime.now() + timedelta(seconds=300), )
r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail)
# The call should fail and disable async refresh,
# but the exception will be catch inside the tread.
result = r.token()
assert result == stale_token
# Give time to the async refresh to fail
time.sleep(1)
assert r._refresh_err
# Now, the refresh should be blocking.
# Blocking refresh only happens for expired, not stale.
# Therefore, the next call should return the stale token.
r._refresh_effect = static_token(new_token, wait=1)
result = r.token()
assert result == stale_token
# Wait to be sure no async thread was started
time.sleep(1)
assert r._refresh_count == 0

# Inject an expired token.
expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), )
r._token = expired_token

# This should be blocking and return the new token.
result = r.token()
assert r._refresh_count == 1
assert result == new_token
# The refresh error should be cleared.
assert not r._refresh_err
Loading