diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 26ba70c3d..9c12df319 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -13,5 +13,6 @@ * Update Jobs ListRuns API to support paginated responses ([#890](https://github.com/databricks/databricks-sdk-py/pull/890)) * Introduce automated tagging ([#888](https://github.com/databricks/databricks-sdk-py/pull/888)) * Update Jobs GetJob API to support paginated responses ([#869](https://github.com/databricks/databricks-sdk-py/pull/869)). +* Update On Behalf Of User Authentication in Multithreaded applications ([#907](https://github.com/databricks/databricks-sdk-py/pull/907)) ### API Changes diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 8fb1b45c2..86acac86c 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -853,8 +853,8 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str: return self.current_token def _get_invokers_token(self): - current_thread = threading.current_thread() - thread_data = current_thread.__dict__ + main_thread = threading.main_thread() + thread_data = main_thread.__dict__ invokers_token = None if "invokers_token" in thread_data: invokers_token = thread_data["invokers_token"] diff --git a/tests/test_model_serving_auth.py b/tests/test_model_serving_auth.py index ba9319f8a..3c3ddfa99 100644 --- a/tests/test_model_serving_auth.py +++ b/tests/test_model_serving_auth.py @@ -198,6 +198,26 @@ def test_agent_user_credentials(monkeypatch, mocker): assert cfg.host == "x" assert headers.get("Authorization") == f"Bearer {invokers_token_val}" + # Test invokers token in child thread + + successful_authentication_event = threading.Event() + + def authenticate(): + try: + cfg = Config(credentials_strategy=ModelServingUserCredentials()) + headers = cfg.authenticate() + assert cfg.host == "x" + assert headers.get("Authorization") == f"Bearer databricks_invokers_token_v2" + successful_authentication_event.set() + except Exception: + successful_authentication_event.clear() + + thread = threading.Thread(target=authenticate) + + thread.start() + thread.join() + assert successful_authentication_event.is_set() + # If this credential strategy is being used in a non model serving environments then use default credential strategy instead def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):