Skip to content

Commit 25a1a33

Browse files
✨ fix vector search tool
✨ fix vector search tool
2 parents bf69923 + 16ab9b8 commit 25a1a33

File tree

3 files changed

+105
-70
lines changed

3 files changed

+105
-70
lines changed

src/databricks/labs/mcp/servers/unity_catalog/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ class CliSettings(BaseSettings):
3232
validation_alias=AliasChoices("g", "genie_space_ids"),
3333
)
3434

35+
vector_search_num_results: int = Field(
36+
default=5,
37+
description="Number of results to return from vector search queries",
38+
validation_alias=AliasChoices("vn", "vector_search_num_results", "vector_num_results"),
39+
)
40+
3541
def get_catalog_name(self):
3642
return self.schema_full_name.split(".")[0] if self.schema_full_name else None
3743

src/databricks/labs/mcp/servers/unity_catalog/tools/vector_search.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,82 @@
1-
import io
21
import json
3-
from contextlib import redirect_stdout
4-
5-
from databricks_openai import VectorSearchRetrieverTool
2+
from pydantic import BaseModel
63
from databricks.sdk import WorkspaceClient
4+
from databricks.vector_search.client import VectorSearchClient
5+
from databricks.labs.mcp.servers.unity_catalog.tools.base_tool import BaseTool
6+
from databricks.labs.mcp.servers.unity_catalog.cli import CliSettings
7+
from mcp.types import TextContent, Tool as ToolSpec
78

8-
from mcp.types import Tool as ToolSpec, TextContent
9+
# Name of the vector index's content-vector column; it is excluded from the columns requested in searches
10+
CONTENT_VECTOR_COLUMN_NAME = "__db_content_vector"
911

10-
from databricks.labs.mcp.servers.unity_catalog.tools.base_tool import BaseTool
12+
13+
class QueryInput(BaseModel):
14+
query: str
1115

1216

1317
class VectorSearchTool(BaseTool):
14-
def __init__(self, tool_obj: VectorSearchRetrieverTool):
15-
self.tool_obj = tool_obj
16-
tool_info = tool_obj.tool["function"]
17-
llm_friendly_tool_name = tool_info["name"]
18+
def __init__(self, endpoint_name: str, index_name: str, tool_name: str, columns: list[str], num_results: int = 5):
19+
self.endpoint_name = endpoint_name
20+
self.index_name = index_name
21+
self.tool_name = tool_name
22+
self.columns = columns
23+
self.num_results = num_results
24+
1825
tool_spec = ToolSpec(
19-
name=llm_friendly_tool_name,
20-
description=tool_info["description"],
21-
inputSchema=tool_info["parameters"],
26+
name=tool_name,
27+
description=f"Searches the vector index `{index_name}`.",
28+
inputSchema=QueryInput.model_json_schema(),
2229
)
23-
super().__init__(tool_spec=tool_spec)
30+
super().__init__(tool_spec)
2431

2532
def execute(self, **kwargs):
26-
"""
27-
Executes the vector search tool with the provided arguments.
28-
"""
29-
# Create a buffer to capture stdout from vector search client
30-
# print statements
31-
f = io.StringIO()
32-
with redirect_stdout(f):
33-
res = self.tool_obj.execute(**kwargs)
34-
return [
35-
TextContent(
36-
type="text",
37-
text=json.dumps(vs_res),
38-
)
39-
for vs_res in res
40-
]
33+
model = QueryInput.model_validate(kwargs)
34+
vsc = VectorSearchClient(disable_notice=True)
35+
36+
index = vsc.get_index(index_name=self.index_name)
37+
38+
results = index.similarity_search(
39+
query_text=model.query,
40+
columns=self.columns,
41+
num_results=self.num_results,
42+
)
43+
44+
docs = results.get("result", {}).get("data_array", [])
45+
46+
return [TextContent(type="text", text=json.dumps(docs, indent=2))]
47+
48+
49+
def get_table_columns(workspace_client: WorkspaceClient, full_table_name: str) -> list[str]:
50+
table_info = workspace_client.tables.get(full_table_name)
51+
return [
52+
col.name
53+
for col in table_info.columns
54+
if col.name != CONTENT_VECTOR_COLUMN_NAME
55+
]
4156

4257

4358
def _list_vector_search_tools(
44-
workspace_client: WorkspaceClient, catalog_name: str, schema_name: str
59+
workspace_client: WorkspaceClient, catalog_name: str, schema_name: str, vector_search_num_results: int
4560
) -> list[VectorSearchTool]:
4661
tools = []
4762
for table in workspace_client.tables.list(
4863
catalog_name=catalog_name, schema_name=schema_name
4964
):
50-
# TODO: support filtering tables by securable kind (e.g. by making securable
51-
# kind accessible here)
5265
if not table.properties or "model_endpoint_url" not in table.properties:
5366
continue
54-
tool_obj = VectorSearchRetrieverTool(index_name=table.full_name)
55-
tools.append(VectorSearchTool(tool_obj))
67+
68+
endpoint = table.properties["model_endpoint_url"]
69+
index_name = table.full_name
70+
tool_name = f"vector_search_{table.name}"
71+
72+
columns = get_table_columns(workspace_client, index_name)
73+
74+
tools.append(VectorSearchTool(endpoint, index_name, tool_name, columns, vector_search_num_results))
75+
5676
return tools
5777

5878

59-
def list_vector_search_tools(settings) -> list[VectorSearchTool]:
79+
def list_vector_search_tools(settings: CliSettings) -> list[VectorSearchTool]:
6080
workspace_client = WorkspaceClient()
6181
catalog_name, schema_name = settings.schema_full_name.split(".")
62-
return _list_vector_search_tools(workspace_client, catalog_name, schema_name)
82+
return _list_vector_search_tools(workspace_client, catalog_name, schema_name, settings.vector_search_num_results)

tests/test_vector_search.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
class DummyTable:
1010
def __init__(self, full_name, properties):
1111
self.full_name = full_name
12+
self.name = full_name.split(".")[-1]
1213
self.properties = properties
1314

1415

@@ -21,6 +22,17 @@ def list(self, catalog_name=None, schema_name=None):
2122
DummyTable(full_name="cat.sch.tbl2", properties={}),
2223
]
2324

25+
def get(self, full_table_name):
26+
# Stub for tables.get(), which get_table_columns() calls to read the table's column list
27+
class DummyColumn:
28+
def __init__(self, name):
29+
self.name = name
30+
31+
class DummyTableInfo:
32+
columns = [DummyColumn("col1"), DummyColumn("col2"), DummyColumn("__db_content_vector")]
33+
34+
return DummyTableInfo()
35+
2436

2537
class DummyWorkspaceClient:
2638
def __init__(self):
@@ -35,47 +47,44 @@ class DummySettings:
3547
"databricks.labs.mcp.servers.unity_catalog.tools.vector_search.WorkspaceClient",
3648
new=DummyWorkspaceClient,
3749
)
38-
@mock.patch(
39-
"databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchRetrieverTool"
40-
)
41-
def test_list_vector_search_tools_filters_and_returns_expected(
42-
MockVectorSearchRetrieverTool,
43-
):
44-
MockVectorSearchRetrieverTool.side_effect = lambda index_name: mock.Mock(
45-
tool={"function": {"name": index_name, "description": "", "parameters": {}}},
46-
index_name=index_name,
47-
)
50+
def test_list_vector_search_tools_filters_and_returns_expected():
4851
settings = DummySettings()
4952
tools = list_vector_search_tools(settings)
5053
assert len(tools) == 1
5154
tool = tools[0]
5255
assert isinstance(tool, VectorSearchTool)
53-
assert tool.tool_obj.index_name == "cat.sch.tbl1"
56+
assert tool.index_name == "cat.sch.tbl1"
57+
assert tool.columns == ["col1", "col2"] # filtered out "__db_content_vector"
5458

5559

5660
def test_internal_list_vector_search_tools_direct():
57-
with mock.patch(
58-
"databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchRetrieverTool"
59-
) as MockVectorSearchRetrieverTool:
60-
MockVectorSearchRetrieverTool.side_effect = lambda index_name: mock.Mock(
61-
tool={
62-
"function": {"name": index_name, "description": "", "parameters": {}}
63-
},
64-
index_name=index_name,
65-
)
66-
client = DummyWorkspaceClient()
67-
tools = _list_vector_search_tools(client, "cat", "sch")
68-
assert len(tools) == 1
69-
assert tools[0].tool_obj.index_name == "cat.sch.tbl1"
70-
71-
72-
def test_vector_search_tool_execute():
73-
tool_obj = mock.Mock()
74-
tool_obj.tool = {
75-
"function": {"name": "vs_tool", "description": "", "parameters": {}}
61+
client = DummyWorkspaceClient()
62+
tools = _list_vector_search_tools(client, "cat", "sch")
63+
assert len(tools) == 1
64+
assert isinstance(tools[0], VectorSearchTool)
65+
assert tools[0].index_name == "cat.sch.tbl1"
66+
assert tools[0].columns == ["col1", "col2"]
67+
68+
69+
@mock.patch("databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient")
70+
def test_vector_search_tool_execute(MockVectorSearchClient):
71+
mock_index = mock.Mock()
72+
mock_index.similarity_search.return_value = {
73+
"result": {"data_array": [{"id": 1, "score": 0.9}]}
7674
}
77-
tool_obj.execute.return_value = [{"foo": "bar"}]
78-
tool = VectorSearchTool(tool_obj)
79-
result = tool.execute(query="test")
75+
76+
# Make get_index return our mock_index
77+
MockVectorSearchClient.return_value.get_index.return_value = mock_index
78+
79+
tool = VectorSearchTool(
80+
endpoint_name="endpoint1",
81+
index_name="cat.sch.tbl1",
82+
tool_name="vector_search_test",
83+
columns=["col1", "col2"],
84+
)
85+
86+
result = tool.execute(query="test query")
87+
8088
assert isinstance(result, list)
81-
assert result[0].text == '{"foo": "bar"}'
89+
assert result[0].text.strip().startswith("[") # The text should be a JSON-encoded list
90+
assert "score" in result[0].text

0 commit comments

Comments
 (0)