From aff1d01dd1b0116f50878d8df518684bf4f33c02 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 27 Mar 2025 11:29:09 -0400 Subject: [PATCH 01/17] Add ability to register Agent instances --- .../agbench/src/agbench/linter/__init__.py | 2 +- .../agbench/src/agbench/linter/_base.py | 5 +- .../agbench/src/agbench/linter/cli.py | 8 ++- .../src/agbench/linter/coders/oai_coder.py | 8 +-- .../autogen-core/src/autogen_core/_agent.py | 4 ++ .../src/autogen_core/_agent_runtime.py | 6 ++ .../src/autogen_core/_base_agent.py | 68 ++++++++++++++++--- .../_single_threaded_agent_runtime.py | 32 ++++++++- .../autogen-core/tests/test_runtime.py | 64 +++++++++++++++++ .../src/autogen_ext/models/openai/__init__.py | 2 +- .../models/openai/_openai_client.py | 3 +- .../runtimes/grpc/_worker_runtime.py | 7 ++ 12 files changed, 183 insertions(+), 26 deletions(-) diff --git a/python/packages/agbench/src/agbench/linter/__init__.py b/python/packages/agbench/src/agbench/linter/__init__.py index 797b7f272a5b..a104962445f6 100644 --- a/python/packages/agbench/src/agbench/linter/__init__.py +++ b/python/packages/agbench/src/agbench/linter/__init__.py @@ -1,4 +1,4 @@ # __init__.py -from ._base import Code, Document, CodedDocument, BaseQualitativeCoder +from ._base import BaseQualitativeCoder, Code, CodedDocument, Document __all__ = ["Code", "Document", "CodedDocument", "BaseQualitativeCoder"] diff --git a/python/packages/agbench/src/agbench/linter/_base.py b/python/packages/agbench/src/agbench/linter/_base.py index 4f6209b7809c..c59e826d201b 100644 --- a/python/packages/agbench/src/agbench/linter/_base.py +++ b/python/packages/agbench/src/agbench/linter/_base.py @@ -1,7 +1,8 @@ -import json import hashlib +import json import re -from typing import Protocol, List, Set, Optional +from typing import List, Optional, Protocol, Set + from pydantic import BaseModel, Field diff --git a/python/packages/agbench/src/agbench/linter/cli.py b/python/packages/agbench/src/agbench/linter/cli.py index 426890258b69..14f428929b17 100644 --- a/python/packages/agbench/src/agbench/linter/cli.py +++ b/python/packages/agbench/src/agbench/linter/cli.py @@ -1,8 +1,10 @@ -import os import argparse -from typing import List, Sequence, Optional +import os +from typing import List, Optional, Sequence + from openai import OpenAI -from ._base import Document, CodedDocument + +from ._base import CodedDocument, Document from .coders.oai_coder import OAIQualitativeCoder diff --git a/python/packages/agbench/src/agbench/linter/coders/oai_coder.py b/python/packages/agbench/src/agbench/linter/coders/oai_coder.py index 374093d3d81b..01322e0c5ccc 100644 --- a/python/packages/agbench/src/agbench/linter/coders/oai_coder.py +++ b/python/packages/agbench/src/agbench/linter/coders/oai_coder.py @@ -1,13 +1,11 @@ import os import re - -from typing import List, Set, Optional -from pydantic import BaseModel +from typing import List, Optional, Set from openai import OpenAI +from pydantic import BaseModel -from .._base import CodedDocument, Document, Code -from .._base import BaseQualitativeCoder +from .._base import BaseQualitativeCoder, Code, CodedDocument, Document class CodeList(BaseModel): diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index 0f37b822ff8a..144ff21358ff 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -17,6 +17,10 @@ def id(self) -> AgentId: """ID of the agent.""" ... + async def init(self, **kwargs: Any) -> None: + """Function for Agents requiring a two-phase initialization process.""" + pass + async def on_message(self, message: Any, ctx: MessageContext) -> Any: """Message handler for the agent. This should only be called by the runtime, not by other agents. diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 6510d84fbe17..3a3a74d18c7e 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -130,6 +130,12 @@ async def main() -> None: """ ... + async def register_agent_instance( + self, + agent_id: AgentId, + agent_instance: T | Awaitable[T], + ) -> AgentId: ... + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] """Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases. diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index bffb61b876bb..5efe29f786f0 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -10,7 +10,6 @@ from ._agent import Agent from ._agent_id import AgentId -from ._agent_instantiation import AgentInstantiationContext from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._agent_type import AgentType @@ -82,20 +81,20 @@ def metadata(self) -> AgentMetadata: return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) def __init__(self, description: str) -> None: - try: - runtime = AgentInstantiationContext.current_runtime() - id = AgentInstantiationContext.current_agent_id() - except LookupError as e: - raise RuntimeError( - "BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." - ) from e - - self._runtime: AgentRuntime = runtime - self._id: AgentId = id if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description + async def init(self, **kwargs: Any) -> None: + if "runtime" not in kwargs or "agent_id" not in kwargs: + raise ValueError("Agent must be initialized with runtime and agent_id") + if not isinstance(kwargs["runtime"], AgentRuntime): + raise ValueError("Agent must be initialized with runtime of type AgentRuntime") + if not isinstance(kwargs["agent_id"], AgentId): + raise ValueError("Agent must be initialized with agent_id of type AgentId") + self._runtime = kwargs["runtime"] + self._id = kwargs["agent_id"] + @property def type(self) -> str: return self.id.type @@ -155,6 +154,53 @@ async def load_state(self, state: Mapping[str, Any]) -> None: async def close(self) -> None: pass + async def register_instance( + self, + runtime: AgentRuntime, + agent_id: AgentId, + *, + skip_class_subscriptions: bool = False, + skip_direct_message_subscription: bool = False, + ) -> AgentId: + agent_id = await runtime.register_agent_instance(agent_id=agent_id, agent_instance=self) + if not skip_class_subscriptions: + with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): + subscriptions: List[Subscription] = [] + for unbound_subscription in self._unbound_subscriptions(): + subscriptions_list_result = unbound_subscription() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list = await subscriptions_list_result + else: + subscriptions_list = subscriptions_list_result + + subscriptions.extend(subscriptions_list) + try: + for subscription in subscriptions: + await runtime.add_subscription(subscription) + except ValueError: + # We don't care if the subscription already exists + pass + + if not skip_direct_message_subscription: + # Additionally adds a special prefix subscription for this agent to receive direct messages + try: + await runtime.add_subscription( + TypePrefixSubscription( + # The prefix MUST include ":" to avoid collisions with other agents + topic_type_prefix=agent_id.type + ":", + agent_type=agent_id.type, + ) + ) + except ValueError: + # We don't care if the subscription already exists + pass + + # TODO: deduplication + for _message_type, serializer in self._handles_types(): + runtime.add_message_serializer(serializer) + + return agent_id + @classmethod async def register( cls, diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 9610e7f54ebc..97b220f0a075 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -265,6 +265,7 @@ def __init__( self._serialization_registry = SerializationRegistry() self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions self._background_exception: BaseException | None = None + self._agent_instance_types: Dict[str, Type[Agent]] = {} @property def unprocessed_messages_count( @@ -830,6 +831,33 @@ async def factory_wrapper() -> T: return type + async def register_agent_instance( + self, + agent_id: AgentId, + agent_instance: T | Awaitable[T], + ) -> AgentId: + def agent_factory() -> T: + raise RuntimeError("Agent factory should not be called when registering an agent instance.") + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + + if inspect.isawaitable(agent_instance): + agent_instance = await agent_instance + + if agent_id.type not in self._agent_factories: + self._agent_factories[agent_id.type] = agent_factory + self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) + else: + if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__: + raise ValueError("Agent factories and agent instances cannot be registered to the same type.") + if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): + raise ValueError("Agent instances must be the same object type.") + + await agent_instance.init(runtime=self, agent_id=agent_id) + self._instantiated_agents[agent_id] = agent_instance + return agent_id + async def _invoke_agent_factory( self, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], @@ -851,7 +879,9 @@ async def _invoke_agent_factory( raise ValueError("Agent factory must take 0 or 2 arguments.") if inspect.isawaitable(agent): - return cast(T, await agent) + agent = cast(T, await agent) + + await agent.init(runtime=self, agent_id=agent_id) return agent diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 64a1cccf4b12..f87e63d8c077 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -82,6 +82,70 @@ def agent_factory() -> NoopAgent: await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent) +@pytest.mark.asyncio +async def test_agent_type_register_instance() -> None: + runtime = SingleThreadedAgentRuntime() + agent1 = NoopAgent() + agent2 = NoopAgent() + agent1_id = AgentId(type="name", key="default") + agentdup_id = AgentId(type="name", key="duplicate") + agent2_id = AgentId(type="name", key="notdefault") + await agent1.register_instance(runtime, agent1_id) + await agent1.register_instance(runtime, agentdup_id) + await agent2.register_instance(runtime, agent2_id) + + assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 + assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 + assert await runtime.try_get_underlying_agent_instance(agentdup_id, type=NoopAgent) == agent1 + + +@pytest.mark.asyncio +async def test_agent_type_register_instance_duplicate_ids() -> None: + runtime = SingleThreadedAgentRuntime() + agent_id = AgentId(type="name", key="default") + agent1 = NoopAgent() + agent2 = NoopAgent() + await agent1.register_instance(runtime, agent_id) + with pytest.raises(ValueError): + await agent2.register_instance(runtime, agent_id) + + +@pytest.mark.asyncio +async def test_agent_type_register_instance_different_types() -> None: + runtime = SingleThreadedAgentRuntime() + agent_id1 = AgentId(type="name", key="noop") + agent_id2 = AgentId(type="name", key="loopback") + agent1 = NoopAgent() + agent2 = LoopbackAgent() + await agent1.register_instance(runtime, agent_id1) + with pytest.raises(ValueError): + await agent2.register_instance(runtime, agent_id2) + + +@pytest.mark.asyncio +async def test_agent_type_register_instance_publish_new_source() -> None: + runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + agent_id = AgentId(type="name", key="default") + agent1 = LoopbackAgent() + await agent1.register_instance(runtime, agent_id) + await runtime.add_subscription(TypeSubscription("notdefault", "name")) + + with pytest.raises(RuntimeError): + runtime.start() + await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault")) + await runtime.stop_when_idle() + + +@pytest.mark.asyncio +async def test_register_instance_factory() -> None: + runtime = SingleThreadedAgentRuntime() + agent1 = NoopAgent() + agent1_id = AgentId(type="name", key="default") + await agent1.register_instance(runtime, agent1_id) + with pytest.raises(ValueError): + await NoopAgent.register(runtime, "name", lambda: NoopAgent()) + + @pytest.mark.asyncio async def test_register_receives_publish(tracer_provider: TracerProvider) -> None: runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py index cd0689b8e01b..e09ff22d3ab7 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py @@ -1,8 +1,8 @@ from ._openai_client import ( + AZURE_OPENAI_USER_AGENT, AzureOpenAIChatCompletionClient, BaseOpenAIChatCompletionClient, OpenAIChatCompletionClient, - AZURE_OPENAI_USER_AGENT, ) from .config import ( AzureOpenAIClientConfigurationConfigModel, diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index 74a9caa4458d..b61af6f0154e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -8,6 +8,7 @@ import warnings from asyncio import Task from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version from typing import ( Any, AsyncGenerator, @@ -87,8 +88,6 @@ OpenAIClientConfiguration, OpenAIClientConfigurationConfigModel, ) -from importlib.metadata import PackageNotFoundError, version - logger = logging.getLogger(EVENT_LOGGER_NAME) trace_logger = logging.getLogger(TRACE_LOGGER_NAME) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 504838740283..8cfb1ad8ec49 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -737,6 +737,13 @@ async def factory_wrapper() -> T: ) return type + async def register_agent_instance( + self, + agent_id: AgentId, + agent_instance: T | Awaitable[T], + ) -> AgentId: + return agent_id + async def _invoke_agent_factory( self, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], From 693cef88c95cd06bccfe0f4c9da7d3b0297210db Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 27 Mar 2025 13:21:48 -0400 Subject: [PATCH 02/17] Initialize agent from context if it exists --- .../autogen-core/src/autogen_core/_agent_instantiation.py | 6 ++++++ .../packages/autogen-core/src/autogen_core/_base_agent.py | 4 ++++ .../src/autogen_core/_single_threaded_agent_runtime.py | 2 -- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py index 71921225cfbd..7cfae8d6fed5 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py @@ -118,3 +118,9 @@ def current_agent_id(cls) -> AgentId: raise RuntimeError( "AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." ) from e + + @classmethod + def is_in_runtime(cls) -> bool: + if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None: + return False + return True diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 5efe29f786f0..b2c144f14370 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -10,6 +10,7 @@ from ._agent import Agent from ._agent_id import AgentId +from ._agent_instantiation import AgentInstantiationContext from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._agent_type import AgentType @@ -81,6 +82,9 @@ def metadata(self) -> AgentMetadata: return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) def __init__(self, description: str) -> None: + if AgentInstantiationContext.is_in_runtime(): + self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime() + self._id = AgentInstantiationContext.current_agent_id() if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 97b220f0a075..c75554770a75 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -881,8 +881,6 @@ async def _invoke_agent_factory( if inspect.isawaitable(agent): agent = cast(T, await agent) - await agent.init(runtime=self, agent_id=agent_id) - return agent except BaseException as e: From 12d26c76c4b246558d6a4bbfbc0f59d73db41548 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 27 Mar 2025 16:25:27 -0400 Subject: [PATCH 03/17] BaseAgent initializes with runtime and id --- .../autogen-core/src/autogen_core/_agent.py | 4 - .../src/autogen_core/_agent_instantiation.py | 6 -- .../src/autogen_core/_agent_runtime.py | 52 +++++++++- .../src/autogen_core/_base_agent.py | 48 +++++---- .../src/autogen_core/_routed_agent.py | 9 +- .../_single_threaded_agent_runtime.py | 15 ++- .../autogen-core/tests/test_base_agent.py | 17 +++- .../autogen-core/tests/test_runtime.py | 39 +++----- .../runtimes/grpc/_worker_runtime.py | 44 +++++++-- .../autogen-ext/tests/test_worker_runtime.py | 97 +++++++++++++++++++ .../src/autogen_test_utils/__init__.py | 18 ++-- 11 files changed, 270 insertions(+), 79 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index 144ff21358ff..0f37b822ff8a 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -17,10 +17,6 @@ def id(self) -> AgentId: """ID of the agent.""" ... - async def init(self, **kwargs: Any) -> None: - """Function for Agents requiring a two-phase initialization process.""" - pass - async def on_message(self, message: Any, ctx: MessageContext) -> Any: """Message handler for the agent. This should only be called by the runtime, not by other agents. diff --git a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py index 7cfae8d6fed5..71921225cfbd 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py @@ -118,9 +118,3 @@ def current_agent_id(cls) -> AgentId: raise RuntimeError( "AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." ) from e - - @classmethod - def is_in_runtime(cls) -> bool: - if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None: - return False - return True diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 3a3a74d18c7e..c88e9f228ef8 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -132,9 +132,57 @@ async def main() -> None: async def register_agent_instance( self, - agent_id: AgentId, agent_instance: T | Awaitable[T], - ) -> AgentId: ... + ) -> AgentId: + """Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions. + + .. note:: + + This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + + from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event + from autogen_core.models import UserMessage + + + @dataclass + class MyMessage: + content: str + + + class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My core agent") + + @event + async def handler(self, message: UserMessage, context: MessageContext) -> None: + print("Event received: ", message.content) + + + async def my_agent_factory(): + return MyAgent() + + + async def main() -> None: + runtime: AgentRuntime = ... # type: ignore + await runtime.register_agent_instance(runtime=runtime, agent_id=AgentId(type="my_agent", key="default")) + + + import asyncio + + asyncio.run(main()) + + + Args: + agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`. + agent_instance (T | Awaitable[T]): A concrete instance of the agent. + """ + ... # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index b2c144f14370..2aeb88f5d003 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -4,7 +4,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final +from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Optional, Tuple, Type, TypeVar, final from typing_extensions import Self @@ -81,24 +81,37 @@ def metadata(self) -> AgentMetadata: assert self._id is not None return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) - def __init__(self, description: str) -> None: - if AgentInstantiationContext.is_in_runtime(): - self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime() - self._id = AgentInstantiationContext.current_agent_id() + def __init__( + self, description: str, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None + ) -> None: + param_count = 0 + if runtime is not None: + param_count += 1 + if agent_id is not None: + param_count += 1 + + if param_count != 0 and param_count != 2: + raise ValueError("BaseAgent must be instantiated with both runtime and agent_id or neither.") + if param_count == 0: + try: + runtime = AgentInstantiationContext.current_runtime() + agent_id = AgentInstantiationContext.current_agent_id() + except LookupError as e: + raise RuntimeError( + "BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." + ) from e + else: + if not isinstance(runtime, AgentRuntime): + raise ValueError("Agent must be initialized with runtime of type AgentRuntime") + if not isinstance(agent_id, AgentId): + raise ValueError("Agent must be initialized with agent_id of type AgentId") + + self._runtime: AgentRuntime = runtime + self._id: AgentId = agent_id if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description - async def init(self, **kwargs: Any) -> None: - if "runtime" not in kwargs or "agent_id" not in kwargs: - raise ValueError("Agent must be initialized with runtime and agent_id") - if not isinstance(kwargs["runtime"], AgentRuntime): - raise ValueError("Agent must be initialized with runtime of type AgentRuntime") - if not isinstance(kwargs["agent_id"], AgentId): - raise ValueError("Agent must be initialized with agent_id of type AgentId") - self._runtime = kwargs["runtime"] - self._id = kwargs["agent_id"] - @property def type(self) -> str: return self.id.type @@ -160,13 +173,12 @@ async def close(self) -> None: async def register_instance( self, - runtime: AgentRuntime, - agent_id: AgentId, *, skip_class_subscriptions: bool = False, skip_direct_message_subscription: bool = False, ) -> AgentId: - agent_id = await runtime.register_agent_instance(agent_id=agent_id, agent_instance=self) + runtime = self.runtime + agent_id = await runtime.register_agent_instance(agent_instance=self) if not skip_class_subscriptions: with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): subscriptions: List[Subscription] = [] diff --git a/python/packages/autogen-core/src/autogen_core/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/_routed_agent.py index a5908278cab9..2cf348ff90d2 100644 --- a/python/packages/autogen-core/src/autogen_core/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_routed_agent.py @@ -7,6 +7,7 @@ DefaultDict, List, Literal, + Optional, Protocol, Sequence, Tuple, @@ -18,6 +19,8 @@ runtime_checkable, ) +from ._agent_id import AgentId +from ._agent_runtime import AgentRuntime from ._base_agent import BaseAgent from ._message_context import MessageContext from ._serialization import MessageSerializer, try_get_known_serializers_for_type @@ -457,7 +460,9 @@ async def handle_special_rpc_message(self, message: MessageWithContent, ctx: Mes return Response() """ - def __init__(self, description: str) -> None: + def __init__( + self, description: str, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None + ) -> None: # Self is already bound to the handlers self._handlers: DefaultDict[ Type[Any], @@ -469,7 +474,7 @@ def __init__(self, description: str) -> None: for target_type in message_handler.target_types: self._handlers[target_type].append(message_handler) - super().__init__(description) + super().__init__(description, runtime, agent_id) async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: """Handle a message by routing it to the appropriate message handler. diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index c75554770a75..a0dfb162036f 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -34,6 +34,7 @@ from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._agent_type import AgentType +from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken from ._intervention import DropMessage, InterventionHandler from ._message_context import MessageContext @@ -833,18 +834,23 @@ async def factory_wrapper() -> T: async def register_agent_instance( self, - agent_id: AgentId, agent_instance: T | Awaitable[T], ) -> AgentId: def agent_factory() -> T: raise RuntimeError("Agent factory should not be called when registering an agent instance.") - if agent_id in self._instantiated_agents: - raise ValueError(f"Agent with id {agent_id} already exists.") - if inspect.isawaitable(agent_instance): agent_instance = await agent_instance + # Agent type does not have the concept of a runtime + if isinstance(agent_instance, BaseAgent): + if agent_instance.runtime is not self: + raise ValueError("Agent instance is associated with a different runtime.") + agent_id = agent_instance.id + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + if agent_id.type not in self._agent_factories: self._agent_factories[agent_id.type] = agent_factory self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) @@ -854,7 +860,6 @@ def agent_factory() -> T: if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): raise ValueError("Agent instances must be the same object type.") - await agent_instance.init(runtime=self, agent_id=agent_id) self._instantiated_agents[agent_id] = agent_instance return agent_id diff --git a/python/packages/autogen-core/tests/test_base_agent.py b/python/packages/autogen-core/tests/test_base_agent.py index 64bcf59d1774..0b3a65e321ba 100644 --- a/python/packages/autogen-core/tests/test_base_agent.py +++ b/python/packages/autogen-core/tests/test_base_agent.py @@ -8,8 +8,17 @@ async def test_base_agent_create(mocker: MockerFixture) -> None: runtime = mocker.Mock(spec=AgentRuntime) + agent1 = NoopAgent(runtime=runtime, agent_id=AgentId("name1", "namespace1")) + assert agent1.runtime == runtime + assert agent1.id == AgentId("name1", "namespace1") + + with pytest.raises(ValueError): + NoopAgent(runtime=runtime, agent_id=None) + with pytest.raises(ValueError): + NoopAgent(runtime=None, agent_id=AgentId("name_fail", "namespace_fail")) + # Shows how to set the context for the agent instantiation in a test context - with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))): - agent = NoopAgent() - assert agent.runtime == runtime - assert agent.id == AgentId("name", "namespace") + with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))): + agent2 = NoopAgent() + assert agent2.runtime == runtime + assert agent2.id == AgentId("name2", "namespace2") diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index f87e63d8c077..c4bc519ce52e 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -85,29 +85,18 @@ def agent_factory() -> NoopAgent: @pytest.mark.asyncio async def test_agent_type_register_instance() -> None: runtime = SingleThreadedAgentRuntime() - agent1 = NoopAgent() - agent2 = NoopAgent() agent1_id = AgentId(type="name", key="default") - agentdup_id = AgentId(type="name", key="duplicate") agent2_id = AgentId(type="name", key="notdefault") - await agent1.register_instance(runtime, agent1_id) - await agent1.register_instance(runtime, agentdup_id) - await agent2.register_instance(runtime, agent2_id) + agent1 = NoopAgent(runtime=runtime, agent_id=agent1_id) + agent1_dup = NoopAgent(runtime=runtime, agent_id=agent1_id) + agent2 = NoopAgent(runtime=runtime, agent_id=agent2_id) + await agent1.register_instance() + await agent2.register_instance() assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 - assert await runtime.try_get_underlying_agent_instance(agentdup_id, type=NoopAgent) == agent1 - - -@pytest.mark.asyncio -async def test_agent_type_register_instance_duplicate_ids() -> None: - runtime = SingleThreadedAgentRuntime() - agent_id = AgentId(type="name", key="default") - agent1 = NoopAgent() - agent2 = NoopAgent() - await agent1.register_instance(runtime, agent_id) with pytest.raises(ValueError): - await agent2.register_instance(runtime, agent_id) + await agent1_dup.register_instance() @pytest.mark.asyncio @@ -115,19 +104,19 @@ async def test_agent_type_register_instance_different_types() -> None: runtime = SingleThreadedAgentRuntime() agent_id1 = AgentId(type="name", key="noop") agent_id2 = AgentId(type="name", key="loopback") - agent1 = NoopAgent() - agent2 = LoopbackAgent() - await agent1.register_instance(runtime, agent_id1) + agent1 = NoopAgent(runtime=runtime, agent_id=agent_id1) + agent2 = LoopbackAgent(runtime=runtime, agent_id=agent_id2) + await agent1.register_instance() with pytest.raises(ValueError): - await agent2.register_instance(runtime, agent_id2) + await agent2.register_instance() @pytest.mark.asyncio async def test_agent_type_register_instance_publish_new_source() -> None: runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) agent_id = AgentId(type="name", key="default") - agent1 = LoopbackAgent() - await agent1.register_instance(runtime, agent_id) + agent1 = LoopbackAgent(runtime=runtime, agent_id=agent_id) + await agent1.register_instance() await runtime.add_subscription(TypeSubscription("notdefault", "name")) with pytest.raises(RuntimeError): @@ -139,9 +128,9 @@ async def test_agent_type_register_instance_publish_new_source() -> None: @pytest.mark.asyncio async def test_register_instance_factory() -> None: runtime = SingleThreadedAgentRuntime() - agent1 = NoopAgent() agent1_id = AgentId(type="name", key="default") - await agent1.register_instance(runtime, agent1_id) + agent1 = NoopAgent(runtime=runtime, agent_id=agent1_id) + await agent1.register_instance() with pytest.raises(ValueError): await NoopAgent.register(runtime, "name", lambda: NoopAgent()) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 8cfb1ad8ec49..42ec1a0fbd5d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -40,6 +40,7 @@ AgentMetadata, AgentRuntime, AgentType, + BaseAgent, CancellationToken, MessageContext, MessageHandlerContext, @@ -251,6 +252,7 @@ def __init__( self._subscription_manager = SubscriptionManager() self._serialization_registry = SerializationRegistry() self._extra_grpc_config = extra_grpc_config or [] + self._agent_instance_types: Dict[str, Type[Agent]] = {} if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}: raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}") @@ -701,6 +703,14 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any: except BaseException as e: logger.error("Error handling event", exc_info=e) + async def _register_agent_type(self, agent_type: str) -> None: + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + message = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type) + _response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent( + message, metadata=self._host_connection.metadata + ) + async def register_factory( self, type: str | AgentType, @@ -729,19 +739,41 @@ async def factory_wrapper() -> T: return agent_instance self._agent_factories[type.type] = factory_wrapper - # Send the registration request message to the host. - message = agent_worker_pb2.RegisterAgentTypeRequest(type=type.type) - _response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent( - message, metadata=self._host_connection.metadata - ) + await self._register_agent_type(type.type) + return type async def register_agent_instance( self, - agent_id: AgentId, agent_instance: T | Awaitable[T], ) -> AgentId: + def agent_factory() -> T: + raise RuntimeError("Agent factory should not be called when registering an agent instance.") + + if inspect.isawaitable(agent_instance): + agent_instance = await agent_instance + + # Agent type does not have the concept of a runtime + if isinstance(agent_instance, BaseAgent): + if agent_instance.runtime is not self: + raise ValueError("Agent instance is associated with a different runtime.") + agent_id = agent_instance.id + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + + if agent_id.type not in self._agent_factories: + self._agent_factories[agent_id.type] = agent_factory + await self._register_agent_type(agent_id.type) + self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) + else: + if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__: + raise ValueError("Agent factories and agent instances cannot be registered to the same type.") + if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): + raise ValueError("Agent instances must be the same object type.") + + self._instantiated_agents[agent_id] = agent_instance return agent_id async def _invoke_agent_factory( diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index dede306853ed..c90ebf282b6d 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -577,6 +577,103 @@ async def test_grpc_max_message_size() -> None: await host.stop() +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_agent_type_register_instance() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="name", key="default") + agentdup_id = AgentId(type="name", key="default") + agent2_id = AgentId(type="name", key="notdefault") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = NoopAgent(runtime=worker, agent_id=agent1_id) + agent2 = NoopAgent(runtime=worker, agent_id=agent2_id) + agentdup = NoopAgent(runtime=worker, agent_id=agentdup_id) + await worker.start() + + await worker.register_agent_instance(agent1) + await worker.register_agent_instance(agent2) + + with pytest.raises(ValueError): + await worker.register_agent_instance(agentdup) + + assert await worker.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 + assert await worker.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 + + await worker.stop() + await host.stop() + + +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_agent_type_register_instance_different_types() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="name", key="noop") + agent2_id = AgentId(type="name", key="loopback") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = NoopAgent(runtime=worker, agent_id=agent1_id) + agent2 = LoopbackAgent(runtime=worker, agent_id=agent2_id) + await worker.start() + + await worker.register_agent_instance(agent1) + with pytest.raises(ValueError): + await worker.register_agent_instance(agent2) + + await worker.stop() + await host.stop() + + +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_register_instance_factory() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="name", key="default") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = NoopAgent(runtime=worker, agent_id=agent1_id) + await worker.start() + + await agent1.register_instance() + + with pytest.raises(ValueError): + await NoopAgent.register(runtime=worker, type="name", factory=lambda: NoopAgent()) + + await worker.stop() + await host.stop() + + +# GrpcWorkerAgentRuntimeHost eats exceptions in the main loop +# @pytest.mark.grpc +# @pytest.mark.asyncio +# async def test_agent_type_register_instance_publish_new_source() -> None: +# host_address = "localhost:50056" +# agent_id = AgentId(type="name", key="default") +# agent1 = LoopbackAgent() +# host = GrpcWorkerAgentRuntimeHost(address=host_address) +# host.start() +# worker = GrpcWorkerAgentRuntime(host_address=host_address) +# await worker.start() +# publisher = GrpcWorkerAgentRuntime(host_address=host_address) +# publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) +# await publisher.start() + +# await agent1.register_instance(worker, agent_id=agent_id) +# await worker.add_subscription(TypeSubscription("notdefault", "name")) + +# with pytest.raises(RuntimeError): +# await worker.publish_message(MessageType(), TopicId("notdefault", "notdefault")) +# await asyncio.sleep(2) + +# await worker.stop() +# await host.stop() + if __name__ == "__main__": os.environ["GRPC_VERBOSITY"] = "DEBUG" os.environ["GRPC_TRACE"] = "all" diff --git a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py index d3e5af883605..50cffc3a2242 100644 --- a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py +++ b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py @@ -2,9 +2,11 @@ from asyncio import Event from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from autogen_core import ( + AgentId, + AgentRuntime, BaseAgent, Component, ComponentBase, @@ -33,8 +35,8 @@ class ContentMessage: class LoopbackAgent(RoutedAgent): - def __init__(self) -> None: - super().__init__("A loop back agent.") + def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None: + super().__init__("A loop back agent.", runtime, agent_id) self.num_calls = 0 self.received_messages: list[Any] = [] self.event = Event() @@ -55,8 +57,10 @@ class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... @default_subscription class CascadingAgent(RoutedAgent): - def __init__(self, max_rounds: int) -> None: - super().__init__("A cascading agent.") + def __init__( + self, max_rounds: int, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None + ) -> None: + super().__init__("A cascading agent.", runtime, agent_id) self.num_calls = 0 self.max_rounds = max_rounds @@ -69,8 +73,8 @@ async def on_new_message(self, message: CascadingMessageType, ctx: MessageContex class NoopAgent(BaseAgent): - def __init__(self) -> None: - super().__init__("A no op agent") + def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None: + super().__init__("A no op agent", runtime=runtime, agent_id=agent_id) async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError From 4ebad57cde9182c7b7c0b276493aa0b1241dcb37 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 28 Mar 2025 10:51:53 -0400 Subject: [PATCH 04/17] Fix example code --- .../autogen-core/src/autogen_core/_agent_runtime.py | 13 +++++-------- .../autogen-ext/tests/tools/http/test_http_tool.py | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index c88e9f228ef8..886e564481a8 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -145,6 +145,7 @@ async def register_agent_instance( .. code-block:: python from dataclasses import dataclass + from typing import Optional from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event from autogen_core.models import UserMessage @@ -156,21 +157,18 @@ class MyMessage: class MyAgent(RoutedAgent): - def __init__(self) -> None: - super().__init__("My core agent") + def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None: + super().__init__("My core agent", runtime, agent_id) @event async def handler(self, message: UserMessage, context: MessageContext) -> None: print("Event received: ", message.content) - async def my_agent_factory(): - return MyAgent() - - async def main() -> None: runtime: AgentRuntime = ... # type: ignore - await runtime.register_agent_instance(runtime=runtime, agent_id=AgentId(type="my_agent", key="default")) + agent = MyAgent(runtime=runtime, agent_id=AgentId(type="my_agent", key="default")) + await runtime.register_agent_instance(agent) import asyncio @@ -179,7 +177,6 @@ async def main() -> None: Args: - agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`. agent_instance (T | Awaitable[T]): A concrete instance of the agent. """ ... diff --git a/python/packages/autogen-ext/tests/tools/http/test_http_tool.py b/python/packages/autogen-ext/tests/tools/http/test_http_tool.py index 9d2898bae505..8e50e48ba926 100644 --- a/python/packages/autogen-ext/tests/tools/http/test_http_tool.py +++ b/python/packages/autogen-ext/tests/tools/http/test_http_tool.py @@ -176,7 +176,7 @@ async def test_invalid_request(test_config: ComponentModel, test_server: None) - config.config["host"] = "fake" tool = HttpTool.load_component(config) - with pytest.raises(httpx.ConnectError): + with pytest.raises((httpx.ConnectError, httpx.ConnectTimeout)): await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) From 1076fd5af786c1b3d13bdbd7be5fc7aebb6ec9c2 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Tue, 1 Apr 2025 17:08:12 -0400 Subject: [PATCH 05/17] Revert design change. runtime and agent_id passed in register fn --- .../autogen-core/src/autogen_core/_agent.py | 4 ++ .../src/autogen_core/_agent_instantiation.py | 6 +++ .../src/autogen_core/_agent_runtime.py | 13 +++-- .../src/autogen_core/_base_agent.py | 48 +++++++------------ .../src/autogen_core/_routed_agent.py | 9 +--- .../_single_threaded_agent_runtime.py | 11 ++--- .../autogen-core/tests/test_base_agent.py | 33 +++++-------- .../autogen-core/tests/test_runtime.py | 31 ++++++------ .../runtimes/grpc/_worker_runtime.py | 11 ++--- .../models/test_azure_ai_model_client.py | 4 +- .../autogen-ext/tests/test_worker_runtime.py | 24 +++++----- .../src/autogen_test_utils/__init__.py | 18 +++---- 12 files changed, 93 insertions(+), 119 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index 0f37b822ff8a..063d8515e5f2 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -17,6 +17,10 @@ def id(self) -> AgentId: """ID of the agent.""" ... + async def init(self, **kwargs: Any) -> None: + """Function used for two-phase initialization.""" + ... + async def on_message(self, message: Any, ctx: MessageContext) -> Any: """Message handler for the agent. This should only be called by the runtime, not by other agents. diff --git a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py index 71921225cfbd..7cfae8d6fed5 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py @@ -118,3 +118,9 @@ def current_agent_id(cls) -> AgentId: raise RuntimeError( "AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." ) from e + + @classmethod + def is_in_runtime(cls) -> bool: + if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None: + return False + return True diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 886e564481a8..6cb0ec6031d2 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -133,6 +133,7 @@ async def main() -> None: async def register_agent_instance( self, agent_instance: T | Awaitable[T], + agent_id: AgentId, ) -> AgentId: """Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions. @@ -145,7 +146,6 @@ async def register_agent_instance( .. code-block:: python from dataclasses import dataclass - from typing import Optional from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event from autogen_core.models import UserMessage @@ -157,8 +157,8 @@ class MyMessage: class MyAgent(RoutedAgent): - def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None: - super().__init__("My core agent", runtime, agent_id) + def __init__(self) -> None: + super().__init__("My core agent") @event async def handler(self, message: UserMessage, context: MessageContext) -> None: @@ -167,8 +167,10 @@ async def handler(self, message: UserMessage, context: MessageContext) -> None: async def main() -> None: runtime: AgentRuntime = ... # type: ignore - agent = MyAgent(runtime=runtime, agent_id=AgentId(type="my_agent", key="default")) - await runtime.register_agent_instance(agent) + agent: Agent = MyAgent() + await runtime.register_agent_instance( + agent_instance=agent, agent_id=AgentId(type="my_agent", key="default") + ) import asyncio @@ -178,6 +180,7 @@ async def main() -> None: Args: agent_instance (T | Awaitable[T]): A concrete instance of the agent. + agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`. """ ... diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 2aeb88f5d003..57a7959c1b6c 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -4,7 +4,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Optional, Tuple, Type, TypeVar, final +from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final from typing_extensions import Self @@ -81,37 +81,24 @@ def metadata(self) -> AgentMetadata: assert self._id is not None return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) - def __init__( - self, description: str, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None - ) -> None: - param_count = 0 - if runtime is not None: - param_count += 1 - if agent_id is not None: - param_count += 1 - - if param_count != 0 and param_count != 2: - raise ValueError("BaseAgent must be instantiated with both runtime and agent_id or neither.") - if param_count == 0: - try: - runtime = AgentInstantiationContext.current_runtime() - agent_id = AgentInstantiationContext.current_agent_id() - except LookupError as e: - raise RuntimeError( - "BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." - ) from e - else: - if not isinstance(runtime, AgentRuntime): - raise ValueError("Agent must be initialized with runtime of type AgentRuntime") - if not isinstance(agent_id, AgentId): - raise ValueError("Agent must be initialized with agent_id of type AgentId") - - self._runtime: AgentRuntime = runtime - self._id: AgentId = agent_id + def __init__(self, description: str) -> None: + if AgentInstantiationContext.is_in_runtime(): + self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime() + self._id = AgentInstantiationContext.current_agent_id() if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description + async def init(self, **kwargs: Any) -> None: + if "runtime" not in kwargs or "agent_id" not in kwargs: + raise ValueError("Agent must be initialized with runtime and agent_id") + if not isinstance(kwargs["runtime"], AgentRuntime): + raise ValueError("Agent must be initialized with runtime of type AgentRuntime") + if not isinstance(kwargs["agent_id"], AgentId): + raise ValueError("Agent must be initialized with agent_id of type AgentId") + self._runtime = kwargs["runtime"] + self._id = kwargs["agent_id"] + @property def type(self) -> str: return self.id.type @@ -173,12 +160,13 @@ async def close(self) -> None: async def register_instance( self, + runtime: AgentRuntime, + agent_id: AgentId, *, skip_class_subscriptions: bool = False, skip_direct_message_subscription: bool = False, ) -> AgentId: - runtime = self.runtime - agent_id = await runtime.register_agent_instance(agent_instance=self) + agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id) if not skip_class_subscriptions: with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): subscriptions: List[Subscription] = [] diff --git a/python/packages/autogen-core/src/autogen_core/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/_routed_agent.py index 2cf348ff90d2..a5908278cab9 100644 --- a/python/packages/autogen-core/src/autogen_core/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_routed_agent.py @@ -7,7 +7,6 @@ DefaultDict, List, Literal, - Optional, Protocol, Sequence, Tuple, @@ -19,8 +18,6 @@ runtime_checkable, ) -from ._agent_id import AgentId -from ._agent_runtime import AgentRuntime from ._base_agent import BaseAgent from ._message_context import MessageContext from ._serialization import MessageSerializer, try_get_known_serializers_for_type @@ -460,9 +457,7 @@ async def handle_special_rpc_message(self, message: MessageWithContent, ctx: Mes return Response() """ - def __init__( - self, description: str, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None - ) -> None: + def __init__(self, description: str) -> None: # Self is already bound to the handlers self._handlers: DefaultDict[ Type[Any], @@ -474,7 +469,7 @@ def __init__( for target_type in message_handler.target_types: self._handlers[target_type].append(message_handler) - super().__init__(description, runtime, agent_id) + super().__init__(description) async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: """Handle a message by routing it to the appropriate message handler. diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index a0dfb162036f..4522da91c8e7 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -34,7 +34,6 @@ from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._agent_type import AgentType -from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken from ._intervention import DropMessage, InterventionHandler from ._message_context import MessageContext @@ -835,6 +834,7 @@ async def factory_wrapper() -> T: async def register_agent_instance( self, agent_instance: T | Awaitable[T], + agent_id: AgentId, ) -> AgentId: def agent_factory() -> T: raise RuntimeError("Agent factory should not be called when registering an agent instance.") @@ -842,12 +842,6 @@ def agent_factory() -> T: if inspect.isawaitable(agent_instance): agent_instance = await agent_instance - # Agent type does not have the concept of a runtime - if isinstance(agent_instance, BaseAgent): - if agent_instance.runtime is not self: - raise ValueError("Agent instance is associated with a different runtime.") - agent_id = agent_instance.id - if agent_id in self._instantiated_agents: raise ValueError(f"Agent with id {agent_id} already exists.") @@ -860,6 +854,7 @@ def agent_factory() -> T: if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): raise ValueError("Agent instances must be the same object type.") + await agent_instance.init(runtime=self, agent_id=agent_id) self._instantiated_agents[agent_id] = agent_instance return agent_id @@ -885,7 +880,7 @@ async def _invoke_agent_factory( if inspect.isawaitable(agent): agent = cast(T, await agent) - + await agent.init(runtime=self, agent_id=agent_id) return agent except BaseException as e: diff --git a/python/packages/autogen-core/tests/test_base_agent.py b/python/packages/autogen-core/tests/test_base_agent.py index 0b3a65e321ba..2e2e1f91d28b 100644 --- a/python/packages/autogen-core/tests/test_base_agent.py +++ b/python/packages/autogen-core/tests/test_base_agent.py @@ -1,24 +1,15 @@ -import pytest -from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime -from autogen_test_utils import NoopAgent -from pytest_mock import MockerFixture +# import pytest +# from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime +# from autogen_test_utils import NoopAgent +# from pytest_mock import MockerFixture -@pytest.mark.asyncio -async def test_base_agent_create(mocker: MockerFixture) -> None: - runtime = mocker.Mock(spec=AgentRuntime) +# @pytest.mark.asyncio +# async def test_base_agent_create(mocker: MockerFixture) -> None: +# runtime = mocker.Mock(spec=AgentRuntime) - agent1 = NoopAgent(runtime=runtime, agent_id=AgentId("name1", "namespace1")) - assert agent1.runtime == runtime - assert agent1.id == AgentId("name1", "namespace1") - - with pytest.raises(ValueError): - NoopAgent(runtime=runtime, agent_id=None) - with pytest.raises(ValueError): - NoopAgent(runtime=None, agent_id=AgentId("name_fail", "namespace_fail")) - - # Shows how to set the context for the agent instantiation in a test context - with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))): - agent2 = NoopAgent() - assert agent2.runtime == runtime - assert agent2.id == AgentId("name2", "namespace2") +# # Shows how to set the context for the agent instantiation in a test context +# with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))): +# agent2 = NoopAgent() +# assert agent2.runtime == runtime +# assert agent2.id == AgentId("name2", "namespace2") diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index c4bc519ce52e..e93a57a6a291 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -87,16 +87,16 @@ async def test_agent_type_register_instance() -> None: runtime = SingleThreadedAgentRuntime() agent1_id = AgentId(type="name", key="default") agent2_id = AgentId(type="name", key="notdefault") - agent1 = NoopAgent(runtime=runtime, agent_id=agent1_id) - agent1_dup = NoopAgent(runtime=runtime, agent_id=agent1_id) - agent2 = NoopAgent(runtime=runtime, agent_id=agent2_id) - await agent1.register_instance() - await agent2.register_instance() + agent1 = NoopAgent() + agent1_dup = NoopAgent() + agent2 = NoopAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent1_id) + await agent2.register_instance(runtime=runtime, agent_id=agent2_id) assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 with pytest.raises(ValueError): - await agent1_dup.register_instance() + await agent1_dup.register_instance(runtime=runtime, agent_id=agent1_id) @pytest.mark.asyncio @@ -104,33 +104,34 @@ async def test_agent_type_register_instance_different_types() -> None: runtime = SingleThreadedAgentRuntime() agent_id1 = AgentId(type="name", key="noop") agent_id2 = AgentId(type="name", key="loopback") - agent1 = NoopAgent(runtime=runtime, agent_id=agent_id1) - agent2 = LoopbackAgent(runtime=runtime, agent_id=agent_id2) - await agent1.register_instance() + agent1 = NoopAgent() + agent2 = LoopbackAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent_id1) with pytest.raises(ValueError): - await agent2.register_instance() + await agent2.register_instance(runtime=runtime, agent_id=agent_id2) @pytest.mark.asyncio async def test_agent_type_register_instance_publish_new_source() -> None: runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) agent_id = AgentId(type="name", key="default") - agent1 = LoopbackAgent(runtime=runtime, agent_id=agent_id) - await agent1.register_instance() + agent1 = LoopbackAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent_id) await runtime.add_subscription(TypeSubscription("notdefault", "name")) + runtime.start() with pytest.raises(RuntimeError): - runtime.start() await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault")) await runtime.stop_when_idle() + await runtime.close() @pytest.mark.asyncio async def test_register_instance_factory() -> None: runtime = SingleThreadedAgentRuntime() agent1_id = AgentId(type="name", key="default") - agent1 = NoopAgent(runtime=runtime, agent_id=agent1_id) - await agent1.register_instance() + agent1 = NoopAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent1_id) with pytest.raises(ValueError): await NoopAgent.register(runtime, "name", lambda: NoopAgent()) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 42ec1a0fbd5d..06225a75adb0 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -40,7 +40,6 @@ AgentMetadata, AgentRuntime, AgentType, - BaseAgent, CancellationToken, MessageContext, MessageHandlerContext, @@ -747,6 +746,7 @@ async def factory_wrapper() -> T: async def register_agent_instance( self, agent_instance: T | Awaitable[T], + agent_id: AgentId, ) -> AgentId: def agent_factory() -> T: raise RuntimeError("Agent factory should not be called when registering an agent instance.") @@ -754,12 +754,6 @@ def agent_factory() -> T: if inspect.isawaitable(agent_instance): agent_instance = await agent_instance - # Agent type does not have the concept of a runtime - if isinstance(agent_instance, BaseAgent): - if agent_instance.runtime is not self: - raise ValueError("Agent instance is associated with a different runtime.") - agent_id = agent_instance.id - if agent_id in self._instantiated_agents: raise ValueError(f"Agent with id {agent_id} already exists.") @@ -773,6 +767,7 @@ def agent_factory() -> T: if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): raise ValueError("Agent instances must be the same object type.") + await agent_instance.init(runtime=self, agent_id=agent_id) self._instantiated_agents[agent_id] = agent_instance return agent_id @@ -796,7 +791,7 @@ async def _invoke_agent_factory( raise ValueError("Agent factory must take 0 or 2 arguments.") if inspect.isawaitable(agent): - return cast(T, await agent) + agent = cast(T, await agent) return agent diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index 2f4f02aeaf13..2e15938903f8 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -3,7 +3,7 @@ import os from datetime import datetime from typing import Any, AsyncGenerator, List, Type, Union -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from autogen_core import CancellationToken, FunctionCall, Image @@ -570,7 +570,7 @@ async def _mock_thought_with_tool_call_stream( ) mock_client = MagicMock() - mock_client.close = MagicMock() + mock_client.close = AsyncMock() async def mock_complete(*args: Any, **kwargs: Any) -> Any: if kwargs.get("stream", False): diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index c90ebf282b6d..a6bf4441cbcc 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -588,16 +588,16 @@ async def test_agent_type_register_instance() -> None: host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - agent1 = NoopAgent(runtime=worker, agent_id=agent1_id) - agent2 = NoopAgent(runtime=worker, agent_id=agent2_id) - agentdup = NoopAgent(runtime=worker, agent_id=agentdup_id) + agent1 = NoopAgent() + agent2 = NoopAgent() + agentdup = NoopAgent() await worker.start() - await worker.register_agent_instance(agent1) - await worker.register_agent_instance(agent2) + await worker.register_agent_instance(agent1, agent_id=agent1_id) + await worker.register_agent_instance(agent2, agent_id=agent2_id) with pytest.raises(ValueError): - await worker.register_agent_instance(agentdup) + await worker.register_agent_instance(agentdup, agent_id=agentdup_id) assert await worker.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 assert await worker.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 @@ -616,13 +616,13 @@ async def test_agent_type_register_instance_different_types() -> None: host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - agent1 = NoopAgent(runtime=worker, agent_id=agent1_id) - agent2 = LoopbackAgent(runtime=worker, agent_id=agent2_id) + agent1 = NoopAgent() + agent2 = LoopbackAgent() await worker.start() - await worker.register_agent_instance(agent1) + await worker.register_agent_instance(agent1, agent_id=agent1_id) with pytest.raises(ValueError): - await worker.register_agent_instance(agent2) + await worker.register_agent_instance(agent2, agent_id=agent2_id) await worker.stop() await host.stop() @@ -637,10 +637,10 @@ async def test_register_instance_factory() -> None: host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - agent1 = NoopAgent(runtime=worker, agent_id=agent1_id) + agent1 = NoopAgent() await worker.start() - await agent1.register_instance() + await agent1.register_instance(runtime=worker, agent_id=agent1_id) with pytest.raises(ValueError): await NoopAgent.register(runtime=worker, type="name", factory=lambda: NoopAgent()) diff --git a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py index 50cffc3a2242..d3e5af883605 100644 --- a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py +++ b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py @@ -2,11 +2,9 @@ from asyncio import Event from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from autogen_core import ( - AgentId, - AgentRuntime, BaseAgent, Component, ComponentBase, @@ -35,8 +33,8 @@ class ContentMessage: class LoopbackAgent(RoutedAgent): - def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None: - super().__init__("A loop back agent.", runtime, agent_id) + def __init__(self) -> None: + super().__init__("A loop back agent.") self.num_calls = 0 self.received_messages: list[Any] = [] self.event = Event() @@ -57,10 +55,8 @@ class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ... @default_subscription class CascadingAgent(RoutedAgent): - def __init__( - self, max_rounds: int, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None - ) -> None: - super().__init__("A cascading agent.", runtime, agent_id) + def __init__(self, max_rounds: int) -> None: + super().__init__("A cascading agent.") self.num_calls = 0 self.max_rounds = max_rounds @@ -73,8 +69,8 @@ async def on_new_message(self, message: CascadingMessageType, ctx: MessageContex class NoopAgent(BaseAgent): - def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None: - super().__init__("A no op agent", runtime=runtime, agent_id=agent_id) + def __init__(self) -> None: + super().__init__("A no op agent") async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError From 434f2a5c5ddacc9e826c1bd836236dc04139f355 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Wed, 7 May 2025 15:06:20 -0400 Subject: [PATCH 06/17] address PR comments, add test --- .../autogen-core/src/autogen_core/_agent.py | 15 ++++++++-- .../src/autogen_core/_base_agent.py | 14 ++++----- .../_single_threaded_agent_runtime.py | 3 +- .../runtimes/grpc/_worker_runtime.py | 2 +- .../test_docker_commandline_code_executor.py | 2 +- .../autogen-ext/tests/test_worker_runtime.py | 30 +++++++++++++++++++ 6 files changed, 50 insertions(+), 16 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index 063d8515e5f2..e407fe137394 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -1,9 +1,13 @@ -from typing import Any, Mapping, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable from ._agent_id import AgentId from ._agent_metadata import AgentMetadata from ._message_context import MessageContext +# Forward declaration for type checking only +if TYPE_CHECKING: + from ._agent_runtime import AgentRuntime + @runtime_checkable class Agent(Protocol): @@ -17,8 +21,13 @@ def id(self) -> AgentId: """ID of the agent.""" ... - async def init(self, **kwargs: Any) -> None: - """Function used for two-phase initialization.""" + async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None: + """Function used to bind an Agent instance to an `AgentRuntime`. + + Args: + agent_id (AgentId): ID of the agent. + runtime (AgentRuntime): AgentRuntime instance to bind the agent to. + """ ... async def on_message(self, message: Any, ctx: MessageContext) -> Any: diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 57a7959c1b6c..39c17e6a86ba 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -89,15 +89,11 @@ def __init__(self, description: str) -> None: raise ValueError("Agent description must be a string") self._description = description - async def init(self, **kwargs: Any) -> None: - if "runtime" not in kwargs or "agent_id" not in kwargs: - raise ValueError("Agent must be initialized with runtime and agent_id") - if not isinstance(kwargs["runtime"], AgentRuntime): - raise ValueError("Agent must be initialized with runtime of type AgentRuntime") - if not isinstance(kwargs["agent_id"], AgentId): - raise ValueError("Agent must be initialized with agent_id of type AgentId") - self._runtime = kwargs["runtime"] - self._id = kwargs["agent_id"] + async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None: + if hasattr(self, "_id"): + raise RuntimeError("Agent is already bound to an ID and runtime") + self._id = id + self._runtime = runtime @property def type(self) -> str: diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 4522da91c8e7..966504b29b94 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -854,7 +854,7 @@ def agent_factory() -> T: if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): raise ValueError("Agent instances must be the same object type.") - await agent_instance.init(runtime=self, agent_id=agent_id) + await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self) self._instantiated_agents[agent_id] = agent_instance return agent_id @@ -880,7 +880,6 @@ async def _invoke_agent_factory( if inspect.isawaitable(agent): agent = cast(T, await agent) - await agent.init(runtime=self, agent_id=agent_id) return agent except BaseException as e: diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 06225a75adb0..e5c8d87a471b 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -767,7 +767,7 @@ def agent_factory() -> T: if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): raise ValueError("Agent instances must be the same object type.") - await agent_instance.init(runtime=self, agent_id=agent_id) + await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self) self._instantiated_agents[agent_id] = agent_instance return agent_id diff --git a/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py b/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py index b374f5371c93..81c890efa643 100644 --- a/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py +++ b/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py @@ -1,8 +1,8 @@ # mypy: disable-error-code="no-any-unimported" import asyncio import os -import sys import shutil +import sys import tempfile from pathlib import Path from typing import AsyncGenerator, TypeAlias diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index a6bf4441cbcc..a1aa122966dc 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -649,6 +649,36 @@ async def test_register_instance_factory() -> None: await host.stop() +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_instance_factory_messaging() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="instance_agent", key="instance_agent") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = CascadingAgent(max_rounds=5) + await worker.start() + await worker.register_agent_instance(agent1, agent_id=agent1_id) + await worker.add_subscription(TypeSubscription("instance_agent", "instance_agent")) + await CascadingAgent.register(worker, "factory_agent", lambda: CascadingAgent(max_rounds=5)) + + # instance_agent will publish a message that factory_agent will pick up + for i in range(5): + await worker.publish_message( + CascadingMessageType(round=i), TopicId(type="instance_agent", source="instance_agent") + ) + await asyncio.sleep(2) + + agent = await worker.try_get_underlying_agent_instance(AgentId("instance_agent", "instance_agent"), CascadingAgent) + assert agent.num_calls == 5 + assert agent1.num_calls == 5 + + await worker.stop() + await host.stop() + + # GrpcWorkerAgentRuntimeHost eats exceptions in the main loop # @pytest.mark.grpc # @pytest.mark.asyncio From 4b4462ebba02883703d331c398319ba6aa6a6254 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Wed, 7 May 2025 16:45:57 -0400 Subject: [PATCH 07/17] fix test --- python/packages/autogen-ext/tests/test_worker_runtime.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index a1aa122966dc..5212ecb30b6b 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -658,7 +658,7 @@ async def test_instance_factory_messaging() -> None: host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - agent1 = CascadingAgent(max_rounds=5) + agent1 = CascadingAgent(max_rounds=4) await worker.start() await worker.register_agent_instance(agent1, agent_id=agent1_id) await worker.add_subscription(TypeSubscription("instance_agent", "instance_agent")) @@ -671,9 +671,9 @@ async def test_instance_factory_messaging() -> None: ) await asyncio.sleep(2) - agent = await worker.try_get_underlying_agent_instance(AgentId("instance_agent", "instance_agent"), CascadingAgent) + agent = await worker.try_get_underlying_agent_instance(AgentId("factory_agent", "default"), CascadingAgent) assert agent.num_calls == 5 - assert agent1.num_calls == 5 + assert agent1.num_calls == 4 await worker.stop() await host.stop() From b6d5ecdda399edcdc5da50525e73bd14c95b862e Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Wed, 7 May 2025 17:25:48 -0400 Subject: [PATCH 08/17] Agent instances can only subscribe to a single subscription. Make that the default behavior --- .../src/autogen_core/_base_agent.py | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 39c17e6a86ba..a86b88c55c7c 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -21,6 +21,7 @@ from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId from ._type_prefix_subscription import TypePrefixSubscription +from ._type_subscription import TypeSubscription T = TypeVar("T", bound=Agent) @@ -159,27 +160,18 @@ async def register_instance( runtime: AgentRuntime, agent_id: AgentId, *, - skip_class_subscriptions: bool = False, + skip_broadcast_subscription: bool = False, skip_direct_message_subscription: bool = False, ) -> AgentId: + """ + This function is similar to `register` but is used for registering an instance of an agent. One major difference between the two functions is that using this function will restrict the topics that the agent can subscribe to based on the agent_id. + """ agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id) - if not skip_class_subscriptions: - with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): - subscriptions: List[Subscription] = [] - for unbound_subscription in self._unbound_subscriptions(): - subscriptions_list_result = unbound_subscription() - if inspect.isawaitable(subscriptions_list_result): - subscriptions_list = await subscriptions_list_result - else: - subscriptions_list = subscriptions_list_result - subscriptions.extend(subscriptions_list) - try: - for subscription in subscriptions: - await runtime.add_subscription(subscription) - except ValueError: - # We don't care if the subscription already exists - pass + if not skip_broadcast_subscription: + # Only add a subscription based on the agent_id. Default should NOT be added in this case. + subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type) + await runtime.add_subscription(subscription) if not skip_direct_message_subscription: # Additionally adds a special prefix subscription for this agent to receive direct messages From 8b4eb109f4dd01767c7f87ca2bc09eecbf6f064d Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Thu, 8 May 2025 16:10:31 -0400 Subject: [PATCH 09/17] fix subscription behavior --- .../src/autogen_core/_base_agent.py | 32 +++++++++++++++---- .../autogen-ext/tests/test_worker_runtime.py | 4 +-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index a86b88c55c7c..6b8e7d848f73 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -92,7 +92,13 @@ def __init__(self, description: str) -> None: async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None: if hasattr(self, "_id"): - raise RuntimeError("Agent is already bound to an ID and runtime") + if self._id != id: + raise RuntimeError("Agent is already bound to a different ID") + + if hasattr(self, "_runtime"): + if self._runtime != runtime: + raise RuntimeError("Agent is already bound to a different runtime") + self._id = id self._runtime = runtime @@ -160,18 +166,30 @@ async def register_instance( runtime: AgentRuntime, agent_id: AgentId, *, - skip_broadcast_subscription: bool = False, + skip_class_subscriptions: bool = True, skip_direct_message_subscription: bool = False, ) -> AgentId: """ - This function is similar to `register` but is used for registering an instance of an agent. One major difference between the two functions is that using this function will restrict the topics that the agent can subscribe to based on the agent_id. + This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime. """ agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id) - if not skip_broadcast_subscription: - # Only add a subscription based on the agent_id. Default should NOT be added in this case. - subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type) - await runtime.add_subscription(subscription) + id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type) + await runtime.add_subscription(id_subscription) + + if not skip_class_subscriptions: + with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): + subscriptions: List[Subscription] = [] + for unbound_subscription in self._unbound_subscriptions(): + subscriptions_list_result = unbound_subscription() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list = await subscriptions_list_result + else: + subscriptions_list = subscriptions_list_result + + subscriptions.extend(subscriptions_list) + for subscription in subscriptions: + await runtime.add_subscription(subscription) if not skip_direct_message_subscription: # Additionally adds a special prefix subscription for this agent to receive direct messages diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index 5212ecb30b6b..081ea4d3d7d3 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -660,8 +660,8 @@ async def test_instance_factory_messaging() -> None: worker = GrpcWorkerAgentRuntime(host_address=host_address) agent1 = CascadingAgent(max_rounds=4) await worker.start() - await worker.register_agent_instance(agent1, agent_id=agent1_id) - await worker.add_subscription(TypeSubscription("instance_agent", "instance_agent")) + + await agent1.register_instance(worker, agent_id=agent1_id) await CascadingAgent.register(worker, "factory_agent", lambda: CascadingAgent(max_rounds=5)) # instance_agent will publish a message that factory_agent will pick up From b45a3beec55f25a7698f61f35225043890e872b3 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 10:38:56 -0400 Subject: [PATCH 10/17] add additional test case --- .../autogen-ext/tests/test_worker_runtime.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index 081ea4d3d7d3..73f420820cf5 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -653,27 +653,35 @@ async def test_register_instance_factory() -> None: @pytest.mark.asyncio async def test_instance_factory_messaging() -> None: host_address = "localhost:50051" - agent1_id = AgentId(type="instance_agent", key="instance_agent") + loopback_agent_id = AgentId(type="dm_agent", key="dm_agent") + cascading_agent_id = AgentId(type="instance_agent", key="instance_agent") host = GrpcWorkerAgentRuntimeHost(address=host_address) host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - agent1 = CascadingAgent(max_rounds=4) + cascading_agent = CascadingAgent(max_rounds=5) + loopback_agent = LoopbackAgent() await worker.start() - await agent1.register_instance(worker, agent_id=agent1_id) + await loopback_agent.register_instance(worker, agent_id=loopback_agent_id) + resp = await worker.send_message( + message=ContentMessage(content="Hello!"), recipient=loopback_agent_id + ) + assert resp == ContentMessage(content="Hello!") + + await cascading_agent.register_instance(worker, agent_id=cascading_agent_id) await CascadingAgent.register(worker, "factory_agent", lambda: CascadingAgent(max_rounds=5)) # instance_agent will publish a message that factory_agent will pick up for i in range(5): await worker.publish_message( - CascadingMessageType(round=i), TopicId(type="instance_agent", source="instance_agent") + CascadingMessageType(round=i+1), TopicId(type="instance_agent", source="instance_agent") ) await asyncio.sleep(2) agent = await worker.try_get_underlying_agent_instance(AgentId("factory_agent", "default"), CascadingAgent) - assert agent.num_calls == 5 - assert agent1.num_calls == 4 + assert agent.num_calls == 4 + assert cascading_agent.num_calls == 5 await worker.stop() await host.stop() From 8939fac62344219f5638e4d8f5e6c40b39ff14ca Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 10:58:06 -0400 Subject: [PATCH 11/17] formatting --- python/packages/autogen-ext/tests/test_worker_runtime.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index 73f420820cf5..ec57f187e821 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -664,9 +664,7 @@ async def test_instance_factory_messaging() -> None: await worker.start() await loopback_agent.register_instance(worker, agent_id=loopback_agent_id) - resp = await worker.send_message( - message=ContentMessage(content="Hello!"), recipient=loopback_agent_id - ) + resp = await worker.send_message(message=ContentMessage(content="Hello!"), recipient=loopback_agent_id) assert resp == ContentMessage(content="Hello!") await cascading_agent.register_instance(worker, agent_id=cascading_agent_id) @@ -675,7 +673,7 @@ async def test_instance_factory_messaging() -> None: # instance_agent will publish a message that factory_agent will pick up for i in range(5): await worker.publish_message( - CascadingMessageType(round=i+1), TopicId(type="instance_agent", source="instance_agent") + CascadingMessageType(round=i + 1), TopicId(type="instance_agent", source="instance_agent") ) await asyncio.sleep(2) From 9bee1d83ce59a5ced744bf89243852d1574ddae5 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 15:23:05 -0400 Subject: [PATCH 12/17] address PR comments --- .../src/autogen_core/_agent_instantiation.py | 2 +- .../src/autogen_core/_agent_runtime.py | 2 +- .../src/autogen_core/_base_agent.py | 2 +- .../_single_threaded_agent_runtime.py | 8 ++++--- .../autogen-core/tests/test_base_agent.py | 24 +++++++++---------- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py index 7cfae8d6fed5..a8904a42da56 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py @@ -120,7 +120,7 @@ def current_agent_id(cls) -> AgentId: ) from e @classmethod - def is_in_runtime(cls) -> bool: + def is_in_factory_call(cls) -> bool: if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None: return False return True diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 875b09e1b491..f4faf6136e91 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -132,7 +132,7 @@ async def main() -> None: async def register_agent_instance( self, - agent_instance: T | Awaitable[T], + agent_instance: Agent, agent_id: AgentId, ) -> AgentId: """Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions. diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 6b8e7d848f73..0ad0bc60776c 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -83,7 +83,7 @@ def metadata(self) -> AgentMetadata: return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) def __init__(self, description: str) -> None: - if AgentInstantiationContext.is_in_runtime(): + if AgentInstantiationContext.is_in_factory_call(): self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime() self._id = AgentInstantiationContext.current_agent_id() if not isinstance(description, str): diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 966504b29b94..0b803f2a72e3 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -833,11 +833,13 @@ async def factory_wrapper() -> T: async def register_agent_instance( self, - agent_instance: T | Awaitable[T], + agent_instance: Agent, agent_id: AgentId, ) -> AgentId: - def agent_factory() -> T: - raise RuntimeError("Agent factory should not be called when registering an agent instance.") + def agent_factory() -> Agent: + raise RuntimeError( + "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." + ) if inspect.isawaitable(agent_instance): agent_instance = await agent_instance diff --git a/python/packages/autogen-core/tests/test_base_agent.py b/python/packages/autogen-core/tests/test_base_agent.py index 2e2e1f91d28b..010bd0624478 100644 --- a/python/packages/autogen-core/tests/test_base_agent.py +++ b/python/packages/autogen-core/tests/test_base_agent.py @@ -1,15 +1,15 @@ -# import pytest -# from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime -# from autogen_test_utils import NoopAgent -# from pytest_mock import MockerFixture +import pytest +from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime +from autogen_test_utils import NoopAgent +from pytest_mock import MockerFixture -# @pytest.mark.asyncio -# async def test_base_agent_create(mocker: MockerFixture) -> None: -# runtime = mocker.Mock(spec=AgentRuntime) +@pytest.mark.asyncio +async def test_base_agent_create(mocker: MockerFixture) -> None: + runtime = mocker.Mock(spec=AgentRuntime) -# # Shows how to set the context for the agent instantiation in a test context -# with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))): -# agent2 = NoopAgent() -# assert agent2.runtime == runtime -# assert agent2.id == AgentId("name2", "namespace2") + # Shows how to set the context for the agent instantiation in a test context + with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))): + agent2 = NoopAgent() + assert agent2.runtime == runtime + assert agent2.id == AgentId("name2", "namespace2") From 2bdc4c2386916dd620957ebd6c134eb2eb80d34c Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 16:08:47 -0400 Subject: [PATCH 13/17] update documentation --- .../packages/autogen-core/src/autogen_core/_agent_runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index f4faf6136e91..d4bac4a9c0ab 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -167,7 +167,7 @@ async def handler(self, message: UserMessage, context: MessageContext) -> None: async def main() -> None: runtime: AgentRuntime = ... # type: ignore - agent: Agent = MyAgent() + agent = MyAgent() await runtime.register_agent_instance( agent_instance=agent, agent_id=AgentId(type="my_agent", key="default") ) @@ -179,7 +179,7 @@ async def main() -> None: Args: - agent_instance (T | Awaitable[T]): A concrete instance of the agent. + agent_instance (Agent): A concrete instance of the agent. agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`. """ ... From c6a1c1da09296498456411a185bfff89ecb362fd Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 16:22:50 -0400 Subject: [PATCH 14/17] remove unnecessary code --- .../src/autogen_core/_single_threaded_agent_runtime.py | 3 --- .../src/autogen_ext/runtimes/grpc/_worker_runtime.py | 5 +---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 0b803f2a72e3..3c3d9afa1b96 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -841,9 +841,6 @@ def agent_factory() -> Agent: "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." ) - if inspect.isawaitable(agent_instance): - agent_instance = await agent_instance - if agent_id in self._instantiated_agents: raise ValueError(f"Agent with id {agent_id} already exists.") diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index e5c8d87a471b..92090ab11bbf 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -745,15 +745,12 @@ async def factory_wrapper() -> T: async def register_agent_instance( self, - agent_instance: T | Awaitable[T], + agent_instance: Agent, agent_id: AgentId, ) -> AgentId: def agent_factory() -> T: raise RuntimeError("Agent factory should not be called when registering an agent instance.") - if inspect.isawaitable(agent_instance): - agent_instance = await agent_instance - if agent_id in self._instantiated_agents: raise ValueError(f"Agent with id {agent_id} already exists.") From 5cbb7c2404279b8030395f1da6ad94d54d4b9af8 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 16:23:18 -0400 Subject: [PATCH 15/17] missed change --- .../src/autogen_ext/runtimes/grpc/_worker_runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 92090ab11bbf..e7591df401e4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -748,7 +748,7 @@ async def register_agent_instance( agent_instance: Agent, agent_id: AgentId, ) -> AgentId: - def agent_factory() -> T: + def agent_factory() -> Agent: raise RuntimeError("Agent factory should not be called when registering an agent instance.") if agent_id in self._instantiated_agents: From 771f2d5b00923e6b0f4c1401597faecda95d9a2f Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 9 May 2025 16:24:07 -0400 Subject: [PATCH 16/17] update error message --- .../src/autogen_ext/runtimes/grpc/_worker_runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index e7591df401e4..db958ff9e2e7 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -749,7 +749,7 @@ async def register_agent_instance( agent_id: AgentId, ) -> AgentId: def agent_factory() -> Agent: - raise RuntimeError("Agent factory should not be called when registering an agent instance.") + raise RuntimeError("Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent.") if agent_id in self._instantiated_agents: raise ValueError(f"Agent with id {agent_id} already exists.") From f4614a8601bad58f4d98dcc9d8f71f9c1b863656 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Mon, 12 May 2025 11:27:42 -0400 Subject: [PATCH 17/17] formatting --- .../src/autogen_ext/runtimes/grpc/_worker_runtime.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index db958ff9e2e7..6a3963586e18 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -749,7 +749,9 @@ async def register_agent_instance( agent_id: AgentId, ) -> AgentId: def agent_factory() -> Agent: - raise RuntimeError("Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent.") + raise RuntimeError( + "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." + ) if agent_id in self._instantiated_agents: raise ValueError(f"Agent with id {agent_id} already exists.")