Skip to content

Commit 9bd539e

Browse files
Merge pull request #1656 from google:fix-transfer-to-agent-parameters-issue
PiperOrigin-RevId: 778334126
2 parents 3d2f13c + 0959b06 commit 9bd539e

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

src/google/adk/tools/function_tool.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import inspect
1618
from typing import Any
1719
from typing import Callable
@@ -79,9 +81,13 @@ async def run_async(
7981
) -> Any:
8082
args_to_call = args.copy()
8183
signature = inspect.signature(self.func)
82-
if 'tool_context' in signature.parameters:
84+
valid_params = {param for param in signature.parameters}
85+
if 'tool_context' in valid_params:
8386
args_to_call['tool_context'] = tool_context
8487

88+
# Filter args_to_call to only include valid parameters for the function
89+
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
90+
8591
# Before invoking the function, we check for if the list of args passed in
8692
# has all the mandatory arguments or not.
8793
# If the check fails, then we don't invoke the tool and let the Agent know

tests/unittests/tools/test_function_tool.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
from unittest.mock import MagicMock
1616

17+
from google.adk.agents.invocation_context import InvocationContext
18+
from google.adk.sessions.session import Session
1719
from google.adk.tools.function_tool import FunctionTool
20+
from google.adk.tools.tool_context import ToolContext
1821
import pytest
1922

2023

@@ -294,3 +297,51 @@ async def async_func_with_optional_args(
294297
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
295298
result = await tool.run_async(args=args, tool_context=MagicMock())
296299
assert result == "test_value_1,test_value_3"
300+
301+
302+
@pytest.mark.asyncio
303+
async def test_run_async_with_unexpected_argument():
304+
"""Test that run_async filters out unexpected arguments."""
305+
306+
def sample_func(expected_arg: str):
307+
return {"received_arg": expected_arg}
308+
309+
tool = FunctionTool(sample_func)
310+
mock_invocation_context = MagicMock(spec=InvocationContext)
311+
mock_invocation_context.session = MagicMock(spec=Session)
312+
# Add the missing state attribute to the session mock
313+
mock_invocation_context.session.state = MagicMock()
314+
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
315+
316+
result = await tool.run_async(
317+
args={"expected_arg": "hello", "parameters": "should_be_filtered"},
318+
tool_context=tool_context_mock,
319+
)
320+
assert result == {"received_arg": "hello"}
321+
322+
323+
@pytest.mark.asyncio
324+
async def test_run_async_with_tool_context_and_unexpected_argument():
325+
"""Test that run_async handles tool_context and filters out unexpected arguments."""
326+
327+
def sample_func_with_context(expected_arg: str, tool_context: ToolContext):
328+
return {"received_arg": expected_arg, "context_present": bool(tool_context)}
329+
330+
tool = FunctionTool(sample_func_with_context)
331+
mock_invocation_context = MagicMock(spec=InvocationContext)
332+
mock_invocation_context.session = MagicMock(spec=Session)
333+
# Add the missing state attribute to the session mock
334+
mock_invocation_context.session.state = MagicMock()
335+
mock_tool_context = ToolContext(invocation_context=mock_invocation_context)
336+
337+
result = await tool.run_async(
338+
args={
339+
"expected_arg": "world",
340+
"parameters": "should_also_be_filtered",
341+
},
342+
tool_context=mock_tool_context,
343+
)
344+
assert result == {
345+
"received_arg": "world",
346+
"context_present": True,
347+
}

0 commit comments

Comments
 (0)