diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 2f512118..b1d4b5a1 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -942,6 +942,14 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str: ) from e return self.current_token + def _get_invokers_token_from_greenlet(self): + # Attempt to retrieve 'invokers_token' from greenlet local + from greenlet import greenlet, getcurrent + greenlet = getcurrent() + if hasattr(greenlet, 'invokers_token'): + return greenlet.invokers_token + raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving") + def _get_invokers_token(self): main_thread = threading.main_thread() thread_data = main_thread.__dict__ @@ -950,7 +958,8 @@ def _get_invokers_token(self): invokers_token = thread_data["invokers_token"] if invokers_token is None: - raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving") + # This is likely async server code, so we should check greenlet local + return self._get_invokers_token_from_greenlet() return invokers_token diff --git a/pyproject.toml b/pyproject.toml index 60c33f0e..d2d5b1e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ dependencies = [ "requests>=2.28.1,<3", "google-auth~=2.0", + "greenlet>=3.2.0", ] [project.urls] @@ -82,4 +83,4 @@ include = ["."] exclude = ["**/node_modules", "**/__pycache__"] reportMissingImports = true reportMissingTypeStubs = false -pythonVersion = "3.7" \ No newline at end of file +pythonVersion = "3.7" diff --git a/tests/test_model_serving_auth.py b/tests/test_model_serving_auth.py index 3c3ddfa9..6cdc47e7 100644 --- a/tests/test_model_serving_auth.py +++ b/tests/test_model_serving_auth.py @@ -5,6 +5,7 @@ from databricks.sdk.core import Config from databricks.sdk.credentials_provider import ModelServingUserCredentials +from greenlet import greenlet, getcurrent from .conftest import raises @@ -217,7 +218,41 @@ def authenticate(): thread.start() thread.join() assert successful_authentication_event.is_set() + del current_thread.__dict__["invokers_token"] # Clean up invokers token +def test_agent_user_credentials_via_greenlet(monkeypatch, mocker): + # Guarantee that the tests defaults to env variables rather than config file. + # + # TODO: this is hacky and we should find a better way to tell the config + # that it should not read from the config file. + monkeypatch.setenv("DATABRICKS_CONFIG_FILE", "x") + + monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true") + monkeypatch.setenv("DB_MODEL_SERVING_HOST_URL", "x") + monkeypatch.setattr( + "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", + "tests/testdata/model-serving-test-token", + ) + + invokers_token_val = "databricks_invokers_token" + greenlet_local = getcurrent() + setattr(greenlet_local, "invokers_token", invokers_token_val) + + cfg = Config(credentials_strategy=ModelServingUserCredentials()) + assert cfg.auth_type == "model_serving_user_credentials" + + headers = cfg.authenticate() + + assert cfg.host == "x" + assert headers.get("Authorization") == f"Bearer {invokers_token_val}" + + # Test updates of invokers token + invokers_token_val = "databricks_invokers_token_v2" + setattr(greenlet_local, "invokers_token", invokers_token_val) + + headers = cfg.authenticate() + assert cfg.host == "x" + assert headers.get("Authorization") == f"Bearer {invokers_token_val}" # If this credential strategy is being used in a non model serving environments then use default credential strategy instead def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):