From e58b2e8d7b81219882fef246a2e778582f23ad3c Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 9 Apr 2025 12:19:11 -0700 Subject: [PATCH 1/4] draft Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 1 + .../vector_search_retriever_tool.py | 4 +- .../vector_search_retriever_tool.py | 21 ++- .../test_vector_search_retriever_tool.py | 67 +++++++++ .../utils/test_vector_search.py | 129 ++++++++++++++++++ 5 files changed, 213 insertions(+), 9 deletions(-) create mode 100644 tests/databricks_ai_bridge/test_vector_search_retriever_tool.py create mode 100644 tests/databricks_ai_bridge/utils/test_vector_search.py diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 465d89dd..300c08d9 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -57,6 +57,7 @@ def _validate_tool_inputs(self): self.resources = self._get_resources( self.index_name, (self.embedding.endpoint if isinstance(self.embedding, DatabricksEmbeddings) else None), + IndexDetails(dbvs.index) ) return self diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index cb2731b6..dceb33bb 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -138,9 +138,9 @@ def _validate_tool_inputs(self): self.workspace_client.serving_endpoints.get(self.embedding_model_name) else: WorkspaceClient().serving_endpoints.get(self.embedding_model_name) - self.resources = self._get_resources(self.index_name, self.embedding_model_name) + self.resources = self._get_resources(self.index_name, self.embedding_model_name, self._index_details) except ResourceDoesNotExist: - self.resources = self._get_resources(self.index_name, None) + self.resources = self._get_resources(self.index_name, None, self._index_details) return self diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index aedb55b8..9beb377f 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -89,13 +89,20 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str: ) return DEFAULT_TOOL_DESCRIPTION - def _get_resources(self, index_name: str, embedding_endpoint: str) -> List[Resource]: - return ([DatabricksVectorSearchIndex(index_name=index_name)] if index_name else []) + ( - [DatabricksServingEndpoint(endpoint_name=embedding_endpoint)] - if embedding_endpoint - else [] - ) - + def _get_resources(self, index_name: str, embedding_endpoint: str, index_details: IndexDetails) -> List[Resource]: + resources = [] + if index_name: + resources.append(DatabricksVectorSearchIndex(index_name=index_name)) + if embedding_endpoint: + resources.append(DatabricksServingEndpoint(endpoint_name=embedding_endpoint)) + if ( + index_details.is_databricks_managed_embeddings and + (managed_embedding := index_details.embedding_source_column.get("embedding_model_endpoint_name", None)) + ): + if managed_embedding != embedding_endpoint: + resources.append(DatabricksServingEndpoint(endpoint_name=managed_embedding)) + return resources + def _get_tool_name(self) -> str: tool_name = self.tool_name or self.index_name.replace(".", "__") diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py new file mode 100644 index 00000000..31b0b6af --- /dev/null +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -0,0 +1,67 @@ +import pytest +from unittest.mock import MagicMock +from databricks_ai_bridge.utils.vector_search import IndexDetails, VectorSearchRetrieverToolMixin +from mlflow.models.resources import DatabricksVectorSearchIndex, DatabricksServingEndpoint +from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 + ALL_INDEX_NAMES, + DELTA_SYNC_INDEX, + INPUT_TEXTS, + _get_index, + mock_vs_client, + mock_workspace_client, +) + +class DummyRetriever(VectorSearchRetrieverToolMixin): + pass + +@pytest.fixture +def mock_index_details(): + mock = MagicMock(spec=IndexDetails) + mock.is_databricks_managed_embeddings = False + mock.embedding_source_column = {} + return mock + +def test_get_resources_index_only(mock_index_details): + index_name = "catalog.schema.index" + tool = DummyRetriever(index_name=index_name) + resources = tool._get_resources(index_name, None, mock_index_details) + + assert resources == [DatabricksVectorSearchIndex(index_name)] + +def test_get_resources_with_embedding_endpoint(mock_index_details): + index_name = "catalog.schema.index" + tool = DummyRetriever(index_name=index_name) + resources = tool._get_resources(index_name, "embedding_endpoint", mock_index_details) + + assert resources == [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint") + ] + +def test_get_resources_with_managed_embeddings(): + index_name = "catalog.schema.index" + mock = MagicMock(spec=IndexDetails) + mock.is_databricks_managed_embeddings = True + mock.embedding_source_column = {"embedding_model_endpoint_name": "embedding_endpoint"} + + tool = DummyRetriever(index_name=index_name) + resources = tool._get_resources("catalog.schema.index", None, mock) + + assert resources == [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint") + ] + +def test_get_resources_with_duplicate_embedding_endpoints(): + index_name = "catalog.schema.index" + mock = MagicMock(spec=IndexDetails) + mock.is_databricks_managed_embeddings = True + mock.embedding_source_column = {"embedding_model_endpoint_name": "embedding_endpoint"} + + tool = DummyRetriever(index_name=index_name) + resources = tool._get_resources("catalog.schema.index", "embedding_endpoint", mock) + + assert resources == [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint") + ] diff --git a/tests/databricks_ai_bridge/utils/test_vector_search.py b/tests/databricks_ai_bridge/utils/test_vector_search.py new file mode 100644 index 00000000..706a98bc --- /dev/null +++ b/tests/databricks_ai_bridge/utils/test_vector_search.py @@ -0,0 +1,129 @@ +import pytest + +from databricks_ai_bridge.utils.vector_search import ( + IndexDetails, + RetrieverSchema, + parse_vector_search_response, + validate_and_get_return_columns, + validate_and_get_text_column, +) + +import pytest +from dataclasses import dataclass +from typing import Dict, List, Optional, Any + +# -- Mock setup -- + +@dataclass +class MockDocument: + page_content: str + metadata: Dict[str, Any] + +@dataclass +class IndexDetails: + primary_key: str + +# -- Fixtures -- + +@pytest.fixture +def search_resp(): + return { + "manifest": { + "columns": [ + {"name": "id"}, + {"name": "text"}, + {"name": "uri"}, + {"name": "chunk"}, + {"name": "extra"}, + {"name": "score"}, + ] + }, + "result": { + "data_array": [ + [1, "This is text A", "doc://a", "chunk-1", "x", 0.9], + [2, "This is text B", "doc://b", "chunk-2", "y", 0.8], + ] + } + } + +@pytest.fixture +def retriever_schema(): + return RetrieverSchema( + text_column="text", + doc_uri="uri", + chunk_id="chunk", + other_columns=["extra"] + ) + +@pytest.fixture +def index_details(): + return IndexDetails(primary_key="id") + + +def test_parses_basic_response(search_resp, index_details, retriever_schema): + results = parse_vector_search_response( + search_resp=search_resp, + index_details=index_details, + retriever_schema=retriever_schema, + document_class=MockDocument, + ) + + assert len(results) == 2 + + doc1, score1 = results[0] + assert doc1.page_content == "This is text A" + assert doc1.metadata["doc_uri"] == "doc://a" + assert doc1.metadata["chunk_id"] == "chunk-1" + assert doc1.metadata["extra"] == "x" + assert doc1.metadata["id"] == 1 + assert score1 == 0.9 + + doc2, score2 = results[1] + assert doc2.page_content == "This is text B" + assert doc2.metadata["doc_uri"] == "doc://b" + assert doc2.metadata["chunk_id"] == "chunk-2" + assert doc2.metadata["extra"] == "y" + assert doc2.metadata["id"] == 2 + assert score2 == 0.8 + + +def test_ignores_specified_columns(search_resp, index_details, retriever_schema): + results = parse_vector_search_response( + search_resp=search_resp, + index_details=index_details, + retriever_schema=retriever_schema, + ignore_cols=["extra"], + document_class=MockDocument, + ) + + doc, _ = results[0] + assert "extra" not in doc.metadata + + +def test_handles_empty_results(index_details, retriever_schema): + empty_resp = {"manifest": {"columns": []}, "result": {"data_array": []}} + results = parse_vector_search_response( + search_resp=empty_resp, + index_details=index_details, + retriever_schema=retriever_schema, + document_class=MockDocument, + ) + assert results == [] + + +def test_missing_optional_fields_handled_gracefully(search_resp): + retriever_schema = RetrieverSchema(text_column="text") # no doc_uri, chunk_id, or other_columns + index_details = IndexDetails(primary_key="id") + + results = parse_vector_search_response( + search_resp=search_resp, + index_details=index_details, + retriever_schema=retriever_schema, + document_class=MockDocument, + ) + + doc, _ = results[0] + assert doc.page_content == "This is text A" + assert doc.metadata["id"] == 1 + assert "doc_uri" not in doc.metadata + assert "chunk_id" not in doc.metadata From 2eccd08cf949469af3769e3c9afd2067c0728b19 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 9 Apr 2025 15:25:03 -0700 Subject: [PATCH 2/4] update tests Signed-off-by: Ann Zhang --- .../test_vector_search_retriever_tool.py | 114 ++++++++-------- .../utils/test_vector_search.py | 129 ------------------ 2 files changed, 56 insertions(+), 187 deletions(-) delete mode 100644 tests/databricks_ai_bridge/utils/test_vector_search.py diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index 31b0b6af..792ec98f 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -1,67 +1,65 @@ import pytest from unittest.mock import MagicMock -from databricks_ai_bridge.utils.vector_search import IndexDetails, VectorSearchRetrieverToolMixin +from databricks_ai_bridge.utils.vector_search import IndexDetails +from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin from mlflow.models.resources import DatabricksVectorSearchIndex, DatabricksServingEndpoint -from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 - ALL_INDEX_NAMES, - DELTA_SYNC_INDEX, - INPUT_TEXTS, - _get_index, - mock_vs_client, - mock_workspace_client, -) -class DummyRetriever(VectorSearchRetrieverToolMixin): +class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): pass -@pytest.fixture -def mock_index_details(): - mock = MagicMock(spec=IndexDetails) - mock.is_databricks_managed_embeddings = False - mock.embedding_source_column = {} - return mock - -def test_get_resources_index_only(mock_index_details): - index_name = "catalog.schema.index" - tool = DummyRetriever(index_name=index_name) - resources = tool._get_resources(index_name, None, mock_index_details) - - assert resources == [DatabricksVectorSearchIndex(index_name)] - -def test_get_resources_with_embedding_endpoint(mock_index_details): - index_name = "catalog.schema.index" - tool = DummyRetriever(index_name=index_name) - resources = tool._get_resources(index_name, "embedding_endpoint", mock_index_details) - - assert resources == [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint") - ] +index_name = "catalog.schema.index" -def test_get_resources_with_managed_embeddings(): - index_name = "catalog.schema.index" +def make_mock_index_details(is_databricks_managed_embeddings=False, embedding_source_column=None): mock = MagicMock(spec=IndexDetails) - mock.is_databricks_managed_embeddings = True - mock.embedding_source_column = {"embedding_model_endpoint_name": "embedding_endpoint"} - - tool = DummyRetriever(index_name=index_name) - resources = tool._get_resources("catalog.schema.index", None, mock) - - assert resources == [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint") - ] - -def test_get_resources_with_duplicate_embedding_endpoints(): - index_name = "catalog.schema.index" - mock = MagicMock(spec=IndexDetails) - mock.is_databricks_managed_embeddings = True - mock.embedding_source_column = {"embedding_model_endpoint_name": "embedding_endpoint"} - - tool = DummyRetriever(index_name=index_name) - resources = tool._get_resources("catalog.schema.index", "embedding_endpoint", mock) + mock.is_databricks_managed_embeddings = is_databricks_managed_embeddings + mock.embedding_source_column = embedding_source_column or {} + return mock - assert resources == [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint") - ] +@pytest.mark.parametrize("embedding_endpoint,index_details,resources", [ + ( + None, + make_mock_index_details(False, {}), + [DatabricksVectorSearchIndex(index_name)] + ), + ( + "embedding_endpoint", + make_mock_index_details(False, {}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint") + ] + ), + ( + None, + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint") + ] + ), # The following cases should not happen, but ensuring that they have reasonable behavior + ( + "embedding_endpoint", + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint") + ] + ), + ( + "embedding_endpoint_1", + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint_2"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint_1"), + DatabricksServingEndpoint("embedding_endpoint_2") + ] + ), + ( + None, + make_mock_index_details(True, {}), + [DatabricksVectorSearchIndex(index_name)] + ) +]) +def test_get_resources(embedding_endpoint, index_details, resources): + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + assert tool._get_resources(index_name, embedding_endpoint, index_details) == resources diff --git a/tests/databricks_ai_bridge/utils/test_vector_search.py b/tests/databricks_ai_bridge/utils/test_vector_search.py deleted file mode 100644 index 706a98bc..00000000 --- a/tests/databricks_ai_bridge/utils/test_vector_search.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest - -from databricks_ai_bridge.utils.vector_search import ( - IndexDetails, - RetrieverSchema, - parse_vector_search_response, - validate_and_get_return_columns, - validate_and_get_text_column, -) - -import pytest -from dataclasses import dataclass -from typing import Dict, List, Optional, Any - -# -- Mock setup -- - -@dataclass -class MockDocument: - page_content: str - metadata: Dict[str, Any] - -@dataclass -class IndexDetails: - primary_key: str - -# -- Fixtures -- - -@pytest.fixture -def search_resp(): - return { - "manifest": { - "columns": [ - {"name": "id"}, - {"name": "text"}, - {"name": "uri"}, - {"name": "chunk"}, - {"name": "extra"}, - {"name": "score"}, - ] - }, - "result": { - "data_array": [ - [1, "This is text A", "doc://a", "chunk-1", "x", 0.9], - [2, "This is text B", "doc://b", "chunk-2", "y", 0.8], - ] - } - } - -@pytest.fixture -def retriever_schema(): - return RetrieverSchema( - text_column="text", - doc_uri="uri", - chunk_id="chunk", - other_columns=["extra"] - ) - -@pytest.fixture -def index_details(): - return IndexDetails(primary_key="id") - - -def test_parses_basic_response(search_resp, index_details, retriever_schema): - results = parse_vector_search_response( - search_resp=search_resp, - index_details=index_details, - retriever_schema=retriever_schema, - document_class=MockDocument, - ) - - assert len(results) == 2 - - doc1, score1 = results[0] - assert doc1.page_content == "This is text A" - assert doc1.metadata["doc_uri"] == "doc://a" - assert doc1.metadata["chunk_id"] == "chunk-1" - assert doc1.metadata["extra"] == "x" - assert doc1.metadata["id"] == 1 - assert score1 == 0.9 - - doc2, score2 = results[1] - assert doc2.page_content == "This is text B" - assert doc2.metadata["doc_uri"] == "doc://b" - assert doc2.metadata["chunk_id"] == "chunk-2" - assert doc2.metadata["extra"] == "y" - assert doc2.metadata["id"] == 2 - assert score2 == 0.8 - - -def test_ignores_specified_columns(search_resp, index_details, retriever_schema): - results = parse_vector_search_response( - search_resp=search_resp, - index_details=index_details, - retriever_schema=retriever_schema, - ignore_cols=["extra"], - document_class=MockDocument, - ) - - doc, _ = results[0] - assert "extra" not in doc.metadata - - -def test_handles_empty_results(index_details, retriever_schema): - empty_resp = {"manifest": {"columns": []}, "result": {"data_array": []}} - results = parse_vector_search_response( - search_resp=empty_resp, - index_details=index_details, - retriever_schema=retriever_schema, - document_class=MockDocument, - ) - assert results == [] - - -def test_missing_optional_fields_handled_gracefully(search_resp): - retriever_schema = RetrieverSchema(text_column="text") # no doc_uri, chunk_id, or other_columns - index_details = IndexDetails(primary_key="id") - - results = parse_vector_search_response( - search_resp=search_resp, - index_details=index_details, - retriever_schema=retriever_schema, - document_class=MockDocument, - ) - - doc, _ = results[0] - assert doc.page_content == "This is text A" - assert doc.metadata["id"] == 1 - assert "doc_uri" not in doc.metadata - assert "chunk_id" not in doc.metadata From 605ffddd8e2795f589bc8493d4723a8f8cc64578 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 9 Apr 2025 15:35:39 -0700 Subject: [PATCH 3/4] update tests and lint Signed-off-by: Ann Zhang --- .../vector_search_retriever_tool.py | 2 +- .../test_vector_search_retriever_tool.py | 15 ++- .../vector_search_retriever_tool.py | 4 +- .../test_vector_search_retriever_tool.py | 21 +++- .../test_utils/vector_search.py | 4 +- .../vector_search_retriever_tool.py | 13 ++- .../test_vector_search_retriever_tool.py | 97 ++++++++++--------- 7 files changed, 95 insertions(+), 61 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 300c08d9..49206c94 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -57,7 +57,7 @@ def _validate_tool_inputs(self): self.resources = self._get_resources( self.index_name, (self.embedding.endpoint if isinstance(self.embedding, DatabricksEmbeddings) else None), - IndexDetails(dbvs.index) + IndexDetails(dbvs.index), ) return self diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 2bd4f4c6..f5cbf748 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -12,6 +12,7 @@ from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, + DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, INPUT_TEXTS, _get_index, mock_vs_client, @@ -157,8 +158,18 @@ def test_vector_search_retriever_tool_resources( vector_search_tool = VectorSearchRetrieverTool( index_name=index_name, embedding=embeddings, text_column=text_column ) - expected_resources = [DatabricksVectorSearchIndex(index_name=index_name)] + ( - [DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else [] + expected_resources = ( + ([DatabricksVectorSearchIndex(index_name=index_name)]) + + ([DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else []) + + ( + [ + DatabricksServingEndpoint( + endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME + ) + ] + if index_name == DELTA_SYNC_INDEX + else [] + ) ) assert [res.to_dict() for res in vector_search_tool.resources] == [ res.to_dict() for res in expected_resources diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index dceb33bb..d7bf049a 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -138,7 +138,9 @@ def _validate_tool_inputs(self): self.workspace_client.serving_endpoints.get(self.embedding_model_name) else: WorkspaceClient().serving_endpoints.get(self.embedding_model_name) - self.resources = self._get_resources(self.index_name, self.embedding_model_name, self._index_details) + self.resources = self._get_resources( + self.index_name, self.embedding_model_name, self._index_details + ) except ResourceDoesNotExist: self.resources = self._get_resources(self.index_name, None, self._index_details) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 254626a9..9347e813 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -12,6 +12,7 @@ from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, + DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, DIRECT_ACCESS_INDEX, INPUT_TEXTS, mock_vs_client, @@ -143,10 +144,22 @@ def test_vector_search_retriever_tool_init( ) assert isinstance(vector_search_tool, BaseModel) - expected_resources = [DatabricksVectorSearchIndex(index_name=index_name)] + ( - [DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")] - if self_managed_embeddings_test.embedding_model_name - else [] + expected_resources = ( + ([DatabricksVectorSearchIndex(index_name=index_name)]) + + ( + [DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")] + if self_managed_embeddings_test.embedding_model_name + else [] + ) + + ( + [ + DatabricksServingEndpoint( + endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME + ) + ] + if index_name == DELTA_SYNC_INDEX + else [] + ) ) assert [res.to_dict() for res in vector_search_tool.resources] == [ res.to_dict() for res in expected_resources diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index 4e0ea96d..19e90505 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -56,6 +56,8 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]: DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX, } +DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME = "openai-text-embedding" + INDEX_DETAILS = { DELTA_SYNC_INDEX: { "name": DELTA_SYNC_INDEX, @@ -68,7 +70,7 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]: "embedding_source_columns": [ { "name": "text", - "embedding_model_endpoint_name": "openai-text-embedding", + "embedding_model_endpoint_name": DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, } ], }, diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 9beb377f..a106e6da 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -89,20 +89,23 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str: ) return DEFAULT_TOOL_DESCRIPTION - def _get_resources(self, index_name: str, embedding_endpoint: str, index_details: IndexDetails) -> List[Resource]: + def _get_resources( + self, index_name: str, embedding_endpoint: str, index_details: IndexDetails + ) -> List[Resource]: resources = [] if index_name: resources.append(DatabricksVectorSearchIndex(index_name=index_name)) if embedding_endpoint: resources.append(DatabricksServingEndpoint(endpoint_name=embedding_endpoint)) - if ( - index_details.is_databricks_managed_embeddings and - (managed_embedding := index_details.embedding_source_column.get("embedding_model_endpoint_name", None)) + if index_details.is_databricks_managed_embeddings and ( + managed_embedding := index_details.embedding_source_column.get( + "embedding_model_endpoint_name", None + ) ): if managed_embedding != embedding_endpoint: resources.append(DatabricksServingEndpoint(endpoint_name=managed_embedding)) return resources - + def _get_tool_name(self) -> str: tool_name = self.tool_name or self.index_name.replace(".", "__") diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index 792ec98f..b4dc66f7 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -1,65 +1,68 @@ -import pytest from unittest.mock import MagicMock + +import pytest +from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex + from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin -from mlflow.models.resources import DatabricksVectorSearchIndex, DatabricksServingEndpoint + class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): pass + index_name = "catalog.schema.index" + def make_mock_index_details(is_databricks_managed_embeddings=False, embedding_source_column=None): mock = MagicMock(spec=IndexDetails) mock.is_databricks_managed_embeddings = is_databricks_managed_embeddings mock.embedding_source_column = embedding_source_column or {} return mock -@pytest.mark.parametrize("embedding_endpoint,index_details,resources", [ - ( - None, - make_mock_index_details(False, {}), - [DatabricksVectorSearchIndex(index_name)] - ), - ( - "embedding_endpoint", - make_mock_index_details(False, {}), - [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint") - ] - ), - ( - None, - make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), - [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint") - ] - ), # The following cases should not happen, but ensuring that they have reasonable behavior - ( - "embedding_endpoint", - make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), - [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint") - ] - ), - ( - "embedding_endpoint_1", - make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint_2"}), - [ - DatabricksVectorSearchIndex(index_name), - DatabricksServingEndpoint("embedding_endpoint_1"), - DatabricksServingEndpoint("embedding_endpoint_2") - ] - ), - ( - None, - make_mock_index_details(True, {}), - [DatabricksVectorSearchIndex(index_name)] - ) -]) + +@pytest.mark.parametrize( + "embedding_endpoint,index_details,resources", + [ + (None, make_mock_index_details(False, {}), [DatabricksVectorSearchIndex(index_name)]), + ( + "embedding_endpoint", + make_mock_index_details(False, {}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint"), + ], + ), + ( + None, + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint"), + ], + ), # The following cases should not happen, but ensuring that they have reasonable behavior + ( + "embedding_endpoint", + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint"), + ], + ), + ( + "embedding_endpoint_1", + make_mock_index_details( + True, {"embedding_model_endpoint_name": "embedding_endpoint_2"} + ), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint_1"), + DatabricksServingEndpoint("embedding_endpoint_2"), + ], + ), + (None, make_mock_index_details(True, {}), [DatabricksVectorSearchIndex(index_name)]), + ], +) def test_get_resources(embedding_endpoint, index_details, resources): tool = DummyVectorSearchRetrieverTool(index_name=index_name) assert tool._get_resources(index_name, embedding_endpoint, index_details) == resources From edece1cec9303424e36b80bde76d04b3e7b89303 Mon Sep 17 00:00:00 2001 From: Ann Zhang Date: Wed, 9 Apr 2025 15:37:12 -0700 Subject: [PATCH 4/4] format again Signed-off-by: Ann Zhang --- .../tests/unit_tests/test_vector_search_retriever_tool.py | 2 +- .../tests/unit_tests/test_vector_search_retriever_tool.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index f5cbf748..1e0cfe13 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -159,7 +159,7 @@ def test_vector_search_retriever_tool_resources( index_name=index_name, embedding=embeddings, text_column=text_column ) expected_resources = ( - ([DatabricksVectorSearchIndex(index_name=index_name)]) + [DatabricksVectorSearchIndex(index_name=index_name)] + ([DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else []) + ( [ diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 9347e813..a28aabe0 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -145,7 +145,7 @@ def test_vector_search_retriever_tool_init( assert isinstance(vector_search_tool, BaseModel) expected_resources = ( - ([DatabricksVectorSearchIndex(index_name=index_name)]) + [DatabricksVectorSearchIndex(index_name=index_name)] + ( [DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")] if self_managed_embeddings_test.embedding_model_name