diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index 0f37b822ff8a..e407fe137394 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -1,9 +1,13 @@ -from typing import Any, Mapping, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable from ._agent_id import AgentId from ._agent_metadata import AgentMetadata from ._message_context import MessageContext +# Forward declaration for type checking only +if TYPE_CHECKING: + from ._agent_runtime import AgentRuntime + @runtime_checkable class Agent(Protocol): @@ -17,6 +21,15 @@ def id(self) -> AgentId: """ID of the agent.""" ... + async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None: + """Function used to bind an Agent instance to an `AgentRuntime`. + + Args: + agent_id (AgentId): ID of the agent. + runtime (AgentRuntime): AgentRuntime instance to bind the agent to. + """ + ... + async def on_message(self, message: Any, ctx: MessageContext) -> Any: """Message handler for the agent. This should only be called by the runtime, not by other agents. diff --git a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py index 71921225cfbd..a8904a42da56 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_instantiation.py @@ -118,3 +118,9 @@ def current_agent_id(cls) -> AgentId: raise RuntimeError( "AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so." ) from e + + @classmethod + def is_in_factory_call(cls) -> bool: + if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None: + return False + return True diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 22493626f911..d4bac4a9c0ab 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -130,6 +130,60 @@ async def main() -> None: """ ... + async def register_agent_instance( + self, + agent_instance: Agent, + agent_id: AgentId, + ) -> AgentId: + """Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions. + + .. note:: + + This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + + from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event + from autogen_core.models import UserMessage + + + @dataclass + class MyMessage: + content: str + + + class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My core agent") + + @event + async def handler(self, message: UserMessage, context: MessageContext) -> None: + print("Event received: ", message.content) + + + async def main() -> None: + runtime: AgentRuntime = ... # type: ignore + agent = MyAgent() + await runtime.register_agent_instance( + agent_instance=agent, agent_id=AgentId(type="my_agent", key="default") + ) + + + import asyncio + + asyncio.run(main()) + + + Args: + agent_instance (Agent): A concrete instance of the agent. + agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`. + """ + ... + # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] """Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases. diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index bffb61b876bb..0ad0bc60776c 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -21,6 +21,7 @@ from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId from ._type_prefix_subscription import TypePrefixSubscription +from ._type_subscription import TypeSubscription T = TypeVar("T", bound=Agent) @@ -82,20 +83,25 @@ def metadata(self) -> AgentMetadata: return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) def __init__(self, description: str) -> None: - try: - runtime = AgentInstantiationContext.current_runtime() - id = AgentInstantiationContext.current_agent_id() - except LookupError as e: - raise RuntimeError( - "BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." - ) from e - - self._runtime: AgentRuntime = runtime - self._id: AgentId = id + if AgentInstantiationContext.is_in_factory_call(): + self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime() + self._id = AgentInstantiationContext.current_agent_id() if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description + async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None: + if hasattr(self, "_id"): + if self._id != id: + raise RuntimeError("Agent is already bound to a different ID") + + if hasattr(self, "_runtime"): + if self._runtime != runtime: + raise RuntimeError("Agent is already bound to a different runtime") + + self._id = id + self._runtime = runtime + @property def type(self) -> str: return self.id.type @@ -155,6 +161,56 @@ async def load_state(self, state: Mapping[str, Any]) -> None: async def close(self) -> None: pass + async def register_instance( + self, + runtime: AgentRuntime, + agent_id: AgentId, + *, + skip_class_subscriptions: bool = True, + skip_direct_message_subscription: bool = False, + ) -> AgentId: + """ + This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime. + """ + agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id) + + id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type) + await runtime.add_subscription(id_subscription) + + if not skip_class_subscriptions: + with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)): + subscriptions: List[Subscription] = [] + for unbound_subscription in self._unbound_subscriptions(): + subscriptions_list_result = unbound_subscription() + if inspect.isawaitable(subscriptions_list_result): + subscriptions_list = await subscriptions_list_result + else: + subscriptions_list = subscriptions_list_result + + subscriptions.extend(subscriptions_list) + for subscription in subscriptions: + await runtime.add_subscription(subscription) + + if not skip_direct_message_subscription: + # Additionally adds a special prefix subscription for this agent to receive direct messages + try: + await runtime.add_subscription( + TypePrefixSubscription( + # The prefix MUST include ":" to avoid collisions with other agents + topic_type_prefix=agent_id.type + ":", + agent_type=agent_id.type, + ) + ) + except ValueError: + # We don't care if the subscription already exists + pass + + # TODO: deduplication + for _message_type, serializer in self._handles_types(): + runtime.add_message_serializer(serializer) + + return agent_id + @classmethod async def register( cls, diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index f806ad31f19b..307cdd863e0b 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -266,6 +266,7 @@ def __init__( self._serialization_registry = SerializationRegistry() self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions self._background_exception: BaseException | None = None + self._agent_instance_types: Dict[str, Type[Agent]] = {} @property def unprocessed_messages_count( @@ -909,6 +910,32 @@ async def factory_wrapper() -> T: return type + async def register_agent_instance( + self, + agent_instance: Agent, + agent_id: AgentId, + ) -> AgentId: + def agent_factory() -> Agent: + raise RuntimeError( + "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." + ) + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + + if agent_id.type not in self._agent_factories: + self._agent_factories[agent_id.type] = agent_factory + self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) + else: + if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__: + raise ValueError("Agent factories and agent instances cannot be registered to the same type.") + if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): + raise ValueError("Agent instances must be the same object type.") + + await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self) + self._instantiated_agents[agent_id] = agent_instance + return agent_id + async def _invoke_agent_factory( self, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], @@ -930,8 +957,7 @@ async def _invoke_agent_factory( raise ValueError("Agent factory must take 0 or 2 arguments.") if inspect.isawaitable(agent): - return cast(T, await agent) - + agent = cast(T, await agent) return agent except BaseException as e: diff --git a/python/packages/autogen-core/tests/test_base_agent.py b/python/packages/autogen-core/tests/test_base_agent.py index 64bcf59d1774..010bd0624478 100644 --- a/python/packages/autogen-core/tests/test_base_agent.py +++ b/python/packages/autogen-core/tests/test_base_agent.py @@ -9,7 +9,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None: runtime = mocker.Mock(spec=AgentRuntime) # Shows how to set the context for the agent instantiation in a test context - with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))): - agent = NoopAgent() - assert agent.runtime == runtime - assert agent.id == AgentId("name", "namespace") + with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))): + agent2 = NoopAgent() + assert agent2.runtime == runtime + assert agent2.id == AgentId("name2", "namespace2") diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 64a1cccf4b12..e93a57a6a291 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -82,6 +82,60 @@ def agent_factory() -> NoopAgent: await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent) +@pytest.mark.asyncio +async def test_agent_type_register_instance() -> None: + runtime = SingleThreadedAgentRuntime() + agent1_id = AgentId(type="name", key="default") + agent2_id = AgentId(type="name", key="notdefault") + agent1 = NoopAgent() + agent1_dup = NoopAgent() + agent2 = NoopAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent1_id) + await agent2.register_instance(runtime=runtime, agent_id=agent2_id) + + assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 + assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 + with pytest.raises(ValueError): + await agent1_dup.register_instance(runtime=runtime, agent_id=agent1_id) + + +@pytest.mark.asyncio +async def test_agent_type_register_instance_different_types() -> None: + runtime = SingleThreadedAgentRuntime() + agent_id1 = AgentId(type="name", key="noop") + agent_id2 = AgentId(type="name", key="loopback") + agent1 = NoopAgent() + agent2 = LoopbackAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent_id1) + with pytest.raises(ValueError): + await agent2.register_instance(runtime=runtime, agent_id=agent_id2) + + +@pytest.mark.asyncio +async def test_agent_type_register_instance_publish_new_source() -> None: + runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + agent_id = AgentId(type="name", key="default") + agent1 = LoopbackAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent_id) + await runtime.add_subscription(TypeSubscription("notdefault", "name")) + + runtime.start() + with pytest.raises(RuntimeError): + await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault")) + await runtime.stop_when_idle() + await runtime.close() + + +@pytest.mark.asyncio +async def test_register_instance_factory() -> None: + runtime = SingleThreadedAgentRuntime() + agent1_id = AgentId(type="name", key="default") + agent1 = NoopAgent() + await agent1.register_instance(runtime=runtime, agent_id=agent1_id) + with pytest.raises(ValueError): + await NoopAgent.register(runtime, "name", lambda: NoopAgent()) + + @pytest.mark.asyncio async def test_register_receives_publish(tracer_provider: TracerProvider) -> None: runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 504838740283..6a3963586e18 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -251,6 +251,7 @@ def __init__( self._subscription_manager = SubscriptionManager() self._serialization_registry = SerializationRegistry() self._extra_grpc_config = extra_grpc_config or [] + self._agent_instance_types: Dict[str, Type[Agent]] = {} if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}: raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}") @@ -701,6 +702,14 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any: except BaseException as e: logger.error("Error handling event", exc_info=e) + async def _register_agent_type(self, agent_type: str) -> None: + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + message = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type) + _response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent( + message, metadata=self._host_connection.metadata + ) + async def register_factory( self, type: str | AgentType, @@ -729,14 +738,38 @@ async def factory_wrapper() -> T: return agent_instance self._agent_factories[type.type] = factory_wrapper - # Send the registration request message to the host. - message = agent_worker_pb2.RegisterAgentTypeRequest(type=type.type) - _response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent( - message, metadata=self._host_connection.metadata - ) + await self._register_agent_type(type.type) + return type + async def register_agent_instance( + self, + agent_instance: Agent, + agent_id: AgentId, + ) -> AgentId: + def agent_factory() -> Agent: + raise RuntimeError( + "Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent." + ) + + if agent_id in self._instantiated_agents: + raise ValueError(f"Agent with id {agent_id} already exists.") + + if agent_id.type not in self._agent_factories: + self._agent_factories[agent_id.type] = agent_factory + await self._register_agent_type(agent_id.type) + self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance) + else: + if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__: + raise ValueError("Agent factories and agent instances cannot be registered to the same type.") + if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance): + raise ValueError("Agent instances must be the same object type.") + + await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self) + self._instantiated_agents[agent_id] = agent_instance + return agent_id + async def _invoke_agent_factory( self, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], @@ -757,7 +790,7 @@ async def _invoke_agent_factory( raise ValueError("Agent factory must take 0 or 2 arguments.") if inspect.isawaitable(agent): - return cast(T, await agent) + agent = cast(T, await agent) return agent diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index 2f4f02aeaf13..2e15938903f8 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -3,7 +3,7 @@ import os from datetime import datetime from typing import Any, AsyncGenerator, List, Type, Union -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from autogen_core import CancellationToken, FunctionCall, Image @@ -570,7 +570,7 @@ async def _mock_thought_with_tool_call_stream( ) mock_client = MagicMock() - mock_client.close = MagicMock() + mock_client.close = AsyncMock() async def mock_complete(*args: Any, **kwargs: Any) -> Any: if kwargs.get("stream", False): diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index dede306853ed..ec57f187e821 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -577,6 +577,139 @@ async def test_grpc_max_message_size() -> None: await host.stop() +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_agent_type_register_instance() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="name", key="default") + agentdup_id = AgentId(type="name", key="default") + agent2_id = AgentId(type="name", key="notdefault") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = NoopAgent() + agent2 = NoopAgent() + agentdup = NoopAgent() + await worker.start() + + await worker.register_agent_instance(agent1, agent_id=agent1_id) + await worker.register_agent_instance(agent2, agent_id=agent2_id) + + with pytest.raises(ValueError): + await worker.register_agent_instance(agentdup, agent_id=agentdup_id) + + assert await worker.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1 + assert await worker.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2 + + await worker.stop() + await host.stop() + + +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_agent_type_register_instance_different_types() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="name", key="noop") + agent2_id = AgentId(type="name", key="loopback") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = NoopAgent() + agent2 = LoopbackAgent() + await worker.start() + + await worker.register_agent_instance(agent1, agent_id=agent1_id) + with pytest.raises(ValueError): + await worker.register_agent_instance(agent2, agent_id=agent2_id) + + await worker.stop() + await host.stop() + + +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_register_instance_factory() -> None: + host_address = "localhost:50051" + agent1_id = AgentId(type="name", key="default") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + agent1 = NoopAgent() + await worker.start() + + await agent1.register_instance(runtime=worker, agent_id=agent1_id) + + with pytest.raises(ValueError): + await NoopAgent.register(runtime=worker, type="name", factory=lambda: NoopAgent()) + + await worker.stop() + await host.stop() + + +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_instance_factory_messaging() -> None: + host_address = "localhost:50051" + loopback_agent_id = AgentId(type="dm_agent", key="dm_agent") + cascading_agent_id = AgentId(type="instance_agent", key="instance_agent") + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker = GrpcWorkerAgentRuntime(host_address=host_address) + cascading_agent = CascadingAgent(max_rounds=5) + loopback_agent = LoopbackAgent() + await worker.start() + + await loopback_agent.register_instance(worker, agent_id=loopback_agent_id) + resp = await worker.send_message(message=ContentMessage(content="Hello!"), recipient=loopback_agent_id) + assert resp == ContentMessage(content="Hello!") + + await cascading_agent.register_instance(worker, agent_id=cascading_agent_id) + await CascadingAgent.register(worker, "factory_agent", lambda: CascadingAgent(max_rounds=5)) + + # instance_agent will publish a message that factory_agent will pick up + for i in range(5): + await worker.publish_message( + CascadingMessageType(round=i + 1), TopicId(type="instance_agent", source="instance_agent") + ) + await asyncio.sleep(2) + + agent = await worker.try_get_underlying_agent_instance(AgentId("factory_agent", "default"), CascadingAgent) + assert agent.num_calls == 4 + assert cascading_agent.num_calls == 5 + + await worker.stop() + await host.stop() + + +# GrpcWorkerAgentRuntimeHost eats exceptions in the main loop +# @pytest.mark.grpc +# @pytest.mark.asyncio +# async def test_agent_type_register_instance_publish_new_source() -> None: +# host_address = "localhost:50056" +# agent_id = AgentId(type="name", key="default") +# agent1 = LoopbackAgent() +# host = GrpcWorkerAgentRuntimeHost(address=host_address) +# host.start() +# worker = GrpcWorkerAgentRuntime(host_address=host_address) +# await worker.start() +# publisher = GrpcWorkerAgentRuntime(host_address=host_address) +# publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) +# await publisher.start() + +# await agent1.register_instance(worker, agent_id=agent_id) +# await worker.add_subscription(TypeSubscription("notdefault", "name")) + +# with pytest.raises(RuntimeError): +# await worker.publish_message(MessageType(), TopicId("notdefault", "notdefault")) +# await asyncio.sleep(2) + +# await worker.stop() +# await host.stop() + if __name__ == "__main__": os.environ["GRPC_VERBOSITY"] = "DEBUG" os.environ["GRPC_TRACE"] = "all" diff --git a/python/packages/autogen-ext/tests/tools/http/test_http_tool.py b/python/packages/autogen-ext/tests/tools/http/test_http_tool.py index 9d2898bae505..8e50e48ba926 100644 --- a/python/packages/autogen-ext/tests/tools/http/test_http_tool.py +++ b/python/packages/autogen-ext/tests/tools/http/test_http_tool.py @@ -176,7 +176,7 @@ async def test_invalid_request(test_config: ComponentModel, test_server: None) - config.config["host"] = "fake" tool = HttpTool.load_component(config) - with pytest.raises(httpx.ConnectError): + with pytest.raises((httpx.ConnectError, httpx.ConnectTimeout)): await tool.run_json({"query": "test query", "value": 42}, CancellationToken())