diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 0c1dcef..9e593aa 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -65,8 +65,16 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str: + def _run(self, query: str, filters: Optional[Dict[str, Any]] = None, **kwargs) -> str: + kwargs = {**kwargs, **(self.model_extra or {})} combined_filters = {**(filters or {}), **(self.filters or {})} - return self._vector_store.similarity_search( - query, k=self.num_results, filter=combined_filters, query_type=self.query_type + # Ensure that we don't have duplicate keys + kwargs.update( + { + "query": query, + "k": self.num_results, + "filter": combined_filters, + "query_type": self.query_type, + } ) + return self._vector_store.similarity_search(**kwargs) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index be1012f..c0d3508 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import inspect import logging import uuid from functools import partial @@ -458,14 +459,19 @@ def similarity_search_with_score( query_text = None query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr] - search_resp = self.index.similarity_search( - columns=self._columns, - query_text=query_text, - query_vector=query_vector, - filters=filter, - num_results=k, - query_type=query_type, + signature = inspect.signature(self.index.similarity_search) + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + kwargs.update( + { + "columns": self._columns, + "query_text": query_text, + "query_vector": query_vector, + "filters": filter, + "num_results": k, + "query_type": query_type, + } ) + search_resp = self.index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=Document ) @@ -577,6 +583,7 @@ def similarity_search_by_vector_with_score( filters=filter, num_results=k, query_type=query_type, + **kwargs, ) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=Document @@ -696,6 +703,7 @@ def max_marginal_relevance_search_by_vector( filters=filter, num_results=fetch_k, query_type=query_type, + **kwargs, ) embeddings_result_index = ( diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index e4ce414..e76a668 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -50,18 +50,21 @@ def init_vector_search_tool( doc_uri: Optional[str] = None, primary_key: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "embedding": embedding, - "text_column": text_column, - "doc_uri": doc_uri, - "primary_key": primary_key, - "filters": filters, - } + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding": embedding, + "text_column": text_column, + "doc_uri": doc_uri, + "primary_key": primary_key, + "filters": filters, + } + ) if index_name != DELTA_SYNC_INDEX: kwargs.update( { @@ -96,7 +99,7 @@ def test_filters_are_passed_through() -> None: {"query": "what cities are in Germany", "filters": {"country": "Germany"}} ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", + query="what cities are in Germany", k=vector_search_tool.num_results, filter={"country": "Germany"}, query_type=vector_search_tool.query_type, @@ -111,7 +114,7 @@ def test_filters_are_combined() -> None: {"query": "what cities are in Germany", "filters": {"country": "Germany"}} ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", + query="what cities are in Germany", k=vector_search_tool.num_results, filter={"city LIKE": "Berlin", "country": "Germany"}, query_type=vector_search_tool.query_type, @@ -302,3 +305,20 @@ def test_vector_search_client_non_model_serving_environment(): workspace_client=w, ) mockVSClient.assert_called_once_with(disable_notice=True) + + +def test_kwargs_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) + vector_search_tool._vector_store.similarity_search = MagicMock() + + vector_search_tool.invoke( + {"query": "what cities are in Germany", "extra_param": "something random"}, + ) + vector_search_tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + filter={}, + score_threshold=0.5, + extra_param="something random", + ) diff --git a/integrations/langchain/tests/unit_tests/test_vectorstores.py b/integrations/langchain/tests/unit_tests/test_vectorstores.py index b352b54..89db8c9 100644 --- a/integrations/langchain/tests/unit_tests/test_vectorstores.py +++ b/integrations/langchain/tests/unit_tests/test_vectorstores.py @@ -304,21 +304,29 @@ def test_similarity_search_hybrid(index_name: str) -> None: assert all(["id" in d.metadata for d in search_result]) -def test_similarity_search_both_filter_and_filters_passed() -> None: - vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) +def test_similarity_search_passing_kwargs() -> None: + vectorsearch = init_vector_search(DELTA_SYNC_INDEX) query = "foo" - filter = {"some filter": True} - filters = {"some other filter": False} + filters = {"some filter": True} + query_type = "ANN" - vectorsearch.similarity_search(query, filter=filter, filters=filters) + search_result = vectorsearch.similarity_search( + query, + k=5, + filter=filters, + query_type=query_type, + score_threshold=0.5, + num_results=10, + random_parameters="not included", + ) vectorsearch.index.similarity_search.assert_called_once_with( columns=["id", "text"], - query_vector=EMBEDDING_MODEL.embed_query(query), - # `filter` should prevail over `filters` - filters=filter, - num_results=4, - query_text=None, - query_type=None, + query_text=query, + query_vector=None, + filters=filters, + query_type=query_type, + num_results=5, # maintained + score_threshold=0.5, # passed ) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index 06c4dfe..c7ae324 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Optional, Tuple from databricks_ai_bridge.utils.vector_search import ( @@ -77,7 +78,7 @@ def __init__(self, **data): # Define the similarity search function def similarity_search( - query: str, filters: Optional[Dict[str, Any]] = None + query: str, filters: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> List[Dict[str, Any]]: def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]: if self._index_details.is_databricks_managed_embeddings(): @@ -108,14 +109,23 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa query_text, query_vector = get_query_text_vector(query) combined_filters = {**(filters or {}), **(self.filters or {})} - search_resp = self._index.similarity_search( - columns=self.columns, - query_text=query_text, - query_vector=query_vector, - filters=combined_filters, - num_results=self.num_results, - query_type=self.query_type, + + signature = inspect.signature(self._index.similarity_search) + kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + + # Ensure that we don't have duplicate keys + kwargs.update( + { + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": self.num_results, + "query_type": self.query_type, + } ) + search_resp = self._index.similarity_search(**kwargs) return parse_vector_search_response( search_resp, self._retriever_schema, document_class=dict ) diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index 9968049..8b2eb4d 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,11 +1,12 @@ import os import threading from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest from databricks.sdk import WorkspaceClient from databricks.sdk.credentials_provider import ModelServingUserCredentials +from databricks.vector_search.client import VectorSearchIndex from databricks.vector_search.utils import CredentialStrategy from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, @@ -53,16 +54,19 @@ def init_vector_search_tool( embedding: Optional[BaseEmbedding] = None, text_column: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "embedding": embedding, - "text_column": text_column, - "filters": filters, - } + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding": embedding, + "text_column": text_column, + "filters": filters, + } + ) if index_name != DELTA_SYNC_INDEX: kwargs.update( { @@ -180,6 +184,26 @@ def test_vector_search_client_non_model_serving_environment(): mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) +def test_kwargs_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) + vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True) + + # extra_param is ignored because it isn't part of the signature for similarity_search + vector_search_tool.call( + query="what cities are in Germany", debug_level=2, extra_param="something random" + ) + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text="what cities are in Germany", + num_results=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + query_vector=None, + filters={}, + score_threshold=0.5, + debug_level=2, + ) + + def test_filters_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) vector_search_tool._index.similarity_search = MagicMock() diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index ec94291..1057127 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,3 +1,4 @@ +import inspect import logging from typing import Any, Dict, List, Optional, Tuple @@ -167,6 +168,7 @@ def execute( query: str, filters: Optional[Dict[str, Any]] = None, openai_client: OpenAI = None, + **kwargs: Any, ) -> List[Dict]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the @@ -208,14 +210,21 @@ def execute( ) combined_filters = {**(filters or {}), **(self.filters or {})} - search_resp = self._index.similarity_search( - columns=self.columns, - query_text=query_text, - query_vector=query_vector, - filters=combined_filters, - num_results=self.num_results, - query_type=self.query_type, + + signature = inspect.signature(self._index.similarity_search) + kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + kwargs.update( + { + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": self.num_results, + "query_type": self.query_type, + } ) + search_resp = self._index.similarity_search(**kwargs) docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( search_resp=search_resp, retriever_schema=self._retriever_schema, diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 5b15865..b53b588 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -2,12 +2,13 @@ import os import threading from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, create_autospec, patch import mlflow import pytest from databricks.sdk import WorkspaceClient from databricks.sdk.credentials_provider import ModelServingUserCredentials +from databricks.vector_search.client import VectorSearchIndex from databricks.vector_search.utils import CredentialStrategy from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, @@ -90,16 +91,19 @@ def init_vector_search_tool( text_column: Optional[str] = None, embedding_model_name: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "text_column": text_column, - "embedding_model_name": embedding_model_name, - "filters": filters, - } + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "text_column": text_column, + "embedding_model_name": embedding_model_name, + "filters": filters, + } + ) if index_name != DELTA_SYNC_INDEX: kwargs.update( { @@ -302,6 +306,26 @@ def test_vector_search_client_non_model_serving_environment(): mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) +def test_kwargs_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) + vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True) + + # extra_param is ignored because it isn't part of the signature for similarity_search + vector_search_tool.execute( + query="what cities are in Germany", debug_level=2, extra_param="something random" + ) + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text="what cities are in Germany", + num_results=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + query_vector=None, + filters={}, + score_threshold=0.5, + debug_level=2, + ) + + def test_filters_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) vector_search_tool._index.similarity_search = MagicMock() diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index 586b634..72aa7c2 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -1,7 +1,7 @@ import uuid from typing import Generator, List, Optional from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest from databricks.vector_search.client import VectorSearchIndex # type: ignore @@ -132,7 +132,7 @@ def _get_index( endpoint_name: Optional[str] = None, index_name: str = None, # type: ignore ) -> MagicMock: - index = MagicMock(spec=VectorSearchIndex) + index = create_autospec(VectorSearchIndex, instance=True) if index_name not in INDEX_DETAILS: index_name = DIRECT_ACCESS_INDEX index.describe.return_value = INDEX_DETAILS[index_name] diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index d0bf5b3..a289ef6 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -37,6 +37,7 @@ def wrapper(self, *args, **kwargs): class VectorSearchRetrieverToolInput(BaseModel): + model_config = ConfigDict(extra="allow") query: str = Field( description="The string used to query the index with and identify the most similar " "vectors and return the associated documents." @@ -62,7 +63,7 @@ class VectorSearchRetrieverToolMixin(BaseModel): implementations should follow. """ - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") index_name: str = Field( ..., description="The name of the index to use, format: 'catalog.schema.index'." )