From eeaf6219f8345251520680cdb2c8b311f13e2a8e Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Tue, 1 Jul 2025 18:40:04 -0700 Subject: [PATCH] feat: add agent.clone() method This would allow users to easily make a copy of the agents they built without having to add too much boilerplates. This promotes code reuse, modularity and testability of agents. PiperOrigin-RevId: 778276937 --- src/google/adk/agents/base_agent.py | 47 +++ src/google/adk/agents/langgraph_agent.py | 18 + src/google/adk/agents/llm_agent.py | 15 + src/google/adk/agents/loop_agent.py | 16 + src/google/adk/agents/parallel_agent.py | 16 + src/google/adk/agents/remote_a2a_agent.py | 16 + src/google/adk/agents/sequential_agent.py | 16 + tests/unittests/agents/test_agent_clone.py | 371 +++++++++++++++++++++ 8 files changed, 515 insertions(+) create mode 100644 tests/unittests/agents/test_agent_clone.py diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index bdc10ac3a..b348ebc4d 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -20,6 +20,7 @@ from typing import Awaitable from typing import Callable from typing import final +from typing import Mapping from typing import Optional from typing import TYPE_CHECKING from typing import Union @@ -121,6 +122,52 @@ class BaseAgent(BaseModel): response and appended to event history as agent response. """ + def clone(self, update: Mapping[str, Any] | None = None) -> BaseAgent: + """Creates a copy of this BaseAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + For example: {"name": "cloned_agent"} + + Returns: + A new BaseAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + if update is not None and 'parent_agent' in update: + raise ValueError( + 'Cannot update `parent_agent` field in clone. Parent agent is set' + ' only when the parent agent is instantiated with the sub-agents.' + ) + + # Only allow updating fields that are defined in the agent class. + allowed_fields = set(self.__class__.model_fields) + if update is not None: + invalid_fields = set(update) - allowed_fields + if invalid_fields: + raise ValueError( + f'Cannot update non-existent fields in {self.__class__.__name__}:' + f' {invalid_fields}' + ) + + cloned_agent = self.model_copy(update=update) + + if update is None or 'sub_agents' not in update: + # If `sub_agents` is not provided in the update, need to recursively clone + # the sub-agents to avoid sharing the sub-agents with the original agent. + cloned_agent.sub_agents = [] + for sub_agent in self.sub_agents: + cloned_sub_agent = sub_agent.clone() + cloned_sub_agent.parent_agent = cloned_agent + cloned_agent.sub_agents.append(cloned_sub_agent) + else: + for sub_agent in cloned_agent.sub_agents: + sub_agent.parent_agent = cloned_agent + + # Remove the parent agent from the cloned agent to avoid sharing the parent + # agent with the cloned agent. + cloned_agent.parent_agent = None + return cloned_agent + @final async def run_async( self, diff --git a/src/google/adk/agents/langgraph_agent.py b/src/google/adk/agents/langgraph_agent.py index f07b203fa..ce481ba6f 100644 --- a/src/google/adk/agents/langgraph_agent.py +++ b/src/google/adk/agents/langgraph_agent.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import Any from typing import AsyncGenerator +from typing import cast +from typing import Mapping from typing import Union from google.genai import types @@ -59,6 +64,19 @@ class LangGraphAgent(BaseAgent): instruction: str = '' + @override + def clone(self, update: Mapping[str, Any] | None = None) -> LangGraphAgent: + """Creates a copy of this LangGraphAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + + Returns: + A new LangGraphAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + return cast(LangGraphAgent, super().clone(update)) + @override async def _run_async_impl( self, diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index a5c859e26..4241703ae 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -20,7 +20,9 @@ from typing import AsyncGenerator from typing import Awaitable from typing import Callable +from typing import cast from typing import Literal +from typing import Mapping from typing import Optional from typing import Union @@ -268,6 +270,19 @@ class LlmAgent(BaseAgent): """ # Callbacks - End + @override + def clone(self, update: Mapping[str, Any] | None = None) -> LlmAgent: + """Creates a copy of this LlmAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + + Returns: + A new LlmAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + return cast(LlmAgent, super().clone(update)) + @override async def _run_async_impl( self, ctx: InvocationContext diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index 219e0c22f..2efa7c665 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -16,7 +16,10 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator +from typing import cast +from typing import Mapping from typing import Optional from typing_extensions import override @@ -40,6 +43,19 @@ class LoopAgent(BaseAgent): escalates. """ + @override + def clone(self, update: Mapping[str, Any] | None = None) -> LoopAgent: + """Creates a copy of this LoopAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + + Returns: + A new LoopAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + return cast(LoopAgent, super().clone(update)) + @override async def _run_async_impl( self, ctx: InvocationContext diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index 427128cec..10810b715 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -17,7 +17,10 @@ from __future__ import annotations import asyncio +from typing import Any from typing import AsyncGenerator +from typing import cast +from typing import Mapping from typing_extensions import override @@ -92,6 +95,19 @@ class ParallelAgent(BaseAgent): - Generating multiple responses for review by a subsequent evaluation agent. """ + @override + def clone(self, update: Mapping[str, Any] | None = None) -> ParallelAgent: + """Creates a copy of this ParallelAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + + Returns: + A new ParallelAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + return cast(ParallelAgent, super().clone(update)) + @override async def _run_async_impl( self, ctx: InvocationContext diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index b9f765576..133d9a9ac 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -19,6 +19,8 @@ from pathlib import Path from typing import Any from typing import AsyncGenerator +from typing import cast +from typing import Mapping from typing import Optional from typing import Union from urllib.parse import urlparse @@ -48,6 +50,7 @@ from google.genai import types as genai_types import httpx +from typing_extensions import override from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event @@ -151,6 +154,19 @@ def __init__( f"got {type(agent_card)}" ) + @override + def clone(self, update: Mapping[str, Any] | None = None) -> "RemoteA2aAgent": + """Creates a copy of this RemoteA2aAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + + Returns: + A new RemoteA2aAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + return cast(RemoteA2aAgent, super().clone(update)) + async def _ensure_httpx_client(self) -> httpx.AsyncClient: """Ensure HTTP client is available and properly configured.""" if not self._httpx_client: diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index 845dd5ac1..ef081044f 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -16,7 +16,10 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator +from typing import cast +from typing import Mapping from typing_extensions import override @@ -29,6 +32,19 @@ class SequentialAgent(BaseAgent): """A shell agent that runs its sub-agents in sequence.""" + @override + def clone(self, update: Mapping[str, Any] | None = None) -> SequentialAgent: + """Creates a copy of this SequentialAgent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + + Returns: + A new SequentialAgent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + return cast(SequentialAgent, super().clone(update)) + @override async def _run_async_impl( self, ctx: InvocationContext diff --git a/tests/unittests/agents/test_agent_clone.py b/tests/unittests/agents/test_agent_clone.py new file mode 100644 index 000000000..5e6f9f683 --- /dev/null +++ b/tests/unittests/agents/test_agent_clone.py @@ -0,0 +1,371 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testings for the clone functionality of agents.""" + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +import pytest + + +def test_llm_agent_clone(): + """Test cloning an LLM agent.""" + # Create an LLM agent + original = LlmAgent( + name="llm_agent", + description="An LLM agent", + instruction="You are a helpful assistant.", + ) + + # Clone it with name update + cloned = original.clone(update={"name": "cloned_llm_agent"}) + + # Verify the clone + assert cloned.name == "cloned_llm_agent" + assert cloned.description == "An LLM agent" + assert cloned.instruction == "You are a helpful assistant." + assert cloned.parent_agent is None + assert len(cloned.sub_agents) == 0 + assert isinstance(cloned, LlmAgent) + + # Verify the original is unchanged + assert original.name == "llm_agent" + assert original.instruction == "You are a helpful assistant." + + +def test_agent_with_sub_agents(): + """Test cloning an agent that has sub-agents.""" + # Create sub-agents + sub_agent1 = LlmAgent(name="sub_agent1", description="First sub-agent") + sub_agent2 = LlmAgent(name="sub_agent2", description="Second sub-agent") + + # Create a parent agent with sub-agents + original = SequentialAgent( + name="parent_agent", + description="Parent agent with sub-agents", + sub_agents=[sub_agent1, sub_agent2], + ) + + # Clone it with name update + cloned = original.clone(update={"name": "cloned_parent"}) + + # Verify the clone has sub-agents (deep copy behavior) + assert cloned.name == "cloned_parent" + assert cloned.description == "Parent agent with sub-agents" + assert cloned.parent_agent is None + assert len(cloned.sub_agents) == 2 + + # Sub-agents should be cloned with their original names + assert cloned.sub_agents[0].name == "sub_agent1" + assert cloned.sub_agents[1].name == "sub_agent2" + + # Sub-agents should have the cloned agent as their parent + assert cloned.sub_agents[0].parent_agent == cloned + assert cloned.sub_agents[1].parent_agent == cloned + + # Sub-agents should be different objects from the original + assert cloned.sub_agents[0] is not original.sub_agents[0] + assert cloned.sub_agents[1] is not original.sub_agents[1] + + # Verify the original still has sub-agents + assert original.name == "parent_agent" + assert len(original.sub_agents) == 2 + assert original.sub_agents[0].name == "sub_agent1" + assert original.sub_agents[1].name == "sub_agent2" + assert original.sub_agents[0].parent_agent == original + assert original.sub_agents[1].parent_agent == original + + +def test_three_level_nested_agent(): + """Test cloning a three-level nested agent to verify recursive cloning logic.""" + # Create third-level agents (leaf nodes) + leaf_agent1 = LlmAgent(name="leaf1", description="First leaf agent") + leaf_agent2 = LlmAgent(name="leaf2", description="Second leaf agent") + + # Create second-level agents + middle_agent1 = SequentialAgent( + name="middle1", description="First middle agent", sub_agents=[leaf_agent1] + ) + middle_agent2 = ParallelAgent( + name="middle2", + description="Second middle agent", + sub_agents=[leaf_agent2], + ) + + # Create top-level agent + root_agent = LoopAgent( + name="root_agent", + description="Root agent with three levels", + max_iterations=5, + sub_agents=[middle_agent1, middle_agent2], + ) + + # Clone the root agent + cloned_root = root_agent.clone(update={"name": "cloned_root"}) + + # Verify root level + assert cloned_root.name == "cloned_root" + assert cloned_root.description == "Root agent with three levels" + assert cloned_root.max_iterations == 5 + assert cloned_root.parent_agent is None + assert len(cloned_root.sub_agents) == 2 + assert isinstance(cloned_root, LoopAgent) + + # Verify middle level + cloned_middle1 = cloned_root.sub_agents[0] + cloned_middle2 = cloned_root.sub_agents[1] + + assert cloned_middle1.name == "middle1" + assert cloned_middle1.description == "First middle agent" + assert cloned_middle1.parent_agent == cloned_root + assert len(cloned_middle1.sub_agents) == 1 + assert isinstance(cloned_middle1, SequentialAgent) + + assert cloned_middle2.name == "middle2" + assert cloned_middle2.description == "Second middle agent" + assert cloned_middle2.parent_agent == cloned_root + assert len(cloned_middle2.sub_agents) == 1 + assert isinstance(cloned_middle2, ParallelAgent) + + # Verify leaf level + cloned_leaf1 = cloned_middle1.sub_agents[0] + cloned_leaf2 = cloned_middle2.sub_agents[0] + + assert cloned_leaf1.name == "leaf1" + assert cloned_leaf1.description == "First leaf agent" + assert cloned_leaf1.parent_agent == cloned_middle1 + assert len(cloned_leaf1.sub_agents) == 0 + assert isinstance(cloned_leaf1, LlmAgent) + + assert cloned_leaf2.name == "leaf2" + assert cloned_leaf2.description == "Second leaf agent" + assert cloned_leaf2.parent_agent == cloned_middle2 + assert len(cloned_leaf2.sub_agents) == 0 + assert isinstance(cloned_leaf2, LlmAgent) + + # Verify all objects are different from originals + assert cloned_root is not root_agent + assert cloned_middle1 is not middle_agent1 + assert cloned_middle2 is not middle_agent2 + assert cloned_leaf1 is not leaf_agent1 + assert cloned_leaf2 is not leaf_agent2 + + # Verify original structure is unchanged + assert root_agent.name == "root_agent" + assert root_agent.sub_agents[0].name == "middle1" + assert root_agent.sub_agents[1].name == "middle2" + assert root_agent.sub_agents[0].sub_agents[0].name == "leaf1" + assert root_agent.sub_agents[1].sub_agents[0].name == "leaf2" + + +def test_multiple_clones(): + """Test creating multiple clones with automatic naming.""" + # Create multiple agents and clone each one + original = LlmAgent( + name="original_agent", description="Agent for multiple cloning" + ) + + # Test multiple clones from the same original + clone1 = original.clone(update={"name": "clone1"}) + clone2 = original.clone(update={"name": "clone2"}) + + assert clone1.name == "clone1" + assert clone2.name == "clone2" + assert clone1 is not clone2 + + +def test_clone_with_complex_configuration(): + """Test cloning an agent with complex configuration.""" + # Create an LLM agent with various configurations + original = LlmAgent( + name="complex_agent", + description="A complex agent with many settings", + instruction="You are a specialized assistant.", + global_instruction="Always be helpful and accurate.", + disallow_transfer_to_parent=True, + disallow_transfer_to_peers=True, + include_contents="none", + ) + + # Clone it with name update + cloned = original.clone(update={"name": "complex_clone"}) + + # Verify all configurations are preserved + assert cloned.name == "complex_clone" + assert cloned.description == "A complex agent with many settings" + assert cloned.instruction == "You are a specialized assistant." + assert cloned.global_instruction == "Always be helpful and accurate." + assert cloned.disallow_transfer_to_parent is True + assert cloned.disallow_transfer_to_peers is True + assert cloned.include_contents == "none" + + # Verify parent and sub-agents are set + assert cloned.parent_agent is None + assert len(cloned.sub_agents) == 0 + + +def test_clone_without_updates(): + """Test cloning without providing updates (should use original values).""" + original = LlmAgent(name="test_agent", description="Test agent") + + cloned = original.clone() + + assert cloned.name == "test_agent" + assert cloned.description == "Test agent" + + +def test_clone_with_multiple_updates(): + """Test cloning with multiple field updates.""" + original = LlmAgent( + name="original_agent", + description="Original description", + instruction="Original instruction", + ) + + cloned = original.clone( + update={ + "name": "updated_agent", + "description": "Updated description", + "instruction": "Updated instruction", + } + ) + + assert cloned.name == "updated_agent" + assert cloned.description == "Updated description" + assert cloned.instruction == "Updated instruction" + + +def test_clone_with_sub_agents_deep_copy(): + """Test cloning with deep copy of sub-agents.""" + # Create an agent with sub-agents + sub_agent = LlmAgent(name="sub_agent", description="Sub agent") + original = LlmAgent( + name="root_agent", + description="Root agent", + sub_agents=[sub_agent], + ) + + # Clone with deep copy + cloned = original.clone(update={"name": "cloned_root_agent"}) + assert cloned.name == "cloned_root_agent" + assert cloned.sub_agents[0].name == "sub_agent" + assert cloned.sub_agents[0].parent_agent == cloned + assert cloned.sub_agents[0] is not original.sub_agents[0] + + +def test_clone_invalid_field(): + """Test that cloning with invalid fields raises an error.""" + original = LlmAgent(name="test_agent", description="Test agent") + + with pytest.raises(ValueError, match="Cannot update non-existent fields"): + original.clone(update={"invalid_field": "value"}) + + +def test_clone_parent_agent_field(): + """Test that cloning with parent_agent field raises an error.""" + original = LlmAgent(name="test_agent", description="Test agent") + + with pytest.raises( + ValueError, match="Cannot update `parent_agent` field in clone" + ): + original.clone(update={"parent_agent": None}) + + +def test_clone_preserves_agent_type(): + """Test that cloning preserves the specific agent type.""" + # Test LlmAgent + llm_original = LlmAgent(name="llm_test") + llm_cloned = llm_original.clone() + assert isinstance(llm_cloned, LlmAgent) + + # Test SequentialAgent + seq_original = SequentialAgent(name="seq_test") + seq_cloned = seq_original.clone() + assert isinstance(seq_cloned, SequentialAgent) + + # Test ParallelAgent + par_original = ParallelAgent(name="par_test") + par_cloned = par_original.clone() + assert isinstance(par_cloned, ParallelAgent) + + # Test LoopAgent + loop_original = LoopAgent(name="loop_test") + loop_cloned = loop_original.clone() + assert isinstance(loop_cloned, LoopAgent) + + +def test_clone_with_none_update(): + """Test cloning with explicit None update parameter.""" + original = LlmAgent(name="test_agent", description="Test agent") + + cloned = original.clone(update=None) + + assert cloned.name == "test_agent" + assert cloned.description == "Test agent" + assert cloned is not original + + +def test_clone_with_empty_update(): + """Test cloning with empty update dictionary.""" + original = LlmAgent(name="test_agent", description="Test agent") + + cloned = original.clone(update={}) + + assert cloned.name == "test_agent" + assert cloned.description == "Test agent" + assert cloned is not original + + +def test_clone_with_sub_agents_update(): + """Test cloning with sub_agents provided in update.""" + # Create original sub-agents + original_sub1 = LlmAgent(name="original_sub1", description="Original sub 1") + original_sub2 = LlmAgent(name="original_sub2", description="Original sub 2") + + # Create new sub-agents for the update + new_sub1 = LlmAgent(name="new_sub1", description="New sub 1") + new_sub2 = LlmAgent(name="new_sub2", description="New sub 2") + + # Create original agent with sub-agents + original = SequentialAgent( + name="original_agent", + description="Original agent", + sub_agents=[original_sub1, original_sub2], + ) + + # Clone with sub_agents update + cloned = original.clone( + update={"name": "cloned_agent", "sub_agents": [new_sub1, new_sub2]} + ) + + # Verify the clone uses the new sub-agents + assert cloned.name == "cloned_agent" + assert len(cloned.sub_agents) == 2 + assert cloned.sub_agents[0].name == "new_sub1" + assert cloned.sub_agents[1].name == "new_sub2" + assert cloned.sub_agents[0].parent_agent == cloned + assert cloned.sub_agents[1].parent_agent == cloned + + # Verify original is unchanged + assert original.name == "original_agent" + assert len(original.sub_agents) == 2 + assert original.sub_agents[0].name == "original_sub1" + assert original.sub_agents[1].name == "original_sub2" + + +if __name__ == "__main__": + # Run a specific test for debugging + test_three_level_nested_agent()