From 85508de8c016fa578c911861394585f81053c116 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 29 Apr 2025 23:18:25 -0700 Subject: [PATCH 1/9] draft Signed-off-by: Ann Zhang --- .../databricks_langchain/vector_search_retriever_tool.py | 5 +++-- .../langchain/src/databricks_langchain/vectorstores.py | 3 +++ .../databricks_llamaindex/vector_search_retriever_tool.py | 4 +++- .../src/databricks_openai/vector_search_retriever_tool.py | 6 +++++- .../tests/unit_tests/test_vector_search_retriever_tool.py | 4 ++++ src/databricks_ai_bridge/vector_search_retriever_tool.py | 2 +- 6 files changed, 19 insertions(+), 5 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index c8caec28..f5ff8a81 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -65,7 +65,8 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def _run(self, query: str) -> str: + def _run(self, query: str, **kwargs) -> str: + combined_kwargs = {**kwargs, **(self.model_extra or {})} return self._vector_store.similarity_search( - query, k=self.num_results, filter=self.filters, query_type=self.query_type + query, k=self.num_results, filter=self.filters, query_type=self.query_type, **combined_kwargs ) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index be1012f7..ba02ffb5 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -465,6 +465,7 @@ def similarity_search_with_score( filters=filter, num_results=k, query_type=query_type, + **kwargs, ) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=Document @@ -577,6 +578,7 @@ def similarity_search_by_vector_with_score( filters=filter, num_results=k, query_type=query_type, + **kwargs, ) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=Document @@ -696,6 +698,7 @@ def max_marginal_relevance_search_by_vector( filters=filter, num_results=fetch_k, query_type=query_type, + **kwargs ) embeddings_result_index = ( diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index 9179fdf2..c99aaa29 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -76,7 +76,7 @@ def __init__(self, **data): ) # Define the similarity search function - def similarity_search(query: str) -> List[Dict[str, Any]]: + def similarity_search(query: str, **kwargs: Any) -> List[Dict[str, Any]]: def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]: if self._index_details.is_databricks_managed_embeddings(): if self.embedding: @@ -105,6 +105,7 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa return text, vector query_text, query_vector = get_query_text_vector(query) + combined_kwargs = {**kwargs, **(self.model_extra or {})} search_resp = self._index.similarity_search( columns=self.columns, query_text=query_text, @@ -112,6 +113,7 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa filters=self.filters, num_results=self.num_results, query_type=self.query_type, + **combined_kwargs, ) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=dict diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 17a87c88..adefe5a5 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Any from databricks.vector_search.client import VectorSearchIndex from databricks_ai_bridge.utils.vector_search import ( @@ -162,6 +162,7 @@ def execute( self, query: str, openai_client: OpenAI = None, + **kwargs: Any ) -> List[Dict]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the @@ -202,6 +203,8 @@ def execute( f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}" ) + combined_kwargs = {**kwargs, **(self.model_extra or {})} + search_resp = self._index.similarity_search( columns=self.columns, query_text=query_text, @@ -209,6 +212,7 @@ def execute( filters=self.filters, num_results=self.num_results, query_type=self.query_type, + **combined_kwargs ) docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( search_resp=search_resp, diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index a28aabe0..35062823 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -298,3 +298,7 @@ def test_vector_search_client_non_model_serving_environment(): workspace_client=w, ) mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) + + +def test_pass_in_kwargs(): + pass diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 7d944565..01787828 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -50,7 +50,7 @@ class VectorSearchRetrieverToolMixin(BaseModel): implementations should follow. """ - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") index_name: str = Field( ..., description="The name of the index to use, format: 'catalog.schema.index'." ) From 059469e00fe12df4e1c18be6533f95c22d729a99 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 30 Apr 2025 18:43:05 -0700 Subject: [PATCH 2/9] update Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 6 ++- .../src/databricks_langchain/vectorstores.py | 2 +- .../test_vector_search_retriever_tool.py | 44 ++++++++++++++----- .../tests/unit_tests/test_vectorstores.py | 18 -------- .../test_vector_search_retriever_tool.py | 41 ++++++++++++----- .../vector_search_retriever_tool.py | 11 ++--- .../test_vector_search_retriever_tool.py | 38 +++++++++++----- .../vector_search_retriever_tool.py | 1 + 8 files changed, 100 insertions(+), 61 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index f5ff8a81..db81184f 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -68,5 +68,9 @@ def _validate_tool_inputs(self): def _run(self, query: str, **kwargs) -> str: combined_kwargs = {**kwargs, **(self.model_extra or {})} return self._vector_store.similarity_search( - query, k=self.num_results, filter=self.filters, query_type=self.query_type, **combined_kwargs + query, + k=self.num_results, + filter=self.filters, + query_type=self.query_type, + **combined_kwargs, ) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index ba02ffb5..100ac9db 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -698,7 +698,7 @@ def max_marginal_relevance_search_by_vector( filters=filter, num_results=fetch_k, query_type=query_type, - **kwargs + **kwargs, ) embeddings_result_index = ( diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 7d11192e..016e5a00 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,8 +1,8 @@ import json import os import threading -from typing import Any, Dict, List, Optional -from unittest.mock import patch +from typing import Any, List, Optional +from unittest.mock import MagicMock, patch import mlflow import pytest @@ -49,17 +49,20 @@ def init_vector_search_tool( text_column: Optional[str] = None, doc_uri: Optional[str] = None, primary_key: Optional[str] = None, + **kwargs: Any, ) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "embedding": embedding, - "text_column": text_column, - "doc_uri": doc_uri, - "primary_key": primary_key, - } + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding": embedding, + "text_column": text_column, + "doc_uri": doc_uri, + "primary_key": primary_key, + } + ) if index_name != DELTA_SYNC_INDEX: kwargs.update( { @@ -270,3 +273,20 @@ def test_vector_search_client_non_model_serving_environment(): workspace_client=w, ) mockVSClient.assert_called_once_with(disable_notice=True) + + +def test_kwargs_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) + vector_search_tool._vector_store.similarity_search = MagicMock() + + vector_search_tool.invoke( + {"query": "what cities are in Germany", "extra_param": "something random"}, + ) + vector_search_tool._vector_store.similarity_search.assert_called_once_with( + "what cities are in Germany", + k=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + filter=None, + score_threshold=0.5, + extra_param="something random", + ) diff --git a/integrations/langchain/tests/unit_tests/test_vectorstores.py b/integrations/langchain/tests/unit_tests/test_vectorstores.py index b352b548..c728091f 100644 --- a/integrations/langchain/tests/unit_tests/test_vectorstores.py +++ b/integrations/langchain/tests/unit_tests/test_vectorstores.py @@ -304,24 +304,6 @@ def test_similarity_search_hybrid(index_name: str) -> None: assert all(["id" in d.metadata for d in search_result]) -def test_similarity_search_both_filter_and_filters_passed() -> None: - vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) - query = "foo" - filter = {"some filter": True} - filters = {"some other filter": False} - - vectorsearch.similarity_search(query, filter=filter, filters=filters) - vectorsearch.index.similarity_search.assert_called_once_with( - columns=["id", "text"], - query_vector=EMBEDDING_MODEL.embed_query(query), - # `filter` should prevail over `filters` - filters=filter, - num_results=4, - query_text=None, - query_type=None, - ) - - @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) @pytest.mark.parametrize( "columns, expected_columns", diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index 6378d607..e35ddfd7 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,7 +1,7 @@ import os import threading -from typing import Any, Dict, List, Optional -from unittest.mock import patch +from typing import Any, List, Optional +from unittest.mock import MagicMock, patch import pytest from databricks.sdk import WorkspaceClient @@ -52,15 +52,18 @@ def init_vector_search_tool( tool_description: Optional[str] = None, embedding: Optional[BaseEmbedding] = None, text_column: Optional[str] = None, + **kwargs: Any, ) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "embedding": embedding, - "text_column": text_column, - } + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding": embedding, + "text_column": text_column, + } + ) if index_name != DELTA_SYNC_INDEX: kwargs.update( { @@ -176,3 +179,21 @@ def test_vector_search_client_non_model_serving_environment(): workspace_client=w, ) mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) + + +def test_kwargs_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) + vector_search_tool._index.similarity_search = MagicMock() + + vector_search_tool.call(query="what cities are in Germany", extra_param="something random") + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text="what cities are in Germany", + num_results=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + query_vector=None, + filters=None, + score_threshold=0.5, + extra_param="something random", + requires_context=False, + ) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index adefe5a5..bccfe21d 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple, Any +from typing import Any, Dict, List, Optional, Tuple from databricks.vector_search.client import VectorSearchIndex from databricks_ai_bridge.utils.vector_search import ( @@ -158,12 +158,7 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def execute( - self, - query: str, - openai_client: OpenAI = None, - **kwargs: Any - ) -> List[Dict]: + def execute(self, query: str, openai_client: OpenAI = None, **kwargs: Any) -> List[Dict]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. @@ -212,7 +207,7 @@ def execute( filters=self.filters, num_results=self.num_results, query_type=self.query_type, - **combined_kwargs + **combined_kwargs, ) docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( search_resp=search_resp, diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 35062823..ddb834a6 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,7 +1,7 @@ import json import os import threading -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from unittest.mock import MagicMock, Mock, patch import mlflow @@ -89,15 +89,18 @@ def init_vector_search_tool( tool_description: Optional[str] = None, text_column: Optional[str] = None, embedding_model_name: Optional[str] = None, + **kwargs: Any, ) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "text_column": text_column, - "embedding_model_name": embedding_model_name, - } + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "text_column": text_column, + "embedding_model_name": embedding_model_name, + } + ) if index_name != DELTA_SYNC_INDEX: kwargs.update( { @@ -300,5 +303,18 @@ def test_vector_search_client_non_model_serving_environment(): mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) -def test_pass_in_kwargs(): - pass +def test_kwargs_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) + vector_search_tool._index.similarity_search = MagicMock() + + vector_search_tool.execute(query="what cities are in Germany", extra_param="something random") + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text="what cities are in Germany", + num_results=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + query_vector=None, + filters=None, + score_threshold=0.5, + extra_param="something random", + ) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 01787828..bc3ec9a9 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -37,6 +37,7 @@ def wrapper(self, *args, **kwargs): class VectorSearchRetrieverToolInput(BaseModel): + model_config = ConfigDict(extra="allow") query: str = Field( description="The string used to query the index with and identify the most similar " "vectors and return the associated documents." From 8091c1da503feb46f216f310298f0eb8ee9efac9 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 30 Apr 2025 21:44:07 -0700 Subject: [PATCH 3/9] update Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 17 ++++++------ .../src/databricks_langchain/vectorstores.py | 21 ++++++++------- .../tests/unit_tests/test_vectorstores.py | 26 +++++++++++++++++++ .../test_utils/vector_search.py | 4 +-- 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index db81184f..b7441d79 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -66,11 +66,12 @@ def _validate_tool_inputs(self): @vector_search_retriever_tool_trace def _run(self, query: str, **kwargs) -> str: - combined_kwargs = {**kwargs, **(self.model_extra or {})} - return self._vector_store.similarity_search( - query, - k=self.num_results, - filter=self.filters, - query_type=self.query_type, - **combined_kwargs, - ) + kwargs = {**kwargs, **(self.model_extra or {})} + # Ensure that we don't have duplicate keys + kwargs.update({ + "query": query, + "k": self.num_results, + "filter": self.filters, + "query_type": self.query_type + }) + return self._vector_store.similarity_search(**kwargs) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index 100ac9db..092d41ff 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import inspect import logging import uuid from functools import partial @@ -458,15 +459,17 @@ def similarity_search_with_score( query_text = None query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr] - search_resp = self.index.similarity_search( - columns=self._columns, - query_text=query_text, - query_vector=query_vector, - filters=filter, - num_results=k, - query_type=query_type, - **kwargs, - ) + signature = inspect.signature(self.index.similarity_search) + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + kwargs.update({ + "columns": self._columns, + "query_text": query_text, + "query_vector": query_vector, + "filters": filter, + "num_results": k, + "query_type": query_type, + }) + search_resp = self.index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=Document ) diff --git a/integrations/langchain/tests/unit_tests/test_vectorstores.py b/integrations/langchain/tests/unit_tests/test_vectorstores.py index c728091f..d92e9cb6 100644 --- a/integrations/langchain/tests/unit_tests/test_vectorstores.py +++ b/integrations/langchain/tests/unit_tests/test_vectorstores.py @@ -304,6 +304,32 @@ def test_similarity_search_hybrid(index_name: str) -> None: assert all(["id" in d.metadata for d in search_result]) +def test_similarity_search_passing_kwargs() -> None: + vectorsearch = init_vector_search(DELTA_SYNC_INDEX) + query = "foo" + filters = {"some filter": True} + query_type="ANN" + + search_result = vectorsearch.similarity_search( + query, + k=5, + filter=filters, + query_type=query_type, + score_threshold=0.5, + num_results=10, + random_parameters="not included" + ) + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_text=query, + query_vector=None, + filters=filters, + query_type=query_type, + num_results=5, # maintained + score_threshold=0.5 # passed + ) + + @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) @pytest.mark.parametrize( "columns, expected_columns", diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index a5e18291..3eb440c2 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -1,7 +1,7 @@ import uuid from typing import Generator, List, Optional from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import create_autospec, MagicMock, patch import pytest from databricks.vector_search.client import VectorSearchIndex # type: ignore @@ -132,7 +132,7 @@ def _get_index( endpoint_name: Optional[str] = None, index_name: str = None, # type: ignore ) -> MagicMock: - index = MagicMock(spec=VectorSearchIndex) + index = create_autospec(VectorSearchIndex, instance=True) if index_name not in INDEX_DETAILS: index_name = DIRECT_ACCESS_INDEX index.describe.return_value = INDEX_DETAILS[index_name] From fb90218ae49c14fe18d34a9e078184a62218ef92 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 6 May 2025 11:24:31 -0700 Subject: [PATCH 4/9] update Signed-off-by: Ann Zhang --- .../test_vector_search_retriever_tool.py | 8 +++---- .../vector_search_retriever_tool.py | 22 ++++++++++--------- .../test_vector_search_retriever_tool.py | 6 +++-- .../test_vector_search_retriever_tool.py | 4 ++-- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 81a8f710..e76a668d 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -99,7 +99,7 @@ def test_filters_are_passed_through() -> None: {"query": "what cities are in Germany", "filters": {"country": "Germany"}} ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", + query="what cities are in Germany", k=vector_search_tool.num_results, filter={"country": "Germany"}, query_type=vector_search_tool.query_type, @@ -114,7 +114,7 @@ def test_filters_are_combined() -> None: {"query": "what cities are in Germany", "filters": {"country": "Germany"}} ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", + query="what cities are in Germany", k=vector_search_tool.num_results, filter={"city LIKE": "Berlin", "country": "Germany"}, query_type=vector_search_tool.query_type, @@ -315,10 +315,10 @@ def test_kwargs_are_passed_through() -> None: {"query": "what cities are in Germany", "extra_param": "something random"}, ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", + query="what cities are in Germany", k=vector_search_tool.num_results, query_type=vector_search_tool.query_type, - filter=None, + filter={}, score_threshold=0.5, extra_param="something random", ) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index 8812517d..f9f57626 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -107,17 +107,19 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa return text, vector query_text, query_vector = get_query_text_vector(query) - combined_kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {**kwargs, **(self.model_extra or {})} combined_filters = {**(filters or {}), **(self.filters or {})} - search_resp = self._index.similarity_search( - columns=self.columns, - query_text=query_text, - query_vector=query_vector, - filters=combined_filters, - num_results=self.num_results, - query_type=self.query_type, - **combined_kwargs, - ) + # Ensure that we don't have duplicate keys + kwargs.update({ + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": self.num_results, + "query_type": self.query_type + }) + print(kwargs) + search_resp = self._index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=dict ) diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index 3c7f09f4..592c4cc6 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -194,13 +194,13 @@ def test_kwargs_are_passed_through() -> None: num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, query_vector=None, - filters=None, + filters={}, score_threshold=0.5, extra_param="something random", requires_context=False, ) - + def test_filters_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) vector_search_tool._index.similarity_search = MagicMock() @@ -213,6 +213,7 @@ def test_filters_are_passed_through() -> None: num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, query_vector=None, + requires_context=False, ) @@ -228,4 +229,5 @@ def test_filters_are_combined() -> None: num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, query_vector=None, + requires_context=False, ) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index e3c63da6..98053b4b 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,7 +1,7 @@ import json import os import threading -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from unittest.mock import MagicMock, Mock, patch import mlflow @@ -316,7 +316,7 @@ def test_kwargs_are_passed_through() -> None: num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, query_vector=None, - filters=None, + filters={}, score_threshold=0.5, extra_param="something random", ) From 928067f6cc13a772442510b5329e5ce205e287a8 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 6 May 2025 11:35:33 -0700 Subject: [PATCH 5/9] remove requires_context Signed-off-by: Ann Zhang --- .../databricks_llamaindex/vector_search_retriever_tool.py | 7 +++++-- .../tests/unit_tests/test_vector_search_retriever_tool.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index f9f57626..0530407b 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -107,7 +107,11 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa return text, vector query_text, query_vector = get_query_text_vector(query) - kwargs = {**kwargs, **(self.model_extra or {})} + filtered_model_extra = { + k: v for k, v in (self.model_extra or {}).items() + if k != "requires_context" # don't include extra parameters set by FunctionTool + } + kwargs = {**kwargs, **filtered_model_extra} combined_filters = {**(filters or {}), **(self.filters or {})} # Ensure that we don't have duplicate keys kwargs.update({ @@ -118,7 +122,6 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa "num_results": self.num_results, "query_type": self.query_type }) - print(kwargs) search_resp = self._index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=dict diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index 592c4cc6..b36d24f9 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -197,7 +197,6 @@ def test_kwargs_are_passed_through() -> None: filters={}, score_threshold=0.5, extra_param="something random", - requires_context=False, ) @@ -213,7 +212,6 @@ def test_filters_are_passed_through() -> None: num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, query_vector=None, - requires_context=False, ) @@ -229,5 +227,4 @@ def test_filters_are_combined() -> None: num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, query_vector=None, - requires_context=False, ) From 36a1dbb40755a58747074e2720d1438f2f8e36ec Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 6 May 2025 11:35:49 -0700 Subject: [PATCH 6/9] ruff Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 14 ++++++----- .../src/databricks_langchain/vectorstores.py | 18 ++++++++------- .../tests/unit_tests/test_vectorstores.py | 8 +++---- .../vector_search_retriever_tool.py | 23 +++++++++++-------- .../vector_search_retriever_tool.py | 2 +- .../test_utils/vector_search.py | 2 +- 6 files changed, 37 insertions(+), 30 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 34232505..9e593aa5 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -69,10 +69,12 @@ def _run(self, query: str, filters: Optional[Dict[str, Any]] = None, **kwargs) - kwargs = {**kwargs, **(self.model_extra or {})} combined_filters = {**(filters or {}), **(self.filters or {})} # Ensure that we don't have duplicate keys - kwargs.update({ - "query": query, - "k": self.num_results, - "filter": combined_filters, - "query_type": self.query_type - }) + kwargs.update( + { + "query": query, + "k": self.num_results, + "filter": combined_filters, + "query_type": self.query_type, + } + ) return self._vector_store.similarity_search(**kwargs) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index 092d41ff..c0d35081 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -461,14 +461,16 @@ def similarity_search_with_score( signature = inspect.signature(self.index.similarity_search) kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} - kwargs.update({ - "columns": self._columns, - "query_text": query_text, - "query_vector": query_vector, - "filters": filter, - "num_results": k, - "query_type": query_type, - }) + kwargs.update( + { + "columns": self._columns, + "query_text": query_text, + "query_vector": query_vector, + "filters": filter, + "num_results": k, + "query_type": query_type, + } + ) search_resp = self.index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=Document diff --git a/integrations/langchain/tests/unit_tests/test_vectorstores.py b/integrations/langchain/tests/unit_tests/test_vectorstores.py index d92e9cb6..89db8c91 100644 --- a/integrations/langchain/tests/unit_tests/test_vectorstores.py +++ b/integrations/langchain/tests/unit_tests/test_vectorstores.py @@ -308,7 +308,7 @@ def test_similarity_search_passing_kwargs() -> None: vectorsearch = init_vector_search(DELTA_SYNC_INDEX) query = "foo" filters = {"some filter": True} - query_type="ANN" + query_type = "ANN" search_result = vectorsearch.similarity_search( query, @@ -317,7 +317,7 @@ def test_similarity_search_passing_kwargs() -> None: query_type=query_type, score_threshold=0.5, num_results=10, - random_parameters="not included" + random_parameters="not included", ) vectorsearch.index.similarity_search.assert_called_once_with( columns=["id", "text"], @@ -325,8 +325,8 @@ def test_similarity_search_passing_kwargs() -> None: query_vector=None, filters=filters, query_type=query_type, - num_results=5, # maintained - score_threshold=0.5 # passed + num_results=5, # maintained + score_threshold=0.5, # passed ) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index 0530407b..3efebe41 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -108,20 +108,23 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa query_text, query_vector = get_query_text_vector(query) filtered_model_extra = { - k: v for k, v in (self.model_extra or {}).items() - if k != "requires_context" # don't include extra parameters set by FunctionTool + k: v + for k, v in (self.model_extra or {}).items() + if k != "requires_context" # don't include extra parameters set by FunctionTool } kwargs = {**kwargs, **filtered_model_extra} combined_filters = {**(filters or {}), **(self.filters or {})} # Ensure that we don't have duplicate keys - kwargs.update({ - "query_text": query_text, - "query_vector": query_vector, - "columns": self.columns, - "filters": combined_filters, - "num_results": self.num_results, - "query_type": self.query_type - }) + kwargs.update( + { + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": self.num_results, + "query_type": self.query_type, + } + ) search_resp = self._index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=dict diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 999cb6f2..4dc028f2 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -167,7 +167,7 @@ def execute( query: str, filters: Optional[Dict[str, Any]] = None, openai_client: OpenAI = None, - **kwargs: Any + **kwargs: Any, ) -> List[Dict]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index 0126a501..72aa7c29 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -1,7 +1,7 @@ import uuid from typing import Generator, List, Optional from unittest import mock -from unittest.mock import create_autospec, MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest from databricks.vector_search.client import VectorSearchIndex # type: ignore From e11106dd79c022957d6477010143ccbce3e3e897 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 6 May 2025 14:23:51 -0700 Subject: [PATCH 7/9] fix openai Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 4dc028f2..36cf1e6b 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -208,17 +208,19 @@ def execute( f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}" ) - combined_kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {**kwargs, **(self.model_extra or {})} combined_filters = {**(filters or {}), **(self.filters or {})} - search_resp = self._index.similarity_search( - columns=self.columns, - query_text=query_text, - query_vector=query_vector, - filters=combined_filters, - num_results=self.num_results, - query_type=self.query_type, - **combined_kwargs, + kwargs.update( + { + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": self.num_results, + "query_type": self.query_type, + } ) + search_resp = self._index.similarity_search(**kwargs) docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( search_resp=search_resp, retriever_schema=self._retriever_schema, From 756c02be59ceb58e8356bc48caa2324c22802214 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 6 May 2025 16:31:34 -0700 Subject: [PATCH 8/9] update Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 12 ++++++------ .../vector_search_retriever_tool.py | 6 +++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index 3efebe41..b84085b1 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Optional, Tuple from databricks_ai_bridge.utils.vector_search import ( @@ -107,13 +108,12 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa return text, vector query_text, query_vector = get_query_text_vector(query) - filtered_model_extra = { - k: v - for k, v in (self.model_extra or {}).items() - if k != "requires_context" # don't include extra parameters set by FunctionTool - } - kwargs = {**kwargs, **filtered_model_extra} combined_filters = {**(filters or {}), **(self.filters or {})} + + signature = inspect.signature(self.index.similarity_search) + kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + # Ensure that we don't have duplicate keys kwargs.update( { diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 36cf1e6b..2c012797 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,3 +1,4 @@ +import inspect import logging from typing import Any, Dict, List, Optional, Tuple @@ -208,8 +209,11 @@ def execute( f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}" ) - kwargs = {**kwargs, **(self.model_extra or {})} combined_filters = {**(filters or {}), **(self.filters or {})} + + signature = inspect.signature(self.index.similarity_search) + kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} kwargs.update( { "query_text": query_text, From 873e5a5cc62ac9ac7e4f8ab68de2a8f70ed42ad6 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 6 May 2025 16:47:04 -0700 Subject: [PATCH 9/9] fix Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 6 +++--- .../unit_tests/test_vector_search_retriever_tool.py | 12 ++++++++---- .../vector_search_retriever_tool.py | 4 ++-- .../unit_tests/test_vector_search_retriever_tool.py | 12 ++++++++---- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index b84085b1..c7ae3243 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -109,11 +109,11 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa query_text, query_vector = get_query_text_vector(query) combined_filters = {**(filters or {}), **(self.filters or {})} - - signature = inspect.signature(self.index.similarity_search) + + signature = inspect.signature(self._index.similarity_search) kwargs = {**kwargs, **(self.model_extra or {})} kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} - + # Ensure that we don't have duplicate keys kwargs.update( { diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index b36d24f9..8b2eb4d9 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,11 +1,12 @@ import os import threading from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest from databricks.sdk import WorkspaceClient from databricks.sdk.credentials_provider import ModelServingUserCredentials +from databricks.vector_search.client import VectorSearchIndex from databricks.vector_search.utils import CredentialStrategy from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, @@ -185,9 +186,12 @@ def test_vector_search_client_non_model_serving_environment(): def test_kwargs_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) - vector_search_tool._index.similarity_search = MagicMock() + vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True) - vector_search_tool.call(query="what cities are in Germany", extra_param="something random") + # extra_param is ignored because it isn't part of the signature for similarity_search + vector_search_tool.call( + query="what cities are in Germany", debug_level=2, extra_param="something random" + ) vector_search_tool._index.similarity_search.assert_called_once_with( columns=vector_search_tool.columns, query_text="what cities are in Germany", @@ -196,7 +200,7 @@ def test_kwargs_are_passed_through() -> None: query_vector=None, filters={}, score_threshold=0.5, - extra_param="something random", + debug_level=2, ) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 2c012797..10571270 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -210,8 +210,8 @@ def execute( ) combined_filters = {**(filters or {}), **(self.filters or {})} - - signature = inspect.signature(self.index.similarity_search) + + signature = inspect.signature(self._index.similarity_search) kwargs = {**kwargs, **(self.model_extra or {})} kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} kwargs.update( diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 98053b4b..b53b588a 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -2,12 +2,13 @@ import os import threading from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, create_autospec, patch import mlflow import pytest from databricks.sdk import WorkspaceClient from databricks.sdk.credentials_provider import ModelServingUserCredentials +from databricks.vector_search.client import VectorSearchIndex from databricks.vector_search.utils import CredentialStrategy from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, @@ -307,9 +308,12 @@ def test_vector_search_client_non_model_serving_environment(): def test_kwargs_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) - vector_search_tool._index.similarity_search = MagicMock() + vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True) - vector_search_tool.execute(query="what cities are in Germany", extra_param="something random") + # extra_param is ignored because it isn't part of the signature for similarity_search + vector_search_tool.execute( + query="what cities are in Germany", debug_level=2, extra_param="something random" + ) vector_search_tool._index.similarity_search.assert_called_once_with( columns=vector_search_tool.columns, query_text="what cities are in Germany", @@ -318,7 +322,7 @@ def test_kwargs_are_passed_through() -> None: query_vector=None, filters={}, score_threshold=0.5, - extra_param="something random", + debug_level=2, )