From dbdd053f302787d94e8cb57b63e25ac87e98a023 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 15:10:31 -0700 Subject: [PATCH 01/13] draft Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 5 ++- .../vector_search_retriever_tool.py | 43 ++++++++++++++++++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index c8caec28..7246c3a1 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -65,7 +65,8 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def _run(self, query: str) -> str: + def _run(self, query: str, **kwargs) -> str: + filter = kwargs.get("filters", self.filters) return self._vector_store.similarity_search( - query, k=self.num_results, filter=self.filters, query_type=self.query_type + query, k=self.num_results, filter=filter, query_type=self.query_type ) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 7d944565..943637f6 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -87,14 +87,53 @@ def validate_tool_name(cls, tool_name): raise ValueError("tool_name must match the pattern '^[a-zA-Z0-9_-]{1,64}$'") return tool_name + def _describe_columns(self, columns: List[Tuple[str, str]]) -> str: + description_lines = [f" - {name}: {col_type}" for name, col_type in columns] + description = "\n".join(description_lines) + + return ( + "This vector search index includes the following columns:\n\n" + f"{description}\n\n" + "You can refine vector search results by passing a `filters` dictionary when invoking the tool. " + "Supported operators include:\n\n" + "Equality: {\"column\": value}, {\"column\": [value1, value2]}\n\n" + "Inequality: {\"column NOT\": value}\n\n" + "Comparisons: {\"column <\": value}, {\"column >=\": value}, etc.\n\n" + "Pattern match: {\"column LIKE\": \"word\"} (matches full tokens)\n\n" + "OR: {\"column OR column2\": [value1, value2]}" + ) + + def _get_index_columns(self): + try: + from databricks.sdk import WorkspaceClient + import json + + if self.workspace_client: + table_info = self.workspace_client.tables.get(full_name = self.index_name) + else: + table_info = WorkspaceClient().tables.get(full_name = self.index_name) + + columns = [] + + for column_info in table_info.columns: + column_name = column_info.name + column_type = json.loads(column_info.type_json).get("type", None) + columns.append((column_name, column_type)) + + return columns + except: + pass + def _get_default_tool_description(self, index_details: IndexDetails) -> str: + if index_details.is_delta_sync_index(): source_table = index_details.index_spec.get("source_table", "") - return ( + description = ( DEFAULT_TOOL_DESCRIPTION + f" The queried index uses the source table {source_table}" ) - return DEFAULT_TOOL_DESCRIPTION + description = DEFAULT_TOOL_DESCRIPTION + return description + self._describe_columns(self._get_index_columns()) def _get_resources( self, index_name: str, embedding_endpoint: str, index_details: IndexDetails From b20afae76c8f01f23c7f8a2da91be72f2a583c2f Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 15:24:35 -0700 Subject: [PATCH 02/13] update Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 943637f6..4f89fe14 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -87,20 +87,19 @@ def validate_tool_name(cls, tool_name): raise ValueError("tool_name must match the pattern '^[a-zA-Z0-9_-]{1,64}$'") return tool_name - def _describe_columns(self, columns: List[Tuple[str, str]]) -> str: - description_lines = [f" - {name}: {col_type}" for name, col_type in columns] - description = "\n".join(description_lines) + def _describe_columns(self, columns) -> str: + description = "\n".join(f" - {name}: {col_type}" for name, col_type in columns) return ( "This vector search index includes the following columns:\n\n" f"{description}\n\n" - "You can refine vector search results by passing a `filters` dictionary when invoking the tool. " - "Supported operators include:\n\n" - "Equality: {\"column\": value}, {\"column\": [value1, value2]}\n\n" - "Inequality: {\"column NOT\": value}\n\n" - "Comparisons: {\"column <\": value}, {\"column >=\": value}, etc.\n\n" - "Pattern match: {\"column LIKE\": \"word\"} (matches full tokens)\n\n" - "OR: {\"column OR column2\": [value1, value2]}" + "You can refine results by passing a `filters` dictionary when calling this tool. " + "Supported filters include:\n\n" + "Equality: {\"column\": value} or {\"column\": [value1, value2]}\n" + "Inequality: {\"column NOT\": value}\n" + "Comparisons: {\"column <\": value}, {\"column >=\": value}, etc.\n" + "Pattern match: {\"column LIKE\": \"word\"} (matches full tokens)\n" + "OR condition: {\"column1 OR column2\": [value1, value2]}" ) def _get_index_columns(self): From 99c5cd73236ddda4e92b1483ba7c1fa0a81a4a31 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 16:05:41 -0700 Subject: [PATCH 03/13] update Signed-off-by: Ann Zhang --- .../databricks_langchain/vector_search_retriever_tool.py | 7 +++---- src/databricks_ai_bridge/vector_search_retriever_tool.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 7246c3a1..8a051099 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Optional, Type, Dict, Any from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import ( @@ -65,8 +65,7 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def _run(self, query: str, **kwargs) -> str: - filter = kwargs.get("filters", self.filters) + def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str: return self._vector_store.similarity_search( - query, k=self.num_results, filter=filter, query_type=self.query_type + query, k=self.num_results, filter=filters or self.filters, query_type=self.query_type ) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 4f89fe14..bfb01d87 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -121,6 +121,7 @@ def _get_index_columns(self): return columns except: + # log a warning pass def _get_default_tool_description(self, index_details: IndexDetails) -> str: From 62bc934e569c222211fc5e34003894d92abf020c Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 16:18:06 -0700 Subject: [PATCH 04/13] update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 5 +- .../vector_search_retriever_tool.py | 52 ++++++++++--------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 8a051099..ee9452af 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -65,7 +65,8 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str: + def _run(self, query: str, filters: Dict[str, Any] = None) -> str: + combined_filters = {**(filters or {}), **(self.filters or {})} return self._vector_store.similarity_search( - query, k=self.num_results, filter=filters or self.filters, query_type=self.query_type + query, k=self.num_results, filter=combined_filters, query_type=self.query_type ) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index bfb01d87..024bc5f7 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -41,6 +41,18 @@ class VectorSearchRetrieverToolInput(BaseModel): description="The string used to query the index with and identify the most similar " "vectors and return the associated documents." ) + filters: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Optional filters to refine vector search results. " + "Supports:\n\n" + "- Equality: {\"column\": value} or {\"column\": [value1, value2]}\n" + "- Inequality: {\"column NOT\": value}\n" + "- Comparisons: {\"column <\": value}, {\"column >=\": value}\n" + "- Pattern match: {\"column LIKE\": \"word\"} (matches full tokens)\n" + "- OR logic: {\"column OR column2\": [value1, value2]}" + ) + ) class VectorSearchRetrieverToolMixin(BaseModel): @@ -87,22 +99,7 @@ def validate_tool_name(cls, tool_name): raise ValueError("tool_name must match the pattern '^[a-zA-Z0-9_-]{1,64}$'") return tool_name - def _describe_columns(self, columns) -> str: - description = "\n".join(f" - {name}: {col_type}" for name, col_type in columns) - - return ( - "This vector search index includes the following columns:\n\n" - f"{description}\n\n" - "You can refine results by passing a `filters` dictionary when calling this tool. " - "Supported filters include:\n\n" - "Equality: {\"column\": value} or {\"column\": [value1, value2]}\n" - "Inequality: {\"column NOT\": value}\n" - "Comparisons: {\"column <\": value}, {\"column >=\": value}, etc.\n" - "Pattern match: {\"column LIKE\": \"word\"} (matches full tokens)\n" - "OR condition: {\"column1 OR column2\": [value1, value2]}" - ) - - def _get_index_columns(self): + def _describe_columns(self) -> str: try: from databricks.sdk import WorkspaceClient import json @@ -115,25 +112,30 @@ def _get_index_columns(self): columns = [] for column_info in table_info.columns: - column_name = column_info.name - column_type = json.loads(column_info.type_json).get("type", None) - columns.append((column_name, column_type)) + name, comment = column_info.name, column_info.comment + if comment == None: + comment = "No description provided" + col_type = json.loads(column_info.type_json).get("type", None) + if not name.startswith("__"): + columns.append((name, col_type, comment)) - return columns + return ( + "The vector search index includes the following columns:\n\n" + + "\n".join(f"{name} ({col_type}): {comment}" for name, col_type, comment in columns) + ) except: - # log a warning - pass + _logger.warning("Unable to retrieve column information automatically. Please manually specify column names, types, and descriptions in the tool description to help LLMs apply filters correctly.") def _get_default_tool_description(self, index_details: IndexDetails) -> str: - if index_details.is_delta_sync_index(): source_table = index_details.index_spec.get("source_table", "") description = ( DEFAULT_TOOL_DESCRIPTION + f" The queried index uses the source table {source_table}" ) - description = DEFAULT_TOOL_DESCRIPTION - return description + self._describe_columns(self._get_index_columns()) + else: + description = DEFAULT_TOOL_DESCRIPTION + return description + self._describe_columns() def _get_resources( self, index_name: str, embedding_endpoint: str, index_details: IndexDetails From 194d9632504a34108283ebbabe11dcedaf5f80cf Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 17:32:02 -0700 Subject: [PATCH 05/13] update description Signed-off-by: Ann Zhang update default description Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang --- .../test_utils/vector_search.py | 23 ++++++++++++++ .../vector_search_retriever_tool.py | 30 +++++++++---------- .../test_vector_search_retriever_tool.py | 15 ++++++++-- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index a5e18291..7ccd039e 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -155,8 +155,31 @@ def _get_serving_endpoint(full_name: str) -> MagicMock: endpoint.name = full_name return endpoint + def _construct_column(name, col_type, comment): + mock_column = MagicMock() + mock_column.name = name + mock_column.type_name = MagicMock() + mock_column.type_name.name = col_type + mock_column.comment = comment + return mock_column + + def _get_table(full_name: str) -> MagicMock: + columns = [ + ("city_id", "INT", None), + ("city", "STRING", "Name of the city"), + ("country", "STRING", "Name of the country"), + ("description", "STRING", "Detailed description of the city"), + ("__db_description_vector", "ARRAY", None), + ] + return MagicMock( + full_name=full_name, + columns=[_construct_column(*col) for col in columns] + ) + mock_client = MagicMock() mock_client.serving_endpoints.get.side_effect = _get_serving_endpoint + mock_client.tables.get.side_effect = _get_table + with patch( "databricks.sdk.WorkspaceClient", return_value=mock_client, diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 024bc5f7..f5d69f8d 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -44,13 +44,13 @@ class VectorSearchRetrieverToolInput(BaseModel): filters: Optional[Dict[str, Any]] = Field( default=None, description=( - "Optional filters to refine vector search results. " - "Supports:\n\n" - "- Equality: {\"column\": value} or {\"column\": [value1, value2]}\n" - "- Inequality: {\"column NOT\": value}\n" - "- Comparisons: {\"column <\": value}, {\"column >=\": value}\n" - "- Pattern match: {\"column LIKE\": \"word\"} (matches full tokens)\n" - "- OR logic: {\"column OR column2\": [value1, value2]}" + "Optional filters to refine vector search results. Supports the following operators:\n\n" + "- Inclusion: {\"column\": value} or {\"column\": [value1, value2]} (matches if the column equals any of the provided values)\n" + "- Exclusion: {\"column NOT\": value}\n" + "- Comparisons: {\"column <\": value}, {\"column >=\": value}, etc.\n" + "- Pattern match: {\"column LIKE\": \"word\"} (matches full tokens separated by whitespace)\n" + "- OR logic: {\"column1 OR column2\": [value1, value2]} " + "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)" ) ) @@ -102,7 +102,6 @@ def validate_tool_name(cls, tool_name): def _describe_columns(self) -> str: try: from databricks.sdk import WorkspaceClient - import json if self.workspace_client: table_info = self.workspace_client.tables.get(full_name = self.index_name) @@ -112,15 +111,14 @@ def _describe_columns(self) -> str: columns = [] for column_info in table_info.columns: - name, comment = column_info.name, column_info.comment - if comment == None: - comment = "No description provided" - col_type = json.loads(column_info.type_json).get("type", None) + name = column_info.name + comment = column_info.comment or "No description provided" + col_type = column_info.type_name.name if not name.startswith("__"): columns.append((name, col_type, comment)) - + return ( - "The vector search index includes the following columns:\n\n" + + "The vector search index includes the following columns:\n" + "\n".join(f"{name} ({col_type}): {comment}" for name, col_type, comment in columns) ) except: @@ -131,11 +129,11 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str: source_table = index_details.index_spec.get("source_table", "") description = ( DEFAULT_TOOL_DESCRIPTION - + f" The queried index uses the source table {source_table}" + + f" The queried index uses the source table {source_table}." ) else: description = DEFAULT_TOOL_DESCRIPTION - return description + self._describe_columns() + return f"{description}\n\n{self._describe_columns}()" def _get_resources( self, index_name: str, embedding_endpoint: str, index_details: IndexDetails diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index b4dc66f7..2e3b6996 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -1,11 +1,11 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin - +from databricks_ai_bridge.test_utils.vector_search import mock_workspace_client class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): pass @@ -66,3 +66,14 @@ def make_mock_index_details(is_databricks_managed_embeddings=False, embedding_so def test_get_resources(embedding_endpoint, index_details, resources): tool = DummyVectorSearchRetrieverTool(index_name=index_name) assert tool._get_resources(index_name, embedding_endpoint, index_details) == resources + + +def test_describe_columns(): + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + assert tool._describe_columns() == ( + "The vector search index includes the following columns:\n" + "city_id (INT): No description provided\n" + "city (STRING): Name of the city\n" + "country (STRING): Name of the country\n" + "description (STRING): Detailed description of the city" + ) From e079e76dc7f9325e51bafa713ad7ad8819d4b404 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 18:35:01 -0700 Subject: [PATCH 06/13] unit tests Signed-off-by: Ann Zhang --- .../test_vector_search_retriever_tool.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 7d11192e..cfa7460e 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -2,7 +2,7 @@ import os import threading from typing import Any, Dict, List, Optional -from unittest.mock import patch +from unittest.mock import MagicMock, patch import mlflow import pytest @@ -49,6 +49,7 @@ def init_vector_search_tool( text_column: Optional[str] = None, doc_uri: Optional[str] = None, primary_key: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, ) -> VectorSearchRetrieverTool: kwargs: Dict[str, Any] = { "index_name": index_name, @@ -59,6 +60,7 @@ def init_vector_search_tool( "text_column": text_column, "doc_uri": doc_uri, "primary_key": primary_key, + "filters": filters, } if index_name != DELTA_SYNC_INDEX: kwargs.update( @@ -86,6 +88,38 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: assert isinstance(response, AIMessage) +def test_filters_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + vector_search_tool._vector_store.similarity_search = MagicMock() + + vector_search_tool.invoke({ + "query": "what cities are in Germany", + "filters": {"country": "Germany"} + }) + vector_search_tool._vector_store.similarity_search.assert_called_once_with( + "what cities are in Germany", + k = vector_search_tool.num_results, + filter = {"country": "Germany"}, + query_type = vector_search_tool.query_type, + ) + + +def test_filters_are_combined() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) + vector_search_tool._vector_store.similarity_search = MagicMock() + + vector_search_tool.invoke({ + "query": "what cities are in Germany", + "filters": {"country": "Germany"} + }) + vector_search_tool._vector_store.similarity_search.assert_called_once_with( + "what cities are in Germany", + k = vector_search_tool.num_results, + filter = {"city LIKE": "Berlin", "country": "Germany"}, + query_type = vector_search_tool.query_type, + ) + + @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @pytest.mark.parametrize("columns", [None, ["id", "text"]]) @pytest.mark.parametrize("tool_name", [None, "test_tool"]) From c7295e83ce57b80aa9e9ec199f20ba3c38eece5a Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 18:42:03 -0700 Subject: [PATCH 07/13] openai + llamaindex Signed-off-by: Ann Zhang keep import Signed-off-by: Ann Zhang ruff Signed-off-by: Ann Zhang add dependency Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 4 +-- .../test_vector_search_retriever_tool.py | 30 +++++++++---------- .../vector_search_retriever_tool.py | 7 +++-- .../vector_search_retriever_tool.py | 8 +++-- pyproject.toml | 1 + .../test_utils/vector_search.py | 5 +--- .../vector_search_retriever_tool.py | 29 +++++++++--------- .../test_vector_search_retriever_tool.py | 5 ++-- 8 files changed, 46 insertions(+), 43 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index ee9452af..0c1dcef4 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Optional, Type, Dict, Any +from typing import Any, Dict, Optional, Type from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import ( @@ -65,7 +65,7 @@ def _validate_tool_inputs(self): return self @vector_search_retriever_tool_trace - def _run(self, query: str, filters: Dict[str, Any] = None) -> str: + def _run(self, query: str, filters: Optional[Dict[str, Any]] = None) -> str: combined_filters = {**(filters or {}), **(self.filters or {})} return self._vector_store.similarity_search( query, k=self.num_results, filter=combined_filters, query_type=self.query_type diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index cfa7460e..e4ce414e 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -91,32 +91,30 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: def test_filters_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) vector_search_tool._vector_store.similarity_search = MagicMock() - - vector_search_tool.invoke({ - "query": "what cities are in Germany", - "filters": {"country": "Germany"} - }) + + vector_search_tool.invoke( + {"query": "what cities are in Germany", "filters": {"country": "Germany"}} + ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( "what cities are in Germany", - k = vector_search_tool.num_results, - filter = {"country": "Germany"}, - query_type = vector_search_tool.query_type, + k=vector_search_tool.num_results, + filter={"country": "Germany"}, + query_type=vector_search_tool.query_type, ) def test_filters_are_combined() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) vector_search_tool._vector_store.similarity_search = MagicMock() - - vector_search_tool.invoke({ - "query": "what cities are in Germany", - "filters": {"country": "Germany"} - }) + + vector_search_tool.invoke( + {"query": "what cities are in Germany", "filters": {"country": "Germany"}} + ) vector_search_tool._vector_store.similarity_search.assert_called_once_with( "what cities are in Germany", - k = vector_search_tool.num_results, - filter = {"city LIKE": "Berlin", "country": "Germany"}, - query_type = vector_search_tool.query_type, + k=vector_search_tool.num_results, + filter={"city LIKE": "Berlin", "country": "Germany"}, + query_type=vector_search_tool.query_type, ) diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py index 9179fdf2..06c4dfea 100644 --- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py +++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py @@ -76,7 +76,9 @@ def __init__(self, **data): ) # Define the similarity search function - def similarity_search(query: str) -> List[Dict[str, Any]]: + def similarity_search( + query: str, filters: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]: if self._index_details.is_databricks_managed_embeddings(): if self.embedding: @@ -105,11 +107,12 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa return text, vector query_text, query_vector = get_query_text_vector(query) + combined_filters = {**(filters or {}), **(self.filters or {})} search_resp = self._index.similarity_search( columns=self.columns, query_text=query_text, query_vector=query_vector, - filters=self.filters, + filters=combined_filters, num_results=self.num_results, query_type=self.query_type, ) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 17a87c88..85c762de 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from databricks.vector_search.client import VectorSearchIndex from databricks_ai_bridge.utils.vector_search import ( @@ -52,7 +52,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): tool_call = first_response.choices[0].message.tool_calls[0] args = json.loads(tool_call.function.arguments) result = dbvs_tool.execute( - query=args["query"] + query=args["query"], filters=args.get("filters", None) ) # For self-managed embeddings, optionally pass in openai_client=client Step 3: Supply model with results – so it can incorporate them into its final response. @@ -161,6 +161,7 @@ def _validate_tool_inputs(self): def execute( self, query: str, + filters: Optional[Dict[str, Any]] = None, openai_client: OpenAI = None, ) -> List[Dict]: """ @@ -202,11 +203,12 @@ def execute( f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}" ) + combined_filters = {**(filters or {}), **(self.filters or {})} search_resp = self._index.similarity_search( columns=self.columns, query_text=query_text, query_vector=query_vector, - filters=self.filters, + filters=combined_filters, num_results=self.num_results, query_type=self.query_type, ) diff --git a/pyproject.toml b/pyproject.toml index f68b70ce..3871b6d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dev = [ "hatch", "pytest", "ruff==0.6.4", + "databricks-sdk>=0.49.0", ] [tool.ruff] diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index 7ccd039e..586b6341 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -171,10 +171,7 @@ def _get_table(full_name: str) -> MagicMock: ("description", "STRING", "Detailed description of the city"), ("__db_description_vector", "ARRAY", None), ] - return MagicMock( - full_name=full_name, - columns=[_construct_column(*col) for col in columns] - ) + return MagicMock(full_name=full_name, columns=[_construct_column(*col) for col in columns]) mock_client = MagicMock() mock_client.serving_endpoints.get.side_effect = _get_serving_endpoint diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index f5d69f8d..0c255120 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -45,13 +45,13 @@ class VectorSearchRetrieverToolInput(BaseModel): default=None, description=( "Optional filters to refine vector search results. Supports the following operators:\n\n" - "- Inclusion: {\"column\": value} or {\"column\": [value1, value2]} (matches if the column equals any of the provided values)\n" - "- Exclusion: {\"column NOT\": value}\n" - "- Comparisons: {\"column <\": value}, {\"column >=\": value}, etc.\n" - "- Pattern match: {\"column LIKE\": \"word\"} (matches full tokens separated by whitespace)\n" - "- OR logic: {\"column1 OR column2\": [value1, value2]} " + '- Inclusion: {"column": value} or {"column": [value1, value2]} (matches if the column equals any of the provided values)\n' + '- Exclusion: {"column NOT": value}\n' + '- Comparisons: {"column <": value}, {"column >=": value}, etc.\n' + '- Pattern match: {"column LIKE": "word"} (matches full tokens separated by whitespace)\n' + '- OR logic: {"column1 OR column2": [value1, value2]} ' "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)" - ) + ), ) @@ -104,9 +104,9 @@ def _describe_columns(self) -> str: from databricks.sdk import WorkspaceClient if self.workspace_client: - table_info = self.workspace_client.tables.get(full_name = self.index_name) + table_info = self.workspace_client.tables.get(full_name=self.index_name) else: - table_info = WorkspaceClient().tables.get(full_name = self.index_name) + table_info = WorkspaceClient().tables.get(full_name=self.index_name) columns = [] @@ -116,13 +116,14 @@ def _describe_columns(self) -> str: col_type = column_info.type_name.name if not name.startswith("__"): columns.append((name, col_type, comment)) - - return ( - "The vector search index includes the following columns:\n" + - "\n".join(f"{name} ({col_type}): {comment}" for name, col_type, comment in columns) + + return "The vector search index includes the following columns:\n" + "\n".join( + f"{name} ({col_type}): {comment}" for name, col_type, comment in columns + ) + except Exception: + _logger.warning( + "Unable to retrieve column information automatically. Please manually specify column names, types, and descriptions in the tool description to help LLMs apply filters correctly." ) - except: - _logger.warning("Unable to retrieve column information automatically. Please manually specify column names, types, and descriptions in the tool description to help LLMs apply filters correctly.") def _get_default_tool_description(self, index_details: IndexDetails) -> str: if index_details.is_delta_sync_index(): diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index 2e3b6996..dbf72a45 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -1,11 +1,12 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex +from databricks_ai_bridge.test_utils.vector_search import mock_workspace_client # noqa: F401 from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin -from databricks_ai_bridge.test_utils.vector_search import mock_workspace_client + class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): pass From 7e24711ba553f87c8b8d213a4345b79e7ed14036 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Thu, 24 Apr 2025 20:50:55 -0700 Subject: [PATCH 08/13] fix Signed-off-by: Ann Zhang fix Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang make required Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang update Signed-off-by: Ann Zhang --- .../test_vector_search_retriever_tool.py | 30 +++++++++++++++++++ .../vector_search_retriever_tool.py | 7 ++++- pyproject.toml | 2 +- .../vector_search_retriever_tool.py | 6 ++-- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index 6378d607..d40f6d99 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -176,3 +176,33 @@ def test_vector_search_client_non_model_serving_environment(): workspace_client=w, ) mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) + + +def test_filters_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + vector_search_tool._vector_store.similarity_search = MagicMock() + + vector_search_tool.invoke( + {"query": "what cities are in Germany", "filters": {"country": "Germany"}} + ) + vector_search_tool._vector_store.similarity_search.assert_called_once_with( + "what cities are in Germany", + k=vector_search_tool.num_results, + filter={"country": "Germany"}, + query_type=vector_search_tool.query_type, + ) + + +def test_filters_are_combined() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) + vector_search_tool._vector_store.similarity_search = MagicMock() + + vector_search_tool.invoke( + {"query": "what cities are in Germany", "filters": {"country": "Germany"}} + ) + vector_search_tool._vector_store.similarity_search.assert_called_once_with( + "what cities are in Germany", + k=vector_search_tool.num_results, + filter={"city LIKE": "Berlin", "country": "Germany"}, + query_type=vector_search_tool.query_type, + ) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 85c762de..1e88a33e 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -52,7 +52,8 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): tool_call = first_response.choices[0].message.tool_calls[0] args = json.loads(tool_call.function.arguments) result = dbvs_tool.execute( - query=args["query"], filters=args.get("filters", None) + query=args["query"], + filters=args.get("filters", None) ) # For self-managed embeddings, optionally pass in openai_client=client Step 3: Supply model with results – so it can incorporate them into its final response. @@ -141,6 +142,10 @@ def _validate_tool_inputs(self): description=self.tool_description or self._get_default_tool_description(self._index_details), ) + # We need to remove strict: True from the tool in order to support arbitrary filters + if "function" in self.tool and "strict" in self.tool["function"]: + del self.tool["function"]["strict"] + try: from databricks.sdk import WorkspaceClient from databricks.sdk.errors.platform import ResourceDoesNotExist diff --git a/pyproject.toml b/pyproject.toml index 3871b6d8..be66632a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dev = [ "hatch", "pytest", "ruff==0.6.4", - "databricks-sdk>=0.49.0", + "databricks-vectorsearch>=0.50", ] [tool.ruff] diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 0c255120..9a3962f4 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -41,7 +41,7 @@ class VectorSearchRetrieverToolInput(BaseModel): description="The string used to query the index with and identify the most similar " "vectors and return the associated documents." ) - filters: Optional[Dict[str, Any]] = Field( + filters: Dict[str, Any] = Field( default=None, description=( "Optional filters to refine vector search results. Supports the following operators:\n\n" @@ -51,7 +51,7 @@ class VectorSearchRetrieverToolInput(BaseModel): '- Pattern match: {"column LIKE": "word"} (matches full tokens separated by whitespace)\n' '- OR logic: {"column1 OR column2": [value1, value2]} ' "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)" - ), + ) ) @@ -134,7 +134,7 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str: ) else: description = DEFAULT_TOOL_DESCRIPTION - return f"{description}\n\n{self._describe_columns}()" + return f"{description}\n\n{self._describe_columns()}" def _get_resources( self, index_name: str, embedding_endpoint: str, index_details: IndexDetails From e45ab3fba0bc57a897d9f60b35e7c31bd86e49fc Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Tue, 29 Apr 2025 16:06:00 -0700 Subject: [PATCH 09/13] llamaindex tests Signed-off-by: Ann Zhang update openai tests Signed-off-by: Ann Zhang format Signed-off-by: Ann Zhang else case Signed-off-by: Ann Zhang lint Signed-off-by: Ann Zhang improve unit tests Signed-off-by: Ann Zhang --- .../test_vector_search_retriever_tool.py | 40 ++++++++++--------- .../vector_search_retriever_tool.py | 5 +-- .../test_vector_search_retriever_tool.py | 36 +++++++++++++++++ .../vector_search_retriever_tool.py | 9 ++++- 4 files changed, 66 insertions(+), 24 deletions(-) diff --git a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py index d40f6d99..9968049e 100644 --- a/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/llamaindex/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,7 +1,7 @@ import os import threading from typing import Any, Dict, List, Optional -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from databricks.sdk import WorkspaceClient @@ -52,6 +52,7 @@ def init_vector_search_tool( tool_description: Optional[str] = None, embedding: Optional[BaseEmbedding] = None, text_column: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, ) -> VectorSearchRetrieverTool: kwargs: Dict[str, Any] = { "index_name": index_name, @@ -60,6 +61,7 @@ def init_vector_search_tool( "tool_description": tool_description, "embedding": embedding, "text_column": text_column, + "filters": filters, } if index_name != DELTA_SYNC_INDEX: kwargs.update( @@ -180,29 +182,29 @@ def test_vector_search_client_non_model_serving_environment(): def test_filters_are_passed_through() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - vector_search_tool._vector_store.similarity_search = MagicMock() - - vector_search_tool.invoke( - {"query": "what cities are in Germany", "filters": {"country": "Germany"}} - ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", - k=vector_search_tool.num_results, - filter={"country": "Germany"}, + vector_search_tool._index.similarity_search = MagicMock() + + vector_search_tool.call(query="what cities are in Germany", filters={"country": "Germany"}) + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text="what cities are in Germany", + filters={"country": "Germany"}, + num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, + query_vector=None, ) def test_filters_are_combined() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) - vector_search_tool._vector_store.similarity_search = MagicMock() - - vector_search_tool.invoke( - {"query": "what cities are in Germany", "filters": {"country": "Germany"}} - ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - "what cities are in Germany", - k=vector_search_tool.num_results, - filter={"city LIKE": "Berlin", "country": "Germany"}, + vector_search_tool._index.similarity_search = MagicMock() + + vector_search_tool.call(query="what cities are in Germany", filters={"country": "Germany"}) + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text="what cities are in Germany", + filters={"city LIKE": "Berlin", "country": "Germany"}, + num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, + query_vector=None, ) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 1e88a33e..ec942914 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -52,8 +52,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): tool_call = first_response.choices[0].message.tool_calls[0] args = json.loads(tool_call.function.arguments) result = dbvs_tool.execute( - query=args["query"], - filters=args.get("filters", None) + query=args["query"], filters=args.get("filters", None) ) # For self-managed embeddings, optionally pass in openai_client=client Step 3: Supply model with results – so it can incorporate them into its final response. @@ -145,7 +144,7 @@ def _validate_tool_inputs(self): # We need to remove strict: True from the tool in order to support arbitrary filters if "function" in self.tool and "strict" in self.tool["function"]: del self.tool["function"]["strict"] - + try: from databricks.sdk import WorkspaceClient from databricks.sdk.errors.platform import ResourceDoesNotExist diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index a28aabe0..76d6970a 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -89,6 +89,7 @@ def init_vector_search_tool( tool_description: Optional[str] = None, text_column: Optional[str] = None, embedding_model_name: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, ) -> VectorSearchRetrieverTool: kwargs: Dict[str, Any] = { "index_name": index_name, @@ -97,6 +98,7 @@ def init_vector_search_tool( "tool_description": tool_description, "text_column": text_column, "embedding_model_name": embedding_model_name, + "filters": filters, } if index_name != DELTA_SYNC_INDEX: kwargs.update( @@ -298,3 +300,37 @@ def test_vector_search_client_non_model_serving_environment(): workspace_client=w, ) mockVSClient.assert_called_once_with(disable_notice=True, credential_strategy=None) + + +def test_filters_are_passed_through() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + vector_search_tool._index.similarity_search = MagicMock() + + vector_search_tool.execute( + {"query": "what cities are in Germany"}, filters={"country": "Germany"} + ) + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text={"query": "what cities are in Germany"}, + filters={"country": "Germany"}, + num_results=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + query_vector=None, + ) + + +def test_filters_are_combined() -> None: + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) + vector_search_tool._index.similarity_search = MagicMock() + + vector_search_tool.execute( + {"query": "what cities are in Germany"}, filters={"country": "Germany"} + ) + vector_search_tool._index.similarity_search.assert_called_once_with( + columns=vector_search_tool.columns, + query_text={"query": "what cities are in Germany"}, + filters={"city LIKE": "Berlin", "country": "Germany"}, + num_results=vector_search_tool.num_results, + query_type=vector_search_tool.query_type, + query_vector=None, + ) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 9a3962f4..d0bf5b3c 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -51,7 +51,7 @@ class VectorSearchRetrieverToolInput(BaseModel): '- Pattern match: {"column LIKE": "word"} (matches full tokens separated by whitespace)\n' '- OR logic: {"column1 OR column2": [value1, value2]} ' "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)" - ) + ), ) @@ -134,7 +134,12 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str: ) else: description = DEFAULT_TOOL_DESCRIPTION - return f"{description}\n\n{self._describe_columns()}" + + column_description = self._describe_columns() + if column_description: + return f"{description}\n\n{column_description}" + else: + return description def _get_resources( self, index_name: str, embedding_endpoint: str, index_details: IndexDetails From 1f20b2e51cbb3a5449e805133459fdba076a3124 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 30 Apr 2025 18:44:06 -0700 Subject: [PATCH 10/13] update Signed-off-by: Ann Zhang --- .../tests/unit_tests/test_vector_search_retriever_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 76d6970a..9a63033f 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -324,11 +324,11 @@ def test_filters_are_combined() -> None: vector_search_tool._index.similarity_search = MagicMock() vector_search_tool.execute( - {"query": "what cities are in Germany"}, filters={"country": "Germany"} + query="what cities are in Germany", filters={"country": "Germany"} ) vector_search_tool._index.similarity_search.assert_called_once_with( columns=vector_search_tool.columns, - query_text={"query": "what cities are in Germany"}, + query_text="what cities are in Germany", filters={"city LIKE": "Berlin", "country": "Germany"}, num_results=vector_search_tool.num_results, query_type=vector_search_tool.query_type, From 02671d22fcaa7fdc6c9141ac1ab05efae01e840d Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 30 Apr 2025 19:17:24 -0700 Subject: [PATCH 11/13] update Signed-off-by: Ann Zhang --- integrations/llamaindex/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/llamaindex/pyproject.toml b/integrations/llamaindex/pyproject.toml index c88ad7d9..578407d4 100644 --- a/integrations/llamaindex/pyproject.toml +++ b/integrations/llamaindex/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.9" dependencies = [ "databricks-vectorsearch>=0.40", "databricks-ai-bridge>=0.1.0", - "llama-index>=0.11.0", + "llama-index[core]>=0.11.0", ] [project.optional-dependencies] From 896e03c2955cddae8d2660c84dc3f07bd6787ca9 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 30 Apr 2025 20:25:48 -0700 Subject: [PATCH 12/13] update Signed-off-by: Ann Zhang --- integrations/llamaindex/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/llamaindex/pyproject.toml b/integrations/llamaindex/pyproject.toml index 578407d4..c88ad7d9 100644 --- a/integrations/llamaindex/pyproject.toml +++ b/integrations/llamaindex/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.9" dependencies = [ "databricks-vectorsearch>=0.40", "databricks-ai-bridge>=0.1.0", - "llama-index[core]>=0.11.0", + "llama-index>=0.11.0", ] [project.optional-dependencies] From 12f5754c3ffacd8192bc1bc48085e1072c774222 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 30 Apr 2025 20:26:14 -0700 Subject: [PATCH 13/13] update Signed-off-by: Ann Zhang --- .../tests/unit_tests/test_vector_search_retriever_tool.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 9a63033f..5b158654 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -323,9 +323,7 @@ def test_filters_are_combined() -> None: vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) vector_search_tool._index.similarity_search = MagicMock() - vector_search_tool.execute( - query="what cities are in Germany", filters={"country": "Germany"} - ) + vector_search_tool.execute(query="what cities are in Germany", filters={"country": "Germany"}) vector_search_tool._index.similarity_search.assert_called_once_with( columns=vector_search_tool.columns, query_text="what cities are in Germany",