Skip to content

[Internal] Add DataPlane token source #897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion databricks/sdk/data_plane.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,82 @@
from __future__ import annotations

import threading
from dataclasses import dataclass
from typing import Callable, List
from typing import Callable, List, Optional
from urllib import parse

from databricks.sdk import oauth
from databricks.sdk.oauth import Token

URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
OIDC_TOKEN_PATH = "/oidc/v1/token"


class DataPlaneTokenSource:
    """
    EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.

    One DataPlaneEndpointTokenSource is created lazily for every distinct
    (endpoint, authorization details) pair and is reused by later calls.

    :param token_exchange_host: Host passed on to each endpoint token source.
    :param cpts: Callable returning the control plane token; passed on to each
        endpoint token source.
    :param disable_async: Passed on to each endpoint token source.
    """

    # TODO: Enable async once it's stable. @oauth_credentials_provider must also have async enabled.
    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
        self._cpts = cpts
        self._token_exchange_host = token_exchange_host
        # Maps "<endpoint>:<auth_details>" to the token source serving that pair.
        self._token_sources = {}
        self._disable_async = disable_async
        # Guards insertion of new entries into _token_sources.
        self._lock = threading.Lock()

    def token(self, endpoint: str, auth_details: str) -> Token:
        """
        Return a token for the given endpoint and authorization details.

        The token source for the pair is created on first use and cached.

        :param endpoint: DataPlane endpoint the token is requested for.
        :param auth_details: Authorization details sent during the token exchange.
        :return: Whatever the cached token source's ``token()`` returns.
        """
        key = f"{endpoint}:{auth_details}"

        # First, try to read without acquiring the lock to avoid contention.
        # The lookup is compared against None explicitly so that a cached entry
        # is always reused, whatever its truthiness.
        token_source = self._token_sources.get(key)
        if token_source is None:
            # Not cached yet: acquire the lock and check again, because another
            # thread might have created it while we were waiting for the lock.
            with self._lock:
                token_source = self._token_sources.get(key)
                if token_source is None:
                    token_source = DataPlaneEndpointTokenSource(
                        self._token_exchange_host, self._cpts, auth_details, self._disable_async
                    )
                    self._token_sources[key] = token_source

        return token_source.token()


class DataPlaneEndpointTokenSource(oauth.Refreshable):
    """
    EXPERIMENTAL A token source for a specific DataPlane endpoint.

    Each refresh obtains a control plane token from ``cpts`` and exchanges it
    at ``<token_exchange_host>/oidc/v1/token`` using the JWT bearer grant type.
    """

    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, disable_async: bool):
        super().__init__(disable_async=disable_async)
        self._token_exchange_host = token_exchange_host
        self._cpts = cpts
        self._auth_details = auth_details

    def refresh(self) -> Token:
        """Exchange the current control plane token for a DataPlane token."""
        # The control plane access token is sent as the grant's assertion.
        assertion = self._cpts().access_token
        request_headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
        form_fields = {
            "grant_type": JWT_BEARER_GRANT_TYPE,
            "authorization_details": self._auth_details,
            "assertion": assertion,
        }
        encoded_form = parse.urlencode(form_fields)
        exchange_url = self._token_exchange_host + OIDC_TOKEN_PATH
        return oauth.retrieve_token(
            client_id="",
            client_secret="",
            token_url=exchange_url,
            params=encoded_form,
            headers=request_headers,
        )


@dataclass
class DataPlaneDetails:
Expand All @@ -17,6 +90,9 @@ class DataPlaneDetails:
"""Token to query the DataPlane endpoint."""


## Old implementation. TODO: Remove once the new implementation is in use.


class DataPlaneService:
"""Helper class to fetch and manage DataPlane details."""

Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_data_plane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from databricks.sdk.data_plane import DataPlaneTokenSource


def test_data_plane_token_source(ucws, env_or_skip):
    """Fetch a DataPlane token for a real serving endpoint and check it is valid.

    The endpoint name is read through ``env_or_skip("SERVING_ENDPOINT_NAME")``.
    """
    endpoint_name = env_or_skip("SERVING_ENDPOINT_NAME")
    data_plane_info = ucws.serving_endpoints.get(endpoint_name).data_plane_info
    assert data_plane_info is not None

    query_info = data_plane_info.query_info
    assert query_info is not None

    token_source = DataPlaneTokenSource(ucws.config.host, ucws._config.oauth_token)
    fetched = token_source.token(query_info.endpoint_url, query_info.authorization_details)

    assert fetched.valid
71 changes: 71 additions & 0 deletions tests/test_data_plane.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,80 @@
from datetime import datetime, timedelta
from unittest.mock import patch
from urllib import parse

from databricks.sdk import data_plane, oauth
from databricks.sdk.data_plane import DataPlaneService
from databricks.sdk.oauth import Token
from databricks.sdk.service.serving import DataPlaneInfo

cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))


def success_callable(token: oauth.Token):
    """Build a zero-argument callable that always returns ``token``."""
    return lambda: token


def test_endpoint_token_source_get_token(config):
    """The endpoint token source posts the expected exchange request exactly once."""
    expected_form = parse.urlencode(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "authorization_details": "authDetails",
            "assertion": cp_token.access_token,
        }
    )
    endpoint_source = data_plane.DataPlaneEndpointTokenSource(
        config.host, success_callable(cp_token), "authDetails", disable_async=True
    )

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as mocked_retrieve:
        endpoint_source.token()

    mocked_retrieve.assert_called_once()
    # Only the keyword arguments of the single recorded call are inspected.
    _, call_kwargs = mocked_retrieve.call_args

    assert call_kwargs["token_url"] == config.host + "/oidc/v1/token"
    assert call_kwargs["params"] == expected_form
    assert call_kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}


def test_token_source_get_token_not_existing(config):
    """A first request for an endpoint creates, caches and uses a new token source."""
    manager = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as mocked_retrieve:
        fetched = manager.token(endpoint="endpoint", auth_details="authDetails")

    mocked_retrieve.assert_called_once()
    assert fetched.access_token == dp_token.access_token
    # The new source is stored under the "<endpoint>:<auth_details>" key.
    assert "endpoint:authDetails" in manager._token_sources


class MockEndpointTokenSource:
    """Test double for an endpoint token source that serves one fixed token."""

    def __init__(self, token: oauth.Token):
        # Handed back unchanged by every token() call.
        self._token = token

    def token(self):
        """Return the token given at construction time."""
        fixed = self._token
        return fixed


def test_token_source_get_token_existing(config):
    """An already cached endpoint token source is reused without a token exchange."""
    another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
    # Use cp_token like the sibling tests do, rather than the module-level
    # `token` that belongs to the old-implementation tests slated for removal.
    token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)
    # Pre-populate the cache under the key that token() looks up.
    token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")

    retrieve_token.assert_not_called()
    assert result_token.access_token == another_token.access_token


## These tests are for the old implementation. TODO: Remove once the new implementation is in use.

info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")

token = Token(
Expand Down
Loading