Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/core-llm-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v5
with:
ref: ${{ github.event.pull_request.head.sha }}
ref: ${{ github.event.pull_request.head.sha || github.ref }}
- uses: astral-sh/setup-uv@v6
with:
version: "latest"
Expand Down
6 changes: 5 additions & 1 deletion autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from functools import partial
from typing import Any, TypedDict

from ..code_utils import content_str
from ..doc_utils import export_module
from ..events.agent_events import PostCarryoverProcessingEvent
from ..io.base import IOStream
Expand Down Expand Up @@ -132,7 +133,10 @@ def _post_process_carryover_item(carryover_item):
if isinstance(carryover_item, str):
return carryover_item
elif isinstance(carryover_item, dict) and "content" in carryover_item:
return str(carryover_item["content"])
content_value = carryover_item.get("content")
if isinstance(content_value, (str, list)) or content_value is None:
return content_str(content_value)
return str(content_value)
else:
return str(carryover_item)

Expand Down
16 changes: 13 additions & 3 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,12 @@ def _append_oai_message(
oai_message["role"] = message.get("role")
if "tool_responses" in oai_message:
for tool_response in oai_message["tool_responses"]:
tool_response["content"] = str(tool_response["content"])
content_value = tool_response.get("content")
tool_response["content"] = (
content_str(content_value)
if isinstance(content_value, (str, list)) or content_value is None
else str(content_value)
)
elif "override_role" in message:
# If we have a direction to override the role then set the
# role accordingly. Used to customise the role for the
Expand Down Expand Up @@ -1349,9 +1354,10 @@ def _should_terminate_chat(self, recipient: "ConversableAgent", message: dict[st
Returns:
bool: True if the chat should be terminated, False otherwise.
"""
content = message.get("content")
return (
isinstance(recipient, ConversableAgent)
and isinstance(message.get("content"), str)
and content is not None
and hasattr(recipient, "_is_termination_msg")
and recipient._is_termination_msg(message)
)
Expand Down Expand Up @@ -3975,7 +3981,11 @@ def _create_or_get_executor(
if executor_kwargs is None:
executor_kwargs = {}
if "is_termination_msg" not in executor_kwargs:
executor_kwargs["is_termination_msg"] = lambda x: (x["content"] is not None) and "TERMINATE" in x["content"]
executor_kwargs["is_termination_msg"] = lambda x: "TERMINATE" in (
content_str(x.get("content"))
if isinstance(x.get("content"), (str, list)) or x.get("content") is None
else str(x.get("content"))
)

try:
if not self.run_executor:
Expand Down
12 changes: 8 additions & 4 deletions autogen/agentchat/group/group_tool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from copy import deepcopy
from typing import Annotated, Any

from ...code_utils import content_str
from ...oai import OpenAIWrapper
from ...tools import Depends, Tool
from ...tools.dependency_injection import inject_params, on
Expand Down Expand Up @@ -167,9 +168,10 @@ def _generate_group_tool_reply(
agent_name = message.get("name", sender.name if sender else "unknown")
self.set_tool_call_originator(agent_name)

if "tool_calls" in message:
if message.get("tool_calls"):
tool_call_count = len(message["tool_calls"])

tool_message = None
# Loop through tool calls individually (so context can be updated after each function call)
next_target: TransitionTarget | None = None
tool_responses_inner = []
Expand Down Expand Up @@ -203,11 +205,13 @@ def _generate_group_tool_reply(
next_target = content

# Serialize the content to a string
if content is not None:
tool_response["content"] = str(content)
normalized_content = (
content_str(content) if isinstance(content, (str, list)) or content is None else str(content)
)
tool_response["content"] = normalized_content

tool_responses_inner.append(tool_response)
contents.append(str(tool_response["content"]))
contents.append(normalized_content)

self._group_next_target = next_target # type: ignore[attr-defined]

Expand Down
38 changes: 28 additions & 10 deletions autogen/agentchat/group/safeguards/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Callable
from typing import Any

from ....code_utils import content_str
from ....io.base import IOStream
from ....llm_config import LLMConfig
from ...conversable_agent import ConversableAgent
Expand All @@ -21,6 +22,15 @@
class SafeguardEnforcer:
"""Main safeguard enforcer - executes safeguard policies"""

@staticmethod
def _stringify_content(value: Any) -> str:
if isinstance(value, (str, list)) or value is None:
try:
return content_str(value)
except (TypeError, ValueError, AssertionError):
pass
return "" if value is None else str(value)

def __init__(
self,
policy: dict[str, Any] | str,
Expand Down Expand Up @@ -695,9 +705,12 @@ def _handle_masked_content(
# Handle tool_responses
if "tool_responses" in masked_content and masked_content["tool_responses"]:
if "content" in masked_content:
masked_content["content"] = mask_func(str(masked_content["content"]))
masked_content["content"] = mask_func(self._stringify_content(masked_content.get("content")))
masked_content["tool_responses"] = [
{**response, "content": mask_func(str(response.get("content", "")))}
{
**response,
"content": mask_func(self._stringify_content(response.get("content"))),
}
for response in masked_content["tool_responses"]
]
# Handle tool_calls
Expand All @@ -707,17 +720,17 @@ def _handle_masked_content(
**tool_call,
"function": {
**tool_call["function"],
"arguments": mask_func(str(tool_call["function"].get("arguments", ""))),
"arguments": mask_func(self._stringify_content(tool_call["function"].get("arguments"))),
},
}
for tool_call in masked_content["tool_calls"]
]
# Handle regular content
elif "content" in masked_content:
masked_content["content"] = mask_func(str(masked_content["content"]))
masked_content["content"] = mask_func(self._stringify_content(masked_content.get("content")))
# Handle arguments
elif "arguments" in masked_content:
masked_content["arguments"] = mask_func(str(masked_content["arguments"]))
masked_content["arguments"] = mask_func(self._stringify_content(masked_content.get("arguments")))

return masked_content

Expand All @@ -728,33 +741,38 @@ def _handle_masked_content(
if isinstance(item, dict):
masked_item = item.copy()
if "content" in masked_item:
masked_item["content"] = mask_func(str(masked_item["content"]))
masked_item["content"] = mask_func(self._stringify_content(masked_item.get("content")))
if "tool_calls" in masked_item:
masked_item["tool_calls"] = [
{
**tool_call,
"function": {
**tool_call["function"],
"arguments": mask_func(str(tool_call["function"].get("arguments", ""))),
"arguments": mask_func(
self._stringify_content(tool_call["function"].get("arguments"))
),
},
}
for tool_call in masked_item["tool_calls"]
]
if "tool_responses" in masked_item:
masked_item["tool_responses"] = [
{**response, "content": mask_func(str(response.get("content", "")))}
{
**response,
"content": mask_func(self._stringify_content(response.get("content"))),
}
for response in masked_item["tool_responses"]
]
masked_list.append(masked_item)
else:
# For non-dict items, wrap the masked content in a dict
masked_item_content: str = mask_func(str(item))
masked_item_content: str = mask_func(self._stringify_content(item))
masked_list.append({"content": masked_item_content, "role": "function"})
return masked_list

else:
# String content
return mask_func(str(content))
return mask_func(self._stringify_content(content))

def _check_inter_agent_communication(
self, sender_name: str, recipient_name: str, message: str | dict[str, Any]
Expand Down
52 changes: 36 additions & 16 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,13 +1303,16 @@ def run_chat(
reply = guardrails_reply

# check for "clear history" phrase in reply and activate clear history function if found
if (
groupchat.enable_clear_history
and isinstance(reply, dict)
and reply["content"]
and "CLEAR HISTORY" in reply["content"].upper()
):
reply["content"] = self.clear_agents_history(reply, groupchat)
if groupchat.enable_clear_history and isinstance(reply, dict) and reply.get("content"):
raw_content = reply.get("content")
normalized_content = (
content_str(raw_content)
if isinstance(raw_content, (str, list)) or raw_content is None
else str(raw_content)
)
if "CLEAR HISTORY" in normalized_content.upper():
reply["content"] = normalized_content
reply["content"] = self.clear_agents_history(reply, groupchat)

# The speaker sends the message without requesting a reply
speaker.send(reply, self, request_reply=False, silent=silent)
Expand Down Expand Up @@ -1420,13 +1423,16 @@ async def a_run_chat(
reply = guardrails_reply

# check for "clear history" phrase in reply and activate clear history function if found
if (
groupchat.enable_clear_history
and isinstance(reply, dict)
and reply["content"]
and "CLEAR HISTORY" in reply["content"].upper()
):
reply["content"] = self.clear_agents_history(reply, groupchat)
if groupchat.enable_clear_history and isinstance(reply, dict) and reply.get("content"):
raw_content = reply.get("content")
normalized_content = (
content_str(raw_content)
if isinstance(raw_content, (str, list)) or raw_content is None
else str(raw_content)
)
if "CLEAR HISTORY" in normalized_content.upper():
reply["content"] = normalized_content
reply["content"] = self.clear_agents_history(reply, groupchat)

# The speaker sends the message without requesting a reply
await speaker.a_send(reply, self, request_reply=False, silent=silent)
Expand Down Expand Up @@ -1701,7 +1707,13 @@ def _remove_termination_string(content: str) -> str:
_remove_termination_string = remove_termination_string

if _remove_termination_string and messages[-1].get("content"):
messages[-1]["content"] = _remove_termination_string(messages[-1]["content"])
content_value = messages[-1]["content"]
if isinstance(content_value, str):
messages[-1]["content"] = _remove_termination_string(content_value)
elif isinstance(content_value, list):
messages[-1]["content"] = _remove_termination_string(content_str(content_value))
else:
messages[-1]["content"] = _remove_termination_string(str(content_value))

# Check if the last message meets termination (if it has one)
if self._is_termination_msg and self._is_termination_msg(last_message):
Expand Down Expand Up @@ -1764,7 +1776,15 @@ def clear_agents_history(self, reply: dict[str, Any], groupchat: GroupChat) -> s
"""
iostream = IOStream.get_default()

reply_content = reply["content"]
raw_reply_content = reply.get("content")
if isinstance(raw_reply_content, str):
reply_content = raw_reply_content
elif isinstance(raw_reply_content, (list, type(None))):
reply_content = content_str(raw_reply_content)
reply["content"] = reply_content
else:
reply_content = str(raw_reply_content)
reply["content"] = reply_content
# Split the reply into words
words = reply_content.split()
# Find the position of "clear" to determine where to start processing
Expand Down
2 changes: 1 addition & 1 deletion autogen/llm_config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ModelClient(Protocol):
class ModelClientResponseProtocol(Protocol):
class Choice(Protocol):
class Message(Protocol):
content: str | dict[str, Any]
content: str | dict[str, Any] | list[dict[str, Any]]

message: Message

Expand Down
8 changes: 5 additions & 3 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydantic.type_adapter import TypeAdapter

from ..cache import Cache
from ..code_utils import content_str
from ..doc_utils import export_module
from ..events.client_events import StreamEvent, UsageSummaryEvent
from ..exception_utils import ModelToolNotSupportedError
Expand Down Expand Up @@ -365,11 +366,12 @@ def message_retrieval(self, response: ChatCompletion | Completion) -> list[str]
if isinstance(response, Completion):
return [choice.text for choice in choices] # type: ignore [union-attr]

def _format_content(content: str) -> str:
def _format_content(content: str | list[dict[str, Any]] | None) -> str:
normalized_content = content_str(content)
return (
self.response_format.model_validate_json(content).format()
self.response_format.model_validate_json(normalized_content).format()
if isinstance(self.response_format, FormatterProtocol)
else content
else normalized_content
)

if TOOL_ENABLED:
Expand Down
Loading
Loading