From afcbc524b211ebfb989a5a9d253c6000b091a3a2 Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 27 May 2025 14:31:03 -0400 Subject: [PATCH 01/10] Updated the DatabricksRM class to use Databricks service principals. - Updated the base class to include the optional parameters - Updated logic on the _query_via_databricks_sdk to use the SP credentials if they exist; otherwise it falls back to a PAT. --- dspy/retrieve/databricks_rm.py | 38 ++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 0334154dd5..68bb132d2d 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -86,6 +86,8 @@ def __init__( databricks_index_name: str, databricks_endpoint: Optional[str] = None, databricks_token: Optional[str] = None, + databricks_client_id: Optional[str] = None, + databricks_client_secret: Optional[str] = None, columns: Optional[List[str]] = None, filters_json: Optional[str] = None, k: int = 3, @@ -105,6 +107,10 @@ def __init__( when querying the Vector Search Index. Defaults to the value of the ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is used to identify the token based on the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the token is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). columns (Optional[List[str]]): Extra column names to include in response, in addition to the document id and text columns specified by ``docs_id_column_name`` and ``text_column_name``. 
@@ -127,7 +133,13 @@ def __init__( self.databricks_endpoint = ( databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST") ) - if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: + self.databricks_client_id = databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") + self.databricks_client_secret = ( + databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET") + ) + + if not _databricks_sdk_installed and ((self.databricks_token, self.databricks_endpoint).count(None) > 0 and + (self.databricks_client_id, self.databricks_client_secret).count(None) > 0): raise ValueError( "To retrieve documents with Databricks Vector Search, you must install the" " databricks-sdk Python library, supply the databricks_token and" @@ -245,6 +257,8 @@ def forward( query_vector=query_vector, databricks_token=self.databricks_token, databricks_endpoint=self.databricks_endpoint, + databricks_client_id=self.databricks_client_id, + databricks_client_secret=self.databricks_client_secret, filters_json=filters_json or self.filters_json, ) else: @@ -315,6 +329,8 @@ def _query_via_databricks_sdk( query_vector: Optional[List[float]], databricks_token: Optional[str], databricks_endpoint: Optional[str], + databricks_client_id: Optional[str], + databricks_client_secret: Optional[str], filters_json: Optional[str], ) -> Dict[str, Any]: """ @@ -334,15 +350,33 @@ def _query_via_databricks_sdk( the token is resolved from the current environment. databricks_endpoint (str): Databricks index endpoint url. If not specified, the endpoint is resolved from the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the token is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. 
If not specified, + the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). + Returns: Returns: Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. """ + from databricks.sdk import WorkspaceClient if (query_text, query_vector).count(None) != 1: raise ValueError("Exactly one of query_text or query_vector must be specified.") - databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token) + if databricks_client_secret and databricks_client_id: + # Use client ID and secret for authentication if they are provided + databricks_client = WorkspaceClient( + client_id=databricks_client_id, + client_secret=databricks_client_secret, + ) + else: + # Fallback for token-based authentication + databricks_client = WorkspaceClient( + host=databricks_endpoint, + token=databricks_token, + ) + return databricks_client.vector_search_indexes.query_index( index_name=index_name, query_type=query_type, From 784354abc9e89dc9ab9e95d912083a470bc0ca49 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 29 May 2025 11:26:39 -0400 Subject: [PATCH 02/10] Updated comments for the usage of service principals for the Databricks SDK Client --- dspy/retrieve/databricks_rm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 68bb132d2d..88728280da 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -137,14 +137,13 @@ def __init__( self.databricks_client_secret = ( databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET") ) - - if not _databricks_sdk_installed and ((self.databricks_token, self.databricks_endpoint).count(None) > 0 and - (self.databricks_client_id, self.databricks_client_secret).count(None) > 0): + if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: raise ValueError( "To retrieve 
documents with Databricks Vector Search, you must install the" " databricks-sdk Python library, supply the databricks_token and" " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST" - " environment variables." + " environment variables. You may also supply a service principal the databricks_client_id and" + " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET" ) self.databricks_index_name = databricks_index_name self.columns = list({docs_id_column_name, text_column_name, *(columns or [])}) @@ -370,6 +369,7 @@ def _query_via_databricks_sdk( client_id=databricks_client_id, client_secret=databricks_client_secret, ) + else: # Fallback for token-based authentication databricks_client = WorkspaceClient( From f06c2edf9d1e871eb49dedaaae9b358163a7a504 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 29 May 2025 13:39:03 -0400 Subject: [PATCH 03/10] Updated print statements to acknowledge auth method. --- dspy/retrieve/databricks_rm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 88728280da..6437ab5600 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -369,6 +369,7 @@ def _query_via_databricks_sdk( client_id=databricks_client_id, client_secret=databricks_client_secret, ) + print("Creating Databricks workspace client using service principal authentication.") else: # Fallback for token-based authentication @@ -376,6 +377,7 @@ def _query_via_databricks_sdk( host=databricks_endpoint, token=databricks_token, ) + print("Creating Databricks workspace client using token authentication.") return databricks_client.vector_search_indexes.query_index( index_name=index_name, From 4b00022618225b916b01298f31ac2a1678f6cc3f Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 29 May 2025 13:39:03 -0400 Subject: [PATCH 04/10] Updated print statements to acknowledge auth method. 
--- dspy/retrieve/databricks_rm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 88728280da..1be09317d0 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -369,6 +369,7 @@ def _query_via_databricks_sdk( client_id=databricks_client_id, client_secret=databricks_client_secret, ) + print("Creating Databricks workspace client using service principal authentication.") else: # Fallback for token-based authentication @@ -376,6 +377,7 @@ def _query_via_databricks_sdk( host=databricks_endpoint, token=databricks_token, ) + print("Creating Databricks workspace client using token authentication.") return databricks_client.vector_search_indexes.query_index( index_name=index_name, From b47fe02e95552455b20b0b80fad28ef4a449c2e5 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 29 May 2025 16:51:35 -0700 Subject: [PATCH 05/10] format --- dspy/retrieve/databricks_rm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 1be09317d0..115aa79eb8 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -133,9 +133,13 @@ def __init__( self.databricks_endpoint = ( databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST") ) - self.databricks_client_id = databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") + self.databricks_client_id = ( + databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") + ) self.databricks_client_secret = ( - databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET") + databricks_client_secret + if databricks_client_secret is not None + else os.environ.get("DATABRICKS_CLIENT_SECRET") ) if not _databricks_sdk_installed and (self.databricks_token, 
self.databricks_endpoint).count(None) > 0: raise ValueError( From 9cf562a96a88c0054f64d071f4f8f308f82b84f3 Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Jun 2025 11:26:02 -0400 Subject: [PATCH 06/10] Updated for Oauth support via REST API Updated print statements for which auth method is invoked. Updated docstrings. Created helper function: _get_oauth_token Updated logic for databricks token to use oauth token with SP when using the REST API with client secret and client id --- dspy/retrieve/databricks_rm.py | 89 ++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 115aa79eb8..5e200a6de8 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -183,6 +183,7 @@ def _extract_doc_ids(self, item: Dict[str, Any]) -> str: if self.docs_id_column_name == "metadata": docs_dict = json.loads(item["metadata"]) return docs_dict["document_id"] + return item[self.docs_id_column_name] def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: @@ -198,11 +199,13 @@ def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name] } + if self.docs_id_column_name == "metadata": extra_columns = { **extra_columns, **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}}, } + return extra_columns def forward( @@ -251,6 +254,7 @@ def forward( raise ValueError("Query must be a string or a list of floats.") if _databricks_sdk_installed: + print("Using the Databricks SDK to query the Vector Search Index.") results = self._query_via_databricks_sdk( index_name=self.databricks_index_name, k=self.k, @@ -265,6 +269,7 @@ def forward( filters_json=filters_json or self.filters_json, ) else: + print("Using the REST API to query the Vector Search Index.") results = self._query_via_requests( 
index_name=self.databricks_index_name, k=self.k, @@ -313,6 +318,7 @@ def forward( ).to_dict() for doc in sorted_docs ] + else: # Returning the prediction return Prediction( @@ -351,8 +357,10 @@ def _query_via_databricks_sdk( filters_json (Optional[str]): JSON string representing additional query filters. databricks_token (str): Databricks authentication token. If not specified, the token is resolved from the current environment. - databricks_endpoint (str): Databricks index endpoint url. If not specified, - the endpoint is resolved from the current environment. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. databricks_client_id (str): Databricks service principal id. If not specified, the token is resolved from the current environment (DATABRICKS_CLIENT_ID). databricks_client_secret (str): Databricks service principal secret. If not specified, @@ -400,6 +408,8 @@ def _query_via_requests( columns: List[str], databricks_token: str, databricks_endpoint: str, + databricks_client_id: Optional[str], + databricks_client_secret: Optional[str], query_type: str, query_text: Optional[str], query_vector: Optional[List[float]], @@ -413,7 +423,14 @@ def _query_via_requests( k (int): Number of relevant documents to retrieve. columns (List[str]): Column names to include in response. databricks_token (str): Databricks authentication token. - databricks_endpoint (str): Databricks index endpoint url. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. 
+ databricks_client_id (str): Databricks service principal id. If not specified, + the token is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). query_text (Optional[str]): Text query for which to find relevant documents. Exactly one of query_text or query_vector must be specified. query_vector (Optional[List[float]]): Numeric query vector for which to find relevant @@ -423,30 +440,96 @@ def _query_via_requests( Returns: Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. """ + if (query_text, query_vector).count(None) != 1: raise ValueError("Exactly one of query_text or query_vector must be specified.") + if databricks_client_id and databricks_client_secret: + try: + databricks_token = _get_oauth_token( + index_name, databricks_endpoint, databricks_client_id, databricks_client_secret + ) + except Exception as e: + raise ValueError( + f"Failed to retrieve OAuth token. Please check your Databricks client ID and secret. 
Error: {e}" + ) + headers = { "Authorization": f"Bearer {databricks_token}", "Content-Type": "application/json", } + payload = { "columns": columns, "num_results": k, "query_type": query_type, } + if filters_json is not None: payload["filters_json"] = filters_json if query_text is not None: payload["query_text"] = query_text elif query_vector is not None: payload["query_vector"] = query_vector + response = requests.post( f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query", json=payload, headers=headers, ) + results = response.json() if "error_code" in results: raise Exception(f"ERROR: {results['error_code']} -- {results['message']}") return results + +def _get_oauth_token( + index_name: str, + databricks_endpoint: str, + databricks_client_id: str, + databricks_client_secret: str, +) -> str: + """ + Get OAuth token for Databricks service principal authentication. + + Args: + index_name (str): Name of the Databricks vector search index to query + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the token is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). + + Returns: + str: OAuth token. 
+ """ + + authorization_details = { + "type": "unity_catalog_permission", + "securable_type": "table", + "securable_object_name": index_name, + "operation": "ReadVectorIndex" + } + + authorization_details_list = [authorization_details] + + token_url = f"{databricks_endpoint}/oidc/v1/token" + + data = { + 'grant_type': 'client_credentials', + 'scope': 'all-apis', + 'authorization_details': json.dumps(authorization_details_list) + } + + response = requests.post( + token_url, + auth=(databricks_client_id, databricks_client_secret), + data=data, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + + response.raise_for_status() + return response.json()['access_token'] \ No newline at end of file From 34fa87cd98b6941e43fa55e15038f97f2bdbf1dc Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Jun 2025 11:28:13 -0400 Subject: [PATCH 07/10] Updated for Oauth support via REST API Updated print statements for which auth method is invoked. Updated docstrings. Created helper function: _get_oauth_token Updated logic for databricks token to use oauth token with SP when using the REST API with client secret and client id --- dspy/retrieve/databricks_rm.py | 89 ++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 115aa79eb8..5e200a6de8 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -183,6 +183,7 @@ def _extract_doc_ids(self, item: Dict[str, Any]) -> str: if self.docs_id_column_name == "metadata": docs_dict = json.loads(item["metadata"]) return docs_dict["document_id"] + return item[self.docs_id_column_name] def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: @@ -198,11 +199,13 @@ def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name] } + if self.docs_id_column_name == 
"metadata": extra_columns = { **extra_columns, **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}}, } + return extra_columns def forward( @@ -251,6 +254,7 @@ def forward( raise ValueError("Query must be a string or a list of floats.") if _databricks_sdk_installed: + print("Using the Databricks SDK to query the Vector Search Index.") results = self._query_via_databricks_sdk( index_name=self.databricks_index_name, k=self.k, @@ -265,6 +269,7 @@ def forward( filters_json=filters_json or self.filters_json, ) else: + print("Using the REST API to query the Vector Search Index.") results = self._query_via_requests( index_name=self.databricks_index_name, k=self.k, @@ -313,6 +318,7 @@ def forward( ).to_dict() for doc in sorted_docs ] + else: # Returning the prediction return Prediction( @@ -351,8 +357,10 @@ def _query_via_databricks_sdk( filters_json (Optional[str]): JSON string representing additional query filters. databricks_token (str): Databricks authentication token. If not specified, the token is resolved from the current environment. - databricks_endpoint (str): Databricks index endpoint url. If not specified, - the endpoint is resolved from the current environment. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. databricks_client_id (str): Databricks service principal id. If not specified, the token is resolved from the current environment (DATABRICKS_CLIENT_ID). databricks_client_secret (str): Databricks service principal secret. 
If not specified, @@ -400,6 +408,8 @@ def _query_via_requests( columns: List[str], databricks_token: str, databricks_endpoint: str, + databricks_client_id: Optional[str], + databricks_client_secret: Optional[str], query_type: str, query_text: Optional[str], query_vector: Optional[List[float]], @@ -413,7 +423,14 @@ def _query_via_requests( k (int): Number of relevant documents to retrieve. columns (List[str]): Column names to include in response. databricks_token (str): Databricks authentication token. - databricks_endpoint (str): Databricks index endpoint url. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the token is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). query_text (Optional[str]): Text query for which to find relevant documents. Exactly one of query_text or query_vector must be specified. query_vector (Optional[List[float]]): Numeric query vector for which to find relevant @@ -423,30 +440,96 @@ def _query_via_requests( Returns: Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. """ + if (query_text, query_vector).count(None) != 1: raise ValueError("Exactly one of query_text or query_vector must be specified.") + if databricks_client_id and databricks_client_secret: + try: + databricks_token = _get_oauth_token( + index_name, databricks_endpoint, databricks_client_id, databricks_client_secret + ) + except Exception as e: + raise ValueError( + f"Failed to retrieve OAuth token. 
Please check your Databricks client ID and secret. Error: {e}" + ) + headers = { "Authorization": f"Bearer {databricks_token}", "Content-Type": "application/json", } + payload = { "columns": columns, "num_results": k, "query_type": query_type, } + if filters_json is not None: payload["filters_json"] = filters_json if query_text is not None: payload["query_text"] = query_text elif query_vector is not None: payload["query_vector"] = query_vector + response = requests.post( f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query", json=payload, headers=headers, ) + results = response.json() if "error_code" in results: raise Exception(f"ERROR: {results['error_code']} -- {results['message']}") return results + +def _get_oauth_token( + index_name: str, + databricks_endpoint: str, + databricks_client_id: str, + databricks_client_secret: str, +) -> str: + """ + Get OAuth token for Databricks service principal authentication. + + Args: + index_name (str): Name of the Databricks vector search index to query + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the token is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET). + + Returns: + str: OAuth token. 
+ """ + + authorization_details = { + "type": "unity_catalog_permission", + "securable_type": "table", + "securable_object_name": index_name, + "operation": "ReadVectorIndex" + } + + authorization_details_list = [authorization_details] + + token_url = f"{databricks_endpoint}/oidc/v1/token" + + data = { + 'grant_type': 'client_credentials', + 'scope': 'all-apis', + 'authorization_details': json.dumps(authorization_details_list) + } + + response = requests.post( + token_url, + auth=(databricks_client_id, databricks_client_secret), + data=data, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + + response.raise_for_status() + return response.json()['access_token'] \ No newline at end of file From aedce911ba15679f8eccbd0fe117c8447b832fb3 Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Jun 2025 11:28:13 -0400 Subject: [PATCH 08/10] Updated for Oauth support via REST API Updated print statements for which auth method is invoked. Updated docstrings. Created helper function: _get_oauth_token Updated logic for databricks token to use oauth token with SP when using the REST API with client secret and client id --- dspy/retrieve/databricks_rm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 5e200a6de8..62dd788575 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -276,6 +276,8 @@ def forward( columns=self.columns, databricks_token=self.databricks_token, databricks_endpoint=self.databricks_endpoint, + databricks_client_id=self.databricks_client_id, + databricks_client_secret=self.databricks_client_secret, query_type=query_type, query_text=query_text, query_vector=query_vector, @@ -446,6 +448,7 @@ def _query_via_requests( if databricks_client_id and databricks_client_secret: try: + print("Retrieving OAuth token using service principal authentication.") databricks_token = _get_oauth_token( index_name, databricks_endpoint, databricks_client_id, 
databricks_client_secret ) From 594fb04d6de787902a0adc0394a65c6e7847599a Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Jun 2025 11:28:13 -0400 Subject: [PATCH 09/10] Updated for Oauth support via REST API Updated print statements for which auth method is invoked. Updated docstrings. Created helper function: _get_oauth_token Updated logic for databricks token to use oauth token with SP when using the REST API with client secret and client id --- dspy/retrieve/databricks_rm.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 5e200a6de8..4db70fb132 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -9,8 +9,12 @@ import dspy from dspy.primitives.prediction import Prediction -_databricks_sdk_installed = find_spec("databricks.sdk") is not None +_databricks_sdk_installed = False +try: + _databricks_sdk_installed = find_spec("databricks.sdk") is not None +except ModuleNotFoundError: + _databricks_sdk_installed = False @dataclass class Document: @@ -137,11 +141,10 @@ def __init__( databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") ) self.databricks_client_secret = ( - databricks_client_secret - if databricks_client_secret is not None - else os.environ.get("DATABRICKS_CLIENT_SECRET") + databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET") ) - if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: + if not _databricks_sdk_installed and ((self.databricks_token, self.databricks_endpoint).count(None) > 0 + and (self.databricks_client_id, self.databricks_client_secret).count(None) > 0): raise ValueError( "To retrieve documents with Databricks Vector Search, you must install the" " databricks-sdk Python library, supply the databricks_token and" @@ -276,6 +279,8 @@ def forward( 
columns=self.columns, databricks_token=self.databricks_token, databricks_endpoint=self.databricks_endpoint, + databricks_client_id=self.databricks_client_id, + databricks_client_secret=self.databricks_client_secret, query_type=query_type, query_text=query_text, query_vector=query_vector, @@ -446,6 +451,7 @@ def _query_via_requests( if databricks_client_id and databricks_client_secret: try: + print("Retrieving OAuth token using service principal authentication.") databricks_token = _get_oauth_token( index_name, databricks_endpoint, databricks_client_id, databricks_client_secret ) From 0e5584127469a353ba346c9bd4b4ffdc69e39bb7 Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Jun 2025 14:25:21 -0400 Subject: [PATCH 10/10] Updated for Databricks Oauth support via REST API Updated print statements for which auth method is invoked. Updated docstrings. Created helper function: _get_oauth_token Updated logic for databricks token to use oauth token with SP when using the REST API with client secret and client id --- dspy/retrieve/databricks_rm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 4db70fb132..d092a10b11 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -457,7 +457,10 @@ def _query_via_requests( ) except Exception as e: raise ValueError( - f"Failed to retrieve OAuth token. Please check your Databricks client ID and secret. Error: {e}" + f"Failed to retrieve OAuth token. Please check your Databricks client ID and secret. \n" + f"Error: {e} \n \n" + f"NOTE: If you are experiencing a 401 error, be sure to check the permissions on the index that " + f"you are trying to query. The service principal must have the select permission on the index." ) headers = {