Skip to content

Commit ce50e5f

Browse files
committed
Add "required" parameter to ToolParameter.model_dump_tool
1 parent 61a0e46 commit ce50e5f

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/neo4j_graphrag/tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ class ToolParameter(BaseModel):
2424

2525
def model_dump_tool(self) -> Dict[str, Any]:
2626
"""Convert the parameter to a dictionary format for tool usage."""
27-
result = {"type": self.type, "description": self.description}
27+
result: Dict[str, Any] = {"type": self.type, "description": self.description}
28+
if self.required:
29+
result["required"] = True
2830
return result
2931

3032
@classmethod

tests/unit/tool/test_tool.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def test_string_parameter() -> None:
2121
d = param.model_dump_tool()
2222
assert d["type"] == ParameterType.STRING
2323
assert d["enum"] == ["a", "b"]
24+
assert d["required"] is True
2425

2526

2627
def test_integer_parameter() -> None:
@@ -139,6 +140,39 @@ def test_from_dict() -> None:
139140
ToolParameter.from_dict({"description": "no type"})
140141

141142

143+
def test_required_parameter() -> None:
144+
# Test that required=True is included in model_dump_tool output for different parameter types
145+
string_param = StringParameter(description="Required string", required=True)
146+
assert string_param.model_dump_tool()["required"] is True
147+
148+
integer_param = IntegerParameter(description="Required integer", required=True)
149+
assert integer_param.model_dump_tool()["required"] is True
150+
151+
number_param = NumberParameter(description="Required number", required=True)
152+
assert number_param.model_dump_tool()["required"] is True
153+
154+
boolean_param = BooleanParameter(description="Required boolean", required=True)
155+
assert boolean_param.model_dump_tool()["required"] is True
156+
157+
array_param = ArrayParameter(
158+
description="Required array",
159+
items=StringParameter(description="item"),
160+
required=True,
161+
)
162+
assert array_param.model_dump_tool()["required"] is True
163+
164+
object_param = ObjectParameter(
165+
description="Required object",
166+
properties={"prop": StringParameter(description="property")},
167+
required=True,
168+
)
169+
assert object_param.model_dump_tool()["required"] is True
170+
171+
# Test that required=False doesn't include the required field
172+
optional_param = StringParameter(description="Optional string", required=False)
173+
assert "required" not in optional_param.model_dump_tool()
174+
175+
142176
def test_tool_class() -> None:
143177
def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]:
144178
return kwargs

0 commit comments

Comments
 (0)