Skip to content

Support passing arbitrary kwargs through VectorSearchRetrieverTool #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,16 @@ def _validate_tool_inputs(self):
return self

@vector_search_retriever_tool_trace
def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str:
def _run(self, query: str, filters: Optional[Dict[str, Any]] = None, **kwargs) -> str:
kwargs = {**kwargs, **(self.model_extra or {})}
combined_filters = {**(filters or {}), **(self.filters or {})}
return self._vector_store.similarity_search(
query, k=self.num_results, filter=combined_filters, query_type=self.query_type
# Ensure that we don't have duplicate keys
kwargs.update(
{
"query": query,
"k": self.num_results,
"filter": combined_filters,
"query_type": self.query_type,
}
)
return self._vector_store.similarity_search(**kwargs)
22 changes: 15 additions & 7 deletions integrations/langchain/src/databricks_langchain/vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import uuid
from functools import partial
Expand Down Expand Up @@ -458,14 +459,19 @@ def similarity_search_with_score(
query_text = None
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]

search_resp = self.index.similarity_search(
columns=self._columns,
query_text=query_text,
query_vector=query_vector,
filters=filter,
num_results=k,
query_type=query_type,
signature = inspect.signature(self.index.similarity_search)
kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}
kwargs.update(
{
"columns": self._columns,
"query_text": query_text,
"query_vector": query_vector,
"filters": filter,
"num_results": k,
"query_type": query_type,
}
)
search_resp = self.index.similarity_search(**kwargs)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=Document
)
Expand Down Expand Up @@ -577,6 +583,7 @@ def similarity_search_by_vector_with_score(
filters=filter,
num_results=k,
query_type=query_type,
**kwargs,
)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=Document
Expand Down Expand Up @@ -696,6 +703,7 @@ def max_marginal_relevance_search_by_vector(
filters=filter,
num_results=fetch_k,
query_type=query_type,
**kwargs,
)

embeddings_result_index = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,21 @@ def init_vector_search_tool(
doc_uri: Optional[str] = None,
primary_key: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
"filters": filters,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
"filters": filters,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -96,7 +99,7 @@ def test_filters_are_passed_through() -> None:
{"query": "what cities are in Germany", "filters": {"country": "Germany"}}
)
vector_search_tool._vector_store.similarity_search.assert_called_once_with(
"what cities are in Germany",
query="what cities are in Germany",
k=vector_search_tool.num_results,
filter={"country": "Germany"},
query_type=vector_search_tool.query_type,
Expand All @@ -111,7 +114,7 @@ def test_filters_are_combined() -> None:
{"query": "what cities are in Germany", "filters": {"country": "Germany"}}
)
vector_search_tool._vector_store.similarity_search.assert_called_once_with(
"what cities are in Germany",
query="what cities are in Germany",
k=vector_search_tool.num_results,
filter={"city LIKE": "Berlin", "country": "Germany"},
query_type=vector_search_tool.query_type,
Expand Down Expand Up @@ -302,3 +305,20 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True)


def test_kwargs_are_passed_through() -> None:
    """Check that extra kwargs reach the vector store's ``similarity_search``.

    ``score_threshold`` is supplied when the tool is built and ``extra_param``
    is supplied at invoke time; both are expected in the final call next to
    the tool's own ``query``/``k``/``filter``/``query_type`` arguments.
    """
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)
    tool._vector_store.similarity_search = MagicMock()

    tool_input = {"query": "what cities are in Germany", "extra_param": "something random"}
    tool.invoke(tool_input)

    # No filters were given at build or invoke time, hence the empty filter.
    expected_call_kwargs = {
        "query": "what cities are in Germany",
        "k": tool.num_results,
        "query_type": tool.query_type,
        "filter": {},
        "score_threshold": 0.5,
        "extra_param": "something random",
    }
    tool._vector_store.similarity_search.assert_called_once_with(**expected_call_kwargs)
30 changes: 19 additions & 11 deletions integrations/langchain/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,29 @@ def test_similarity_search_hybrid(index_name: str) -> None:
assert all(["id" in d.metadata for d in search_result])


def test_similarity_search_both_filter_and_filters_passed() -> None:
Copy link
Contributor Author

@annzhang-db annzhang-db May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleting this test because kwargs are now passed through, so both `filter` and `filters` would show up in the call. Instead, we trust that the arguments passed through are meaningful and correct.

vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
def test_similarity_search_passing_kwargs() -> None:
vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
filters = {"some filter": True}
query_type = "ANN"

vectorsearch.similarity_search(query, filter=filter, filters=filters)
search_result = vectorsearch.similarity_search(
query,
k=5,
filter=filters,
query_type=query_type,
score_threshold=0.5,
num_results=10,
random_parameters="not included",
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
query_text=query,
query_vector=None,
filters=filters,
query_type=query_type,
num_results=5, # maintained
score_threshold=0.5, # passed
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, **data):

# Define the similarity search function
def similarity_search(
query: str, filters: Optional[Dict[str, Any]] = None
query: str, filters: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> List[Dict[str, Any]]:
def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]:
if self._index_details.is_databricks_managed_embeddings():
Expand Down Expand Up @@ -107,15 +107,25 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
return text, vector

query_text, query_vector = get_query_text_vector(query)
filtered_model_extra = {
k: v
for k, v in (self.model_extra or {}).items()
if k != "requires_context" # don't include extra parameters set by FunctionTool
}
kwargs = {**kwargs, **filtered_model_extra}
combined_filters = {**(filters or {}), **(self.filters or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=combined_filters,
num_results=self.num_results,
query_type=self.query_type,
# Ensure that we don't have duplicate keys
kwargs.update(
{
"query_text": query_text,
"query_vector": query_vector,
"columns": self.columns,
"filters": combined_filters,
"num_results": self.num_results,
"query_type": self.query_type,
}
)
search_resp = self._index.similarity_search(**kwargs)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=dict
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,19 @@ def init_vector_search_tool(
embedding: Optional[BaseEmbedding] = None,
text_column: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"filters": filters,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"filters": filters,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -180,6 +183,23 @@ def test_vector_search_client_non_model_serving_environment():
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_kwargs_are_passed_through() -> None:
    """Check that extra kwargs reach the index's ``similarity_search``.

    ``score_threshold`` is supplied when the tool is built and ``extra_param``
    is supplied to ``call``; both are expected in the final index call next
    to the arguments the tool fills in itself.
    """
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)
    tool._index.similarity_search = MagicMock()

    question = "what cities are in Germany"
    tool.call(query=question, extra_param="something random")

    # No filters were given at build or call time, hence the empty filters.
    expected_call_kwargs = {
        "columns": tool.columns,
        "query_text": question,
        "num_results": tool.num_results,
        "query_type": tool.query_type,
        "query_vector": None,
        "filters": {},
        "score_threshold": 0.5,
        "extra_param": "something random",
    }
    tool._index.similarity_search.assert_called_once_with(**expected_call_kwargs)


def test_filters_are_passed_through() -> None:
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)
vector_search_tool._index.similarity_search = MagicMock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def execute(
query: str,
filters: Optional[Dict[str, Any]] = None,
openai_client: OpenAI = None,
**kwargs: Any,
) -> List[Dict]:
"""
Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the
Expand Down Expand Up @@ -207,15 +208,19 @@ def execute(
f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}"
)

kwargs = {**kwargs, **(self.model_extra or {})}
combined_filters = {**(filters or {}), **(self.filters or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=combined_filters,
num_results=self.num_results,
query_type=self.query_type,
kwargs.update(
{
"query_text": query_text,
"query_vector": query_vector,
"columns": self.columns,
"filters": combined_filters,
"num_results": self.num_results,
"query_type": self.query_type,
}
)
search_resp = self._index.similarity_search(**kwargs)
docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response(
search_resp=search_resp,
retriever_schema=self._retriever_schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,19 @@ def init_vector_search_tool(
text_column: Optional[str] = None,
embedding_model_name: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"text_column": text_column,
"embedding_model_name": embedding_model_name,
"filters": filters,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"text_column": text_column,
"embedding_model_name": embedding_model_name,
"filters": filters,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -302,6 +305,23 @@ def test_vector_search_client_non_model_serving_environment():
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_kwargs_are_passed_through() -> None:
    """Check that extra kwargs reach the index's ``similarity_search``.

    ``score_threshold`` is supplied when the tool is built and ``extra_param``
    is supplied to ``execute``; both are expected in the final index call
    next to the arguments the tool fills in itself.
    """
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)
    tool._index.similarity_search = MagicMock()

    question = "what cities are in Germany"
    tool.execute(query=question, extra_param="something random")

    # No filters were given at build or execute time, hence the empty filters.
    expected_call_kwargs = {
        "columns": tool.columns,
        "query_text": question,
        "num_results": tool.num_results,
        "query_type": tool.query_type,
        "query_vector": None,
        "filters": {},
        "score_threshold": 0.5,
        "extra_param": "something random",
    }
    tool._index.similarity_search.assert_called_once_with(**expected_call_kwargs)


def test_filters_are_passed_through() -> None:
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)
vector_search_tool._index.similarity_search = MagicMock()
Expand Down
4 changes: 2 additions & 2 deletions src/databricks_ai_bridge/test_utils/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from typing import Generator, List, Optional
from unittest import mock
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch

import pytest
from databricks.vector_search.client import VectorSearchIndex # type: ignore
Expand Down Expand Up @@ -132,7 +132,7 @@ def _get_index(
endpoint_name: Optional[str] = None,
index_name: str = None, # type: ignore
) -> MagicMock:
index = MagicMock(spec=VectorSearchIndex)
index = create_autospec(VectorSearchIndex, instance=True)
if index_name not in INDEX_DETAILS:
index_name = DIRECT_ACCESS_INDEX
index.describe.return_value = INDEX_DETAILS[index_name]
Expand Down
Loading