From 5ee6f4c4e58334fba6ccef9282d6efdaa607a89d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 15 Jul 2025 23:28:02 -0700 Subject: [PATCH] chore: Support computer use models PiperOrigin-RevId: 783621903 --- src/google/adk/models/google_llm.py | 15 +- src/google/adk/tools/computer_use/computer.py | 261 +++++++++++ .../tools/computer_use/computer_use_tool.py | 177 ++++++++ .../computer_use/computer_use_toolset.py | 89 ++++ tests/unittests/models/test_google_llm.py | 220 +++------- .../unittests/tools/computer_use/__init__.py | 13 + .../unittests/tools/computer_use/conftest.py | 37 ++ .../tools/computer_use/test_computer.py | 321 ++++++++++++++ .../computer_use/test_computer_use_tool.py | 404 ++++++++++++++++++ .../computer_use/test_computer_use_toolset.py | 342 +++++++++++++++ 10 files changed, 1703 insertions(+), 176 deletions(-) create mode 100644 src/google/adk/tools/computer_use/computer.py create mode 100644 src/google/adk/tools/computer_use/computer_use_tool.py create mode 100644 src/google/adk/tools/computer_use/computer_use_toolset.py create mode 100644 tests/unittests/tools/computer_use/__init__.py create mode 100644 tests/unittests/tools/computer_use/conftest.py create mode 100644 tests/unittests/tools/computer_use/test_computer.py create mode 100644 tests/unittests/tools/computer_use/test_computer_use_tool.py create mode 100644 tests/unittests/tools/computer_use/test_computer_use_toolset.py diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 48547c7a6..15c99ef1c 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -27,7 +27,6 @@ from google.genai import Client from google.genai import types -from google.genai.types import FinishReason from typing_extensions import override from .. 
import version @@ -72,6 +71,8 @@ def supported_models() -> list[str]: r'projects\/.+\/locations\/.+\/endpoints\/.+', # vertex gemini long name r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+', + 'computer-use-exp', + r'projects\/.+\/locations\/.+\/publishers\/google\/models\/computer-use-exp', ] async def generate_content_async( @@ -162,12 +163,8 @@ async def generate_content_async( parts.append(types.Part.from_text(text=text)) yield LlmResponse( content=types.ModelContent(parts=parts), - error_code=None - if response.candidates[0].finish_reason == FinishReason.STOP - else response.candidates[0].finish_reason, - error_message=None - if response.candidates[0].finish_reason == FinishReason.STOP - else response.candidates[0].finish_message, + error_code=response.candidates[0].finish_reason, + error_message=response.candidates[0].finish_message, usage_metadata=usage_metadata, ) @@ -282,6 +279,10 @@ def _preprocess_request(self, llm_request: LlmRequest) -> None: _remove_display_name_if_present(part.inline_data) _remove_display_name_if_present(part.file_data) + # computer use model doesn't support system instruction + if llm_request.model.endswith('computer-use-exp'): + llm_request.config.system_instruction = None + def _build_function_declaration_log( func_decl: types.FunctionDeclaration, diff --git a/src/google/adk/tools/computer_use/computer.py b/src/google/adk/tools/computer_use/computer.py new file mode 100644 index 000000000..aab3f59d8 --- /dev/null +++ b/src/google/adk/tools/computer_use/computer.py @@ -0,0 +1,261 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +from typing import Literal + +from google.genai.types import Environment +import pydantic + +from ...utils.feature_decorator import experimental + + +@experimental +class EnvironmentState(pydantic.BaseModel): + """Represents the current state of the computer environment. + + Attributes: + screenshot: The screenshot in PNG format as bytes. + url: The current URL of the webpage being displayed. + """ + + screenshot: bytes = pydantic.Field( + ..., description="Screenshot in PNG format" + ) + url: str = pydantic.Field(..., description="Current webpage URL") + + @pydantic.field_validator("url") + @classmethod + def validate_url(cls, v: str) -> str: + """Validate that URL is not empty.""" + if not v.strip(): + raise ValueError("URL cannot be empty") + return v + + +@experimental +class Computer(abc.ABC): + """Defines an interface for computer environments. + + This abstract base class defines the standard interface for controlling + computer environments, including web browsers and other interactive systems. + """ + + @abc.abstractmethod + async def screen_size(self) -> tuple[int, int]: + """Returns the screen size of the environment. + + Returns: + A tuple of (width, height) in pixels. + """ + + @abc.abstractmethod + async def open_web_browser(self) -> EnvironmentState: + """Opens the web browser. + + Returns: + The current state after opening the browser. 
+ """ + + @abc.abstractmethod + async def click_at(self, x: int, y: int) -> EnvironmentState: + """Clicks at a specific x, y coordinate on the webpage. + + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to click at. + y: The y-coordinate to click at. + + Returns: + The current state after clicking. + """ + + @abc.abstractmethod + async def hover_at(self, x: int, y: int) -> EnvironmentState: + """Hovers at a specific x, y coordinate on the webpage. + + May be used to explore sub-menus that appear on hover. + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to hover at. + y: The y-coordinate to hover at. + + Returns: + The current state after hovering. + """ + + @abc.abstractmethod + async def type_text_at( + self, + x: int, + y: int, + text: str, + press_enter: bool = True, + clear_before_typing: bool = True, + ) -> EnvironmentState: + """Types text at a specific x, y coordinate. + + The system automatically presses ENTER after typing. To disable this, set `press_enter` to False. + The system automatically clears any existing content before typing the specified `text`. To disable this, set `clear_before_typing` to False. + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to type at. + y: The y-coordinate to type at. + text: The text to type. + press_enter: Whether to press ENTER after typing. + clear_before_typing: Whether to clear existing content before typing. + + Returns: + The current state after typing. + """ + + @abc.abstractmethod + async def scroll_document( + self, direction: Literal["up", "down", "left", "right"] + ) -> EnvironmentState: + """Scrolls the entire webpage "up", "down", "left" or "right" based on direction. + + Args: + direction: The direction to scroll. + + Returns: + The current state after scrolling. 
+ """ + + @abc.abstractmethod + async def scroll_at( + self, + x: int, + y: int, + direction: Literal["up", "down", "left", "right"], + magnitude: int, + ) -> EnvironmentState: + """Scrolls up, down, right, or left at a x, y coordinate by magnitude. + + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to scroll at. + y: The y-coordinate to scroll at. + direction: The direction to scroll. + magnitude: The amount to scroll. + + Returns: + The current state after scrolling. + """ + + @abc.abstractmethod + async def wait_5_seconds(self) -> EnvironmentState: + """Waits for 5 seconds to allow unfinished webpage processes to complete. + + Returns: + The current state after waiting. + """ + + @abc.abstractmethod + async def go_back(self) -> EnvironmentState: + """Navigates back to the previous webpage in the browser history. + + Returns: + The current state after navigating back. + """ + + @abc.abstractmethod + async def go_forward(self) -> EnvironmentState: + """Navigates forward to the next webpage in the browser history. + + Returns: + The current state after navigating forward. + """ + + @abc.abstractmethod + async def search(self) -> EnvironmentState: + """Directly jumps to a search engine home page. + + Used when you need to start with a search. For example, this is used when + the current website doesn't have the information needed or because a new + task is being started. + + Returns: + The current state after navigating to search. + """ + + @abc.abstractmethod + async def navigate(self, url: str) -> EnvironmentState: + """Navigates directly to a specified URL. + + Args: + url: The URL to navigate to. + + Returns: + The current state after navigation. + """ + + @abc.abstractmethod + async def key_combination(self, keys: list[str]) -> EnvironmentState: + """Presses keyboard keys and combinations, such as "control+c" or "enter". + + Args: + keys: List of keys to press in combination. 
+ + Returns: + The current state after key press. + """ + + @abc.abstractmethod + async def drag_and_drop( + self, x: int, y: int, destination_x: int, destination_y: int + ) -> EnvironmentState: + """Drag and drop an element from a x, y coordinate to a destination destination_y, destination_x coordinate. + + The 'x', 'y', 'destination_y' and 'destination_x' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to start dragging from. + y: The y-coordinate to start dragging from. + destination_x: The x-coordinate to drop at. + destination_y: The y-coordinate to drop at. + + Returns: + The current state after drag and drop. + """ + + @abc.abstractmethod + async def current_state(self) -> EnvironmentState: + """Returns the current state of the current webpage. + + Returns: + The current environment state. + """ + + async def initialize(self) -> None: + """Initialize the computer.""" + pass + + async def close(self) -> None: + """Cleanup resource of the computer.""" + pass + + async def environment(self) -> Environment: + """Returns the environment of the computer. + + Returns: + The environment type; defaults to ENVIRONMENT_BROWSER. + """ + return Environment.ENVIRONMENT_BROWSER diff --git a/src/google/adk/tools/computer_use/computer_use_tool.py b/src/google/adk/tools/computer_use/computer_use_tool.py new file mode 100644 index 000000000..8511698cc --- /dev/null +++ b/src/google/adk/tools/computer_use/computer_use_tool.py @@ -0,0 +1,177 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import base64 +import logging +from typing import Any +from typing import Callable + +from google.genai import types +from typing_extensions import override + +from ...models.llm_request import LlmRequest +from ...utils.feature_decorator import experimental +from ..function_tool import FunctionTool +from ..tool_context import ToolContext +from .computer import EnvironmentState + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class ComputerUseTool(FunctionTool): + """A tool that wraps computer control functions for use with LLMs. + + This tool automatically normalizes coordinates from a 1000x1000 coordinate system + to the actual screen size. + """ + + def __init__( + self, + *, + func: Callable[..., Any], + screen_size: tuple[int, int], + environment: types.Environment = types.Environment.ENVIRONMENT_BROWSER, + ): + """Initialize the ComputerUseTool. + + Args: + func: The computer control function to wrap. + screen_size: The actual screen size as (width, height). + environment: The environment type for the tool. 
+ """ + super().__init__(func=func) + self._screen_size = screen_size + self._environment = environment + + # Validate screen size + if not isinstance(screen_size, tuple) or len(screen_size) != 2: + raise ValueError("screen_size must be a tuple of (width, height)") + if screen_size[0] <= 0 or screen_size[1] <= 0: + raise ValueError("screen_size dimensions must be positive") + + def _normalize_x(self, x: int) -> int: + """Normalize x coordinate from 1000-based to actual screen width.""" + if not isinstance(x, (int, float)): + raise ValueError(f"x coordinate must be numeric, got {type(x)}") + + normalized = int(x / 1000 * self._screen_size[0]) + # Clamp to screen bounds + return max(0, min(normalized, self._screen_size[0] - 1)) + + def _normalize_y(self, y: int) -> int: + """Normalize y coordinate from 1000-based to actual screen height.""" + if not isinstance(y, (int, float)): + raise ValueError(f"y coordinate must be numeric, got {type(y)}") + + normalized = int(y / 1000 * self._screen_size[1]) + # Clamp to screen bounds + return max(0, min(normalized, self._screen_size[1] - 1)) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Run the computer control function with normalized coordinates.""" + + try: + # Normalize coordinates if present + if "x" in args: + original_x = args["x"] + args["x"] = self._normalize_x(args["x"]) + logger.debug("Normalized x: %s -> %s", original_x, args["x"]) + + if "y" in args: + original_y = args["y"] + args["y"] = self._normalize_y(args["y"]) + logger.debug("Normalized y: %s -> %s", original_y, args["y"]) + + # Handle destination coordinates for drag and drop + if "destination_x" in args: + original_dest_x = args["destination_x"] + args["destination_x"] = self._normalize_x(args["destination_x"]) + logger.debug( + "Normalized destination_x: %s -> %s", + original_dest_x, + args["destination_x"], + ) + + if "destination_y" in args: + original_dest_y = args["destination_y"] + 
args["destination_y"] = self._normalize_y(args["destination_y"]) + logger.debug( + "Normalized destination_y: %s -> %s", + original_dest_y, + args["destination_y"], + ) + + # Execute the actual computer control function + result = await super().run_async(args=args, tool_context=tool_context) + + # Process the result if it's an EnvironmentState + if isinstance(result, EnvironmentState): + return { + "image": { + "mimetype": "image/png", + "data": base64.b64encode(result.screenshot).decode("utf-8"), + }, + "url": result.url, + } + + return result + + except Exception as e: + logger.error("Error in ComputerUseTool.run_async: %s", e) + raise + + @override + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + """Add computer use configuration to the LLM request.""" + + try: + + # Add this tool to the tools dictionary + llm_request.tools_dict[self.name] = self + + # Initialize config if needed + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + # Check if computer use is already configured + for tool in llm_request.config.tools: + if ( + isinstance(tool, (types.Tool, types.ToolDict)) + and hasattr(tool, "computer_use") + and tool.computer_use + ): + logger.debug("Computer use already configured in LLM request") + return + + # Add computer use tool configuration + llm_request.config.tools.append( + types.Tool( + computer_use=types.ToolComputerUse(environment=self._environment) + ) + ) + logger.debug( + "Added computer use tool with environment: %s", self._environment + ) + + except Exception as e: + logger.error("Error in ComputerUseTool.process_llm_request: %s", e) + raise diff --git a/src/google/adk/tools/computer_use/computer_use_toolset.py b/src/google/adk/tools/computer_use/computer_use_toolset.py new file mode 100644 index 000000000..916efa3d5 --- /dev/null +++ b/src/google/adk/tools/computer_use/computer_use_toolset.py @@ 
-0,0 +1,89 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...agents.readonly_context import ReadonlyContext +from ...utils.feature_decorator import experimental +from ..base_toolset import BaseToolset +from .computer import Computer +from .computer_use_tool import ComputerUseTool + + +@experimental +class ComputerUseToolset(BaseToolset): + + def __init__( + self, + *, + computer: Computer, + ): + super().__init__() + self._computer = computer + self._initialized = False + + async def _ensure_initialized(self) -> None: + if not self._initialized: + await self._computer.initialize() + self._initialized = True + + @override + async def get_tools( + self, + readonly_context: Optional[ReadonlyContext] = None, + ) -> list[ComputerUseTool]: + await self._ensure_initialized() + # Get screen size and environment for tool configuration + screen_size = await self._computer.screen_size() + environment = await self._computer.environment() + + # Get all methods defined in Computer abstract base class, excluding specified methods + excluded_methods = {'screen_size', 'environment', 'close'} + computer_methods = [] + + # Get all methods defined in the Computer ABC interface + for method_name in dir(Computer): + # Skip private methods (starting with underscore) + if method_name.startswith('_'): + continue + + # Skip excluded methods + if 
method_name in excluded_methods: + continue + + # Check if it's a method defined in Computer class + attr = getattr(Computer, method_name, None) + if attr is not None and callable(attr): + # Get the corresponding method from the concrete instance + instance_method = getattr(self._computer, method_name) + computer_methods.append(instance_method) + + # Create ComputerUseTool instances for each method + + return [ + ComputerUseTool( + func=method, + screen_size=screen_size, + environment=environment, + ) + for method in computer_methods + ] + + @override + async def close(self) -> None: + await self._computer.close() diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index fa01daddd..f2b723c9c 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -28,7 +28,6 @@ from google.genai import types from google.genai import version as genai_version from google.genai.types import Content -from google.genai.types import FinishReason from google.genai.types import Part import pytest @@ -75,13 +74,18 @@ def mock_os_environ(): def test_supported_models(): models = Gemini.supported_models() - assert len(models) == 3 + assert len(models) == 5 assert models[0] == r"gemini-.*" assert models[1] == r"projects\/.+\/locations\/.+\/endpoints\/.+" assert ( models[2] == r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+" ) + assert models[3] == "computer-use-exp" + assert ( + models[4] + == r"projects\/.+\/locations\/.+\/publishers\/google\/models\/computer-use-exp" + ) def test_client_version_header(): @@ -672,6 +676,48 @@ def test_preprocess_request_handles_backend_specific_fields( assert llm_request_with_files.config.labels == expected_labels +@pytest.mark.parametrize( + "model_name, expected_system_instruction", + [ + ("gemini-1.5-flash", "You are a helpful assistant"), + ("computer-use-exp", None), + ( + 
"projects/my-project/locations/us-central1/publishers/google/models/computer-use-exp", + None, + ), + ( + "projects/my-project/locations/us-central1/publishers/google/models/gemini-1.5-flash", + "You are a helpful assistant", + ), + ], +) +def test_preprocess_request_computer_use_model_removes_system_instruction( + gemini_llm: Gemini, + model_name: str, + expected_system_instruction: str, +): + """ + Tests that _preprocess_request correctly removes system instruction for computer-use models. + + - For computer-use-exp models (both short and long form), system instruction should be removed + - For other models, system instruction should be preserved + """ + # Arrange: Create a request with system instruction + llm_request = LlmRequest( + model=model_name, + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant", + ), + ) + + # Act: Run the preprocessing method + gemini_llm._preprocess_request(llm_request) + + # Assert: Check if system instruction was correctly processed + assert llm_request.config.system_instruction == expected_system_instruction + + @pytest.mark.asyncio async def test_generate_content_async_stream_aggregated_content_regardless_of_finish_reason(): """Test that aggregated content is generated regardless of finish_reason.""" @@ -755,7 +801,6 @@ async def mock_coro(): # Final response should have aggregated content with error info final_response = responses[2] assert final_response.content.parts[0].text == "Hello world" - # After the code changes, error_code and error_message are set for non-STOP finish reasons assert final_response.error_code == finish_reason assert final_response.error_message == f"Finished with {finish_reason}" @@ -837,163 +882,6 @@ async def mock_coro(): assert final_response.content.parts[0].text == "Think1" assert final_response.content.parts[0].thought is True assert final_response.content.parts[1].text == "Answer" - # After 
the code changes, error_code and error_message are set for non-STOP finish reasons - assert final_response.error_code == types.FinishReason.MAX_TOKENS - assert final_response.error_message == "Maximum tokens reached" - - -@pytest.mark.asyncio -async def test_generate_content_async_stream_error_info_none_for_stop_finish_reason(): - """Test that error_code and error_message are None when finish_reason is STOP.""" - gemini_llm = Gemini(model="gemini-1.5-flash") - llm_request = LlmRequest( - model="gemini-1.5-flash", - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], - config=types.GenerateContentConfig( - temperature=0.1, - response_modalities=[types.Modality.TEXT], - system_instruction="You are a helpful assistant", - ), - ) - - with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - - mock_responses = [ - types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=Content( - role="model", parts=[Part.from_text(text="Hello")] - ), - finish_reason=None, - ) - ] - ), - types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=Content( - role="model", parts=[Part.from_text(text=" world")] - ), - finish_reason=types.FinishReason.STOP, - finish_message="Successfully completed", - ) - ] - ), - ] - - async def mock_coro(): - return MockAsyncIterator(mock_responses) - - mock_client.aio.models.generate_content_stream.return_value = mock_coro() - - responses = [ - resp - async for resp in gemini_llm.generate_content_async( - llm_request, stream=True - ) - ] - - # Should have 3 responses: 2 partial and 1 final aggregated - assert len(responses) == 3 - assert responses[0].partial is True - assert responses[1].partial is True - - # Final response should have aggregated content with 
error info None for STOP finish reason - final_response = responses[2] - assert final_response.content.parts[0].text == "Hello world" - assert final_response.error_code is None - assert final_response.error_message is None - - -@pytest.mark.asyncio -async def test_generate_content_async_stream_error_info_set_for_non_stop_finish_reason(): - """Test that error_code and error_message are set for non-STOP finish reasons.""" - gemini_llm = Gemini(model="gemini-1.5-flash") - llm_request = LlmRequest( - model="gemini-1.5-flash", - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], - config=types.GenerateContentConfig( - temperature=0.1, - response_modalities=[types.Modality.TEXT], - system_instruction="You are a helpful assistant", - ), - ) - - with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - - mock_responses = [ - types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=Content( - role="model", parts=[Part.from_text(text="Hello")] - ), - finish_reason=None, - ) - ] - ), - types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=Content( - role="model", parts=[Part.from_text(text=" world")] - ), - finish_reason=types.FinishReason.MAX_TOKENS, - finish_message="Maximum tokens reached", - ) - ] - ), - ] - - async def mock_coro(): - return MockAsyncIterator(mock_responses) - - mock_client.aio.models.generate_content_stream.return_value = mock_coro() - - responses = [ - resp - async for resp in gemini_llm.generate_content_async( - llm_request, stream=True - ) - ] - - # Should have 3 responses: 2 partial and 1 final aggregated - assert len(responses) == 3 - assert responses[0].partial is True - assert responses[1].partial is True - - # Final response should have aggregated 
content with error info set for non-STOP finish reason - final_response = responses[2] - assert final_response.content.parts[0].text == "Hello world" assert final_response.error_code == types.FinishReason.MAX_TOKENS assert final_response.error_message == "Maximum tokens reached" @@ -1182,9 +1070,7 @@ async def mock_coro(): # Final aggregated text with error info assert responses[4].content.parts[0].text == " second text" - assert ( - responses[4].error_code is None - ) # STOP finish reason should have None error_code + assert responses[4].error_code == types.FinishReason.STOP @pytest.mark.asyncio @@ -1381,9 +1267,7 @@ async def mock_coro(): # Final aggregated response should have both thought and text final_response = responses[-1] - assert ( - final_response.error_code is None - ) # STOP finish reason should have None error_code + assert final_response.error_code == types.FinishReason.STOP assert len(final_response.content.parts) == 2 # thought part + text part assert final_response.content.parts[0].thought is True assert "More thinking..." in final_response.content.parts[0].text @@ -1519,9 +1403,7 @@ async def mock_coro(): # Final aggregation should contain "Second chunk" and have error info final_aggregation = aggregated_text_responses[-1] assert final_aggregation.content.parts[0].text == "Second chunk" - assert ( - final_aggregation.error_code is None - ) # STOP finish reason should have None error_code + assert final_aggregation.error_code == types.FinishReason.STOP # Verify the function call is preserved between aggregations function_call_responses = [ diff --git a/tests/unittests/tools/computer_use/__init__.py b/tests/unittests/tools/computer_use/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/tools/computer_use/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/computer_use/conftest.py b/tests/unittests/tools/computer_use/conftest.py new file mode 100644 index 000000000..6299112e9 --- /dev/null +++ b/tests/unittests/tools/computer_use/conftest.py @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class MockEnvironment: + ENVIRONMENT_BROWSER = "ENVIRONMENT_BROWSER" + ENVIRONMENT_DESKTOP = "ENVIRONMENT_DESKTOP" + + +class MockToolComputerUse: + + def __init__(self, environment=None): + self.environment = environment + + +# Patch google.genai.types before imports +try: + from google.genai import types + + # Add missing types if they don't exist + if not hasattr(types, "Environment"): + types.Environment = MockEnvironment + if not hasattr(types, "ToolComputerUse"): + types.ToolComputerUse = MockToolComputerUse +except ImportError: + pass diff --git a/tests/unittests/tools/computer_use/test_computer.py b/tests/unittests/tools/computer_use/test_computer.py new file mode 100644 index 000000000..65d88067a --- /dev/null +++ b/tests/unittests/tools/computer_use/test_computer.py @@ -0,0 +1,321 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock + +from google.adk.tools.computer_use.computer import Computer +from google.adk.tools.computer_use.computer import EnvironmentState +import pydantic +import pytest + +# Mock Environment enum since it may not be available in the current google.genai version +Environment = MagicMock() +Environment.ENVIRONMENT_BROWSER = "ENVIRONMENT_BROWSER" +Environment.ENVIRONMENT_DESKTOP = "ENVIRONMENT_DESKTOP" + + +class TestEnvironmentState: + """Test cases for EnvironmentState model.""" + + def test_valid_environment_state(self): + """Test creating a valid EnvironmentState.""" + screenshot_data = b"fake_png_data" + url = "https://example.com" + + state = EnvironmentState(screenshot=screenshot_data, url=url) + + assert state.screenshot == screenshot_data + assert state.url == url + + def test_empty_url_raises_validation_error(self): + """Test that empty URL raises validation error.""" + screenshot_data = b"fake_png_data" + + with pytest.raises(pydantic.ValidationError, match="URL cannot be empty"): + EnvironmentState(screenshot=screenshot_data, url="") + + def test_whitespace_only_url_raises_validation_error(self): + """Test that whitespace-only URL raises validation error.""" + screenshot_data = b"fake_png_data" + + with pytest.raises(pydantic.ValidationError, match="URL cannot be empty"): + EnvironmentState(screenshot=screenshot_data, url=" ") + + def test_valid_url_with_spaces_is_accepted(self): + """Test that a URL with leading/trailing spaces is accepted unchanged (not stripped).""" + screenshot_data = b"fake_png_data" + url = " https://example.com " + + state = EnvironmentState(screenshot=screenshot_data, url=url) + assert state.url == url # pydantic validation doesn't auto-strip + + def test_missing_required_fields_raise_validation_error(self): + """Test that missing required fields raise validation errors.""" + with pytest.raises(pydantic.ValidationError): + EnvironmentState() + + with pytest.raises(pydantic.ValidationError): + 
EnvironmentState(screenshot=b"data") + + with pytest.raises(pydantic.ValidationError): + EnvironmentState(url="https://example.com") + + def test_environment_state_serialization(self): + """Test that EnvironmentState can be serialized and deserialized.""" + screenshot_data = b"fake_png_data" + url = "https://example.com" + + original_state = EnvironmentState(screenshot=screenshot_data, url=url) + + # Test dict conversion + state_dict = original_state.model_dump() + assert state_dict["screenshot"] == screenshot_data + assert state_dict["url"] == url + + # Test reconstruction from dict + reconstructed_state = EnvironmentState(**state_dict) + assert reconstructed_state.screenshot == original_state.screenshot + assert reconstructed_state.url == original_state.url + + +class MockComputer(Computer): + """Mock implementation of Computer for testing.""" + + def __init__(self): + self.screen_width = 1920 + self.screen_height = 1080 + self.current_url = "https://example.com" + self.screenshot_data = b"mock_screenshot_data" + + async def screen_size(self) -> tuple[int, int]: + return (self.screen_width, self.screen_height) + + async def open_web_browser(self) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def click_at(self, x: int, y: int) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def hover_at(self, x: int, y: int) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def type_text_at( + self, + x: int, + y: int, + text: str, + press_enter: bool = True, + clear_before_typing: bool = True, + ) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def scroll_document(self, direction: str) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def 
scroll_at( + self, x: int, y: int, direction: str, magnitude: int + ) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def wait_5_seconds(self) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def go_back(self) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def go_forward(self) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def search(self) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def navigate(self, url: str) -> EnvironmentState: + self.current_url = url + return EnvironmentState(screenshot=self.screenshot_data, url=url) + + async def key_combination(self, keys: list[str]) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def drag_and_drop( + self, x: int, y: int, destination_x: int, destination_y: int + ) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + async def current_state(self) -> EnvironmentState: + return EnvironmentState( + screenshot=self.screenshot_data, url=self.current_url + ) + + +class TestComputer: + """Test cases for Computer abstract class.""" + + @pytest.fixture + def mock_computer(self): + """Fixture providing a mock computer instance.""" + return MockComputer() + + @pytest.mark.asyncio + async def test_initialize_default_implementation(self, mock_computer): + """Test that default initialize method works.""" + # Should not raise any exception + await mock_computer.initialize() + + @pytest.mark.asyncio + async def test_close_default_implementation(self, mock_computer): + """Test that default close method works.""" + # Should not raise any exception + await 
mock_computer.close() + + @pytest.mark.asyncio + async def test_environment_default_implementation(self, mock_computer): + """Test that default environment method returns ENVIRONMENT_BROWSER.""" + environment = await mock_computer.environment() + assert environment == Environment.ENVIRONMENT_BROWSER + + @pytest.mark.asyncio + async def test_screen_size(self, mock_computer): + """Test screen_size method.""" + size = await mock_computer.screen_size() + assert size == (1920, 1080) + assert isinstance(size, tuple) + assert len(size) == 2 + assert isinstance(size[0], int) + assert isinstance(size[1], int) + + @pytest.mark.asyncio + async def test_open_web_browser(self, mock_computer): + """Test open_web_browser method.""" + state = await mock_computer.open_web_browser() + assert isinstance(state, EnvironmentState) + assert state.screenshot == b"mock_screenshot_data" + assert state.url == "https://example.com" + + @pytest.mark.asyncio + async def test_click_at(self, mock_computer): + """Test click_at method.""" + state = await mock_computer.click_at(100, 200) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_hover_at(self, mock_computer): + """Test hover_at method.""" + state = await mock_computer.hover_at(150, 250) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_type_text_at(self, mock_computer): + """Test type_text_at method.""" + state = await mock_computer.type_text_at(100, 200, "test text") + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_type_text_at_with_options(self, mock_computer): + """Test type_text_at method with optional parameters.""" + state = await mock_computer.type_text_at( + 100, 200, "test text", press_enter=False, clear_before_typing=False + ) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_scroll_document(self, mock_computer): + """Test scroll_document method.""" + for direction in ["up", "down", 
"left", "right"]: + state = await mock_computer.scroll_document(direction) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_scroll_at(self, mock_computer): + """Test scroll_at method.""" + state = await mock_computer.scroll_at(100, 200, "down", 5) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_wait_5_seconds(self, mock_computer): + """Test wait_5_seconds method.""" + state = await mock_computer.wait_5_seconds() + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_go_back(self, mock_computer): + """Test go_back method.""" + state = await mock_computer.go_back() + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_go_forward(self, mock_computer): + """Test go_forward method.""" + state = await mock_computer.go_forward() + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_search(self, mock_computer): + """Test search method.""" + state = await mock_computer.search() + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_navigate(self, mock_computer): + """Test navigate method.""" + new_url = "https://google.com" + state = await mock_computer.navigate(new_url) + assert isinstance(state, EnvironmentState) + assert state.url == new_url + + @pytest.mark.asyncio + async def test_key_combination(self, mock_computer): + """Test key_combination method.""" + state = await mock_computer.key_combination(["ctrl", "c"]) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_drag_and_drop(self, mock_computer): + """Test drag_and_drop method.""" + state = await mock_computer.drag_and_drop(100, 200, 300, 400) + assert isinstance(state, EnvironmentState) + + @pytest.mark.asyncio + async def test_current_state(self, mock_computer): + """Test current_state method.""" + state = await mock_computer.current_state() + assert isinstance(state, 
EnvironmentState) + + def test_computer_is_abstract(self): + """Test that Computer cannot be instantiated directly.""" + with pytest.raises(TypeError): + Computer() diff --git a/tests/unittests/tools/computer_use/test_computer_use_tool.py b/tests/unittests/tools/computer_use/test_computer_use_tool.py new file mode 100644 index 000000000..fa76b0fc0 --- /dev/null +++ b/tests/unittests/tools/computer_use/test_computer_use_tool.py @@ -0,0 +1,404 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import base64
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

from google.adk.models.llm_request import LlmRequest
from google.adk.tools.computer_use.computer import EnvironmentState
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types
import pytest


# Mock Environment and ToolComputerUse since they may not be available in the current google.genai version
class MockEnvironment:
  """Plain-string stand-in for ``types.Environment``."""

  ENVIRONMENT_BROWSER = "ENVIRONMENT_BROWSER"
  ENVIRONMENT_DESKTOP = "ENVIRONMENT_DESKTOP"


class MockToolComputerUse:
  """Stand-in for ``types.ToolComputerUse`` that only stores its argument."""

  def __init__(self, environment):
    self.environment = environment


# Patch the types module to include our mocks, but only for names that are
# missing. Assigning unconditionally would replace real google.genai types
# (or the stubs installed by conftest.py) for every test module imported
# after this one, because `types` is a shared module object.
if not hasattr(types, "Environment"):
  types.Environment = MockEnvironment
if not hasattr(types, "ToolComputerUse"):
  types.ToolComputerUse = MockToolComputerUse


class TestComputerUseTool:
  """Test cases for ComputerUseTool class."""

  @pytest.fixture
  def mockfunction(self):
    """Fixture providing a mock function for testing."""
    func = AsyncMock()
    # The mock is given an explicit name and docstring before being wrapped.
    func.__name__ = "testfunction"
    func.__doc__ = "Test function documentation"
    return func

  @pytest.fixture
  def screen_size(self):
    """Fixture providing a standard screen size."""
    return (1920, 1080)

  @pytest.fixture
  def computer_use_tool(self, mockfunction, screen_size):
    """Fixture providing a ComputerUseTool instance."""
    return ComputerUseTool(
        func=mockfunction,
        screen_size=screen_size,
        environment=types.Environment.ENVIRONMENT_BROWSER,
    )

  @pytest.fixture
  def tool_context(self):
    """Fixture providing a mock tool context."""
    context = MagicMock(spec=ToolContext)
    context.actions = MagicMock()
    context.actions.skip_summarization = False
    return context

  def test_init_valid_screen_size(self, mockfunction):
    """Test initialization with valid screen size."""
    tool = ComputerUseTool(
        func=mockfunction,
        screen_size=(1920, 1080),
        environment=types.Environment.ENVIRONMENT_BROWSER,
    )
    assert tool._screen_size == (1920, 1080)
    assert tool._environment == types.Environment.ENVIRONMENT_BROWSER

  def test_init_invalid_screen_size_not_tuple(self, mockfunction):
    """Test initialization with invalid screen size (not tuple)."""
    with pytest.raises(ValueError, match="screen_size must be a tuple"):
      ComputerUseTool(
          func=mockfunction,
          screen_size=[1920, 1080],  # list instead of tuple
          environment=types.Environment.ENVIRONMENT_BROWSER,
      )

  def test_init_invalid_screen_size_wrong_length(self, mockfunction):
    """Test initialization with invalid screen size (wrong length)."""
    with pytest.raises(ValueError, match="screen_size must be a tuple"):
      ComputerUseTool(
          func=mockfunction,
          screen_size=(1920,),  # only one element
          environment=types.Environment.ENVIRONMENT_BROWSER,
      )

  def test_init_invalid_screen_size_negative_dimensions(self, mockfunction):
    """Test initialization with negative or zero screen dimensions."""
    with pytest.raises(
        ValueError, match="screen_size dimensions must be positive"
    ):
      ComputerUseTool(
          func=mockfunction,
          screen_size=(-1920, 1080),
          environment=types.Environment.ENVIRONMENT_BROWSER,
      )

    # Zero is rejected as well, not just negative values.
    with pytest.raises(
        ValueError, match="screen_size dimensions must be positive"
    ):
      ComputerUseTool(
          func=mockfunction,
          screen_size=(1920, 0),
          environment=types.Environment.ENVIRONMENT_BROWSER,
      )

  def test_normalize_x_coordinate(self, computer_use_tool):
    """Test x coordinate normalization."""
    # Test basic normalization (500 on 1000 scale -> 960 on 1920 scale)
    assert computer_use_tool._normalize_x(500) == 960

    # Test edge cases
    assert computer_use_tool._normalize_x(0) == 0
    assert (
        computer_use_tool._normalize_x(1000) == 1919
    )  # clamped to screen width - 1

    # Test clamping
    assert computer_use_tool._normalize_x(-100) == 0
    assert computer_use_tool._normalize_x(1500) == 1919

  def test_normalize_y_coordinate(self, computer_use_tool):
    """Test y coordinate normalization."""
    # Test basic normalization (500 on 1000 scale -> 540 on 1080 scale)
    assert computer_use_tool._normalize_y(500) == 540

    # Test edge cases
    assert computer_use_tool._normalize_y(0) == 0
    assert (
        computer_use_tool._normalize_y(1000) == 1079
    )  # clamped to screen height - 1

    # Test clamping
    assert computer_use_tool._normalize_y(-100) == 0
    assert computer_use_tool._normalize_y(1500) == 1079

  def test_normalize_coordinate_invalid_type(self, computer_use_tool):
    """Test coordinate normalization with invalid types."""
    with pytest.raises(ValueError, match="x coordinate must be numeric"):
      computer_use_tool._normalize_x("invalid")

    with pytest.raises(ValueError, match="y coordinate must be numeric"):
      computer_use_tool._normalize_y("invalid")

  def test_normalize_coordinate_float_input(self, computer_use_tool):
    """Test coordinate normalization with float input."""
    # Float inputs should be converted to int
    assert computer_use_tool._normalize_x(500.5) == 960
    assert computer_use_tool._normalize_y(500.7) == 540

  @pytest.mark.asyncio
  async def test_run_async_coordinate_normalization(self, tool_context):
    """Test that run_async normalizes coordinates properly."""
    # Create a simple function that records the arguments it receives
    received_args = {}

    async def capture_func(x, y, tool_context=None):
      received_args.update({"x": x, "y": y, "tool_context": tool_context})
      return "test_result"

    computer_use_tool = ComputerUseTool(
        func=capture_func,
        screen_size=(1920, 1080),
        environment=types.Environment.ENVIRONMENT_BROWSER,
    )

    args = {"x": 500, "y": 600}
    result = await computer_use_tool.run_async(
        args=args, tool_context=tool_context
    )

    # Check that coordinates were normalized
    assert received_args.get("x") == 960  # 500/1000*1920
    assert received_args.get("y") == 648  # 600/1000*1080
    assert result == "test_result"

  @pytest.mark.asyncio
  async def test_run_async_destination_coordinates(self, tool_context):
    """Test that run_async normalizes destination coordinates for drag and drop."""
    # Create a simple function that records the arguments it receives
    received_args = {}

    async def capture_func(
        x, y, destination_x, destination_y, tool_context=None
    ):
      received_args.update({
          "x": x,
          "y": y,
          "destination_x": destination_x,
          "destination_y": destination_y,
          "tool_context": tool_context,
      })
      return "test_result"

    computer_use_tool = ComputerUseTool(
        func=capture_func,
        screen_size=(1920, 1080),
        environment=types.Environment.ENVIRONMENT_BROWSER,
    )

    args = {
        "x": 250,
        "y": 300,
        "destination_x": 750,
        "destination_y": 800,
    }
    await computer_use_tool.run_async(args=args, tool_context=tool_context)

    # Check that coordinates were normalized
    assert received_args.get("x") == 480  # 250/1000*1920
    assert received_args.get("y") == 324  # 300/1000*1080
    assert received_args.get("destination_x") == 1440  # 750/1000*1920
    assert received_args.get("destination_y") == 864  # 800/1000*1080

  @pytest.mark.asyncio
  async def test_run_async_no_coordinates(self, tool_context):
    """Test that run_async works without coordinates."""
    # Create a simple function that records the arguments it receives
    received_args = {}

    async def capture_func(text, press_enter=True, tool_context=None):
      received_args.update({
          "text": text,
          "press_enter": press_enter,
          "tool_context": tool_context,
      })
      return "test_result"

    computer_use_tool = ComputerUseTool(
        func=capture_func,
        screen_size=(1920, 1080),
        environment=types.Environment.ENVIRONMENT_BROWSER,
    )

    args = {"text": "hello world", "press_enter": True}
    result = await computer_use_tool.run_async(
        args=args, tool_context=tool_context
    )

    # Args should remain unchanged (no coordinate normalization)
    assert received_args.get("text") == "hello world"
    # Identity check: a truthy non-bool (e.g. 1) must not satisfy this.
    assert received_args.get("press_enter") is True
    assert result == "test_result"

  @pytest.mark.asyncio
  async def test_run_async_environment_state_result(
      self, computer_use_tool, tool_context
  ):
    """Test that run_async processes EnvironmentState results correctly."""
    screenshot_data = b"fake_png_data"
    environment_state = EnvironmentState(
        screenshot=screenshot_data, url="https://example.com"
    )
    computer_use_tool.func.return_value = environment_state

    args = {"x": 500, "y": 600}
    result = await computer_use_tool.run_async(
        args=args, tool_context=tool_context
    )

    # The state is expected back as a dict with a base64-encoded PNG payload.
    expected_result = {
        "image": {
            "mimetype": "image/png",
            "data": base64.b64encode(screenshot_data).decode("utf-8"),
        },
        "url": "https://example.com",
    }
    assert result == expected_result

  @pytest.mark.asyncio
  async def test_run_async_exception_handling(
      self, computer_use_tool, tool_context
  ):
    """Test that run_async logs once and re-raises the wrapped function's error."""
    computer_use_tool.func.side_effect = ValueError("Test error")

    args = {"x": 500, "y": 600}

    with patch(
        "google.adk.tools.computer_use.computer_use_tool.logger"
    ) as mock_logger:
      with pytest.raises(ValueError, match="Test error"):
        await computer_use_tool.run_async(args=args, tool_context=tool_context)

      mock_logger.error.assert_called_once()

  @pytest.mark.asyncio
  async def test_process_llm_request_new_config(
      self, computer_use_tool, tool_context
  ):
    """Test process_llm_request with new LLM request config."""
    # Mock the method since computer_use parameter is not available in current types.Tool
    # NOTE(review): this patches the method under test, so the assertion below
    # only proves the mock was called; the real implementation never runs.
    with patch.object(computer_use_tool, "process_llm_request") as mock_process:
      llm_request = LlmRequest()
      llm_request.tools_dict = {}

      await computer_use_tool.process_llm_request(
          tool_context=tool_context, llm_request=llm_request
      )

      mock_process.assert_called_once_with(
          tool_context=tool_context, llm_request=llm_request
      )

  @pytest.mark.asyncio
  async def test_process_llm_request_existing_config(
      self, computer_use_tool, tool_context
  ):
    """Test process_llm_request with existing LLM request config."""
    # Mock the method since computer_use parameter is not available in current types.Tool
    # NOTE(review): this patches the method under test, so the assertion below
    # only proves the mock was called; the real implementation never runs.
    with patch.object(computer_use_tool, "process_llm_request") as mock_process:
      llm_request = LlmRequest()
      llm_request.tools_dict = {}
      llm_request.config = types.GenerateContentConfig()
      llm_request.config.tools = []

      await computer_use_tool.process_llm_request(
          tool_context=tool_context, llm_request=llm_request
      )

      mock_process.assert_called_once_with(
          tool_context=tool_context, llm_request=llm_request
      )

  @pytest.mark.asyncio
  async def test_process_llm_request_already_configured(
      self, computer_use_tool, tool_context
  ):
    """Test process_llm_request when computer use is already configured."""
    # Mock the method since computer_use parameter is not available in current types.Tool
    # NOTE(review): this patches the method under test, and the request is set
    # up with an empty tools list, so no "already configured" state is actually
    # exercised here.
    with patch.object(computer_use_tool, "process_llm_request") as mock_process:
      llm_request = LlmRequest()
      llm_request.tools_dict = {}
      llm_request.config = types.GenerateContentConfig()
      llm_request.config.tools = []

      await computer_use_tool.process_llm_request(
          tool_context=tool_context, llm_request=llm_request
      )

      mock_process.assert_called_once_with(
          tool_context=tool_context, llm_request=llm_request
      )

  @pytest.mark.asyncio
  async def test_process_llm_request_exception_handling(
      self, computer_use_tool, tool_context
  ):
    """Test that process_llm_request logs once and re-raises on failure."""
    llm_request = MagicMock()
    llm_request.tools_dict = {}
    llm_request.config = None

    # Make types.GenerateContentConfig() raise an exception
    with patch(
        "google.genai.types.GenerateContentConfig",
        side_effect=RuntimeError("Test error"),
    ):
      with patch(
          "google.adk.tools.computer_use.computer_use_tool.logger"
      ) as mock_logger:
        with pytest.raises(RuntimeError, match="Test error"):
          await computer_use_tool.process_llm_request(
              tool_context=tool_context, llm_request=llm_request
          )

        mock_logger.error.assert_called_once()

  def test_inheritance(self, computer_use_tool):
    """Test that ComputerUseTool properly inherits from FunctionTool."""
    from google.adk.tools.function_tool import FunctionTool

    assert isinstance(computer_use_tool, FunctionTool)

  @pytest.mark.asyncio
  async def test_custom_environment(self, mockfunction):
    """Test ComputerUseTool with custom environment."""
    tool = ComputerUseTool(
        func=mockfunction,
        screen_size=(1920, 1080),
        environment=types.Environment.ENVIRONMENT_DESKTOP,
    )

    # Just test that the tool has the correct environment stored
    assert tool._environment == types.Environment.ENVIRONMENT_DESKTOP

# diff --git a/tests/unittests/tools/computer_use/test_computer_use_toolset.py b/tests/unittests/tools/computer_use/test_computer_use_toolset.py
# new file mode 100644
# index 000000000..89fd82357
# --- /dev/null
# +++ b/tests/unittests/tools/computer_use/test_computer_use_toolset.py
# @@ -0,0 +1,342 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock

from google.adk.tools.computer_use.computer import Computer
from google.adk.tools.computer_use.computer import EnvironmentState
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
from google.adk.tools.computer_use.computer_use_toolset import ComputerUseToolset
import pytest


# Mock Environment enum since it may not be available in the current google.genai version
class MockEnvironment:
  """Plain-string stand-in for the google.genai Environment enum."""

  ENVIRONMENT_BROWSER = "ENVIRONMENT_BROWSER"
  ENVIRONMENT_DESKTOP = "ENVIRONMENT_DESKTOP"


Environment = MockEnvironment


class MockComputer(Computer):
  """Mock Computer implementation for testing.

  Records whether initialize() and close() were awaited, and lets tests
  override the reported screen size and environment through ``_screen_size``
  and ``_environment``.
  """

  def __init__(self):
    self.initialize_called = False
    self.close_called = False
    self._screen_size = (1920, 1080)
    self._environment = Environment.ENVIRONMENT_BROWSER

  async def initialize(self):
    self.initialize_called = True

  async def close(self):
    self.close_called = True

  async def screen_size(self) -> tuple[int, int]:
    return self._screen_size

  async def environment(self) -> Environment:
    return self._environment

  # Implement all abstract methods to make this a concrete class
  async def open_web_browser(self) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def click_at(self, x: int, y: int) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def hover_at(self, x: int, y: int) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def type_text_at(
      self,
      x: int,
      y: int,
      text: str,
      press_enter: bool = True,
      clear_before_typing: bool = True,
  ) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def scroll_document(self, direction: str) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def scroll_at(
      self, x: int, y: int, direction: str, magnitude: int
  ) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def wait_5_seconds(self) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def go_back(self) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def go_forward(self) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def search(self) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def navigate(self, url: str) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url=url)

  async def key_combination(self, keys: list[str]) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def drag_and_drop(
      self, x: int, y: int, destination_x: int, destination_y: int
  ) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")

  async def current_state(self) -> EnvironmentState:
    return EnvironmentState(screenshot=b"test", url="https://example.com")


class TestComputerUseToolset:
  """Test cases for ComputerUseToolset class."""

  @pytest.fixture
  def mock_computer(self):
    """Fixture providing a mock computer."""
    return MockComputer()

  @pytest.fixture
  def toolset(self, mock_computer):
    """Fixture providing a ComputerUseToolset instance."""
    return ComputerUseToolset(computer=mock_computer)

  def test_init(self, mock_computer):
    """Test ComputerUseToolset initialization."""
    toolset = ComputerUseToolset(computer=mock_computer)

    assert toolset._computer == mock_computer
    assert toolset._initialized is False

  @pytest.mark.asyncio
  async def test_ensure_initialized(self, toolset, mock_computer):
    """Test that _ensure_initialized calls computer.initialize()."""
    assert not mock_computer.initialize_called
    assert not toolset._initialized

    await toolset._ensure_initialized()

    assert mock_computer.initialize_called
    assert toolset._initialized

  @pytest.mark.asyncio
  async def test_ensure_initialized_only_once(self, toolset, mock_computer):
    """Test that _ensure_initialized only calls initialize once."""
    await toolset._ensure_initialized()

    # Reset the flag to test it's not called again
    mock_computer.initialize_called = False

    await toolset._ensure_initialized()

    # Should not be called again
    assert not mock_computer.initialize_called
    assert toolset._initialized

  @pytest.mark.asyncio
  async def test_get_tools(self, toolset, mock_computer):
    """Test that get_tools returns ComputerUseTool instances."""
    tools = await toolset.get_tools()

    # Should initialize the computer
    assert mock_computer.initialize_called

    # Should return a list of ComputerUseTool instances
    assert isinstance(tools, list)
    assert len(tools) > 0
    assert all(isinstance(tool, ComputerUseTool) for tool in tools)

    # Each tool should have the correct configuration
    for tool in tools:
      assert tool._screen_size == (1920, 1080)
      assert tool._environment == Environment.ENVIRONMENT_BROWSER

  @pytest.mark.asyncio
  async def test_get_tools_excludes_utility_methods(self, toolset):
    """Test that get_tools excludes utility methods like screen_size, environment, close."""
    tools = await toolset.get_tools()

    # Get tool function names
    tool_names = [tool.func.__name__ for tool in tools]

    # Should exclude utility methods
    excluded_methods = {"screen_size", "environment", "close"}
    for method in excluded_methods:
      assert method not in tool_names

    # initialize might be included since it's a concrete method, not just abstract
    # This is acceptable behavior

    # Should include action methods
    expected_methods = {
        "open_web_browser",
        "click_at",
        "hover_at",
        "type_text_at",
        "scroll_document",
        "scroll_at",
        "wait_5_seconds",
        "go_back",
        "go_forward",
        "search",
        "navigate",
        "key_combination",
        "drag_and_drop",
        "current_state",
    }
    for method in expected_methods:
      assert method in tool_names

  @pytest.mark.asyncio
  async def test_get_tools_with_readonly_context(self, toolset):
    """Test get_tools with readonly_context parameter."""
    from google.adk.agents.readonly_context import ReadonlyContext

    readonly_context = MagicMock(spec=ReadonlyContext)

    tools = await toolset.get_tools(readonly_context=readonly_context)

    # Should still return tools (readonly_context doesn't affect behavior currently)
    assert isinstance(tools, list)
    assert len(tools) > 0

  @pytest.mark.asyncio
  async def test_close(self, toolset, mock_computer):
    """Test that close calls computer.close()."""
    await toolset.close()

    assert mock_computer.close_called

  @pytest.mark.asyncio
  async def test_get_tools_creates_tools_with_correct_methods(
      self, toolset, mock_computer
  ):
    """Test that get_tools creates tools with the correct underlying methods."""
    tools = await toolset.get_tools()

    # Find the click_at tool
    click_tool = None
    for tool in tools:
      if tool.func.__name__ == "click_at":
        click_tool = tool
        break

    assert click_tool is not None

    # The tool's function should be bound to the mock computer instance
    assert click_tool.func.__self__ == mock_computer

  @pytest.mark.asyncio
  async def test_get_tools_handles_custom_screen_size(self, mock_computer):
    """Test get_tools with custom screen size."""
    mock_computer._screen_size = (2560, 1440)

    toolset = ComputerUseToolset(computer=mock_computer)
    tools = await toolset.get_tools()

    # All tools should have the custom screen size
    for tool in tools:
      assert tool._screen_size == (2560, 1440)

  @pytest.mark.asyncio
  async def test_get_tools_handles_custom_environment(self, mock_computer):
    """Test get_tools with custom environment."""
    mock_computer._environment = Environment.ENVIRONMENT_DESKTOP

    toolset = ComputerUseToolset(computer=mock_computer)
    tools = await toolset.get_tools()

    # All tools should have the custom environment
    for tool in tools:
      assert tool._environment == Environment.ENVIRONMENT_DESKTOP

  @pytest.mark.asyncio
  async def test_multiple_get_tools_calls_return_different_instances(
      self, toolset
  ):
    """Test that multiple get_tools calls return different tool instances."""
    tools1 = await toolset.get_tools()
    tools2 = await toolset.get_tools()

    # zip() stops at the shorter input, so without these two checks an empty
    # or shorter second result would let the loop below pass vacuously.
    assert len(tools1) > 0
    assert len(tools1) == len(tools2)

    # Should return different instances
    for tool1, tool2 in zip(tools1, tools2):
      assert tool1 is not tool2
      # But should have the same configuration
      assert tool1._screen_size == tool2._screen_size
      assert tool1._environment == tool2._environment

  def test_inheritance(self, toolset):
    """Test that ComputerUseToolset inherits from BaseToolset."""
    from google.adk.tools.base_toolset import BaseToolset

    assert isinstance(toolset, BaseToolset)

  @pytest.mark.asyncio
  async def test_get_tools_method_filtering(self, toolset):
    """Test that get_tools properly filters methods from Computer ABC."""
    tools = await toolset.get_tools()
    tool_names = [tool.func.__name__ for tool in tools]

    # Should not include private methods
    assert not any(name.startswith("_") for name in tool_names)

    # Should not include class methods, static methods, or properties
    # that aren't actual computer action methods
    forbidden_names = {
        "__init__",
        "__new__",
        "__class__",
        "__dict__",
        "__module__",
        "__doc__",
        "__annotations__",
        "__abstractmethods__",
    }
    for forbidden in forbidden_names:
      assert forbidden not in tool_names

  @pytest.mark.asyncio
  async def test_computer_method_binding(self, toolset, mock_computer):
    """Test that computer methods are properly bound to the computer instance."""
    tools = await toolset.get_tools()

    for tool in tools:
      # Each tool's function should be a bound method of the computer
      assert hasattr(tool.func, "__self__")
      assert tool.func.__self__ == mock_computer

  @pytest.mark.asyncio
  async def test_toolset_handles_computer_initialization_failure(
      self, mock_computer
  ):
    """Test toolset behavior when computer initialization fails."""

    async def failing_initialize():
      raise RuntimeError("Initialization failed")

    # Shadow the bound method on this instance only.
    mock_computer.initialize = failing_initialize
    toolset = ComputerUseToolset(computer=mock_computer)

    with pytest.raises(RuntimeError, match="Initialization failed"):
      await toolset.get_tools()

    # Should not be marked as initialized
    assert not toolset._initialized