Skip to content

Add ability to register Agent instances #6131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
aff1d01
Add ability to register Agent instances
peterychang Mar 27, 2025
693cef8
Initialize agent from context if it exists
peterychang Mar 27, 2025
12d26c7
BaseAgent initializes with runtime and id
peterychang Mar 27, 2025
4ebad57
Fix example code
peterychang Mar 28, 2025
e0c6c5c
Merge branch 'main' into register_agent_instances
ekzhu Mar 31, 2025
1076fd5
Revert design change. runtime and agent_id passed in register fn
peterychang Apr 1, 2025
edf4414
Merge branch 'register_agent_instances' of github.com:peterychang/aut…
peterychang Apr 1, 2025
4c67138
Merge branch 'main' into register_agent_instances
peterychang May 5, 2025
434f2a5
address PR comments, add test
peterychang May 7, 2025
4b4462e
fix test
peterychang May 7, 2025
b6d5ecd
Agent instances can only subscribe to a single subscription. Make tha…
peterychang May 7, 2025
8b4eb10
fix subscription behavior
peterychang May 8, 2025
b45a3be
add additional test case
peterychang May 9, 2025
8939fac
formatting
peterychang May 9, 2025
9bee1d8
address PR comments
peterychang May 9, 2025
2bdc4c2
update documentation
peterychang May 9, 2025
096790c
Merge branch 'main' into register_agent_instances
peterychang May 9, 2025
c6a1c1d
remove unnecessary code
peterychang May 9, 2025
5cbb7c2
missed change
peterychang May 9, 2025
771f2d5
update error message
peterychang May 9, 2025
eeb3488
Merge branch 'main' into register_agent_instances
peterychang May 9, 2025
78672c4
Merge branch 'main' into register_agent_instances
peterychang May 12, 2025
f4614a8
formatting
peterychang May 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/packages/autogen-core/src/autogen_core/_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any, Mapping, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable

from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._message_context import MessageContext

# Forward declaration for type checking only
if TYPE_CHECKING:
from ._agent_runtime import AgentRuntime


@runtime_checkable
class Agent(Protocol):
Expand All @@ -17,6 +21,15 @@ def id(self) -> AgentId:
"""ID of the agent."""
...

async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None:
"""Function used to bind an Agent instance to an `AgentRuntime`.

Args:
agent_id (AgentId): ID of the agent.
runtime (AgentRuntime): AgentRuntime instance to bind the agent to.
"""
...

async def on_message(self, message: Any, ctx: MessageContext) -> Any:
"""Message handler for the agent. This should only be called by the runtime, not by other agents.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,9 @@ def current_agent_id(cls) -> AgentId:
raise RuntimeError(
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
) from e

@classmethod
def is_in_factory_call(cls) -> bool:
if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None:
return False
return True
54 changes: 54 additions & 0 deletions python/packages/autogen-core/src/autogen_core/_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,60 @@ async def main() -> None:
"""
...

async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.

.. note::

This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.

Example:

.. code-block:: python

from dataclasses import dataclass

from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
from autogen_core.models import UserMessage


@dataclass
class MyMessage:
content: str


class MyAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("My core agent")

@event
async def handler(self, message: UserMessage, context: MessageContext) -> None:
print("Event received: ", message.content)


async def main() -> None:
runtime: AgentRuntime = ... # type: ignore
agent = MyAgent()
await runtime.register_agent_instance(
agent_instance=agent, agent_id=AgentId(type="my_agent", key="default")
)


import asyncio

asyncio.run(main())


Args:
agent_instance (Agent): A concrete instance of the agent.
agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`.
"""
...

# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
Expand Down
76 changes: 66 additions & 10 deletions python/packages/autogen-core/src/autogen_core/_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
from ._type_prefix_subscription import TypePrefixSubscription
from ._type_subscription import TypeSubscription

T = TypeVar("T", bound=Agent)

Expand Down Expand Up @@ -82,20 +83,25 @@ def metadata(self) -> AgentMetadata:
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)

def __init__(self, description: str) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
except LookupError as e:
raise RuntimeError(
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e

self._runtime: AgentRuntime = runtime
self._id: AgentId = id
if AgentInstantiationContext.is_in_factory_call():
self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime()
self._id = AgentInstantiationContext.current_agent_id()
if not isinstance(description, str):
raise ValueError("Agent description must be a string")
self._description = description

async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None:
if hasattr(self, "_id"):
if self._id != id:
raise RuntimeError("Agent is already bound to a different ID")

if hasattr(self, "_runtime"):
if self._runtime != runtime:
raise RuntimeError("Agent is already bound to a different runtime")

self._id = id
self._runtime = runtime

@property
def type(self) -> str:
return self.id.type
Expand Down Expand Up @@ -155,6 +161,56 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
async def close(self) -> None:
pass

async def register_instance(
self,
runtime: AgentRuntime,
agent_id: AgentId,
*,
skip_class_subscriptions: bool = True,
skip_direct_message_subscription: bool = False,
) -> AgentId:
"""
This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime.
"""
agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id)

id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type)
await runtime.add_subscription(id_subscription)

if not skip_class_subscriptions:
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
subscriptions: List[Subscription] = []
for unbound_subscription in self._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result

subscriptions.extend(subscriptions_list)
for subscription in subscriptions:
await runtime.add_subscription(subscription)

if not skip_direct_message_subscription:
# Additionally adds a special prefix subscription for this agent to receive direct messages
try:
await runtime.add_subscription(
TypePrefixSubscription(
# The prefix MUST include ":" to avoid collisions with other agents
topic_type_prefix=agent_id.type + ":",
agent_type=agent_id.type,
)
)
except ValueError:
# We don't care if the subscription already exists
pass

# TODO: deduplication
for _message_type, serializer in self._handles_types():
runtime.add_message_serializer(serializer)

return agent_id

@classmethod
async def register(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __init__(
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None
self._agent_instance_types: Dict[str, Type[Agent]] = {}

@property
def unprocessed_messages_count(
Expand Down Expand Up @@ -909,6 +910,32 @@ async def factory_wrapper() -> T:

return type

async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
def agent_factory() -> Agent:
raise RuntimeError(
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
)

if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")

if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")

await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
self._instantiated_agents[agent_id] = agent_instance
return agent_id

async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
Expand All @@ -930,8 +957,7 @@ async def _invoke_agent_factory(
raise ValueError("Agent factory must take 0 or 2 arguments.")

if inspect.isawaitable(agent):
return cast(T, await agent)

agent = cast(T, await agent)
return agent

except BaseException as e:
Expand Down
8 changes: 4 additions & 4 deletions python/packages/autogen-core/tests/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)

# Shows how to set the context for the agent instantiation in a test context
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
agent = NoopAgent()
assert agent.runtime == runtime
assert agent.id == AgentId("name", "namespace")
with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))):
agent2 = NoopAgent()
assert agent2.runtime == runtime
assert agent2.id == AgentId("name2", "namespace2")
54 changes: 54 additions & 0 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,60 @@ def agent_factory() -> NoopAgent:
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)


@pytest.mark.asyncio
async def test_agent_type_register_instance() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent2_id = AgentId(type="name", key="notdefault")
agent1 = NoopAgent()
agent1_dup = NoopAgent()
agent2 = NoopAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
await agent2.register_instance(runtime=runtime, agent_id=agent2_id)

assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
with pytest.raises(ValueError):
await agent1_dup.register_instance(runtime=runtime, agent_id=agent1_id)


@pytest.mark.asyncio
async def test_agent_type_register_instance_different_types() -> None:
runtime = SingleThreadedAgentRuntime()
agent_id1 = AgentId(type="name", key="noop")
agent_id2 = AgentId(type="name", key="loopback")
agent1 = NoopAgent()
agent2 = LoopbackAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent_id1)
with pytest.raises(ValueError):
await agent2.register_instance(runtime=runtime, agent_id=agent_id2)


@pytest.mark.asyncio
async def test_agent_type_register_instance_publish_new_source() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
agent_id = AgentId(type="name", key="default")
agent1 = LoopbackAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent_id)
await runtime.add_subscription(TypeSubscription("notdefault", "name"))

runtime.start()
with pytest.raises(RuntimeError):
await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
await runtime.stop_when_idle()
await runtime.close()


@pytest.mark.asyncio
async def test_register_instance_factory() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent1 = NoopAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
with pytest.raises(ValueError):
await NoopAgent.register(runtime, "name", lambda: NoopAgent())


@pytest.mark.asyncio
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
Expand Down
Loading
Loading