Skip to content

Support passing arbitrary kwargs through VectorSearchRetrieverTool #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def _validate_tool_inputs(self):
return self

@vector_search_retriever_tool_trace
def _run(self, query: str) -> str:
def _run(self, query: str, **kwargs) -> str:
combined_kwargs = {**kwargs, **(self.model_extra or {})}
return self._vector_store.similarity_search(
query, k=self.num_results, filter=self.filters, query_type=self.query_type
query,
k=self.num_results,
filter=self.filters,
query_type=self.query_type,
**combined_kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def similarity_search_with_score(
filters=filter,
num_results=k,
query_type=query_type,
**kwargs,
)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=Document
Expand Down Expand Up @@ -577,6 +578,7 @@ def similarity_search_by_vector_with_score(
filters=filter,
num_results=k,
query_type=query_type,
**kwargs,
)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=Document
Expand Down Expand Up @@ -696,6 +698,7 @@ def max_marginal_relevance_search_by_vector(
filters=filter,
num_results=fetch_k,
query_type=query_type,
**kwargs,
)

embeddings_result_index = (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import os
import threading
from typing import Any, Dict, List, Optional
from unittest.mock import patch
from typing import Any, List, Optional
from unittest.mock import MagicMock, patch

import mlflow
import pytest
Expand Down Expand Up @@ -49,17 +49,20 @@ def init_vector_search_tool(
text_column: Optional[str] = None,
doc_uri: Optional[str] = None,
primary_key: Optional[str] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -270,3 +273,20 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True)


def test_kwargs_are_passed_through() -> None:
    """Check that extra keyword arguments reach the vector store's ``similarity_search``.

    One extra (``score_threshold``) is handed to the tool-building helper and
    another (``extra_param``) is supplied in the invoke payload; both are
    expected to appear in the single ``similarity_search`` call.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)

    # Replace the store's search with a mock so the forwarded arguments can be inspected.
    search_mock = MagicMock()
    tool._vector_store.similarity_search = search_mock

    tool.invoke(
        {"query": question, "extra_param": "something random"},
    )

    search_mock.assert_called_once_with(
        question,
        k=tool.num_results,
        filter=None,
        query_type=tool.query_type,
        extra_param="something random",
        score_threshold=0.5,
    )
18 changes: 0 additions & 18 deletions integrations/langchain/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,24 +304,6 @@ def test_similarity_search_hybrid(index_name: str) -> None:
assert all(["id" in d.metadata for d in search_result])


def test_similarity_search_both_filter_and_filters_passed() -> None:
Copy link
Contributor Author

@annzhang-db annzhang-db May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleting this test because kwargs are now passed through unchanged, so both `filter` and `filters` would reach the underlying index call. Instead, we trust that the arguments callers pass through are meaningful and correct.

vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}

vectorsearch.similarity_search(query, filter=filter, filters=filters)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize(
"columns, expected_columns",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, **data):
)

# Define the similarity search function
def similarity_search(query: str) -> List[Dict[str, Any]]:
def similarity_search(query: str, **kwargs: Any) -> List[Dict[str, Any]]:
def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]:
if self._index_details.is_databricks_managed_embeddings():
if self.embedding:
Expand Down Expand Up @@ -105,13 +105,15 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
return text, vector

query_text, query_vector = get_query_text_vector(query)
combined_kwargs = {**kwargs, **(self.model_extra or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=self.filters,
num_results=self.num_results,
query_type=self.query_type,
**combined_kwargs,
)
return parse_vector_search_response(
search_resp, self._retriever_schema, document_class=dict
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import threading
from typing import Any, Dict, List, Optional
from unittest.mock import patch
from typing import Any, List, Optional
from unittest.mock import MagicMock, patch

import pytest
from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -52,15 +52,18 @@ def init_vector_search_tool(
tool_description: Optional[str] = None,
embedding: Optional[BaseEmbedding] = None,
text_column: Optional[str] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -176,3 +179,21 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_kwargs_are_passed_through() -> None:
    """Check that extra keyword arguments reach the index's ``similarity_search``.

    One extra (``score_threshold``) is handed to the tool-building helper and
    another (``extra_param``) is supplied to ``call``; the single
    ``similarity_search`` call is expected to carry both, alongside the
    tool's own search settings.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)

    # Replace the index search with a mock so the forwarded arguments can be inspected.
    search_mock = MagicMock()
    tool._index.similarity_search = search_mock

    tool.call(query=question, extra_param="something random")

    expected_kwargs = {
        "columns": tool.columns,
        "query_text": question,
        "query_vector": None,
        "filters": None,
        "num_results": tool.num_results,
        "query_type": tool.query_type,
        "score_threshold": 0.5,
        "extra_param": "something random",
        "requires_context": False,
    }
    search_mock.assert_called_once_with(**expected_kwargs)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from databricks.vector_search.client import VectorSearchIndex
from databricks_ai_bridge.utils.vector_search import (
Expand Down Expand Up @@ -158,11 +158,7 @@ def _validate_tool_inputs(self):
return self

@vector_search_retriever_tool_trace
def execute(
self,
query: str,
openai_client: OpenAI = None,
) -> List[Dict]:
def execute(self, query: str, openai_client: OpenAI = None, **kwargs: Any) -> List[Dict]:
"""
Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the
self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages.
Expand Down Expand Up @@ -202,13 +198,16 @@ def execute(
f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}"
)

combined_kwargs = {**kwargs, **(self.model_extra or {})}

search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=self.filters,
num_results=self.num_results,
query_type=self.query_type,
**combined_kwargs,
)
docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response(
search_resp=search_resp,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import threading
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional
from unittest.mock import MagicMock, Mock, patch

import mlflow
Expand Down Expand Up @@ -89,15 +89,18 @@ def init_vector_search_tool(
tool_description: Optional[str] = None,
text_column: Optional[str] = None,
embedding_model_name: Optional[str] = None,
**kwargs: Any,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"text_column": text_column,
"embedding_model_name": embedding_model_name,
}
kwargs.update(
{
"index_name": index_name,
"columns": columns,
"tool_name": tool_name,
"tool_description": tool_description,
"text_column": text_column,
"embedding_model_name": embedding_model_name,
}
)
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
Expand Down Expand Up @@ -298,3 +301,20 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_kwargs_are_passed_through() -> None:
    """Check that extra keyword arguments reach the index's ``similarity_search``.

    One extra (``score_threshold``) is handed to the tool-building helper and
    another (``extra_param``) is supplied to ``execute``; the single
    ``similarity_search`` call is expected to carry both, alongside the
    tool's own search settings.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5)

    # Replace the index search with a mock so the forwarded arguments can be inspected.
    search_mock = MagicMock()
    tool._index.similarity_search = search_mock

    tool.execute(query=question, extra_param="something random")

    expected_kwargs = {
        "columns": tool.columns,
        "query_text": question,
        "query_vector": None,
        "filters": None,
        "num_results": tool.num_results,
        "query_type": tool.query_type,
        "score_threshold": 0.5,
        "extra_param": "something random",
    }
    search_mock.assert_called_once_with(**expected_kwargs)
3 changes: 2 additions & 1 deletion src/databricks_ai_bridge/vector_search_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def wrapper(self, *args, **kwargs):


class VectorSearchRetrieverToolInput(BaseModel):
model_config = ConfigDict(extra="allow")
query: str = Field(
description="The string used to query the index with and identify the most similar "
"vectors and return the associated documents."
Expand All @@ -50,7 +51,7 @@ class VectorSearchRetrieverToolMixin(BaseModel):
implementations should follow.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
index_name: str = Field(
..., description="The name of the index to use, format: 'catalog.schema.index'."
)
Expand Down
Loading