diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index f75645d2..e1836b21 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -84,6 +84,7 @@ ProviderExchangeFiltersAPI, ProviderExchangesAPI, ProviderFilesAPI, ProviderListingsAPI, ProviderPersonalizationRequestsAPI, ProviderProviderAnalyticsDashboardsAPI, ProviderProvidersAPI) +from databricks.sdk.service.mcp import MCP from databricks.sdk.service.ml import (ExperimentsAPI, ForecastingAPI, ModelRegistryAPI) from databricks.sdk.service.oauth2 import (AccountFederationPolicyAPI, @@ -337,6 +338,7 @@ def __init__( self._workspace_bindings = pkg_catalog.WorkspaceBindingsAPI(self._api_client) self._workspace_conf = pkg_settings.WorkspaceConfAPI(self._api_client) self._forecasting = pkg_ml.ForecastingAPI(self._api_client) + self._mcp = MCP(self.config) @property def config(self) -> client.Config: @@ -860,6 +862,10 @@ def forecasting(self) -> pkg_ml.ForecastingAPI: """The Forecasting API allows you to create and get serverless forecasting experiments.""" return self._forecasting + @property + def mcp(self) -> MCP: + return self._mcp + def get_workspace_id(self) -> int: """Get the workspace ID of the workspace that this client is connected to.""" response = self._api_client.do("GET", "/api/2.0/preview/scim/v2/Me", response_headers=["X-Databricks-Org-Id"]) diff --git a/databricks/sdk/service/mcp.py b/databricks/sdk/service/mcp.py new file mode 100644 index 00000000..d906e348 --- /dev/null +++ b/databricks/sdk/service/mcp.py @@ -0,0 +1,27 @@ +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared.auth import OAuthToken + + +class DatabricksTokenStorage(TokenStorage): + def __init__(self, config): + self.config = config + + async def get_tokens(self) -> OAuthToken | None: + headers = self.config.authenticate() + token = headers["Authorization"].split("Bearer ")[1] + return OAuthToken(access_token=token, expires_in=60) + + +class MCP: + def __init__(self, config): + self._config = config + self.databricks_token_storage = DatabricksTokenStorage(config) + + def get_oauth_provider(self): + return OAuthClientProvider( + server_url="", + client_metadata=None, + storage=self.databricks_token_storage, + redirect_handler=None, + callback_handler=None, + ) diff --git a/pyproject.toml b/pyproject.toml index 60c33f0e..5e90e6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ dev = [ 'langchain-openai; python_version > "3.7"', "httpx", "build", # some integration tests depend on the databricks-sdk-py wheel + "mcp>=1.9.1", + "pytest-asyncio" ] notebook = [ "ipython>=8,<10", @@ -63,6 +65,9 @@ openai = [ 'langchain-openai; python_version > "3.7"', "httpx", ] +mcp = [ + "mcp>=1.9.1" +] [tool.setuptools.dynamic] version = { attr = "databricks.sdk.version.__version__" } diff --git a/tests/test_mcp.py b/tests/test_mcp.py new file mode 100644 index 00000000..9178ef2a --- /dev/null +++ b/tests/test_mcp.py @@ -0,0 +1,20 @@ +import time + +import httpx +import pytest + + +@pytest.mark.asyncio +async def test_mcp_oauth_provider(monkeypatch): + monkeypatch.setattr(time, "time", lambda: 100) + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv("DATABRICKS_HOST", "test_host") + monkeypatch.setenv("DATABRICKS_TOKEN", "test_token") + + w = WorkspaceClient() + mcp_oauth_provider = w.mcp.get_oauth_provider() + + request = httpx.Request("GET", "https://example.com") + response = await anext(mcp_oauth_provider.async_auth_flow(request)) + assert response.headers["Authorization"] == "Bearer test_token"