Skip to content

Add embedding model as resource #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _validate_tool_inputs(self):
self.resources = self._get_resources(
self.index_name,
(self.embedding.endpoint if isinstance(self.embedding, DatabricksEmbeddings) else None),
IndexDetails(dbvs.index),
)

return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
DELTA_SYNC_INDEX,
DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME,
INPUT_TEXTS,
_get_index,
mock_vs_client,
Expand Down Expand Up @@ -157,8 +158,18 @@ def test_vector_search_retriever_tool_resources(
vector_search_tool = VectorSearchRetrieverTool(
index_name=index_name, embedding=embeddings, text_column=text_column
)
expected_resources = [DatabricksVectorSearchIndex(index_name=index_name)] + (
[DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else []
expected_resources = (
[DatabricksVectorSearchIndex(index_name=index_name)]
+ ([DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else [])
+ (
[
DatabricksServingEndpoint(
endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME
)
]
if index_name == DELTA_SYNC_INDEX
else []
)
)
assert [res.to_dict() for res in vector_search_tool.resources] == [
res.to_dict() for res in expected_resources
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,11 @@ def _validate_tool_inputs(self):
self.workspace_client.serving_endpoints.get(self.embedding_model_name)
else:
WorkspaceClient().serving_endpoints.get(self.embedding_model_name)
self.resources = self._get_resources(self.index_name, self.embedding_model_name)
self.resources = self._get_resources(
self.index_name, self.embedding_model_name, self._index_details
)
except ResourceDoesNotExist:
self.resources = self._get_resources(self.index_name, None)
self.resources = self._get_resources(self.index_name, None, self._index_details)

return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
DELTA_SYNC_INDEX,
DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME,
DIRECT_ACCESS_INDEX,
INPUT_TEXTS,
mock_vs_client,
Expand Down Expand Up @@ -143,10 +144,22 @@ def test_vector_search_retriever_tool_init(
)
assert isinstance(vector_search_tool, BaseModel)

expected_resources = [DatabricksVectorSearchIndex(index_name=index_name)] + (
[DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")]
if self_managed_embeddings_test.embedding_model_name
else []
expected_resources = (
[DatabricksVectorSearchIndex(index_name=index_name)]
+ (
[DatabricksServingEndpoint(endpoint_name="text-embedding-3-small")]
if self_managed_embeddings_test.embedding_model_name
else []
)
+ (
[
DatabricksServingEndpoint(
endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME
)
]
if index_name == DELTA_SYNC_INDEX
else []
)
)
assert [res.to_dict() for res in vector_search_tool.resources] == [
res.to_dict() for res in expected_resources
Expand Down
4 changes: 3 additions & 1 deletion src/databricks_ai_bridge/test_utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]:
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
}

# Embedding model endpoint name recorded for DELTA_SYNC_INDEX in INDEX_DETAILS.
# Kept as a named constant so tests can import the same value instead of
# repeating the string literal.
DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME = "openai-text-embedding"

INDEX_DETAILS = {
DELTA_SYNC_INDEX: {
"name": DELTA_SYNC_INDEX,
Expand All @@ -68,7 +70,7 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]:
"embedding_source_columns": [
{
"name": "text",
"embedding_model_endpoint_name": "openai-text-embedding",
"embedding_model_endpoint_name": DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME,
}
],
},
Expand Down
22 changes: 16 additions & 6 deletions src/databricks_ai_bridge/vector_search_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,22 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str:
)
return DEFAULT_TOOL_DESCRIPTION

# Two-argument form of _get_resources: yields a DatabricksVectorSearchIndex
# entry when index_name is truthy, followed by a DatabricksServingEndpoint
# entry when embedding_endpoint is truthy.
def _get_resources(self, index_name: str, embedding_endpoint: str) -> List[Resource]:
return ([DatabricksVectorSearchIndex(index_name=index_name)] if index_name else []) + (
[DatabricksServingEndpoint(endpoint_name=embedding_endpoint)]
if embedding_endpoint
else []
)
def _get_resources(
    self, index_name: str, embedding_endpoint: str, index_details: IndexDetails
) -> List[Resource]:
    """Collect the resources this tool depends on.

    Args:
        index_name: Name of the vector search index; skipped when falsy.
        embedding_endpoint: Name of the caller-supplied embedding serving
            endpoint; skipped when falsy.
        index_details: Index metadata. Its ``embedding_source_column`` is only
            read when ``is_databricks_managed_embeddings`` is truthy.

    Returns:
        A list holding, in order: the index resource, the caller-supplied
        endpoint resource, and the endpoint named under
        ``"embedding_model_endpoint_name"`` in the index's embedding source
        column when that name is set and differs from ``embedding_endpoint``.
    """
    collected = []
    if index_name:
        collected.append(DatabricksVectorSearchIndex(index_name=index_name))
    if embedding_endpoint:
        collected.append(DatabricksServingEndpoint(endpoint_name=embedding_endpoint))

    # Without Databricks-managed embeddings the index contributes no endpoint.
    if not index_details.is_databricks_managed_embeddings:
        return collected

    managed_endpoint = index_details.embedding_source_column.get(
        "embedding_model_endpoint_name", None
    )
    # Skip the managed endpoint when it is missing or already listed above.
    if managed_endpoint and managed_endpoint != embedding_endpoint:
        collected.append(DatabricksServingEndpoint(endpoint_name=managed_endpoint))
    return collected

def _get_tool_name(self) -> str:
tool_name = self.tool_name or self.index_name.replace(".", "__")
Expand Down
68 changes: 68 additions & 0 deletions tests/databricks_ai_bridge/test_vector_search_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from unittest.mock import MagicMock

import pytest
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex

from databricks_ai_bridge.utils.vector_search import IndexDetails
from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin


# Subclass that adds nothing of its own; it exists so the tests can instantiate
# the mixin and call its methods directly.
class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):
    ...


# Index name shared by every parametrized case: it is passed to the tool under
# test and used to build the expected DatabricksVectorSearchIndex entries.
index_name = "catalog.schema.index"


def make_mock_index_details(is_databricks_managed_embeddings=False, embedding_source_column=None):
    """Build a ``MagicMock`` restricted to the ``IndexDetails`` spec.

    Args:
        is_databricks_managed_embeddings: Value stored on the mock's attribute
            of the same name.
        embedding_source_column: Value stored on the mock's
            ``embedding_source_column``; any falsy value (including the
            default ``None``) is replaced with a new empty dict.

    Returns:
        The configured mock.
    """
    details = MagicMock(spec=IndexDetails)
    if not embedding_source_column:
        embedding_source_column = {}
    details.is_databricks_managed_embeddings = is_databricks_managed_embeddings
    details.embedding_source_column = embedding_source_column
    return details


@pytest.mark.parametrize(
    "embedding_endpoint,index_details,resources",
    [
        # No endpoint supplied and no managed embeddings: only the index.
        (None, make_mock_index_details(False, {}), [DatabricksVectorSearchIndex(index_name)]),
        # Caller-supplied endpoint only.
        (
            "embedding_endpoint",
            make_mock_index_details(False, {}),
            [
                DatabricksVectorSearchIndex(index_name),
                DatabricksServingEndpoint("embedding_endpoint"),
            ],
        ),
        # Managed embeddings only: the endpoint comes from the index details.
        (
            None,
            make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}),
            [
                DatabricksVectorSearchIndex(index_name),
                DatabricksServingEndpoint("embedding_endpoint"),
            ],
        ),
        # The following cases should not happen, but ensuring that they have
        # reasonable behavior.
        # Same endpoint from both sources: listed once.
        (
            "embedding_endpoint",
            make_mock_index_details(True, {"embedding_model_endpoint_name": "embedding_endpoint"}),
            [
                DatabricksVectorSearchIndex(index_name),
                DatabricksServingEndpoint("embedding_endpoint"),
            ],
        ),
        # Different endpoints from each source: both listed, supplied one first.
        (
            "embedding_endpoint_1",
            make_mock_index_details(
                True, {"embedding_model_endpoint_name": "embedding_endpoint_2"}
            ),
            [
                DatabricksVectorSearchIndex(index_name),
                DatabricksServingEndpoint("embedding_endpoint_1"),
                DatabricksServingEndpoint("embedding_endpoint_2"),
            ],
        ),
        # Managed embeddings flagged but no endpoint name recorded: only the index.
        (None, make_mock_index_details(True, {}), [DatabricksVectorSearchIndex(index_name)]),
    ],
)
def test_get_resources(embedding_endpoint, index_details, resources):
    """Check that ``_get_resources`` returns the expected resources in order.

    Resources are compared through ``to_dict()`` so the assertion checks their
    serialized content and does not rely on the resource classes defining
    ``__eq__``.
    """
    tool = DummyVectorSearchRetrieverTool(index_name=index_name)
    actual = tool._get_resources(index_name, embedding_endpoint, index_details)
    assert [res.to_dict() for res in actual] == [res.to_dict() for res in resources]