Skip to content

Commit 872ef17

Browse files
✏️ fix linter issues
✏️ fix linter issues
2 parents 25a1a33 + d7ecda9 commit 872ef17

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

src/databricks/labs/mcp/servers/unity_catalog/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ class CliSettings(BaseSettings):
3535
vector_search_num_results: int = Field(
3636
default=5,
3737
description="Number of results to return from vector search queries",
38-
validation_alias=AliasChoices("vn", "vector_search_num_results", "vector_num_results"),
38+
validation_alias=AliasChoices(
39+
"vn", "vector_search_num_results", "vector_num_results"
40+
),
3941
)
4042

4143
def get_catalog_name(self):

src/databricks/labs/mcp/servers/unity_catalog/tools/vector_search.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ class QueryInput(BaseModel):
1515

1616

1717
class VectorSearchTool(BaseTool):
18-
def __init__(self, endpoint_name: str, index_name: str, tool_name: str, columns: list[str], num_results: int = 5):
18+
def __init__(
19+
self,
20+
endpoint_name: str,
21+
index_name: str,
22+
tool_name: str,
23+
columns: list[str],
24+
num_results: int = 5,
25+
):
1926
self.endpoint_name = endpoint_name
2027
self.index_name = index_name
2128
self.tool_name = tool_name
@@ -46,17 +53,20 @@ def execute(self, **kwargs):
4653
return [TextContent(type="text", text=json.dumps(docs, indent=2))]
4754

4855

49-
def get_table_columns(workspace_client: WorkspaceClient, full_table_name: str) -> list[str]:
56+
def get_table_columns(
57+
workspace_client: WorkspaceClient, full_table_name: str
58+
) -> list[str]:
5059
table_info = workspace_client.tables.get(full_table_name)
5160
return [
52-
col.name
53-
for col in table_info.columns
54-
if col.name != CONTENT_VECTOR_COLUMN_NAME
61+
col.name for col in table_info.columns if col.name != CONTENT_VECTOR_COLUMN_NAME
5562
]
5663

5764

5865
def _list_vector_search_tools(
59-
workspace_client: WorkspaceClient, catalog_name: str, schema_name: str, vector_search_num_results: int
66+
workspace_client: WorkspaceClient,
67+
catalog_name: str,
68+
schema_name: str,
69+
vector_search_num_results: int,
6070
) -> list[VectorSearchTool]:
6171
tools = []
6272
for table in workspace_client.tables.list(
@@ -71,12 +81,18 @@ def _list_vector_search_tools(
7181

7282
columns = get_table_columns(workspace_client, index_name)
7383

74-
tools.append(VectorSearchTool(endpoint, index_name, tool_name, columns, vector_search_num_results))
84+
tools.append(
85+
VectorSearchTool(
86+
endpoint, index_name, tool_name, columns, vector_search_num_results
87+
)
88+
)
7589

7690
return tools
7791

7892

7993
def list_vector_search_tools(settings: CliSettings) -> list[VectorSearchTool]:
8094
workspace_client = WorkspaceClient()
8195
catalog_name, schema_name = settings.schema_full_name.split(".")
82-
return _list_vector_search_tools(workspace_client, catalog_name, schema_name, settings.vector_search_num_results)
96+
return _list_vector_search_tools(
97+
workspace_client, catalog_name, schema_name, settings.vector_search_num_results
98+
)

tests/test_vector_search.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ def __init__(self, name):
2929
self.name = name
3030

3131
class DummyTableInfo:
32-
columns = [DummyColumn("col1"), DummyColumn("col2"), DummyColumn("__db_content_vector")]
32+
columns = [
33+
DummyColumn("col1"),
34+
DummyColumn("col2"),
35+
DummyColumn("__db_content_vector"),
36+
]
3337

3438
return DummyTableInfo()
3539

@@ -41,6 +45,7 @@ def __init__(self):
4145

4246
class DummySettings:
4347
schema_full_name = "cat.sch"
48+
vector_search_num_results = 5
4449

4550

4651
@mock.patch(
@@ -59,14 +64,16 @@ def test_list_vector_search_tools_filters_and_returns_expected():
5964

6065
def test_internal_list_vector_search_tools_direct():
6166
client = DummyWorkspaceClient()
62-
tools = _list_vector_search_tools(client, "cat", "sch")
67+
tools = _list_vector_search_tools(client, "cat", "sch", vector_search_num_results=5)
6368
assert len(tools) == 1
6469
assert isinstance(tools[0], VectorSearchTool)
6570
assert tools[0].index_name == "cat.sch.tbl1"
6671
assert tools[0].columns == ["col1", "col2"]
6772

6873

69-
@mock.patch("databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient")
74+
@mock.patch(
75+
"databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient"
76+
)
7077
def test_vector_search_tool_execute(MockVectorSearchClient):
7178
mock_index = mock.Mock()
7279
mock_index.similarity_search.return_value = {
@@ -87,4 +94,4 @@ def test_vector_search_tool_execute(MockVectorSearchClient):
8794

8895
assert isinstance(result, list)
8996
assert result[0].text.strip().startswith("[") # It should be JSON string
90-
assert "score" in result[0].text
97+
assert "score" in result[0].text

0 commit comments

Comments
 (0)