diff --git a/setup.py b/setup.py index 890fb79c8..653c4d500 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "openai>=1.3.0", "dashscope>=1.19.0", "nest_asyncio", + "shortuuid", ] extra_service_requires = [ diff --git a/src/agentscope/utils/common.py b/src/agentscope/utils/common.py index 009eea358..315ad0851 100644 --- a/src/agentscope/utils/common.py +++ b/src/agentscope/utils/common.py @@ -15,7 +15,15 @@ import sys import tempfile import threading -from typing import Any, Generator, Optional, Union, Tuple, Literal, List +from typing import ( + Any, + Generator, + Optional, + Union, + Tuple, + Literal, + List, +) from urllib.parse import urlparse import psutil diff --git a/src/agentscope/utils/decorators/__init__.py b/src/agentscope/utils/decorators/__init__.py new file mode 100644 index 000000000..225b424a6 --- /dev/null +++ b/src/agentscope/utils/decorators/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" Decorators for agent applications """ +from .stream_pipeline import pipeline + + +__all__ = [ + "pipeline", +] diff --git a/src/agentscope/utils/decorators/stream_pipeline/__init__.py b/src/agentscope/utils/decorators/stream_pipeline/__init__.py new file mode 100644 index 000000000..c07cfdb1a --- /dev/null +++ b/src/agentscope/utils/decorators/stream_pipeline/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" Stream pipeline decorator """ +from .pipeline_utils import pipeline + + +__all__ = [ + "pipeline", +] diff --git a/src/agentscope/utils/decorators/stream_pipeline/hooks.py b/src/agentscope/utils/decorators/stream_pipeline/hooks.py new file mode 100644 index 000000000..7a89c7741 --- /dev/null +++ b/src/agentscope/utils/decorators/stream_pipeline/hooks.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +""" Hooks for stream output """ +# pylint: disable=unused-argument +import time +import threading +from collections import defaultdict +from typing import Union, Optional, Generator + +from ....agents import AgentBase +from ....message import 
Msg + +_MSG_INSTANCE = defaultdict(list) +_LOCKS = defaultdict(threading.Lock) + + +def pre_speak_msg_buffer_hook( + self: AgentBase, + x: Msg, + stream: bool, + last: bool, +) -> Union[Msg, None]: + """Hook for pre speak msg buffer""" + thread_id = threading.current_thread().name + if thread_id.startswith("pipeline"): + with _LOCKS[thread_id]: + _MSG_INSTANCE[thread_id].append(x) + return x + + +def clear_msg_instances(thread_id: Optional[str] = None) -> None: + """ + Clears all message instances for a specific thread ID. + This function removes all message instances associated with a given + thread ID (`thread_id`). It ensures thread safety through the use of a + threading lock when accessing the shared message instance list. This + prevents race conditions in concurrent environments. + Args: + thread_id (optional): The thread ID for which to clear message + instances. If `None`, the function will do nothing. + Notes: + - It assumes the existence of a global `_LOCKS` for synchronization and + a dictionary `_MSG_INSTANCE` where each thread ID maps to a list of + message instances. + """ + if not thread_id: + return + + with _LOCKS[thread_id]: + _MSG_INSTANCE[thread_id].clear() + + +def get_msg_instances(thread_id: Optional[str] = None) -> Generator: + """ + A generator function that yields message instances for a specific thread ID + This function is designed to continuously monitor and yield new message + instances associated with a given thread ID (`thread_id`). It ensures + thread safety through the use of a threading lock when accessing the shared + message instance list. This prevents race conditions in concurrent + environments. + Args: + thread_id (optional): The thread ID for which to monitor and yield + message instances. If `None`, the function will yield `None` and + terminate. + Yields: + The next available message instance for the specified thread ID. If no + message is available, it will wait and check periodically. 
+ Notes: + - The function uses a small delay (`time.sleep(0.1)`) to prevent busy + waiting. This ensures efficient CPU usage while waiting for new + messages. + - It assumes the existence of a global `_LOCKS` for synchronization and + a dictionary `_MSG_INSTANCE` where each thread ID maps to a list of + message instances. + Example: + for msg in get_msg_instances(thread_id=123): + process_message(msg) + """ + if not thread_id: + yield + return + + while True: + with _LOCKS[thread_id]: + if _MSG_INSTANCE[thread_id]: + yield _MSG_INSTANCE[thread_id].pop(0), len( + _MSG_INSTANCE[thread_id], + ) + else: + yield None, None + time.sleep(0.1) # Avoid busy waiting diff --git a/src/agentscope/utils/decorators/stream_pipeline/pipeline_utils.py b/src/agentscope/utils/decorators/stream_pipeline/pipeline_utils.py new file mode 100644 index 000000000..cd23a8c0e --- /dev/null +++ b/src/agentscope/utils/decorators/stream_pipeline/pipeline_utils.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +""" Utils for pipeline decorator """ +# pylint: disable=redefined-outer-name +import uuid +import threading +from functools import wraps +from typing import Callable, Any, Generator, Type + + +# mypy: disable-error-code="name-defined" +def pipeline(func: Callable) -> Callable: + """ + A decorator that runs the given function in a separate thread and yields + message instances as they are logged. + This decorator is used to execute a function concurrently while providing a + mechanism to yield messages produced during its execution. It leverages + threading to run the function in parallel, yielding messages until the + function completes. + Args: + func: The function to be executed in a separate thread. + Returns: + A wrapped function that, when called, returns a generator yielding + message instances. 
+ """ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Generator: + from ....agents import AgentBase + from .hooks import ( + get_msg_instances, + clear_msg_instances, + pre_speak_msg_buffer_hook, + ) + + global AgentBase + + def _register_pre_speak_msg_buffer_hook( + cls: Type[AgentBase], + ) -> Type[AgentBase]: + original_init = cls.__init__ + + @wraps(original_init) + def modified_init(self: Any, *args: Any, **kwargs: Any) -> None: + original_init(self, *args, **kwargs) + self.register_pre_speak_hook( + "pre_speak_msg_buffer_hook", + pre_speak_msg_buffer_hook, + ) + + cls.__init__ = modified_init + return cls + + original_agent_base = AgentBase + + try: + AgentBase = _register_pre_speak_msg_buffer_hook(AgentBase) + + thread_id = "pipeline" + str(uuid.uuid4()) + + # Run the main function in a separate thread + thread = threading.Thread( + target=func, + name=thread_id, + args=args, + kwargs=kwargs, + ) + clear_msg_instances(thread_id=thread_id) + thread.start() + + # Yield new Msg instances as they are logged + for msg, msg_len in get_msg_instances(thread_id=thread_id): + if msg: + yield msg + # Break if the thread is dead and no more messages are expected + if not thread.is_alive() and msg_len == 0: + break + + # Wait for the function to finish + thread.join() + finally: + AgentBase = original_agent_base + + return wrapper