From 9107e9c06e4a8e359cbe92626133afa04950b3b7 Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Mon, 24 Feb 2025 12:14:28 +0100 Subject: [PATCH 1/2] [Internal] Add DataPlane token source --- databricks/sdk/data_plane.py | 77 +++++++++++++++++++++++++++- tests/integration/test_dataplane.py | 15 ++++++ tests/test_data_plane.py | 78 +++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 tests/integration/test_dataplane.py diff --git a/databricks/sdk/data_plane.py b/databricks/sdk/data_plane.py index 5ad9b79ad..843128825 100644 --- a/databricks/sdk/data_plane.py +++ b/databricks/sdk/data_plane.py @@ -1,9 +1,81 @@ +from __future__ import annotations + import threading from dataclasses import dataclass -from typing import Callable, List +from typing import Callable, List, Optional +from urllib import parse +from databricks.sdk import oauth from databricks.sdk.oauth import Token +URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" +JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer" +OIDC_TOKEN_PATH = "/oidc/v1/token" + + +class DataPlaneTokenSource: + """ + EXPERIMENTAL Manages token sources for multiple DataPlane endpoints. + """ + + # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled. + def __init__(self, + token_exchange_host: str, + cpts: Callable[[], Token], + disable_async: Optional[bool] = True): + self._cpts = cpts + self._token_exchange_host = token_exchange_host + self._token_sources = {} + self._disable_async = disable_async + self._lock = threading.Lock() + + def get_token(self, endpoint, auth_details): + key = f"{endpoint}:{auth_details}" + + # First, try to read without acquiring the lock to avoid contention. + # Reads are atomic, so this is safe. + token_source = self._token_sources.get(key) + if token_source: + return token_source.token() + + # If token_source is not found, acquire the lock and check again. + with self._lock: + # Another thread might have created it while we were waiting for the lock. + token_source = self._token_sources.get(key) + if not token_source: + token_source = DataPlaneEndpointTokenSource(self._token_exchange_host, self._cpts, + auth_details, self._disable_async) + self._token_sources[key] = token_source + + return token_source.token() + + +class DataPlaneEndpointTokenSource(oauth.Refreshable): + """ + EXPERIMENTAL A token source for a specific DataPlane endpoint. + """ + + def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, + disable_async: bool): + super().__init__(disable_async=disable_async) + self._auth_details = auth_details + self._cpts = cpts + self._token_exchange_host = token_exchange_host + + def refresh(self) -> Token: + control_plane_token = self._cpts() + headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE} + params = parse.urlencode({ + "grant_type": JWT_BEARER_GRANT_TYPE, + "authorization_details": self._auth_details, + "assertion": control_plane_token.access_token + }) + return oauth.retrieve_token(client_id="", + client_secret="", + token_url=self._token_exchange_host + OIDC_TOKEN_PATH, + params=params, + headers=headers) + @dataclass class DataPlaneDetails: @@ -16,6 +88,9 @@ class DataPlaneDetails: """Token to query the DataPlane endpoint.""" +## Old implementation. #TODO: Remove after the new implementation is used + + class DataPlaneService: """Helper class to fetch and manage DataPlane details.""" from .service.serving import DataPlaneInfo diff --git a/tests/integration/test_dataplane.py b/tests/integration/test_dataplane.py new file mode 100644 index 000000000..0062a7ed0 --- /dev/null +++ b/tests/integration/test_dataplane.py @@ -0,0 +1,15 @@ +from databricks.sdk.data_plane import DataPlaneTokenSource + + +def test_data_plane_token_source(ucws, env_or_skip): + endpoint = env_or_skip("SERVING_ENDPOINT_NAME") + serving_endpoint = ucws.serving_endpoints.get(endpoint) + assert serving_endpoint.data_plane_info is not None + assert serving_endpoint.data_plane_info.query_info is not None + + info = serving_endpoint.data_plane_info.query_info + + ts = DataPlaneTokenSource(ucws.config.host, ucws._config.oauth_token) + dp_token = ts.token(info.endpoint_url, info.authorization_details) + + assert dp_token.valid diff --git a/tests/test_data_plane.py b/tests/test_data_plane.py index 1eac92382..29a33cb4b 100644 --- a/tests/test_data_plane.py +++ b/tests/test_data_plane.py @@ -1,9 +1,87 @@ from datetime import datetime, timedelta +from unittest.mock import patch +from urllib import parse +from databricks.sdk import data_plane, oauth from databricks.sdk.data_plane import DataPlaneService from databricks.sdk.oauth import Token from databricks.sdk.service.serving import DataPlaneInfo +cp_token = Token(access_token="control plane token", + token_type="type", + expiry=datetime.now() + timedelta(hours=1)) +dp_token = Token(access_token="data plane token", + token_type="type", + expiry=datetime.now() + timedelta(hours=1)) + + +def success_callable(token: oauth.Token): + + def success() -> oauth.Token: + return token + + return success + + +def test_endpoint_token_source_get_token(config): + token_source = data_plane.DataPlaneEndpointTokenSource(config.host, + success_callable(cp_token), + "authDetails", + disable_async=True) + + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: + token_source.token() + + retrieve_token.assert_called_once() + args, kwargs = retrieve_token.call_args + + assert kwargs["token_url"] == config.host + "/oidc/v1/token" + assert kwargs["params"] == parse.urlencode({ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "authorization_details": "authDetails", + "assertion": cp_token.access_token + }) + assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"} + + +def test_token_source_get_token_not_existing(config): + token_source = data_plane.DataPlaneTokenSource(config.host, + success_callable(cp_token), + disable_async=True) + + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: + result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") + + retrieve_token.assert_called_once() + assert result_token.access_token == dp_token.access_token + assert "endpoint:authDetails" in token_source._token_sources + + +class MockEndpointTokenSource: + + def __init__(self, token: oauth.Token): + self._token = token + + def token(self): + return self._token + + +def test_token_source_get_token_existing(config): + another_token = Token(access_token="another token", + token_type="type", + expiry=datetime.now() + timedelta(hours=1)) + token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True) + token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token) + + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: + result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") + + retrieve_token.assert_not_called() + assert result_token.access_token == another_token.access_token + + +## These tests are for the old implementation. #TODO: Remove after the new implementation is used + info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url") token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) From 5530671f02c457648b0edc3f2e16ddfb7a180e2e Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Fri, 28 Feb 2025 15:08:15 +0100 Subject: [PATCH 2/2] black fmt --- databricks/sdk/data_plane.py | 39 ++++++++++--------- .../{test_dataplane.py => test_data_plane.py} | 0 tests/test_data_plane.py | 35 +++++++---------- 3 files changed, 34 insertions(+), 40 deletions(-) rename tests/integration/{test_dataplane.py => test_data_plane.py} (100%) diff --git a/databricks/sdk/data_plane.py b/databricks/sdk/data_plane.py index 97362cb5b..3c059ecf2 100644 --- a/databricks/sdk/data_plane.py +++ b/databricks/sdk/data_plane.py @@ -19,17 +19,14 @@ class DataPlaneTokenSource: """ # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled. - def __init__(self, - token_exchange_host: str, - cpts: Callable[[], Token], - disable_async: Optional[bool] = True): + def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True): self._cpts = cpts self._token_exchange_host = token_exchange_host self._token_sources = {} self._disable_async = disable_async self._lock = threading.Lock() - def get_token(self, endpoint, auth_details): + def token(self, endpoint, auth_details): key = f"{endpoint}:{auth_details}" # First, try to read without acquiring the lock to avoid contention. @@ -43,8 +40,9 @@ def get_token(self, endpoint, auth_details): # Another thread might have created it while we were waiting for the lock. token_source = self._token_sources.get(key) if not token_source: - token_source = DataPlaneEndpointTokenSource(self._token_exchange_host, self._cpts, - auth_details, self._disable_async) + token_source = DataPlaneEndpointTokenSource( + self._token_exchange_host, self._cpts, auth_details, self._disable_async + ) self._token_sources[key] = token_source return token_source.token() @@ -55,8 +53,7 @@ class DataPlaneEndpointTokenSource(oauth.Refreshable): EXPERIMENTAL A token source for a specific DataPlane endpoint. """ - def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, - disable_async: bool): + def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, disable_async: bool): super().__init__(disable_async=disable_async) self._auth_details = auth_details self._cpts = cpts @@ -65,16 +62,20 @@ def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_det def refresh(self) -> Token: control_plane_token = self._cpts() headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE} - params = parse.urlencode({ - "grant_type": JWT_BEARER_GRANT_TYPE, - "authorization_details": self._auth_details, - "assertion": control_plane_token.access_token - }) - return oauth.retrieve_token(client_id="", - client_secret="", - token_url=self._token_exchange_host + OIDC_TOKEN_PATH, - params=params, - headers=headers) + params = parse.urlencode( + { + "grant_type": JWT_BEARER_GRANT_TYPE, + "authorization_details": self._auth_details, + "assertion": control_plane_token.access_token, + } + ) + return oauth.retrieve_token( + client_id="", + client_secret="", + token_url=self._token_exchange_host + OIDC_TOKEN_PATH, + params=params, + headers=headers, + ) @dataclass diff --git a/tests/integration/test_dataplane.py b/tests/integration/test_data_plane.py similarity index 100% rename from tests/integration/test_dataplane.py rename to tests/integration/test_data_plane.py diff --git a/tests/test_data_plane.py b/tests/test_data_plane.py index 2c7a86f58..54ace9ba7 100644 --- a/tests/test_data_plane.py +++ b/tests/test_data_plane.py @@ -7,12 +7,8 @@ from databricks.sdk.oauth import Token from databricks.sdk.service.serving import DataPlaneInfo -cp_token = Token(access_token="control plane token", - token_type="type", - expiry=datetime.now() + timedelta(hours=1)) -dp_token = Token(access_token="data plane token", - token_type="type", - expiry=datetime.now() + timedelta(hours=1)) +cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) +dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) def success_callable(token: oauth.Token): @@ -24,10 +20,9 @@ def success() -> oauth.Token: def test_endpoint_token_source_get_token(config): - token_source = data_plane.DataPlaneEndpointTokenSource(config.host, - success_callable(cp_token), - "authDetails", - disable_async=True) + token_source = data_plane.DataPlaneEndpointTokenSource( + config.host, success_callable(cp_token), "authDetails", disable_async=True + ) with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: token_source.token() @@ -36,18 +31,18 @@ def test_endpoint_token_source_get_token(config): args, kwargs = retrieve_token.call_args assert kwargs["token_url"] == config.host + "/oidc/v1/token" - assert kwargs["params"] == parse.urlencode({ - "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", - "authorization_details": "authDetails", - "assertion": cp_token.access_token - }) + assert kwargs["params"] == parse.urlencode( + { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "authorization_details": "authDetails", + "assertion": cp_token.access_token, + } + ) assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"} def test_token_source_get_token_not_existing(config): - token_source = data_plane.DataPlaneTokenSource(config.host, - success_callable(cp_token), - disable_async=True) + token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True) with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") @@ -67,9 +62,7 @@ def token(self): def test_token_source_get_token_existing(config): - another_token = Token(access_token="another token", - token_type="type", - expiry=datetime.now() + timedelta(hours=1)) + another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True) token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)