Skip to content

[Feature] Fix Model Serving User Credentials threading scenarios #907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
* Update Jobs ListRuns API to support paginated responses ([#890](https://github.com/databricks/databricks-sdk-py/pull/890))
* Introduce automated tagging ([#888](https://github.com/databricks/databricks-sdk-py/pull/888))
* Update Jobs GetJob API to support paginated responses ([#869](https://github.com/databricks/databricks-sdk-py/pull/869)).
* Update On Behalf Of User Authentication in Multithreaded applications ([#907](https://github.com/databricks/databricks-sdk-py/pull/907))

### API Changes
4 changes: 2 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,8 +853,8 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
return self.current_token

def _get_invokers_token(self):
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
main_thread = threading.main_thread()
thread_data = main_thread.__dict__
invokers_token = None
if "invokers_token" in thread_data:
invokers_token = thread_data["invokers_token"]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_model_serving_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,26 @@ def test_agent_user_credentials(monkeypatch, mocker):
assert cfg.host == "x"
assert headers.get("Authorization") == f"Bearer {invokers_token_val}"

# Test invokers token in child thread

successful_authentication_event = threading.Event()

def authenticate():
try:
cfg = Config(credentials_strategy=ModelServingUserCredentials())
headers = cfg.authenticate()
assert cfg.host == "x"
assert headers.get("Authorization") == f"Bearer databricks_invokers_token_v2"
successful_authentication_event.set()
except Exception:
successful_authentication_event.clear()

thread = threading.Thread(target=authenticate)

thread.start()
thread.join()
assert successful_authentication_event.is_set()


# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):
Expand Down
Loading