|
14 | 14 |
|
15 | 15 | from unittest.mock import MagicMock
|
16 | 16 |
|
| 17 | +from google.adk.agents.invocation_context import InvocationContext |
| 18 | +from google.adk.sessions.session import Session |
17 | 19 | from google.adk.tools.function_tool import FunctionTool
|
| 20 | +from google.adk.tools.tool_context import ToolContext |
18 | 21 | import pytest
|
19 | 22 |
|
20 | 23 |
|
@@ -294,3 +297,51 @@ async def async_func_with_optional_args(
|
294 | 297 | args = {"arg1": "test_value_1", "arg3": "test_value_3"}
|
295 | 298 | result = await tool.run_async(args=args, tool_context=MagicMock())
|
296 | 299 | assert result == "test_value_1,test_value_3"
|
| 300 | + |
| 301 | + |
| 302 | +@pytest.mark.asyncio |
| 303 | +async def test_run_async_with_unexpected_argument(): |
| 304 | + """Test that run_async filters out unexpected arguments.""" |
| 305 | + |
| 306 | + def sample_func(expected_arg: str): |
| 307 | + return {"received_arg": expected_arg} |
| 308 | + |
| 309 | + tool = FunctionTool(sample_func) |
| 310 | + mock_invocation_context = MagicMock(spec=InvocationContext) |
| 311 | + mock_invocation_context.session = MagicMock(spec=Session) |
| 312 | + # Add the missing state attribute to the session mock |
| 313 | + mock_invocation_context.session.state = MagicMock() |
| 314 | + tool_context_mock = ToolContext(invocation_context=mock_invocation_context) |
| 315 | + |
| 316 | + result = await tool.run_async( |
| 317 | + args={"expected_arg": "hello", "parameters": "should_be_filtered"}, |
| 318 | + tool_context=tool_context_mock, |
| 319 | + ) |
| 320 | + assert result == {"received_arg": "hello"} |
| 321 | + |
| 322 | + |
| 323 | +@pytest.mark.asyncio |
| 324 | +async def test_run_async_with_tool_context_and_unexpected_argument(): |
| 325 | + """Test that run_async handles tool_context and filters out unexpected arguments.""" |
| 326 | + |
| 327 | + def sample_func_with_context(expected_arg: str, tool_context: ToolContext): |
| 328 | + return {"received_arg": expected_arg, "context_present": bool(tool_context)} |
| 329 | + |
| 330 | + tool = FunctionTool(sample_func_with_context) |
| 331 | + mock_invocation_context = MagicMock(spec=InvocationContext) |
| 332 | + mock_invocation_context.session = MagicMock(spec=Session) |
| 333 | + # Add the missing state attribute to the session mock |
| 334 | + mock_invocation_context.session.state = MagicMock() |
| 335 | + mock_tool_context = ToolContext(invocation_context=mock_invocation_context) |
| 336 | + |
| 337 | + result = await tool.run_async( |
| 338 | + args={ |
| 339 | + "expected_arg": "world", |
| 340 | + "parameters": "should_also_be_filtered", |
| 341 | + }, |
| 342 | + tool_context=mock_tool_context, |
| 343 | + ) |
| 344 | + assert result == { |
| 345 | + "received_arg": "world", |
| 346 | + "context_present": True, |
| 347 | + } |
0 commit comments