diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt
index 6f8229ad..91b12d1c 100644
--- a/.github/actions/spelling/allow.txt
+++ b/.github/actions/spelling/allow.txt
@@ -65,10 +65,20 @@ pypistats
 pyversions
 respx
 resub
+sadd
+sismember
 socio
+srem
 sse
 tagwords
 taskupdate
 testuuid
 typeerror
 vulnz
+zadd
+zrem
+zremrangebyscore
+zscore
+hdel
+hget
+hsetex
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6252577e..79e86ef7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,5 @@ __pycache__
 .venv
 coverage.xml
 .nox
-spec.json
\ No newline at end of file
+spec.json
+.idea
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index e73df213..f8de695f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
 mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
 sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
 sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"]
+redis = ["redis>=6.2.0"]
 
 [project.urls]
 homepage = "https://a2aproject.github.io/A2A/"
@@ -91,6 +92,7 @@ dev = [
     "types-protobuf",
     "types-requests",
     "pre-commit",
+    "fakeredis>=2.30.1",
 ]
 
 [[tool.uv.index]]
diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py
index 1ce2bd21..76a0b512 100644
--- a/src/a2a/server/events/event_queue.py
+++ b/src/a2a/server/events/event_queue.py
@@ -2,6 +2,10 @@
 import logging
 import sys
 
+from typing import Annotated
+
+from pydantic import Field
+
 from a2a.types import (
     Message,
     Task,
@@ -14,7 +18,10 @@
 
 logger = logging.getLogger(__name__)
 
-Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
+Event = Annotated[
+    Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
+    Field(discriminator='kind'),
+]
 """Type alias for events that can be enqueued."""
 
 DEFAULT_MAX_QUEUE_SIZE = 1024
diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py
new file mode 100644
index 00000000..f5981cc4
--- /dev/null
+++ b/src/a2a/server/events/redis_queue_manager.py
@@ -0,0 +1,322 @@
+import asyncio
+import logging
+import uuid
+
+from asyncio import Task
+from typing import Any
+
+from pydantic import TypeAdapter
+from redis.asyncio import Redis
+
+from a2a.server.events import (
+    Event,
+    EventConsumer,
+    EventQueue,
+    NoTaskQueue,
+    QueueManager,
+    TaskQueueExists,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class RedisQueueManager(QueueManager):
+    """This implements the `QueueManager` interface using Redis for event distribution.
+
+    It broadcasts local events to proxy queues in other processes via Redis pub/sub, and subscribes to event messages from Redis pub/sub to replay them into local proxy queues.
+
+    Args:
+        redis_client (Redis): An asyncio Redis connection.
+        relay_channel_key_prefix (str): Prefix used to generate pub/sub channel keys.
+        task_registry_key (str): Key of the hash that stores active `task_id`s.
+        task_id_ttl_in_second (int): TTL in seconds for a task id in the global registry.
+        node_id (str | None): A unique id associated with task ids in the global registry. Events are only relayed to queues in other `RedisQueueManager`s by the node whose id matches the registered one.
+    """
+
+    def __init__(
+        self,
+        redis_client: Redis,
+        relay_channel_key_prefix: str = 'a2a.event.relay.',
+        task_registry_key: str = 'a2a.event.registry',
+        task_id_ttl_in_second: int = 60 * 60 * 24,
+        node_id: str | None = None,
+    ):
+        self._redis = redis_client
+        self._local_queue: dict[str, EventQueue] = {}
+        self._proxy_queue: dict[str, EventQueue] = {}
+        self._lock = asyncio.Lock()
+        self._pubsub = redis_client.pubsub()
+        self._relay_channel_name = relay_channel_key_prefix
+        self._background_tasks: dict[str, Task] = {}
+        self._task_registry_name = task_registry_key
+        self._pubsub_listener_task: Task | None = None
+        self._task_id_ttl_in_second = task_id_ttl_in_second
+        self._node_id = node_id or str(uuid.uuid4())
+
+    def _task_channel_name(self, task_id: str) -> str:
+        return self._relay_channel_name + task_id
+
+    async def _has_task_id(self, task_id: str) -> bool:
+        ret = await self._redis.hget(self._task_registry_name, task_id)  # type: ignore [misc]
+        return ret is not None
+
+    async def _register_task_id(self, task_id: str) -> None:
+        assert await self._redis.hsetex(
+            name=self._task_registry_name,
+            key=task_id,
+            value=self._node_id,
+            ex=self._task_id_ttl_in_second,
+        ) == 1, 'should have registered task id'  # type: ignore [misc]
+        logger.debug(
+            f'Registered task_id {task_id} to node {self._node_id} in registry.'
+        )
+        task_started_event = asyncio.Event()
+
+        async def _wrapped_listen_and_relay() -> None:
+            task_started_event.set()
+            c = EventConsumer(self._local_queue[task_id].tap())
+            async for event in c.consume_all():
+                logger.debug(
+                    f'Publishing event for task {task_id} in QM {self}: {event}'
+                )
+                expected_node_id = await self._redis.hget(
+                    self._task_registry_name, task_id
+                )  # type: ignore [misc]
+                if not expected_node_id:
+                    logger.warning(f'Task {task_id} has expired or is not registered yet.')
+                    continue
+                expected_node_id = (
+                    expected_node_id.decode('utf-8')
+                    if hasattr(expected_node_id, 'decode')
+                    else expected_node_id
+                )
+                if expected_node_id == self._node_id:
+                    # publish the event to the relay channel
+                    await self._redis.publish(
+                        self._task_channel_name(task_id),
+                        event.model_dump_json(exclude_none=True),
+                    )  # type: ignore [misc]
+                    # refresh the TTL for task_id
+                    await self._redis.hsetex(
+                        name=self._task_registry_name,
+                        key=task_id,
+                        value=self._node_id,
+                        ex=self._task_id_ttl_in_second,
+                    )  # type: ignore [misc]
+                else:
+                    logger.warning(
+                        f'Task {task_id} is not registered on this node. Expected node id: {expected_node_id}'
+                    )
+                    break
+
+        self._background_tasks[task_id] = asyncio.create_task(
+            _wrapped_listen_and_relay()
+        )
+        await task_started_event.wait()
+        logger.debug(f'Started to listen and relay events for task {task_id}')
+
+    async def _remove_task_id(self, task_id: str) -> bool:
+        if task_id in self._background_tasks:
+            self._background_tasks[task_id].cancel(
+                'task_id is closed: ' + task_id
+            )  # type: ignore [misc]
+        return await self._redis.hdel(self._task_registry_name, task_id) == 1  # type: ignore [misc]
+
+    async def _subscribe_remote_task_events(self, task_id: str) -> None:
+        channel_id = self._task_channel_name(task_id)
+        await self._pubsub.subscribe(**{channel_id: self._relay_remote_events})
+        # this is a global listener that handles incoming pubsub events
+        if not self._pubsub_listener_task:
+            logger.debug('Creating pubsub listener task.')
+            self._pubsub_listener_task = asyncio.create_task(
+                self._consume_pubsub_messages()
+            )
+        logger.debug(f'Subscribed for remote events for task {task_id}')
+
+    async def _consume_pubsub_messages(self) -> None:
+        async for _ in self._pubsub.listen():
+            pass
+
+    async def _relay_remote_events(
+        self, subscription_event: dict[str, Any]
+    ) -> None:
+        if (
+            'channel' not in subscription_event
+            or 'data' not in subscription_event
+        ):
+            logger.warning(
+                f'channel or data is absent in subscription event: {subscription_event}'
+            )
+            return
+
+        channel_id: str = subscription_event['channel'].decode('utf-8')
+        data_string: str = subscription_event['data'].decode('utf-8')
+        task_id = channel_id.split('.')[-1]
+        if task_id not in self._proxy_queue:
+            logger.warning(f'task_id {task_id} not found in proxy queue')
+            return
+
+        try:
+            logger.debug(
+                f'Received event for task_id {task_id} in QM {self}: {data_string}'
+            )
+            event: Event = TypeAdapter(Event).validate_json(data_string)
+        except Exception as e:
+            logger.warning(
+                f'Failed to parse event from subscription event: {subscription_event}: {e}'
+            )
+            return
+
+        logger.debug(
+            f'Enqueuing event for task_id {task_id} in QM {self}: {event}'
+        )
+        await self._proxy_queue[task_id].enqueue_event(event)
+
+    async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
+        # unsubscribe the channel for the given task_id
+        await self._pubsub.unsubscribe(self._task_channel_name(task_id))
+        # release the global listener if no channel is subscribed anymore
+        if not self._pubsub.subscribed and self._pubsub_listener_task:
+            self._pubsub_listener_task.cancel()
+            self._pubsub_listener_task = None
+
+    async def add(self, task_id: str, queue: EventQueue) -> None:
+        """Add a new local event queue for the specified task.
+
+        Args:
+            task_id (str): The identifier of the task.
+            queue (EventQueue): The event queue to be added.
+
+        Raises:
+            TaskQueueExists: If a queue for the task already exists.
+        """
+        logger.debug(f'add {task_id}')
+        async with self._lock:
+            if await self._has_task_id(task_id):
+                raise TaskQueueExists()
+            self._local_queue[task_id] = queue
+            await self._register_task_id(task_id)
+            logger.debug(f'Local queue is created for task {task_id}')
+
+    async def get(self, task_id: str) -> EventQueue | None:
+        """Get the event queue associated with the given task ID.
+
+        This method first checks if there is a local queue for the task.
+        If not found, it checks the global registry and creates a proxy queue
+        if the task exists globally but not locally.
+
+        Args:
+            task_id (str): The identifier of the task.
+
+        Returns:
+            EventQueue | None: The event queue if found, otherwise None.
+        """
+        logger.debug(f'get {task_id}')
+        async with self._lock:
+            # lookup locally
+            if task_id in self._local_queue:
+                logger.debug(f'Got local queue for task_id {task_id}')
+                return self._local_queue[task_id]
+            # lookup globally
+            if await self._has_task_id(task_id):
+                if task_id not in self._proxy_queue:
+                    logger.debug(f'Creating proxy queue for {task_id}')
+                    queue = EventQueue()
+                    self._proxy_queue[task_id] = queue
+                    await self._subscribe_remote_task_events(task_id)
+                logger.debug(f'Got proxy queue for task_id {task_id}')
+                return self._proxy_queue[task_id]
+            logger.warning(
+                f'Attempted to get non-existing queue for task {task_id}'
+            )
+            return None
+
+    async def tap(self, task_id: str) -> EventQueue | None:
+        """Create a duplicate reference to an existing event queue for the task.
+
+        Args:
+            task_id (str): The identifier of the task.
+
+        Returns:
+            EventQueue | None: A new reference to the event queue if it exists, otherwise None.
+        """
+        logger.debug(f'tap {task_id}')
+        event_queue = await self.get(task_id)
+        if event_queue:
+            logger.debug(f'Tapping event queue for task: {task_id}')
+            return event_queue.tap()
+        return None
+
+    async def close(self, task_id: str) -> None:
+        """Close the event queue associated with the given task ID.
+
+        If the queue is a local queue, it will be removed from both the local store
+        and the global registry. If it's a proxy queue, only the proxy will be closed
+        and unsubscribed from remote events, without removing it from the global registry.
+
+        Args:
+            task_id (str): The identifier of the task.
+
+        Raises:
+            NoTaskQueue: If no queue exists for the given task ID.
+        """
+        logger.debug(f'close {task_id}')
+        async with self._lock:
+            if task_id in self._local_queue:
+                # remove from the global registry when a local queue is closed
+                await self._remove_task_id(task_id)
+                # close locally
+                queue = self._local_queue.pop(task_id)
+                await queue.close()
+                logger.debug(f'Closing local queue for task {task_id}')
+                return
+
+            if task_id in self._proxy_queue:
+                # close the proxy queue
+                queue = self._proxy_queue.pop(task_id)
+                await queue.close()
+                # unsubscribe from remote events, but keep the global registry entry
+                await self._unsubscribe_remote_task_events(task_id)
+                logger.debug(f'Closing proxy queue for task {task_id}')
+                return
+
+            logger.warning(
+                f'Attempted to close non-existing queue for task {task_id}'
+            )
+            raise NoTaskQueue()
+
+    async def create_or_tap(self, task_id: str) -> EventQueue:
+        """Create a new local queue or return a reference to an existing one.
+
+        If the task already has a queue (either local or proxy), this method returns
+        a reference to that queue. Otherwise, a new local queue is created and registered.
+
+        Args:
+            task_id (str): The identifier of the task.
+
+        Returns:
+            EventQueue: An event queue associated with the given task ID.
+        """
+        logger.debug(f'create_or_tap {task_id}')
+        async with self._lock:
+            if await self._has_task_id(task_id):
+                # if it's a local queue, tap directly
+                if task_id in self._local_queue:
+                    logger.debug(f'Tapping a local queue for task {task_id}')
+                    return self._local_queue[task_id].tap()
+
+                # if it's a proxy queue, tap the proxy
+                if task_id not in self._proxy_queue:
+                    # if the proxy is not created yet, create it
+                    queue = EventQueue()
+                    self._proxy_queue[task_id] = queue
+                    await self._subscribe_remote_task_events(task_id)
+                logger.debug(f'Tapping a proxy queue for task {task_id}')
+                return self._proxy_queue[task_id].tap()
+            # the task doesn't exist yet, create a local queue
+            queue = EventQueue()
+            self._local_queue[task_id] = queue
+            await self._register_task_id(task_id)
+            logger.debug(f'Creating a local queue for task {task_id}')
+            return queue
diff --git a/tests/server/events/test_redis_queue_manager.py b/tests/server/events/test_redis_queue_manager.py
new file mode 100644
index 00000000..2826ab75
--- /dev/null
+++ b/tests/server/events/test_redis_queue_manager.py
@@ -0,0 +1,128 @@
+import asyncio
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from fakeredis import FakeAsyncRedis
+
+from a2a.server.events import EventQueue, TaskQueueExists
+from a2a.server.events.redis_queue_manager import RedisQueueManager
+
+
+class TestRedisQueueManager:
+    @pytest.fixture
+    def redis(self):
+        return FakeAsyncRedis()
+
+    @pytest.fixture
+    def queue_manager(self, redis):
+        return RedisQueueManager(redis)
+
+    @pytest.fixture
+    def event_queue(self):
+        queue = MagicMock(spec=EventQueue)
+        # Mock the tap method to return itself
+        queue.tap.return_value = queue
+        return queue
+
+    @pytest.mark.asyncio
+    async def test_init(self, queue_manager):
+        assert queue_manager._local_queue == {}
+        assert queue_manager._proxy_queue == {}
+        assert isinstance(queue_manager._lock, asyncio.Lock)
+
+    @pytest.mark.asyncio
+    async def test_add_new_queue(self, queue_manager, event_queue):
+        """Test adding a new queue to the manager."""
+        task_id = 'test_task_id'
+        await queue_manager.add(task_id, event_queue)
+        assert queue_manager._local_queue[task_id] == event_queue
+
+    @pytest.mark.asyncio
+    async def test_add_existing_queue(self, queue_manager, event_queue):
+        task_id = 'test_task_id'
+        await queue_manager.add(task_id, event_queue)
+
+        with pytest.raises(TaskQueueExists):
+            await queue_manager.add(task_id, event_queue)
+
+    @pytest.mark.asyncio
+    async def test_get_existing_queue(self, queue_manager, event_queue):
+        task_id = 'test_task_id'
+        await queue_manager.add(task_id, event_queue)
+
+        result = await queue_manager.get(task_id)
+        assert result == event_queue
+
+    @pytest.mark.asyncio
+    async def test_get_nonexistent_queue(self, queue_manager):
+        result = await queue_manager.get('nonexistent_task_id')
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_tap_existing_queue(self, queue_manager, event_queue):
+        task_id = 'test_task_id'
+        await queue_manager.add(task_id, event_queue)
+        event_queue.tap.assert_called_once()
+
+        result = await queue_manager.tap(task_id)
+        assert result == event_queue
+        assert event_queue.tap.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_tap_nonexistent_queue(self, queue_manager):
+        result = await queue_manager.tap('nonexistent_task_id')
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_close_existing_queue(self, queue_manager, event_queue):
+        task_id = 'test_task_id'
+        await queue_manager.add(task_id, event_queue)
+
+        await queue_manager.close(task_id)
+        assert task_id not in queue_manager._local_queue
+
+    @pytest.mark.asyncio
+    async def test_create_or_tap_existing_queue(
+        self, queue_manager, event_queue
+    ):
+        task_id = 'test_task_id'
+        await queue_manager.add(task_id, event_queue)
+        event_queue.tap.assert_called_once()
+
+        result = await queue_manager.create_or_tap(task_id)
+
+        assert result == event_queue
+        assert event_queue.tap.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_concurrency(self, queue_manager):
+        async def add_task(task_id):
+            queue = EventQueue()
+            await queue_manager.add(task_id, queue)
+            return task_id
+
+        async def get_task(task_id):
+            return await queue_manager.get(task_id)
+
+        # Create 10 different task IDs
+        task_ids = [f'task_{i}' for i in range(10)]
+
+        # Add tasks concurrently
+        add_tasks = [add_task(task_id) for task_id in task_ids]
+        added_task_ids = await asyncio.gather(*add_tasks)
+
+        # Verify all tasks were added
+        assert set(added_task_ids) == set(task_ids)
+
+        # Get tasks concurrently
+        get_tasks = [get_task(task_id) for task_id in task_ids]
+        queues = await asyncio.gather(*get_tasks)
+
+        # Verify all queues are not None
+        assert all(queue is not None for queue in queues)
+
+        # Verify all tasks are in the manager
+        for task_id in task_ids:
+            assert task_id in queue_manager._local_queue
diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py
index 5581711e..acc1ade6 100644
--- a/tests/server/test_integration.py
+++ b/tests/server/test_integration.py
@@ -1,10 +1,15 @@
 import asyncio
+import logging
+import os
 
+from asyncio import QueueEmpty
 from typing import Any
 from unittest import mock
 
 import pytest
+import pytest_asyncio
 
+from redis.asyncio import Redis
 from starlette.authentication import (
     AuthCredentials,
     AuthenticationBackend,
@@ -22,6 +27,8 @@
     A2AFastAPIApplication,
     A2AStarletteApplication,
 )
+from a2a.server.events import EventQueue, NoTaskQueue, TaskQueueExists
+from a2a.server.events.redis_queue_manager import RedisQueueManager
 from a2a.types import (
     AgentCapabilities,
     AgentCard,
@@ -44,6 +51,7 @@
     TextPart,
     UnsupportedOperationError,
 )
+from a2a.utils import new_agent_text_message
 from a2a.utils.errors import MethodNotImplementedError
 
 
@@ -884,3 +892,165 @@ def test_non_dict_json(client: TestClient):
     data = response.json()
     assert 'error' in data
     assert data['error']['code'] == InvalidRequestError().code
+
+
+# === RedisQueueManager ===
+@pytest_asyncio.fixture(scope='function')
+async def asyncio_redis():
+    redis_server_url = os.getenv('REDIS_SERVER_URL')
+    if redis_server_url:
+        redis = Redis.from_url(redis_server_url)
+    else:
+        # use fakeredis if no redis server url is given
+        from fakeredis import FakeAsyncRedis
+
+        redis = FakeAsyncRedis()
+    logging.info('flush redis for next test case')
+    await redis.flushall(asynchronous=False)
+    yield redis
+    await redis.close()
+
+
+@pytest.mark.asyncio
+async def test_redis_queue_local_only_queue(asyncio_redis):
+    queue_manager = RedisQueueManager(asyncio_redis)
+
+    # set up local queues
+    q1 = EventQueue()
+    await queue_manager.add('task_1', q1)
+    q2 = EventQueue()
+    await queue_manager.add('task_2', q2)
+    q3 = await queue_manager.tap('task_1')
+    assert await queue_manager.get('task_1') == q1
+    assert await queue_manager.get('task_2') == q2
+
+    # send and receive locally
+    msg1 = new_agent_text_message('hello')
+    await q1.enqueue_event(msg1)
+    assert await q1.dequeue_event(no_wait=True) == msg1
+    assert await q3.dequeue_event(no_wait=True) == msg1
+    # raise an error if the queue is empty
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1.dequeue_event(no_wait=True)
+    # q2 is empty
+    with pytest.raises(asyncio.QueueEmpty):
+        await q2.dequeue_event(no_wait=True)
+
+
+@pytest.mark.asyncio
+async def test_redis_queue_mixed_queue(asyncio_redis):
+    qm1 = RedisQueueManager(asyncio_redis)
+    qm2 = RedisQueueManager(asyncio_redis)
+    qm3 = RedisQueueManager(asyncio_redis)
+    qm4 = RedisQueueManager(asyncio_redis)
+    qm5 = RedisQueueManager(asyncio_redis)
+
+    # create a local queue in qm1
+    q1 = EventQueue()
+    await qm1.add('task_1', q1)
+    assert 'task_1' in qm1._local_queue
+    assert await qm1.get('task_1') == q1
+
+    # create a proxy queue in qm2 through the `get` method
+    q1_1 = await qm2.get('task_1')
+    assert 'task_1' in qm2._proxy_queue and 'task_1' not in qm2._local_queue
+    assert q1_1 != q1
+
+    # create a proxy queue in qm3 through the `tap` method
+    q1_2 = await qm3.tap('task_1')
+    assert 'task_1' in qm3._proxy_queue and 'task_1' not in qm3._local_queue
+
+    # create a proxy queue in qm4 through the `create_or_tap` method
+    q1_3 = await qm4.create_or_tap('task_1')
+    assert 'task_1' in qm4._proxy_queue and 'task_1' not in qm4._local_queue
+
+    # create a local queue in qm5 through the `create_or_tap` method
+    q2 = await qm5.create_or_tap('task_2')
+    assert 'task_2' in qm5._local_queue and 'task_2' not in qm5._proxy_queue
+
+    # enqueue and dequeue in q1
+    msg1 = new_agent_text_message('hello')
+    await q1.enqueue_event(msg1)
+    assert await q1.dequeue_event() == msg1
+    q1.task_done()
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1.dequeue_event(no_wait=True)
+
+    # dequeue in q1_1
+    msg1_1: Message = await q1_1.dequeue_event()
+    assert msg1_1.parts[0].root.text == msg1.parts[0].root.text
+    q1_1.task_done()
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1_1.dequeue_event(no_wait=True)
+
+    # dequeue in q1_2
+    msg1_2: Message = await q1_2.dequeue_event()
+    assert msg1_2.parts[0].root.text == msg1.parts[0].root.text
+    q1_2.task_done()
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1_2.dequeue_event(no_wait=True)
+    # q1_2 is only a tapped child, so drain the parent proxy queue for task_1; otherwise dequeue or close on it would block forever
+    _ = await qm3.get('task_1')
+    await _.dequeue_event()
+    _.task_done()
+
+    # dequeue in q1_3
+    msg1_3: Message = await q1_3.dequeue_event()
+    assert msg1_3.parts[0].root.text == msg1.parts[0].root.text
+    q1_3.task_done()
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1_3.dequeue_event(no_wait=True)
+    # q1_3 is only a tapped child, so drain the parent proxy queue for task_1; otherwise dequeue or close on it would block forever
+    _ = await qm4.get('task_1')
+    await _.dequeue_event()
+    _.task_done()
+
+    # enqueue and dequeue in q2
+    msg2 = new_agent_text_message('world')
+    await q2.enqueue_event(msg2)
+    assert await q2.dequeue_event() == msg2
+    q2.task_done()
+
+    # close queues
+    await qm1.close('task_1')
+    await qm2.close('task_1')
+    await qm3.close('task_1')
+    await qm4.close('task_1')
+    await qm5.close('task_2')
+    with pytest.raises(NoTaskQueue):
+        await qm5.close('task_10000')
+
+
+@pytest.mark.asyncio
+async def test_redis_queue_task_id_expiration(asyncio_redis):
+    qm1 = RedisQueueManager(
+        asyncio_redis, task_id_ttl_in_second=1
+    )
+    qm2 = RedisQueueManager(
+        asyncio_redis, task_id_ttl_in_second=1
+    )
+    q1 = EventQueue()
+    await qm1.add('task_1', q1)
+    # adding task_1 again triggers the expected exception
+    with pytest.raises(TaskQueueExists):
+        await qm2.add('task_1', q1)
+    q2 = await qm2.get('task_1')
+    assert q2
+    # enqueue a message to q1, and dequeue it from q2
+    msg1 = new_agent_text_message('hello')
+    await q1.enqueue_event(msg1)
+    assert await q2.dequeue_event()
+    q2.task_done()
+
+    # sleep for 1 second to let the registration expire
+    await asyncio.sleep(1)
+    assert await qm1._has_task_id('task_1') is False
+    assert await qm2._has_task_id('task_1') is False
+
+    # enqueue a message; this would refresh the TTL only if the registration were still alive
+    msg2 = new_agent_text_message('world')
+    await q1.enqueue_event(msg2)
+    assert await qm1._has_task_id('task_1') is False
+    # after qm1's ownership of task_1 has expired, enqueuing in q1 won't broadcast to q2
+    with pytest.raises(QueueEmpty):
+        await q2.dequeue_event(no_wait=True)
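
For reviewers, here is a minimal end-to-end sketch of the behavior this diff introduces, adapted from the tests above. It is not part of the change itself. It assumes a reachable Redis at `redis://localhost:6379/0` (for local experiments, `fakeredis.FakeAsyncRedis()` is a drop-in replacement); the two manager instances stand in for two separate server processes, though the relay path through Redis pub/sub is exercised even when they share one process, as in the tests.

```python
import asyncio

from redis.asyncio import Redis

from a2a.server.events import EventQueue
from a2a.server.events.redis_queue_manager import RedisQueueManager
from a2a.utils import new_agent_text_message


async def main() -> None:
    # Assumed address; fakeredis.FakeAsyncRedis() works as a stand-in.
    redis_client = Redis.from_url('redis://localhost:6379/0')

    # In production these would live in two different server processes.
    owner = RedisQueueManager(redis_client)
    follower = RedisQueueManager(redis_client)

    # The owning node registers a local queue for the task; this also
    # writes the task id -> node id mapping into the global registry.
    local_queue = EventQueue()
    await owner.add('task_1', local_queue)

    # Another node finds the task in the global registry and gets a
    # proxy queue that is fed from the relay channel.
    proxy_queue = await follower.get('task_1')
    assert proxy_queue is not None

    # An event enqueued on the owner is relayed over pub/sub ...
    await local_queue.enqueue_event(new_agent_text_message('hello'))
    # ... and pops out of the proxy queue on the other node.
    event = await proxy_queue.dequeue_event()
    proxy_queue.task_done()
    print(event)

    # Closing the local queue also removes the registry entry; closing
    # the proxy only unsubscribes this node.
    await owner.close('task_1')
    await follower.close('task_1')


asyncio.run(main())
```

Only the node whose id matches the registry entry publishes events for a task; every other node that calls `get`, `tap`, or `create_or_tap` receives a proxy queue fed from the relay channel, which is what `test_redis_queue_mixed_queue` exercises across five managers.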