Skip to content

Support passing arbitrary kwargs through VectorSearchRetrieverTool #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
This pull request was merged, combining 10 commits into the base branch, on May 6, 2025.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,16 @@ def _validate_tool_inputs(self):
return self

@vector_search_retriever_tool_trace
def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str:
def _run(self, query: str, filters: Optional[Dict[str, Any]] = None, **kwargs) -> str:
kwargs = {**kwargs, **(self.model_extra or {})}
combined_filters = {**(filters or {}), **(self.filters or {})}
return self._vector_store.similarity_search(
query, k=self.num_results, filter=combined_filters, query_type=self.query_type
# Ensure that we don't have duplicate keys
kwargs.update(
{
"query": query,
"k": self.num_results,
"filter": combined_filters,
"query_type": self.query_type,
}
)
return self._vector_store.similarity_search(**kwargs)
22 changes: 15 additions & 7 deletions integrations/langchain/src/databricks_langchain/vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import uuid
from functools import partial
Expand Down Expand Up @@ -458,14 +459,19 @@ def similarity_search_with_score(
query_text = None
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]

search_resp = self.index.similarity_search(
columns=self._columns,
query_text=query_text,
query_vector=query_vector,
filters=filter,
num_results=k,
query_type=query_type,
signature = inspect.signature(self.index.similarity_search)
kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}
kwargs.update(
{
"columns": self._columns,
"query_text": query_text,
"query_vector": query_vector,
"filters": filter,
"num_results": k,
"query_type": query_type,
}
)
search_resp = self.index.similarity_search(**kwargs)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=Document
)
Expand Down Expand Up @@ -577,6 +583,7 @@ def similarity_search_by_vector_with_score(
filters=filter,
num_results=k,
query_type=query_type,
**kwargs,
)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=Document
Expand Down Expand Up @@ -696,6 +703,7 @@ def max_marginal_relevance_search_by_vector(
filters=filter,
num_results=fetch_k,
query_type=query_type,
**kwargs,
)

embeddings_result_index = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,21 @@ def init_vector_search_tool(
doc_uri: Optional[str] = None,
primary_key: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
"filters": filters,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
"filters": filters,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -96,7 +99,7 @@ def test_filters_are_passed_through() -> None:
{"query": "what cities are in Germany", "filters": {"country": "Germany"}}
)
vector_search_tool._vector_store.similarity_search.assert_called_once_with(
"what cities are in Germany",
query="what cities are in Germany",
k=vector_search_tool.num_results,
filter={"country": "Germany"},
query_type=vector_search_tool.query_type,
Expand All @@ -111,7 +114,7 @@ def test_filters_are_combined() -> None:
{"query": "what cities are in Germany", "filters": {"country": "Germany"}}
)
vector_search_tool._vector_store.similarity_search.assert_called_once_with(
"what cities are in Germany",
query="what cities are in Germany",
k=vector_search_tool.num_results,
filter={"city LIKE": "Berlin", "country": "Germany"},
query_type=vector_search_tool.query_type,
Expand Down Expand Up @@ -302,3 +305,20 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True)


def test_kwargs_are_passed_through() -> None:
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)
vector_search_tool._vector_store.similarity_search = MagicMock()

vector_search_tool.invoke(
{"query": "what cities are in Germany", "extra_param": "something random"},
)
vector_search_tool._vector_store.similarity_search.assert_called_once_with(
query="what cities are in Germany",
k=vector_search_tool.num_results,
query_type=vector_search_tool.query_type,
filter={},
score_threshold=0.5,
extra_param="something random",
)
30 changes: 19 additions & 11 deletions integrations/langchain/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,29 @@ def test_similarity_search_hybrid(index_name: str) -> None:
assert all(["id" in d.metadata for d in search_result])


def test_similarity_search_both_filter_and_filters_passed() -> None:
Copy link
Contributor Author

@annzhang-db annzhang-db May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleting this test since now we are passing kwargs through, so both filter and filters would show up. We are instead trusting that the arguments passed through are meaningful and correct.

vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
def test_similarity_search_passing_kwargs() -> None:
vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
filters = {"some filter": True}
query_type = "ANN"

vectorsearch.similarity_search(query, filter=filter, filters=filters)
search_result = vectorsearch.similarity_search(
query,
k=5,
filter=filters,
query_type=query_type,
score_threshold=0.5,
num_results=10,
random_parameters="not included",
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
query_text=query,
query_vector=None,
filters=filters,
query_type=query_type,
num_results=5, # maintained
score_threshold=0.5, # passed
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import Any, Dict, List, Optional, Tuple

from databricks_ai_bridge.utils.vector_search import (
Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(self, **data):

# Define the similarity search function
def similarity_search(
query: str, filters: Optional[Dict[str, Any]] = None
query: str, filters: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> List[Dict[str, Any]]:
def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]:
if self._index_details.is_databricks_managed_embeddings():
Expand Down Expand Up @@ -108,14 +109,23 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa

query_text, query_vector = get_query_text_vector(query)
combined_filters = {**(filters or {}), **(self.filters or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=combined_filters,
num_results=self.num_results,
query_type=self.query_type,

signature = inspect.signature(self._index.similarity_search)
kwargs = {**kwargs, **(self.model_extra or {})}
kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}

# Ensure that we don't have duplicate keys
kwargs.update(
{
"query_text": query_text,
"query_vector": query_vector,
"columns": self.columns,
"filters": combined_filters,
"num_results": self.num_results,
"query_type": self.query_type,
}
)
search_resp = self._index.similarity_search(**kwargs)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=dict
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import threading
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch

import pytest
from databricks.sdk import WorkspaceClient
from databricks.sdk.credentials_provider import ModelServingUserCredentials
from databricks.vector_search.client import VectorSearchIndex
from databricks.vector_search.utils import CredentialStrategy
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
Expand Down Expand Up @@ -53,16 +54,19 @@ def init_vector_search_tool(
embedding: Optional[BaseEmbedding] = None,
text_column: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"filters": filters,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"filters": filters,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -180,6 +184,26 @@ def test_vector_search_client_non_model_serving_environment():
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_kwargs_are_passed_through() -> None:
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)
vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True)

# extra_param is ignored because it isn't part of the signature for similarity_search
vector_search_tool.call(
query="what cities are in Germany", debug_level=2, extra_param="something random"
)
vector_search_tool._index.similarity_search.assert_called_once_with(
columns=vector_search_tool.columns,
query_text="what cities are in Germany",
num_results=vector_search_tool.num_results,
query_type=vector_search_tool.query_type,
query_vector=None,
filters={},
score_threshold=0.5,
debug_level=2,
)


def test_filters_are_passed_through() -> None:
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)
vector_search_tool._index.similarity_search = MagicMock()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -167,6 +168,7 @@ def execute(
query: str,
filters: Optional[Dict[str, Any]] = None,
openai_client: OpenAI = None,
**kwargs: Any,
) -> List[Dict]:
"""
Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the
Expand Down Expand Up @@ -208,14 +210,21 @@ def execute(
)

combined_filters = {**(filters or {}), **(self.filters or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=combined_filters,
num_results=self.num_results,
query_type=self.query_type,

signature = inspect.signature(self._index.similarity_search)
kwargs = {**kwargs, **(self.model_extra or {})}
kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}
kwargs.update(
{
"query_text": query_text,
"query_vector": query_vector,
"columns": self.columns,
"filters": combined_filters,
"num_results": self.num_results,
"query_type": self.query_type,
}
)
search_resp = self._index.similarity_search(**kwargs)
docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response(
search_resp=search_resp,
retriever_schema=self._retriever_schema,
Expand Down
Loading