diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 465d89dd..49206c94 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -57,6 +57,7 @@ def _validate_tool_inputs(self): self.resources = self._get_resources( self.index_name, (self.embedding.endpoint if isinstance(self.embedding, DatabricksEmbeddings) else None), + IndexDetails(dbvs.index), ) return self diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 2bd4f4c6..1e0cfe13 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -12,6 +12,7 @@ from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, + DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, INPUT_TEXTS, _get_index, mock_vs_client, @@ -157,8 +158,18 @@ def test_vector_search_retriever_tool_resources( vector_search_tool = VectorSearchRetrieverTool( index_name=index_name, embedding=embeddings, text_column=text_column ) - expected_resources = [DatabricksVectorSearchIndex(index_name=index_name)] + ( - [DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else [] + expected_resources = ( + [DatabricksVectorSearchIndex(index_name=index_name)] + + ([DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else []) + + ( + [ + DatabricksServingEndpoint( + endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME + ) + ] + if index_name == DELTA_SYNC_INDEX + else [] + ) ) assert [res.to_dict() for res in vector_search_tool.resources] == [ res.to_dict() for res in expected_resources diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index cb2731b6..d7bf049a 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -138,9 +138,11 @@ def _validate_tool_inputs(self): self.workspace_client.serving_endpoints.get(self.embedding_model_name) else: WorkspaceClient().serving_endpoints.get(self.embedding_model_name) - self.resources = self._get_resources(self.index_name, self.embedding_model_name) + self.resources = self._get_resources( + self.index_name, self.embedding_model_name, self._index_details + ) except ResourceDoesNotExist: - self.resources = self._get_resources(self.index_name, None) + self.resources = self._get_resources(self.index_name, None, self._index_details) return self diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 254626a9..a28aabe0 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -12,6 +12,7 @@ from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, + DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, DIRECT_ACCESS_INDEX, INPUT_TEXTS, mock_vs_client, @@ -143,10 +144,22 @@ def test_vector_search_retriever_tool_init( ) assert isinstance(vector_search_tool, BaseModel) - expected_resources = [DatabricksVectorSearchIndex(index_name=index_name)] + ( - [DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")] - if self_managed_embeddings_test.embedding_model_name - else [] + expected_resources = ( + [DatabricksVectorSearchIndex(index_name=index_name)] + + ( + [DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")] + if self_managed_embeddings_test.embedding_model_name + else [] + ) + + ( + [ + DatabricksServingEndpoint( + endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME + ) + ] + if index_name == DELTA_SYNC_INDEX + else [] + ) ) assert [res.to_dict() for res in vector_search_tool.resources] == [ res.to_dict() for res in expected_resources diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index 4e0ea96d..19e90505 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -56,6 +56,8 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]: DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX, } +DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME = "openai-text-embedding" + INDEX_DETAILS = { DELTA_SYNC_INDEX: { "name": DELTA_SYNC_INDEX, @@ -68,7 +70,7 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]: "embedding_source_columns": [ { "name": "text", - "embedding_model_endpoint_name": "openai-text-embedding", + "embedding_model_endpoint_name": DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, } ], }, diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index aedb55b8..a106e6da 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -89,12 +89,22 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str: ) return DEFAULT_TOOL_DESCRIPTION - def _get_resources(self, index_name: str, embedding_endpoint: str) -> List[Resource]: - return ([DatabricksVectorSearchIndex(index_name=index_name)] if index_name else []) + ( - [DatabricksServingEndpoint(endpoint_name=embedding_endpoint)] - if embedding_endpoint - else [] - ) + def _get_resources( + self, index_name: str, embedding_endpoint: str, index_details: IndexDetails + ) -> List[Resource]: + resources = [] + if index_name: + resources.append(DatabricksVectorSearchIndex(index_name=index_name)) + if embedding_endpoint: + resources.append(DatabricksServingEndpoint(endpoint_name=embedding_endpoint)) + if index_details.is_databricks_managed_embeddings and ( + managed_embedding := index_details.embedding_source_column.get( + "embedding_model_endpoint_name", None + ) + ): + if managed_embedding != embedding_endpoint: + resources.append(DatabricksServingEndpoint(endpoint_name=managed_embedding)) + return resources def _get_tool_name(self) -> str: tool_name = self.tool_name or self.index_name.replace(".", "__") diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py new file mode 100644 index 00000000..b4dc66f7 --- /dev/null +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock + +import pytest +from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex + +from databricks_ai_bridge.utils.vector_search import IndexDetails +from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin + + +class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): + pass + + +index_name = "catalog.schema.index" + + +def make_mock_index_details(is_databricks_managed_embeddings=False, embedding_source_column=None): + mock = MagicMock(spec=IndexDetails) + mock.is_databricks_managed_embeddings = is_databricks_managed_embeddings + mock.embedding_source_column = embedding_source_column or {} + return mock + + +@pytest.mark.parametrize( + "embedding_endpoint,index_details,resources", + [ + (None, make_mock_index_details(False, {}), [DatabricksVectorSearchIndex(index_name)]), + ( + "embedding_endpoint", + make_mock_index_details(False, {}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint"), + ], + ), + ( + None, + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint"), + ], + ), # The following cases should not happen, but ensuring that they have reasonable behavior + ( + "embedding_endpoint", + make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint"), + ], + ), + ( + "embedding_endpoint_1", + make_mock_index_details( + True, {"embedding_model_endpoint_name": "embedding_endpoint_2"} + ), + [ + DatabricksVectorSearchIndex(index_name), + DatabricksServingEndpoint("embedding_endpoint_1"), + DatabricksServingEndpoint("embedding_endpoint_2"), + ], + ), + (None, make_mock_index_details(True, {}), [DatabricksVectorSearchIndex(index_name)]), + ], +) +def test_get_resources(embedding_endpoint, index_details, resources): + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + assert tool._get_resources(index_name, embedding_endpoint, index_details) == resources