Skip to content

Allow filters to be defined when calling VS tool #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import Any, Dict, Optional, Type

from databricks_ai_bridge.utils.vector_search import IndexDetails
from databricks_ai_bridge.vector_search_retriever_tool import (
Expand Down Expand Up @@ -65,7 +65,8 @@ def _validate_tool_inputs(self):
return self

@vector_search_retriever_tool_trace
def _run(self, query: str) -> str:
def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str:
combined_filters = {**(filters or {}), **(self.filters or {})}
return self._vector_store.similarity_search(
query, k=self.num_results, filter=self.filters, query_type=self.query_type
query, k=self.num_results, filter=combined_filters, query_type=self.query_type
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import threading
from typing import Any, Dict, List, Optional
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import mlflow
import pytest
Expand Down Expand Up @@ -49,6 +49,7 @@ def init_vector_search_tool(
text_column: Optional[str] = None,
doc_uri: Optional[str] = None,
primary_key: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
Expand All @@ -59,6 +60,7 @@ def init_vector_search_tool(
"text_column": text_column,
"doc_uri": doc_uri,
"primary_key": primary_key,
"filters": filters,
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
Expand Down Expand Up @@ -86,6 +88,36 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None:
assert isinstance(response, AIMessage)


def test_filters_are_passed_through() -> None:
    """Invocation-time filters are forwarded to the vector store search.

    The tool is created without preset filters, so the mocked
    ``similarity_search`` must receive exactly the filters from the call.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX)
    search_mock = MagicMock()
    tool._vector_store.similarity_search = search_mock

    tool.invoke({"query": question, "filters": {"country": "Germany"}})

    search_mock.assert_called_once_with(
        question,
        k=tool.num_results,
        filter={"country": "Germany"},
        query_type=tool.query_type,
    )


def test_filters_are_combined() -> None:
    """Preset tool filters and invocation-time filters are merged for the search.

    The tool is created with a ``city LIKE`` filter and invoked with a
    ``country`` filter; the mocked search must see both keys.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"})
    search_mock = MagicMock()
    tool._vector_store.similarity_search = search_mock

    tool.invoke({"query": question, "filters": {"country": "Germany"}})

    search_mock.assert_called_once_with(
        question,
        k=tool.num_results,
        filter={"city LIKE": "Berlin", "country": "Germany"},
        query_type=tool.query_type,
    )


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("columns", [None, ["id", "text"]])
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(self, **data):
)

# Define the similarity search function
def similarity_search(query: str) -> List[Dict[str, Any]]:
def similarity_search(
query: str, filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]:
if self._index_details.is_databricks_managed_embeddings():
if self.embedding:
Expand Down Expand Up @@ -105,11 +107,12 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
return text, vector

query_text, query_vector = get_query_text_vector(query)
combined_filters = {**(filters or {}), **(self.filters or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=self.filters,
filters=combined_filters,
num_results=self.num_results,
query_type=self.query_type,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import threading
from typing import Any, Dict, List, Optional
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -52,6 +52,7 @@ def init_vector_search_tool(
tool_description: Optional[str] = None,
embedding: Optional[BaseEmbedding] = None,
text_column: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
Expand All @@ -60,6 +61,7 @@ def init_vector_search_tool(
"tool_description": tool_description,
"embedding": embedding,
"text_column": text_column,
"filters": filters,
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
Expand Down Expand Up @@ -176,3 +178,33 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_filters_are_passed_through() -> None:
    """Filters given to ``call`` are forwarded to the index search.

    The tool is created without preset filters, so the mocked index
    ``similarity_search`` must receive exactly the filters from the call.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX)
    search_mock = MagicMock()
    tool._index.similarity_search = search_mock

    tool.call(query=question, filters={"country": "Germany"})

    search_mock.assert_called_once_with(
        columns=tool.columns,
        query_text=question,
        filters={"country": "Germany"},
        num_results=tool.num_results,
        query_type=tool.query_type,
        query_vector=None,
    )


def test_filters_are_combined() -> None:
    """Preset tool filters and filters given to ``call`` are merged for the search.

    The tool is created with a ``city LIKE`` filter and called with a
    ``country`` filter; the mocked index search must see both keys.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"})
    search_mock = MagicMock()
    tool._index.similarity_search = search_mock

    tool.call(query=question, filters={"country": "Germany"})

    search_mock.assert_called_once_with(
        columns=tool.columns,
        query_text=question,
        filters={"city LIKE": "Berlin", "country": "Germany"},
        num_results=tool.num_results,
        query_type=tool.query_type,
        query_vector=None,
    )
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from databricks.vector_search.client import VectorSearchIndex
from databricks_ai_bridge.utils.vector_search import (
Expand Down Expand Up @@ -52,7 +52,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):
tool_call = first_response.choices[0].message.tool_calls[0]
args = json.loads(tool_call.function.arguments)
result = dbvs_tool.execute(
query=args["query"]
query=args["query"], filters=args.get("filters", None)
) # For self-managed embeddings, optionally pass in openai_client=client

Step 3: Supply model with results – so it can incorporate them into its final response.
Expand Down Expand Up @@ -141,6 +141,10 @@ def _validate_tool_inputs(self):
description=self.tool_description
or self._get_default_tool_description(self._index_details),
)
# We need to remove strict: True from the tool in order to support arbitrary filters
if "function" in self.tool and "strict" in self.tool["function"]:
del self.tool["function"]["strict"]

try:
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors.platform import ResourceDoesNotExist
Expand All @@ -161,6 +165,7 @@ def _validate_tool_inputs(self):
def execute(
self,
query: str,
filters: Optional[Dict[str, Any]] = None,
openai_client: OpenAI = None,
) -> List[Dict]:
"""
Expand Down Expand Up @@ -202,11 +207,12 @@ def execute(
f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}"
)

combined_filters = {**(filters or {}), **(self.filters or {})}
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=self.filters,
filters=combined_filters,
num_results=self.num_results,
query_type=self.query_type,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def init_vector_search_tool(
tool_description: Optional[str] = None,
text_column: Optional[str] = None,
embedding_model_name: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
Expand All @@ -97,6 +98,7 @@ def init_vector_search_tool(
"tool_description": tool_description,
"text_column": text_column,
"embedding_model_name": embedding_model_name,
"filters": filters,
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
Expand Down Expand Up @@ -298,3 +300,35 @@ def test_vector_search_client_non_model_serving_environment():
workspace_client=w,
)
mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None)


def test_filters_are_passed_through() -> None:
    """Filters given to ``execute`` are forwarded to the index search.

    The tool is created without preset filters, so the mocked index
    ``similarity_search`` must receive exactly the filters from the call.
    """
    vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)
    vector_search_tool._index.similarity_search = MagicMock()

    # ``execute`` takes the query as a plain string; the previous version of
    # this test passed a {"query": ...} dict positionally and then asserted the
    # dict as ``query_text``, which pinned the wrong contract.
    vector_search_tool.execute(
        query="what cities are in Germany", filters={"country": "Germany"}
    )
    vector_search_tool._index.similarity_search.assert_called_once_with(
        columns=vector_search_tool.columns,
        query_text="what cities are in Germany",
        filters={"country": "Germany"},
        num_results=vector_search_tool.num_results,
        query_type=vector_search_tool.query_type,
        query_vector=None,
    )


def test_filters_are_combined() -> None:
    """Preset tool filters and filters given to ``execute`` are merged for the search.

    The tool is created with a ``city LIKE`` filter and executed with a
    ``country`` filter; the mocked index search must see both keys.
    """
    question = "what cities are in Germany"
    tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"})
    search_mock = MagicMock()
    tool._index.similarity_search = search_mock

    tool.execute(query=question, filters={"country": "Germany"})

    search_mock.assert_called_once_with(
        columns=tool.columns,
        query_text=question,
        filters={"city LIKE": "Berlin", "country": "Germany"},
        num_results=tool.num_results,
        query_type=tool.query_type,
        query_vector=None,
    )
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dev = [
"hatch",
"pytest",
"ruff==0.6.4",
"databricks-vectorsearch>=0.50",
]

[tool.ruff]
Expand Down
20 changes: 20 additions & 0 deletions src/databricks_ai_bridge/test_utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,28 @@ def _get_serving_endpoint(full_name: str) -> MagicMock:
endpoint.name = full_name
return endpoint

def _construct_column(name, col_type, comment):
mock_column = MagicMock()
mock_column.name = name
mock_column.type_name = MagicMock()
mock_column.type_name.name = col_type
mock_column.comment = comment
return mock_column

def _get_table(full_name: str) -> MagicMock:
    """Return a mock table for ``full_name`` with a fixed set of city columns.

    The column list includes one ``__``-prefixed vector column alongside the
    regular ones.
    """
    column_specs = (
        ("city_id", "INT", None),
        ("city", "STRING", "Name of the city"),
        ("country", "STRING", "Name of the country"),
        ("description", "STRING", "Detailed description of the city"),
        ("__db_description_vector", "ARRAY", None),
    )
    mock_columns = [
        _construct_column(name, col_type, comment)
        for name, col_type, comment in column_specs
    ]
    return MagicMock(full_name=full_name, columns=mock_columns)

mock_client = MagicMock()
mock_client.serving_endpoints.get.side_effect = _get_serving_endpoint
mock_client.tables.get.side_effect = _get_table

with patch(
"databricks.sdk.WorkspaceClient",
return_value=mock_client,
Expand Down
51 changes: 48 additions & 3 deletions src/databricks_ai_bridge/vector_search_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ class VectorSearchRetrieverToolInput(BaseModel):
description="The string used to query the index with and identify the most similar "
"vectors and return the associated documents."
)
filters: Dict[str, Any] = Field(
default=None,
description=(
"Optional filters to refine vector search results. Supports the following operators:\n\n"
'- Inclusion: {"column": value} or {"column": [value1, value2]} (matches if the column equals any of the provided values)\n'
'- Exclusion: {"column NOT": value}\n'
'- Comparisons: {"column <": value}, {"column >=": value}, etc.\n'
'- Pattern match: {"column LIKE": "word"} (matches full tokens separated by whitespace)\n'
'- OR logic: {"column1 OR column2": [value1, value2]} '
"(matches if column1 equals value1 or column2 equals value2; matches are position-specific)"
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, I didn't realize this is passed to the LLM, so a link won't be useful.

)


class VectorSearchRetrieverToolMixin(BaseModel):
Expand Down Expand Up @@ -87,14 +99,47 @@ def validate_tool_name(cls, tool_name):
raise ValueError("tool_name must match the pattern '^[a-zA-Z0-9_-]{1,64}$'")
return tool_name

def _describe_columns(self) -> str:
try:
from databricks.sdk import WorkspaceClient

if self.workspace_client:
table_info = self.workspace_client.tables.get(full_name=self.index_name)
else:
table_info = WorkspaceClient().tables.get(full_name=self.index_name)

columns = []

for column_info in table_info.columns:
name = column_info.name
comment = column_info.comment or "No description provided"
col_type = column_info.type_name.name
if not name.startswith("__"):
columns.append((name, col_type, comment))

return "The vector search index includes the following columns:\n" + "\n".join(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty cool

f"{name} ({col_type}): {comment}" for name, col_type, comment in columns
)
except Exception:
_logger.warning(
"Unable to retrieve column information automatically. Please manually specify column names, types, and descriptions in the tool description to help LLMs apply filters correctly."
)

def _get_default_tool_description(self, index_details: IndexDetails) -> str:
if index_details.is_delta_sync_index():
source_table = index_details.index_spec.get("source_table", "")
return (
description = (
DEFAULT_TOOL_DESCRIPTION
+ f" The queried index uses the source table {source_table}"
+ f" The queried index uses the source table {source_table}."
)
return DEFAULT_TOOL_DESCRIPTION
else:
description = DEFAULT_TOOL_DESCRIPTION

column_description = self._describe_columns()
if column_description:
return f"{description}\n\n{column_description}"
else:
return description

def _get_resources(
self, index_name: str, embedding_endpoint: str, index_details: IndexDetails
Expand Down
12 changes: 12 additions & 0 deletions tests/databricks_ai_bridge/test_vector_search_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex

from databricks_ai_bridge.test_utils.vector_search import mock_workspace_client # noqa: F401
from databricks_ai_bridge.utils.vector_search import IndexDetails
from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin

Expand Down Expand Up @@ -66,3 +67,14 @@ def make_mock_index_details(is_databricks_managed_embeddings=False, embedding_so
def test_get_resources(embedding_endpoint, index_details, resources):
tool = DummyVectorSearchRetrieverTool(index_name=index_name)
assert tool._get_resources(index_name, embedding_endpoint, index_details) == resources


def test_describe_columns():
    """``_describe_columns`` lists each non-hidden column as ``name (TYPE): comment``."""
    expected_lines = [
        "The vector search index includes the following columns:",
        "city_id (INT): No description provided",
        "city (STRING): Name of the city",
        "country (STRING): Name of the country",
        "description (STRING): Detailed description of the city",
    ]
    tool = DummyVectorSearchRetrieverTool(index_name=index_name)

    described = tool._describe_columns()

    assert described == "\n".join(expected_lines)