
Add support for stream pipeline decorator #543

Open · wants to merge 7 commits into main
1 change: 1 addition & 0 deletions setup.py
@@ -41,6 +41,7 @@
"openai>=1.3.0",
"dashscope>=1.19.0",
"nest_asyncio",
"shortuuid",
]

extra_service_requires = [
10 changes: 9 additions & 1 deletion src/agentscope/utils/common.py
@@ -15,7 +15,15 @@
import sys
import tempfile
import threading
from typing import Any, Generator, Optional, Union, Tuple, Literal, List
from typing import (
Any,
Generator,
Optional,
Union,
Tuple,
Literal,
List,
)
from urllib.parse import urlparse

import psutil
8 changes: 8 additions & 0 deletions src/agentscope/utils/decorators/__init__.py
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
""" Decorators for agent applications """
from .stream_pipeline import pipeline


__all__ = [
"pipeline",
]
8 changes: 8 additions & 0 deletions src/agentscope/utils/decorators/stream_pipeline/__init__.py
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
""" Stream pipeline decorator """
from .pipeline_utils import pipeline


__all__ = [
"pipeline",
]
90 changes: 90 additions & 0 deletions src/agentscope/utils/decorators/stream_pipeline/hooks.py
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
""" Hooks for stream output """
# pylint: disable=unused-argument
import time
import threading
from collections import defaultdict
from typing import Union, Optional, Generator

from ....agents import AgentBase
from ....message import Msg

# Per-thread message buffers and their guarding locks, keyed by thread name
_MSG_INSTANCE = defaultdict(list)
_LOCKS = defaultdict(threading.Lock)


def pre_speak_msg_buffer_hook(
self: AgentBase,
x: Msg,
stream: bool,
last: bool,
) -> Union[Msg, None]:
"""Hook for pre speak msg buffer"""
thread_id = threading.current_thread().name
if thread_id.startswith("pipeline"):
with _LOCKS[thread_id]:
_MSG_INSTANCE[thread_id].append(x)
return x


def clear_msg_instances(thread_id: Optional[str] = None) -> None:
"""
Clears all message instances for a specific thread ID.
This function removes all message instances associated with a given
thread ID (`thread_id`). It ensures thread safety through the use of a
threading lock when accessing the shared message instance list. This
prevents race conditions in concurrent environments.
Args:
thread_id (optional): The thread ID for which to clear message
instances. If `None`, the function will do nothing.
Notes:
- It assumes the existence of a global `_LOCKS` for synchronization and
a dictionary `_MSG_INSTANCE` where each thread ID maps to a list of
message instances.
"""
if not thread_id:
return

with _LOCKS[thread_id]:
_MSG_INSTANCE[thread_id].clear()


def get_msg_instances(thread_id: Optional[str] = None) -> Generator:
"""
A generator function that yields message instances for a specific thread ID
This function is designed to continuously monitor and yield new message
instances associated with a given thread ID (`thread_id`). It ensures
thread safety through the use of a threading lock when accessing the shared
message instance list. This prevents race conditions in concurrent
environments.
Args:
thread_id (optional): The thread ID for which to monitor and yield
message instances. If `None`, the function will yield `None` and
terminate.
Yields:
The next available message instance for the specified thread ID. If no
message is available, it will wait and check periodically.
Notes:
- The function uses a small delay (`time.sleep(0.1)`) to prevent busy
waiting. This ensures efficient CPU usage while waiting for new
messages.
- It assumes the existence of a global `_LOCK` for synchronization and
a dictionary `_MSG_INSTANCE` where each thread ID maps to a list of
message instances.
Example:
for msg in get_msg_instances(thread_id=123):
process_message(msg)
"""
if not thread_id:
yield
return

while True:
with _LOCKS[thread_id]:
if _MSG_INSTANCE[thread_id]:
yield _MSG_INSTANCE[thread_id].pop(0), len(
_MSG_INSTANCE[thread_id],
)
else:
yield None, None
time.sleep(0.1) # Avoid busy waiting
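
For reference, the buffer-and-poll pattern implemented by these hooks can be exercised in isolation. The sketch below is illustrative only and does not import agentscope; the producer thread stands in for an agent whose pre-speak hook appends to the shared buffer, and the consumer loop mirrors the locking and polling in get_msg_instances.

import threading
import time
from collections import defaultdict

# Per-thread buffers and locks, mirroring _MSG_INSTANCE and _LOCKS above
_BUFFER = defaultdict(list)
_BUFFER_LOCKS = defaultdict(threading.Lock)


def producer(thread_id: str) -> None:
    # Stands in for an agent whose pre-speak hook buffers each message
    for i in range(3):
        with _BUFFER_LOCKS[thread_id]:
            _BUFFER[thread_id].append(f"msg-{i}")
        time.sleep(0.2)


def consume(thread_id: str) -> None:
    worker = threading.Thread(target=producer, args=(thread_id,))
    worker.start()
    while True:
        with _BUFFER_LOCKS[thread_id]:
            msg = _BUFFER[thread_id].pop(0) if _BUFFER[thread_id] else None
        if msg is not None:
            print(msg)
        elif not worker.is_alive():
            break  # producer finished and the buffer is drained
        time.sleep(0.1)  # avoid busy waiting, as in get_msg_instances
    worker.join()


consume("pipeline-demo")
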
83 changes: 83 additions & 0 deletions src/agentscope/utils/decorators/stream_pipeline/pipeline_utils.py
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
""" Utils for pipeline decorator """
# pylint: disable=redefined-outer-name
import uuid
import threading
from functools import wraps
from typing import Callable, Any, Generator, Type


# mypy: disable-error-code="name-defined"
def pipeline(func: Callable) -> Callable:
"""
A decorator that runs the given function in a separate thread and yields
message instances as they are logged.
This decorator is used to execute a function concurrently while providing a
mechanism to yield messages produced during its execution. It leverages
threading to run the function in parallel, yielding messages until the
function completes.
Args:
func: The function to be executed in a separate thread.
Returns:
A wrapped function that, when called, returns a generator yielding
message instances.
"""

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Generator:
        # The global declaration must precede the first binding of AgentBase
        # in this scope; placing it after the import raises a SyntaxError.
        global AgentBase

        from ....agents import AgentBase
        from .hooks import (
            get_msg_instances,
            clear_msg_instances,
            pre_speak_msg_buffer_hook,
        )

def _register_pre_speak_msg_buffer_hook(
cls: Type[AgentBase],
) -> Type[AgentBase]:
original_init = cls.__init__

@wraps(original_init)
def modified_init(self: Any, *args: Any, **kwargs: Any) -> None:
original_init(self, *args, **kwargs)
self.register_pre_speak_hook(
"pre_speak_msg_buffer_hook",
pre_speak_msg_buffer_hook,
)

cls.__init__ = modified_init
return cls

        # _register_pre_speak_msg_buffer_hook patches AgentBase.__init__ in
        # place, so keep references to the class and its unpatched __init__
        # for restoration in the finally block.
        original_agent_base = AgentBase
        unpatched_init = AgentBase.__init__

try:
AgentBase = _register_pre_speak_msg_buffer_hook(AgentBase)

            # The "pipeline" name prefix is what pre_speak_msg_buffer_hook
            # checks to decide whether to buffer a message.
            thread_id = "pipeline" + str(uuid.uuid4())

# Run the main function in a separate thread
thread = threading.Thread(
target=func,
name=thread_id,
args=args,
kwargs=kwargs,
)
clear_msg_instances(thread_id=thread_id)
thread.start()

# Yield new Msg instances as they are logged
for msg, msg_len in get_msg_instances(thread_id=thread_id):
if msg:
yield msg
                # Break once the thread is dead and no messages remain
                # buffered (msg_len is 0, or None when the buffer was empty)
                if not thread.is_alive() and not msg_len:
                    break

# Wait for the function to finish
thread.join()
        finally:
            # Undo the in-place patch so repeated pipeline calls do not stack
            # duplicate hook registrations onto AgentBase.__init__.
            AgentBase.__init__ = unpatched_init
            AgentBase = original_agent_base

return wrapper
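
For reference, a minimal usage sketch of the decorator as exported by this PR. The workflow body below is a placeholder; any agents that speak inside the decorated function would have their messages streamed the same way (the print assumes the usual Msg content field).

from agentscope.utils.decorators import pipeline


@pipeline
def my_workflow() -> None:
    # Placeholder: construct agents and run a conversation here. Every
    # message an agent speaks passes through pre_speak_msg_buffer_hook
    # and is buffered for the consumer loop below.
    ...


# Calling the decorated function returns a generator; messages are yielded
# as the workflow produces them, not only after it finishes.
for msg in my_workflow():
    print(msg.content)
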