Skip to content

feat: add agent.clone() method #1747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Awaitable
from typing import Callable
from typing import final
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
Expand Down Expand Up @@ -121,6 +122,52 @@ class BaseAgent(BaseModel):
response and appended to event history as agent response.
"""

def clone(self, update: Mapping[str, Any] | None = None) -> BaseAgent:
"""Creates a copy of this BaseAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.
For example: {"name": "cloned_agent"}

Returns:
A new BaseAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
if update is not None and 'parent_agent' in update:
raise ValueError(
'Cannot update `parent_agent` field in clone. Parent agent is set'
' only when the parent agent is instantiated with the sub-agents.'
)

# Only allow updating fields that are defined in the agent class.
allowed_fields = set(self.__class__.model_fields)
if update is not None:
invalid_fields = set(update) - allowed_fields
if invalid_fields:
raise ValueError(
f'Cannot update non-existent fields in {self.__class__.__name__}:'
f' {invalid_fields}'
)

cloned_agent = self.model_copy(update=update)

if update is None or 'sub_agents' not in update:
# If `sub_agents` is not provided in the update, need to recursively clone
# the sub-agents to avoid sharing the sub-agents with the original agent.
cloned_agent.sub_agents = []
for sub_agent in self.sub_agents:
cloned_sub_agent = sub_agent.clone()
cloned_sub_agent.parent_agent = cloned_agent
cloned_agent.sub_agents.append(cloned_sub_agent)
else:
for sub_agent in cloned_agent.sub_agents:
sub_agent.parent_agent = cloned_agent

# Remove the parent agent from the cloned agent to avoid sharing the parent
# agent with the cloned agent.
cloned_agent.parent_agent = None
return cloned_agent

@final
async def run_async(
self,
Expand Down
18 changes: 18 additions & 0 deletions src/google/adk/agents/langgraph_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Mapping
from typing import Union

from google.genai import types
Expand Down Expand Up @@ -59,6 +64,19 @@ class LangGraphAgent(BaseAgent):

instruction: str = ''

@override
def clone(self, update: Mapping[str, Any] | None = None) -> LangGraphAgent:
"""Creates a copy of this LangGraphAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.

Returns:
A new LangGraphAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
return cast(LangGraphAgent, super().clone(update))

@override
async def _run_async_impl(
self,
Expand Down
15 changes: 15 additions & 0 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from typing import AsyncGenerator
from typing import Awaitable
from typing import Callable
from typing import cast
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -268,6 +270,19 @@ class LlmAgent(BaseAgent):
"""
# Callbacks - End

@override
def clone(self, update: Mapping[str, Any] | None = None) -> LlmAgent:
"""Creates a copy of this LlmAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.

Returns:
A new LlmAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
return cast(LlmAgent, super().clone(update))

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/agents/loop_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from __future__ import annotations

from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Mapping
from typing import Optional

from typing_extensions import override
Expand All @@ -40,6 +43,19 @@ class LoopAgent(BaseAgent):
escalates.
"""

@override
def clone(self, update: Mapping[str, Any] | None = None) -> LoopAgent:
"""Creates a copy of this LoopAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.

Returns:
A new LoopAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
return cast(LoopAgent, super().clone(update))

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/agents/parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from __future__ import annotations

import asyncio
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Mapping

from typing_extensions import override

Expand Down Expand Up @@ -92,6 +95,19 @@ class ParallelAgent(BaseAgent):
- Generating multiple responses for review by a subsequent evaluation agent.
"""

@override
def clone(self, update: Mapping[str, Any] | None = None) -> ParallelAgent:
"""Creates a copy of this ParallelAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.

Returns:
A new ParallelAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
return cast(ParallelAgent, super().clone(update))

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from pathlib import Path
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Mapping
from typing import Optional
from typing import Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -48,6 +50,7 @@

from google.genai import types as genai_types
import httpx
from typing_extensions import override

from ..a2a.converters.event_converter import convert_a2a_message_to_event
from ..a2a.converters.event_converter import convert_a2a_task_to_event
Expand Down Expand Up @@ -151,6 +154,19 @@ def __init__(
f"got {type(agent_card)}"
)

@override
def clone(self, update: Mapping[str, Any] | None = None) -> "RemoteA2aAgent":
"""Creates a copy of this RemoteA2aAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.

Returns:
A new RemoteA2aAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
return cast(RemoteA2aAgent, super().clone(update))

async def _ensure_httpx_client(self) -> httpx.AsyncClient:
"""Ensure HTTP client is available and properly configured."""
if not self._httpx_client:
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/agents/sequential_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from __future__ import annotations

from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Mapping

from typing_extensions import override

Expand All @@ -29,6 +32,19 @@
class SequentialAgent(BaseAgent):
"""A shell agent that runs its sub-agents in sequence."""

@override
def clone(self, update: Mapping[str, Any] | None = None) -> SequentialAgent:
"""Creates a copy of this SequentialAgent instance.

Args:
update: Optional mapping of new values for the fields of the cloned agent.

Returns:
A new SequentialAgent instance with identical configuration as the original
agent except for the fields specified in the update.
"""
return cast(SequentialAgent, super().clone(update))

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
Loading