Skip to content

Commit 2606a72

Browse files
feat: Implement new FunctionTarget (#2031)
* implement new FunctionTarget feature * Discussion WIP * Discussion WIP * refactored function target logic, new FunctionTargetResult and FunctionTargetMessage classes * pre-commit tidy * Headers and init order * Refactor, consolidate into one file, type hints * implement support for extra fn parameters --------- Co-authored-by: Mark Sze <mark@sze.family> Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
1 parent bf75e0c commit 2606a72

File tree

3 files changed

+346
-0
lines changed

3 files changed

+346
-0
lines changed

autogen/agentchat/group/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
GroupManagerTarget,
2525
)
2626
"""
27+
from .targets.function_target import FunctionTarget, FunctionTargetResult
2728
from .targets.transition_target import (
2829
AgentNameTarget,
2930
AgentTarget,
@@ -44,6 +45,8 @@
4445
"ContextVariables",
4546
"ExpressionAvailableCondition",
4647
"ExpressionContextCondition",
48+
"FunctionTarget",
49+
"FunctionTargetResult",
4750
"GroupChatConfig",
4851
"GroupChatTarget",
4952
# "GroupManagerSelectionMessageContextStr",
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from __future__ import annotations
6+
7+
import inspect
8+
from collections.abc import Callable
9+
from typing import TYPE_CHECKING, Any
10+
11+
from pydantic import BaseModel, Field
12+
13+
from ...agent import Agent
14+
from ..context_variables import ContextVariables
15+
from ..speaker_selection_result import SpeakerSelectionResult
16+
from .transition_target import AgentNameTarget, AgentTarget, RevertToUserTarget, StayTarget, TransitionTarget
17+
18+
if TYPE_CHECKING:
19+
from ...conversable_agent import ConversableAgent
20+
from ...groupchat import GroupChat
21+
22+
__all__ = ["FunctionTarget", "FunctionTargetMessage", "FunctionTargetResult", "broadcast"]
23+
24+
25+
class FunctionTargetMessage(BaseModel):
26+
"""Message and target that can be sent as part of the FunctionTargetResult.
27+
28+
Attributes:
29+
content: The content of the message to be sent.
30+
msg_target: The agent to whom the message is to be sent.
31+
"""
32+
33+
content: str
34+
msg_target: Agent
35+
36+
class Config:
37+
arbitrary_types_allowed = True
38+
39+
40+
class FunctionTargetResult(BaseModel):
41+
"""Result of a function handoff that is used to provide the return message and the target to transition to.
42+
43+
Attributes:
44+
messages: Optional list of messages to be broadcast to specific agents, or a single string message.
45+
context_variables: Optional updated context variables that will be applied to the group chat context variables.
46+
target: The next target to transition to.
47+
"""
48+
49+
messages: list[FunctionTargetMessage] | str | None = None
50+
context_variables: ContextVariables | None = None
51+
target: TransitionTarget
52+
53+
54+
def construct_broadcast_messages_list(
55+
messages: list[FunctionTargetMessage] | str,
56+
group_chat: GroupChat,
57+
current_agent: ConversableAgent,
58+
target: TransitionTarget,
59+
user_agent: ConversableAgent | None = None,
60+
) -> list[FunctionTargetMessage]:
61+
"""Construct a list of FunctionTargetMessage from input messages and target."""
62+
if isinstance(messages, str):
63+
if isinstance(target, (AgentTarget, AgentNameTarget)):
64+
next_target = target.agent_name
65+
for agent in group_chat.agents:
66+
if agent.name == next_target:
67+
messages = [FunctionTargetMessage(content=messages, msg_target=agent)]
68+
break
69+
elif isinstance(target, RevertToUserTarget) and user_agent is not None:
70+
messages_list = [FunctionTargetMessage(content=messages, msg_target=user_agent)]
71+
elif isinstance(target, StayTarget):
72+
messages_list = [FunctionTargetMessage(content=messages, msg_target=current_agent)]
73+
else:
74+
# Default to current agent if no target is not agent-based is found
75+
messages_list = [FunctionTargetMessage(content=messages, msg_target=current_agent)]
76+
else:
77+
messages_list = messages
78+
return messages_list
79+
80+
81+
def broadcast(
82+
messages: list[FunctionTargetMessage] | str,
83+
group_chat: GroupChat,
84+
current_agent: ConversableAgent,
85+
fn_name: str,
86+
target: TransitionTarget,
87+
user_agent: ConversableAgent | None = None,
88+
) -> None:
89+
"""Broadcast message(s) to their target agent."""
90+
messages_list = construct_broadcast_messages_list(messages, group_chat, current_agent, target, user_agent)
91+
92+
for message in messages_list:
93+
content = message.content
94+
broadcast = {
95+
"role": "system",
96+
"name": f"{fn_name}",
97+
"content": f"[FUNCTION_HANDOFF] - Reply from function {fn_name}: \n\n {content}",
98+
}
99+
if hasattr(current_agent, "_group_manager") and current_agent._group_manager is not None:
100+
current_agent._group_manager.send(
101+
broadcast,
102+
message.msg_target,
103+
request_reply=False,
104+
silent=False,
105+
)
106+
else:
107+
raise ValueError("Current agent must have a group manager to broadcast messages.")
108+
109+
110+
def validate_fn_sig(
111+
incoming_fn: Callable[..., FunctionTargetResult],
112+
extra_args: dict[str, Any],
113+
) -> None:
114+
"""
115+
Validate a user-defined afterwork_function signature.
116+
117+
Rules:
118+
1. Must have at least two positional parameters (whatever their names).
119+
2. All provided extra_args must exist in the function signature (unless **kwargs is present).
120+
3. All additional required (non-default) params beyond the first two must be satisfied via extra_args.
121+
"""
122+
sig = inspect.signature(incoming_fn)
123+
params = list(sig.parameters.values())
124+
125+
# 1️⃣ Must have at least two positional parameters (whatever names)
126+
if len(params) < 2:
127+
raise ValueError(
128+
f"Function '{incoming_fn.__name__}' must accept at least two positional parameters: "
129+
f"(output, ctx). Current: {[p.name for p in params]}"
130+
)
131+
132+
# 2️⃣ Detect **kwargs
133+
has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)
134+
135+
# Build list of extra (non-core) params after the first two
136+
extra_params = [
137+
p for p in params[2:] if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
138+
]
139+
140+
# 3️⃣ If **kwargs not present, check for invalid extra_args
141+
if not has_kwargs:
142+
valid_names = {p.name for p in params}
143+
invalid_keys = [k for k in extra_args if k not in valid_names]
144+
if invalid_keys:
145+
raise ValueError(
146+
f"Invalid extra_args for function '{incoming_fn.__name__}': {invalid_keys}. "
147+
f"Allowed parameters are: {[p.name for p in params]}."
148+
)
149+
150+
# 4️⃣ Check for missing required params (no defaults) among the extra ones
151+
missing_keys = [p.name for p in extra_params if p.default is p.empty and p.name not in extra_args]
152+
153+
if missing_keys:
154+
raise ValueError(
155+
f"Missing required extra_args for function '{incoming_fn.__name__}': {missing_keys}. "
156+
f"You must supply them via `extra_args`."
157+
)
158+
159+
160+
class FunctionTarget(TransitionTarget):
161+
"""Transition target that invokes a tool function with (prev_output, context).
162+
163+
The function must return a FunctionTargetResult object that includes the next target to transition to.
164+
"""
165+
166+
fn_name: str = Field(...)
167+
fn: Callable[..., FunctionTargetResult] = Field(..., repr=False)
168+
extra_args: dict[str, Any] = Field(default_factory=dict)
169+
170+
def __init__(
171+
self,
172+
incoming_fn: Callable[..., FunctionTargetResult],
173+
*,
174+
extra_args: dict[str, Any] | None = None,
175+
**kwargs: Any,
176+
) -> None:
177+
if callable(incoming_fn):
178+
extra_args = extra_args or {}
179+
180+
validate_fn_sig(incoming_fn, extra_args)
181+
182+
super().__init__(fn_name=incoming_fn.__name__, fn=incoming_fn, extra_args=extra_args, **kwargs)
183+
else:
184+
raise ValueError(
185+
"FunctionTarget must be initialized with a callable function as the first argument or 'fn' keyword argument."
186+
)
187+
188+
def can_resolve_for_speaker_selection(self) -> bool:
189+
return False
190+
191+
def resolve(
192+
self,
193+
groupchat: GroupChat,
194+
current_agent: ConversableAgent,
195+
user_agent: ConversableAgent | None,
196+
) -> SpeakerSelectionResult:
197+
"""Invoke the function, update context variables (optional), broadcast messages (optional), and return the next target to transition to."""
198+
last_message = (
199+
groupchat.messages[-1]["content"] if groupchat.messages and "content" in groupchat.messages[-1] else ""
200+
)
201+
202+
# Run the function to get the FunctionTargetResult
203+
function_target_result = self.fn(
204+
last_message,
205+
current_agent.context_variables,
206+
**self.extra_args,
207+
)
208+
209+
if not isinstance(function_target_result, FunctionTargetResult):
210+
raise ValueError("FunctionTarget function must return a FunctionTargetResult object.")
211+
212+
if function_target_result.context_variables:
213+
# Update the group's Context Variables if the function returned any
214+
current_agent.context_variables.update(function_target_result.context_variables.to_dict())
215+
216+
if function_target_result.messages:
217+
# If we have messages, we need to broadcast them to the appropriate agent based on the target
218+
broadcast(
219+
function_target_result.messages,
220+
groupchat,
221+
current_agent,
222+
self.fn_name,
223+
function_target_result.target,
224+
user_agent,
225+
)
226+
227+
# Resolve and return the next target
228+
return function_target_result.target.resolve(groupchat, current_agent, user_agent)
229+
230+
def display_name(self) -> str:
231+
return self.fn_name
232+
233+
def normalized_name(self) -> str:
234+
return self.fn_name.replace(" ", "_")
235+
236+
def __str__(self) -> str:
237+
return f"Transfer to tool {self.fn_name}"
238+
239+
def needs_agent_wrapper(self) -> bool:
240+
return False
241+
242+
def create_wrapper_agent(self, parent_agent: ConversableAgent, index: int) -> ConversableAgent:
243+
raise NotImplementedError("FunctionTarget is executed inline and needs no wrapper")
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
"""
5+
Minimal FunctionTarget test wiring for a two-agent group chat.
6+
"""
7+
8+
from typing import Any
9+
10+
from dotenv import load_dotenv
11+
12+
from autogen import ConversableAgent, LLMConfig
13+
from autogen.agentchat import initiate_group_chat
14+
from autogen.agentchat.group import AgentTarget, ContextVariables, FunctionTarget
15+
from autogen.agentchat.group.patterns import DefaultPattern
16+
from autogen.agentchat.group.targets.function_target import FunctionTargetMessage, FunctionTargetResult
17+
from autogen.agentchat.group.targets.transition_target import StayTarget
18+
19+
load_dotenv()
20+
21+
import logging
22+
23+
logging.basicConfig(level=logging.INFO)
24+
logger = logging.getLogger(__name__)
25+
26+
27+
def main(session_id: str | None = None) -> dict:
28+
# LLM config
29+
cfg = LLMConfig(api_type="openai", model="gpt-4o-mini")
30+
31+
# Shared context
32+
ctx = ContextVariables(data={"application": "<empty>"})
33+
34+
# Agents
35+
first_agent = ConversableAgent(
36+
name="first_agent",
37+
llm_config=cfg,
38+
system_message="Output a sample email you would send to apply to a job in tech. "
39+
"Listen to the specifics of the instructions.",
40+
)
41+
42+
second_agent = ConversableAgent(
43+
name="second_agent",
44+
llm_config=cfg,
45+
system_message="Do whatever the message sent to you tells you to do.",
46+
)
47+
48+
user_agent = ConversableAgent(
49+
name="user",
50+
human_input_mode="ALWAYS",
51+
)
52+
53+
# After-work hook
54+
def afterwork_function(output: str, context_variables: Any, next_agent: ConversableAgent) -> FunctionTargetResult:
55+
"""
56+
Switches a context variable and routes the next turn.
57+
"""
58+
logger.info(f"After-work function called. Random param: {next_agent}")
59+
if context_variables.get("application") == "<empty>":
60+
context_variables["application"] = output
61+
return FunctionTargetResult(
62+
messages="apply for a job in gpu optimization",
63+
target=StayTarget(),
64+
context_variables=context_variables,
65+
)
66+
67+
return FunctionTargetResult(
68+
messages=[
69+
FunctionTargetMessage(
70+
content=f"Revise the draft written by the first agent: {output}", msg_target=next_agent
71+
)
72+
],
73+
target=AgentTarget(next_agent),
74+
context_variables=context_variables,
75+
)
76+
77+
# Conversation pattern
78+
pattern = DefaultPattern(
79+
initial_agent=first_agent,
80+
agents=[first_agent, second_agent],
81+
user_agent=user_agent,
82+
context_variables=ctx,
83+
group_manager_args={"llm_config": cfg},
84+
)
85+
86+
# Register after-work handoff
87+
first_agent.handoffs.set_after_work(FunctionTarget(afterwork_function, extra_args={"next_agent": second_agent}))
88+
89+
# Run
90+
initiate_group_chat(
91+
pattern=pattern,
92+
messages="the job you are applying to is specifically in machine learning",
93+
max_rounds=20,
94+
)
95+
96+
return {"session_id": session_id}
97+
98+
99+
if __name__ == "__main__":
100+
main()

0 commit comments

Comments
 (0)