1
- import io
2
1
import json
3
- from contextlib import redirect_stdout
4
-
5
- from databricks_openai import VectorSearchRetrieverTool
2
+ from pydantic import BaseModel
6
3
from databricks .sdk import WorkspaceClient
4
+ from databricks .vector_search .client import VectorSearchClient
5
+ from databricks .labs .mcp .servers .unity_catalog .tools .base_tool import BaseTool
6
+ from databricks .labs .mcp .servers .unity_catalog .cli import CliSettings
7
+ from mcp .types import TextContent , Tool as ToolSpec
7
8
8
- from mcp .types import Tool as ToolSpec , TextContent
9
# Column name that `get_table_columns` filters out of an index table's column
# list (presumably the column holding the content embedding vector).
CONTENT_VECTOR_COLUMN_NAME: str = "__db_content_vector"
9
11
10
- from databricks .labs .mcp .servers .unity_catalog .tools .base_tool import BaseTool
12
+
13
class QueryInput(BaseModel):
    # Arguments accepted by a vector search tool call. Documented with `#`
    # comments rather than a docstring so that the JSON schema produced by
    # `model_json_schema()` stays exactly as it was.

    # Text of the search query.
    query: str
11
15
12
16
13
17
class VectorSearchTool(BaseTool):
    """Tool that runs a similarity search against a single vector index.

    Each instance is bound to one index and is exposed under ``tool_name``
    with the input schema generated from ``QueryInput``.
    """

    def __init__(
        self,
        endpoint_name: str,
        index_name: str,
        tool_name: str,
        columns: list[str],
        num_results: int = 5,
    ):
        """Store the search configuration and register the tool spec.

        Args:
            endpoint_name: Endpoint associated with the index. It is stored
                on the instance but not used by ``execute``.
            index_name: Name of the index to look up and query.
            tool_name: Name under which the tool is exposed.
            columns: Column names passed as ``columns`` to the search call.
            num_results: Value passed as ``num_results`` to the search call.
        """
        self.endpoint_name = endpoint_name
        self.index_name = index_name
        self.tool_name = tool_name
        self.columns = columns
        self.num_results = num_results

        tool_spec = ToolSpec(
            name=tool_name,
            description=f"Searches the vector index `{index_name}`.",
            inputSchema=QueryInput.model_json_schema(),
        )
        super().__init__(tool_spec)

    def execute(self, **kwargs):
        """Run a similarity search for the supplied query.

        Args:
            **kwargs: Tool arguments, validated against ``QueryInput``.

        Returns:
            A one-element list containing a ``TextContent`` whose text is the
            JSON-encoded ``data_array`` of the search response, or ``[]``
            when the response carries no rows.

        Raises:
            pydantic.ValidationError: If ``kwargs`` does not satisfy
                ``QueryInput``.
        """
        model = QueryInput.model_validate(kwargs)

        # NOTE(review): `disable_notice=True` presumably silences the client's
        # informational notice output - confirm against the client docs.
        vsc = VectorSearchClient(disable_notice=True)
        index = vsc.get_index(index_name=self.index_name)

        results = index.similarity_search(
            query_text=model.query,
            columns=self.columns,
            num_results=self.num_results,
        )

        # "result" or "data_array" may be missing *or* explicitly null; a
        # plain `.get(key, default)` only covers the missing case, so fall
        # back with `or` to always end up with a list.
        result_section = (results or {}).get("result") or {}
        docs = result_section.get("data_array") or []

        return [TextContent(type="text", text=json.dumps(docs, indent=2))]
47
+
48
+
49
def get_table_columns(workspace_client: WorkspaceClient, full_table_name: str) -> list[str]:
    """Return the table's column names, minus the content vector column.

    Args:
        workspace_client: Client used to fetch the table metadata.
        full_table_name: Name passed to ``workspace_client.tables.get``.

    Returns:
        The names of all columns except ``CONTENT_VECTOR_COLUMN_NAME``; an
        empty list when the table metadata carries no column information.
    """
    table_info = workspace_client.tables.get(full_table_name)
    # `columns` is optional on the returned table info; iterating `None`
    # would raise TypeError, so treat an unset value as "no columns".
    return [
        col.name
        for col in table_info.columns or []
        if col.name != CONTENT_VECTOR_COLUMN_NAME
    ]
41
56
42
57
43
58
def _list_vector_search_tools(
    workspace_client: WorkspaceClient,
    catalog_name: str,
    schema_name: str,
    vector_search_num_results: int,
) -> list[VectorSearchTool]:
    """Build a ``VectorSearchTool`` for each qualifying table in a schema.

    Args:
        workspace_client: Client used to list tables and fetch their columns.
        catalog_name: Catalog that contains the schema.
        schema_name: Schema whose tables are scanned.
        vector_search_num_results: ``num_results`` given to every tool.

    Returns:
        One tool per table that has a ``"model_endpoint_url"`` property,
        named ``vector_search_<table name>``.
    """
    tools = []
    for table in workspace_client.tables.list(
        catalog_name=catalog_name, schema_name=schema_name
    ):
        # Only tables carrying a "model_endpoint_url" property are turned into
        # tools (presumably this property marks vector index tables).
        if not table.properties or "model_endpoint_url" not in table.properties:
            continue

        endpoint = table.properties["model_endpoint_url"]
        index_name = table.full_name
        # No padding around the table name: the tool name must not carry
        # leading or trailing whitespace.
        tool_name = f"vector_search_{table.name}"

        columns = get_table_columns(workspace_client, index_name)

        tools.append(
            VectorSearchTool(
                endpoint, index_name, tool_name, columns, vector_search_num_results
            )
        )

    return tools
57
77
58
78
59
def list_vector_search_tools(settings: CliSettings) -> list[VectorSearchTool]:
    """List the vector search tools for the schema named in ``settings``.

    Args:
        settings: CLI settings; ``schema_full_name`` must have the form
            ``<catalog>.<schema>`` and ``vector_search_num_results`` is
            forwarded to every tool.

    Returns:
        The tools built by ``_list_vector_search_tools`` for that schema.

    Raises:
        ValueError: If ``schema_full_name`` does not consist of exactly two
            dot-separated parts.
    """
    workspace_client = WorkspaceClient()

    name_parts = settings.schema_full_name.split(".")
    if len(name_parts) != 2:
        # Same exception type as the bare tuple unpacking would raise, but
        # with a message that tells the user what to fix.
        raise ValueError(
            "schema_full_name must have the form '<catalog>.<schema>', "
            f"got {settings.schema_full_name!r}"
        )
    catalog_name, schema_name = name_parts

    return _list_vector_search_tools(
        workspace_client,
        catalog_name,
        schema_name,
        settings.vector_search_num_results,
    )
0 commit comments