Skip to content

Commit 9d8bc72

Browse files
committed
Fix merge and tests
1 parent 13a3cc9 commit 9d8bc72

File tree

2 files changed

+42
-33
lines changed

2 files changed

+42
-33
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Any
2+
3+
# from neo4j_graphrag.retrievers.base import Retriever
4+
from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem
5+
6+
7+
class MCPServerInterface:
8+
def __init__(self, *args, **kwargs):
9+
pass
10+
11+
def get_tools(self):
12+
return []
13+
14+
def execute_tool(self, tool) -> Any:
15+
return ""
16+
17+
18+
class MCPRetriever:
19+
def __init__(self, server: MCPServerInterface) -> None:
20+
super().__init__()
21+
self.server = server
22+
self.tools = server.get_tools()
23+
24+
def search(self, query_text: str) -> RetrieverResult:
25+
"""Reimplement the search method because we can't inherit from
26+
the Retriever interface (no need for neo4j.driver here).
27+
28+
1. Call llm with a list of tools
29+
2. Call MCP server for specific tool and LLM-generated arguments
30+
3. Return all results as context in RetrieverResult
31+
"""
32+
raw_result = RawSearchResult(records=[])
33+
search_items = [RetrieverResultItem(content=str(record)) for record in raw_result.records]
34+
metadata = raw_result.metadata or {}
35+
metadata["__retriever"] = self.__class__.__name__
36+
metadata["__tool_results"] = {}
37+
return RetrieverResult(
38+
items=search_items,
39+
metadata=metadata,
40+
)

src/neo4j_graphrag/tool.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def model_dump_tool(self) -> Dict[str, Any]:
180180
if self.required_properties:
181181
result["required"] = self.required_properties
182182

183-
# if not self.additional_properties:
184-
# result["additionalProperties"] = False
183+
if not self.additional_properties:
184+
result["additionalProperties"] = False
185185

186186
return result
187187

@@ -204,37 +204,6 @@ def validate_properties(self) -> "ObjectParameter":
204204
self.properties = validated_properties
205205
return self
206206

207-
class ArrayParameter(ToolParameter):
208-
"""Array parameter for tools."""
209-
210-
def __init__(
211-
self,
212-
description: str,
213-
items: ToolParameter,
214-
required: bool = False,
215-
min_items: Optional[int] = None,
216-
max_items: Optional[int] = None,
217-
):
218-
super().__init__(description, required)
219-
self.items = items
220-
self.min_items = min_items
221-
self.max_items = max_items
222-
223-
def to_dict(self) -> Dict[str, Any]:
224-
result: Dict[str, Any] = {
225-
"type": ParameterType.ARRAY,
226-
"description": self.description,
227-
"items": self.items.to_dict(),
228-
}
229-
230-
if self.min_items is not None:
231-
result["minItems"] = self.min_items
232-
233-
if self.max_items is not None:
234-
result["maxItems"] = self.max_items
235-
236-
return result
237-
238207

239208
class Tool(ABC):
240209
"""Abstract base class defining the interface for all tools in the neo4j-graphrag library."""

0 commit comments

Comments
 (0)