From ce19f184bef8377022cbcc12f0a583724f253749 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Wed, 2 Jul 2025 17:38:25 +0800 Subject: [PATCH 01/11] feat(server): implement Redis-based event queue manager - Add RedisQueueManager class to handle event queues using Redis - Implement core functionalities: add, get, tap, close, create_or_tap - Add unit tests for RedisQueueManager - Update .gitignore to exclude spec.json and .idea - Refactor event type definition in event_queue.py - Add fakeredis dependency in pyproject.toml --- .gitignore | 3 +- pyproject.toml | 1 + src/a2a/server/events/event_queue.py | 5 +- src/a2a/server/events/redis_queue_manager.py | 134 ++++++++++++++++++ .../server/events/test_redis_queue_manager.py | 127 +++++++++++++++++ 5 files changed, 268 insertions(+), 2 deletions(-) create mode 100644 src/a2a/server/events/redis_queue_manager.py create mode 100644 tests/server/events/test_redis_queue_manager.py diff --git a/.gitignore b/.gitignore index 6252577e..79e86ef7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ __pycache__ .venv coverage.xml .nox -spec.json \ No newline at end of file +spec.json +.idea \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e73df213..183bd097 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0", "protobuf==5.29.5", + "fakeredis>=2.30.1", ] classifiers = [ diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 1ce2bd21..d58a214c 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -1,6 +1,9 @@ import asyncio import logging import sys +from typing import Union, Annotated + +from pydantic import Field from a2a.types import ( Message, @@ -14,7 +17,7 @@ logger = logging.getLogger(__name__) -Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent +Event = Annotated[Union[Message, Task ,TaskStatusUpdateEvent, TaskArtifactUpdateEvent], Field(discriminator="kind")] """Type alias for events that can be enqueued.""" DEFAULT_MAX_QUEUE_SIZE = 1024 diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py new file mode 100644 index 00000000..db6240d2 --- /dev/null +++ b/src/a2a/server/events/redis_queue_manager.py @@ -0,0 +1,134 @@ +import asyncio +from asyncio import Task +from functools import partial +from typing import Dict + +from redis.asyncio import Redis + +from a2a.server.events import QueueManager, EventQueue, TaskQueueExists, Event, EventConsumer, NoTaskQueue + + +class RedisQueueManager(QueueManager): + """ + This implements the `QueueManager` interface using Redis for event + queues. Primary jobs: + 1. Broadcast local events to proxy queues in other processes using redis pubsub + 2. Subscribe event messages from redis pubsub and replay to local proxy queues + """ + + def __init__(self, redis_client: Redis, + relay_channel_key_prefix: str = "a2a.event.relay.", + task_registry_key: str = "a2a.event.registry" + ): + self._redis = redis_client + self._local_queue: dict[str, EventQueue] = {} + self._proxy_queue: dict[str, EventQueue] = {} + self._lock = asyncio.Lock() + self._pubsub = redis_client.pubsub() + self._relay_channel_name = relay_channel_key_prefix + self._background_tasks: Dict[str, Task] = {} + self._task_registry_name = task_registry_key + + async def _listen_and_relay(self, task_id: str): + c = EventConsumer(self._local_queue[task_id]) + async for event in c.consume_all(): + await self._redis.publish(self._task_channel_name(task_id), event.model_dump_json(exclude_none=True)) + + def _task_channel_name(self, task_id: str): + return self._relay_channel_name + task_id + + async def _has_task_id(self, task_id: str): + ret = await self._redis.sismember(self._task_registry_name, task_id) + return ret + + async def _register_task_id(self, task_id: str): + await self._redis.sadd(self._task_registry_name, task_id) + self._background_tasks[task_id] = asyncio.create_task(self._listen_and_relay(task_id)) + + async def _remove_task_id(self, task_id: str): + if task_id in self._background_tasks: + self._background_tasks[task_id].cancel("task_id is closed: " + task_id) + return await self._redis.srem(self._task_registry_name, task_id) + + async def _subscribe_remote_task_events(self, task_id: str): + await self._pubsub.subscribe(**{ + self._task_channel_name(task_id): partial(self._relay_remote_events, task_id) + }) + + def _unsubscribe_remote_task_events(self, task_id: str): + self._pubsub.unsubscribe(self._task_channel_name(task_id)) + + def _relay_remote_events(self, task_id: str , event_json: str): + if task_id in self._proxy_queue: + event = Event.model_validate_json(event_json) + self._proxy_queue[task_id].enqueue_event(event) + + async def add(self, task_id: str, queue: EventQueue) -> None: + async with self._lock: + if await self._has_task_id(task_id): + raise TaskQueueExists() + self._local_queue[task_id] = queue + await self._register_task_id(task_id) + + async def get(self, task_id: str) -> EventQueue | None: + async with self._lock: + # lookup locally + if task_id in self._local_queue: + return self._local_queue[task_id] + # lookup globally + if await self._has_task_id(task_id): + if task_id not in self._proxy_queue: + queue = EventQueue() + self._proxy_queue[task_id] = queue + await self._subscribe_remote_task_events(task_id) + return self._proxy_queue[task_id] + return None + + async def tap(self, task_id: str) -> EventQueue | None: + event_queue = await self.get(task_id) + if event_queue: + return event_queue.tap() + return None + + async def close(self, task_id: str) -> None: + async with self._lock: + if task_id in self._local_queue: + # close locally + queue = self._local_queue.pop(task_id) + await queue.close() + # remove from global registry if a local queue is closed + await self._remove_task_id(task_id) + return None + + if task_id in self._proxy_queue: + # close proxy queue + queue = self._proxy_queue.pop(task_id) + await queue.close() + # unsubscribe from remote, but don't remove from global registry + self._unsubscribe_remote_task_events(task_id) + return None + + raise NoTaskQueue() + + async def create_or_tap(self, task_id: str) -> EventQueue: + async with self._lock: + if await self._has_task_id(task_id): + # if it's a local queue, tap directly + if task_id in self._local_queue: + return self._local_queue[task_id].tap() + + # if it's a proxy queue, tap the proxy + if task_id in self._proxy_queue: + return self._proxy_queue[task_id].tap() + + # if the proxy is not created, create the proxy and return + queue = EventQueue() + self._proxy_queue[task_id] = queue + await self._subscribe_remote_task_events(task_id) + return self._proxy_queue[task_id] + else: + # the task doesn't exist before, create a local queue + queue = EventQueue() + self._local_queue[task_id] = queue + await self._register_task_id(task_id) + return queue diff --git a/tests/server/events/test_redis_queue_manager.py b/tests/server/events/test_redis_queue_manager.py new file mode 100644 index 00000000..000dea3e --- /dev/null +++ b/tests/server/events/test_redis_queue_manager.py @@ -0,0 +1,127 @@ +import asyncio +from unittest.mock import MagicMock + +import pytest +from fakeredis import FakeAsyncRedis + +from a2a.server.events import EventQueue, TaskQueueExists +from a2a.server.events.redis_queue_manager import RedisQueueManager + + +class TestRedisQueueManager: + @pytest.fixture + def redis(self): + return FakeAsyncRedis() + + @pytest.fixture + def queue_manager(self, redis): + return RedisQueueManager(redis) + + @pytest.fixture + def event_queue(self): + queue = MagicMock(spec=EventQueue) + # Mock the tap method to return itself + queue.tap.return_value = queue + return queue + + @pytest.mark.asyncio + async def test_init(self, queue_manager): + assert queue_manager._local_queue == {} + assert queue_manager._proxy_queue == {} + assert isinstance(queue_manager._lock, asyncio.Lock) + + + @pytest.mark.asyncio + async def test_add_new_queue(self, queue_manager, event_queue): + """Test adding a new queue to the manager.""" + task_id = 'test_task_id' + await queue_manager.add(task_id, event_queue) + assert queue_manager._local_queue[task_id] == event_queue + + @pytest.mark.asyncio + async def test_add_existing_queue(self, queue_manager, event_queue): + task_id = 'test_task_id' + await queue_manager.add(task_id, event_queue) + + with pytest.raises(TaskQueueExists): + await queue_manager.add(task_id, event_queue) + + @pytest.mark.asyncio + async def test_get_existing_queue(self, queue_manager, event_queue): + task_id = 'test_task_id' + await queue_manager.add(task_id, event_queue) + + result = await queue_manager.get(task_id) + assert result == event_queue + + @pytest.mark.asyncio + async def test_get_nonexistent_queue(self, queue_manager): + result = await queue_manager.get('nonexistent_task_id') + assert result is None + + + @pytest.mark.asyncio + async def test_tap_existing_queue(self, queue_manager, event_queue): + task_id = 'test_task_id' + await queue_manager.add(task_id, event_queue) + + result = await queue_manager.tap(task_id) + assert result == event_queue + event_queue.tap.assert_called_once() + + @pytest.mark.asyncio + async def test_tap_nonexistent_queue(self, queue_manager): + result = await queue_manager.tap('nonexistent_task_id') + assert result is None + + @pytest.mark.asyncio + async def test_close_existing_queue(self, queue_manager, event_queue): + task_id = 'test_task_id' + await queue_manager.add(task_id, event_queue) + + await queue_manager.close(task_id) + assert task_id not in queue_manager._local_queue + + + @pytest.mark.asyncio + async def test_create_or_tap_existing_queue( + self, queue_manager, event_queue + ): + task_id = 'test_task_id' + await queue_manager.add(task_id, event_queue) + + result = await queue_manager.create_or_tap(task_id) + + assert result == event_queue + event_queue.tap.assert_called_once() + + @pytest.mark.asyncio + async def test_concurrency(self, queue_manager): + async def add_task(task_id): + queue = EventQueue() + await queue_manager.add(task_id, queue) + return task_id + + async def get_task(task_id): + return await queue_manager.get(task_id) + + # Create 10 different task IDs + task_ids = [f'task_{i}' for i in range(10)] + + # Add tasks concurrently + add_tasks = [add_task(task_id) for task_id in task_ids] + added_task_ids = await asyncio.gather(*add_tasks) + + # Verify all tasks were added + assert set(added_task_ids) == set(task_ids) + + # Get tasks concurrently + get_tasks = [get_task(task_id) for task_id in task_ids] + queues = await asyncio.gather(*get_tasks) + + # Verify all queues are not None + assert all(queue is not None for queue in queues) + + # Verify all tasks are in the manager + for task_id in task_ids: + assert task_id in queue_manager._local_queue \ No newline at end of file From 1a1709192df331e52441479f5b4bec12628bd88f Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Wed, 2 Jul 2025 22:44:21 +0800 Subject: [PATCH 02/11] refactor(server/events): improve code quality and readability using nox --- src/a2a/server/events/event_queue.py | 8 ++- src/a2a/server/events/redis_queue_manager.py | 66 ++++++++++++------- .../server/events/test_redis_queue_manager.py | 7 +- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index d58a214c..76a0b512 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -1,7 +1,8 @@ import asyncio import logging import sys -from typing import Union, Annotated + +from typing import Annotated from pydantic import Field @@ -17,7 +18,10 @@ logger = logging.getLogger(__name__) -Event = Annotated[Union[Message, Task ,TaskStatusUpdateEvent, TaskArtifactUpdateEvent], Field(discriminator="kind")] +Event = Annotated[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + Field(discriminator='kind'), +] """Type alias for events that can be enqueued.""" DEFAULT_MAX_QUEUE_SIZE = 1024 diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index db6240d2..dfa453ad 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -1,38 +1,49 @@ import asyncio + from asyncio import Task from functools import partial -from typing import Dict from redis.asyncio import Redis -from a2a.server.events import QueueManager, EventQueue, TaskQueueExists, Event, EventConsumer, NoTaskQueue +from a2a.server.events import ( + Event, + EventConsumer, + EventQueue, + NoTaskQueue, + QueueManager, + TaskQueueExists, +) class RedisQueueManager(QueueManager): - """ - This implements the `QueueManager` interface using Redis for event + """This implements the `QueueManager` interface using Redis for event queues. Primary jobs: 1. Broadcast local events to proxy queues in other processes using redis pubsub 2. Subscribe event messages from redis pubsub and replay to local proxy queues """ - def __init__(self, redis_client: Redis, - relay_channel_key_prefix: str = "a2a.event.relay.", - task_registry_key: str = "a2a.event.registry" - ): + def __init__( + self, + redis_client: Redis, + relay_channel_key_prefix: str = 'a2a.event.relay.', + task_registry_key: str = 'a2a.event.registry', + ): self._redis = redis_client self._local_queue: dict[str, EventQueue] = {} self._proxy_queue: dict[str, EventQueue] = {} self._lock = asyncio.Lock() self._pubsub = redis_client.pubsub() self._relay_channel_name = relay_channel_key_prefix - self._background_tasks: Dict[str, Task] = {} + self._background_tasks: dict[str, Task] = {} self._task_registry_name = task_registry_key async def _listen_and_relay(self, task_id: str): c = EventConsumer(self._local_queue[task_id]) async for event in c.consume_all(): - await self._redis.publish(self._task_channel_name(task_id), event.model_dump_json(exclude_none=True)) + await self._redis.publish( + self._task_channel_name(task_id), + event.model_dump_json(exclude_none=True), + ) def _task_channel_name(self, task_id: str): return self._relay_channel_name + task_id @@ -43,22 +54,30 @@ async def _has_task_id(self, task_id: str): async def _register_task_id(self, task_id: str): await self._redis.sadd(self._task_registry_name, task_id) - self._background_tasks[task_id] = asyncio.create_task(self._listen_and_relay(task_id)) + self._background_tasks[task_id] = asyncio.create_task( + self._listen_and_relay(task_id) + ) async def _remove_task_id(self, task_id: str): if task_id in self._background_tasks: - self._background_tasks[task_id].cancel("task_id is closed: " + task_id) + self._background_tasks[task_id].cancel( + 'task_id is closed: ' + task_id + ) return await self._redis.srem(self._task_registry_name, task_id) async def _subscribe_remote_task_events(self, task_id: str): - await self._pubsub.subscribe(**{ - self._task_channel_name(task_id): partial(self._relay_remote_events, task_id) - }) + await self._pubsub.subscribe( + **{ + self._task_channel_name(task_id): partial( + self._relay_remote_events, task_id + ) + } + ) def _unsubscribe_remote_task_events(self, task_id: str): self._pubsub.unsubscribe(self._task_channel_name(task_id)) - def _relay_remote_events(self, task_id: str , event_json: str): + def _relay_remote_events(self, task_id: str, event_json: str): if task_id in self._proxy_queue: event = Event.model_validate_json(event_json) self._proxy_queue[task_id].enqueue_event(event) @@ -98,7 +117,7 @@ async def close(self, task_id: str) -> None: await queue.close() # remove from global registry if a local queue is closed await self._remove_task_id(task_id) - return None + return if task_id in self._proxy_queue: # close proxy queue @@ -106,7 +125,7 @@ async def close(self, task_id: str) -> None: await queue.close() # unsubscribe from remote, but don't remove from global registry self._unsubscribe_remote_task_events(task_id) - return None + return raise NoTaskQueue() @@ -126,9 +145,8 @@ async def create_or_tap(self, task_id: str) -> EventQueue: self._proxy_queue[task_id] = queue await self._subscribe_remote_task_events(task_id) return self._proxy_queue[task_id] - else: - # the task doesn't exist before, create a local queue - queue = EventQueue() - self._local_queue[task_id] = queue - await self._register_task_id(task_id) - return queue + # the task doesn't exist before, create a local queue + queue = EventQueue() + self._local_queue[task_id] = queue + await self._register_task_id(task_id) + return queue diff --git a/tests/server/events/test_redis_queue_manager.py b/tests/server/events/test_redis_queue_manager.py index 000dea3e..3d9b4a4d 100644 --- a/tests/server/events/test_redis_queue_manager.py +++ b/tests/server/events/test_redis_queue_manager.py @@ -1,7 +1,9 @@ import asyncio + from unittest.mock import MagicMock import pytest + from fakeredis import FakeAsyncRedis from a2a.server.events import EventQueue, TaskQueueExists @@ -30,7 +32,6 @@ async def test_init(self, queue_manager): assert queue_manager._proxy_queue == {} assert isinstance(queue_manager._lock, asyncio.Lock) - @pytest.mark.asyncio async def test_add_new_queue(self, queue_manager, event_queue): """Test adding a new queue to the manager.""" @@ -59,7 +60,6 @@ async def test_get_nonexistent_queue(self, queue_manager): result = await queue_manager.get('nonexistent_task_id') assert result is None - @pytest.mark.asyncio async def test_tap_existing_queue(self, queue_manager, event_queue): task_id = 'test_task_id' @@ -82,7 +82,6 @@ async def test_close_existing_queue(self, queue_manager, event_queue): await queue_manager.close(task_id) assert task_id not in queue_manager._local_queue - @pytest.mark.asyncio async def test_create_or_tap_existing_queue( self, queue_manager, event_queue @@ -124,4 +123,4 @@ async def get_task(task_id): # Verify all tasks are in the manager for task_id in task_ids: - assert task_id in queue_manager._local_queue \ No newline at end of file + assert task_id in queue_manager._local_queue From 2ad5ff318bad94d575e8a31fda74cdc7e9cca072 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Thu, 3 Jul 2025 10:37:05 +0800 Subject: [PATCH 03/11] refactor(server): fix ruff errors and spelling check errors --- .github/actions/spelling/allow.txt | 3 + pyproject.toml | 2 + src/a2a/server/events/redis_queue_manager.py | 84 ++++++++++++++++---- 3 files changed, 75 insertions(+), 14 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 6f8229ad..532bdeb8 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -72,3 +72,6 @@ taskupdate testuuid typeerror vulnz +sadd +sismember +srem \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 183bd097..f9b9020f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "grpcio_reflection>=1.7.0", "protobuf==5.29.5", "fakeredis>=2.30.1", + "redis>=6.2.0", ] classifiers = [ @@ -92,6 +93,7 @@ dev = [ "types-protobuf", "types-requests", "pre-commit", + "fakeredis>=2.30.1", ] [[tool.uv.index]] diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index dfa453ad..9dbb1826 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -16,10 +16,14 @@ class RedisQueueManager(QueueManager): - """This implements the `QueueManager` interface using Redis for event - queues. Primary jobs: - 1. Broadcast local events to proxy queues in other processes using redis pubsub - 2. Subscribe event messages from redis pubsub and replay to local proxy queues + """This implements the `QueueManager` interface using Redis for event. + + It will broadcast local events to proxy queues in other processes using redis pubsub, and subscribe event messages from redis pubsub and replay to local proxy queues. + + Args: + redis_client(Redis): asyncio redis connection. + relay_channel_key_prefix(str): prefix for pubsub channel key generation. + task_registry_key(str): key for set data where stores active `task_id`s. """ def __init__( @@ -37,7 +41,7 @@ def __init__( self._background_tasks: dict[str, Task] = {} self._task_registry_name = task_registry_key - async def _listen_and_relay(self, task_id: str): + async def _listen_and_relay(self, task_id: str) -> None: c = EventConsumer(self._local_queue[task_id]) async for event in c.consume_all(): await self._redis.publish( @@ -45,27 +49,27 @@ async def _listen_and_relay(self, task_id: str): event.model_dump_json(exclude_none=True), ) - def _task_channel_name(self, task_id: str): + def _task_channel_name(self, task_id: str) -> str: return self._relay_channel_name + task_id - async def _has_task_id(self, task_id: str): + async def _has_task_id(self, task_id: str) -> bool: ret = await self._redis.sismember(self._task_registry_name, task_id) - return ret + return ret == 1 - async def _register_task_id(self, task_id: str): + async def _register_task_id(self, task_id: str) -> None: await self._redis.sadd(self._task_registry_name, task_id) self._background_tasks[task_id] = asyncio.create_task( self._listen_and_relay(task_id) ) - async def _remove_task_id(self, task_id: str): + async def _remove_task_id(self, task_id: str) -> bool: if task_id in self._background_tasks: self._background_tasks[task_id].cancel( 'task_id is closed: ' + task_id ) - return await self._redis.srem(self._task_registry_name, task_id) + return await self._redis.srem(self._task_registry_name, task_id) == 1 - async def _subscribe_remote_task_events(self, task_id: str): + async def _subscribe_remote_task_events(self, task_id: str) -> None: await self._pubsub.subscribe( **{ self._task_channel_name(task_id): partial( @@ -74,15 +78,24 @@ async def _subscribe_remote_task_events(self, task_id: str): } ) - def _unsubscribe_remote_task_events(self, task_id: str): + def _unsubscribe_remote_task_events(self, task_id: str) -> None: self._pubsub.unsubscribe(self._task_channel_name(task_id)) - def _relay_remote_events(self, task_id: str, event_json: str): + def _relay_remote_events(self, task_id: str, event_json: str) -> None: if task_id in self._proxy_queue: event = Event.model_validate_json(event_json) self._proxy_queue[task_id].enqueue_event(event) async def add(self, task_id: str, queue: EventQueue) -> None: + """Add a new local event queue for the specified task. + + Args: + task_id (str): The identifier of the task. + queue (EventQueue): The event queue to be added. + + Raises: + TaskQueueExists: If a queue for the task already exists. + """ async with self._lock: if await self._has_task_id(task_id): raise TaskQueueExists() @@ -90,6 +103,18 @@ async def add(self, task_id: str, queue: EventQueue) -> None: await self._register_task_id(task_id) async def get(self, task_id: str) -> EventQueue | None: + """Get the event queue associated with the given task ID. + + This method first checks if there is a local queue for the task. + If not found, it checks the global registry and creates a proxy queue + if the task exists globally but not locally. + + Args: + task_id (str): The identifier of the task. + + Returns: + EventQueue | None: The event queue if found, otherwise None. + """ async with self._lock: # lookup locally if task_id in self._local_queue: @@ -104,12 +129,32 @@ async def get(self, task_id: str) -> EventQueue | None: return None async def tap(self, task_id: str) -> EventQueue | None: + """Create a duplicate reference to an existing event queue for the task. + + Args: + task_id (str): The identifier of the task. + + Returns: + EventQueue | None: A new reference to the event queue if it exists, otherwise None. + """ event_queue = await self.get(task_id) if event_queue: return event_queue.tap() return None async def close(self, task_id: str) -> None: + """Close the event queue associated with the given task ID. + + If the queue is a local queue, it will be removed from both the local store + and the global registry. If it's a proxy queue, only the proxy will be closed + and unsubscribed from remote events without removing from the global registry. + + Args: + task_id (str): The identifier of the task. + + Raises: + NoTaskQueue: If no queue exists for the given task ID. + """ async with self._lock: if task_id in self._local_queue: # close locally @@ -130,6 +175,17 @@ async def close(self, task_id: str) -> None: raise NoTaskQueue() async def create_or_tap(self, task_id: str) -> EventQueue: + """Create a new local queue or return a reference to an existing one. + + If the task already has a queue (either local or proxy), this method returns + a reference to that queue. Otherwise, a new local queue is created and registered. + + Args: + task_id (str): The identifier of the task. + + Returns: + EventQueue: An event queue associated with the given task ID. + """ async with self._lock: if await self._has_task_id(task_id): # if it's a local queue, tap directly From f75bf8b35255a0dac44cd35759a06b51c4b37e8e Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Thu, 3 Jul 2025 18:10:19 +0800 Subject: [PATCH 04/11] feat(server): implement RedisQueueManager for distributed event handling - Add RedisQueueManager class to manage event queues across distributed services - Implement local and proxy queue management using Redis pub/sub - Add logging for better visibility and debugging - Update tests to cover new functionality --- src/a2a/server/events/redis_queue_manager.py | 111 ++++++++++++++----- tests/server/test_integration.py | 92 +++++++++++++++ 2 files changed, 177 insertions(+), 26 deletions(-) diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index 9dbb1826..fe544076 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -1,8 +1,11 @@ import asyncio +import logging from asyncio import Task from functools import partial +from typing import Any, Dict, Optional +from pydantic import ValidationError, TypeAdapter from redis.asyncio import Redis from a2a.server.events import ( @@ -14,6 +17,8 @@ TaskQueueExists, ) +logger = logging.getLogger(__name__) + class RedisQueueManager(QueueManager): """This implements the `QueueManager` interface using Redis for event. @@ -40,14 +45,7 @@ def __init__( self._relay_channel_name = relay_channel_key_prefix self._background_tasks: dict[str, Task] = {} self._task_registry_name = task_registry_key - - async def _listen_and_relay(self, task_id: str) -> None: - c = EventConsumer(self._local_queue[task_id]) - async for event in c.consume_all(): - await self._redis.publish( - self._task_channel_name(task_id), - event.model_dump_json(exclude_none=True), - ) + self._pubsub_listener_task: Optional[Task] = None def _task_channel_name(self, task_id: str) -> str: return self._relay_channel_name + task_id @@ -57,10 +55,23 @@ async def _has_task_id(self, task_id: str) -> bool: return ret == 1 async def _register_task_id(self, task_id: str) -> None: - await self._redis.sadd(self._task_registry_name, task_id) + task_started_event = asyncio.Event() + async def _wrapped_listen_and_relay() -> None: + task_started_event.set() + c = EventConsumer(self._local_queue[task_id].tap()) + async for event in c.consume_all(): + logger.debug(f'Publishing event for task {task_id} in QM {self}: {event}') + await self._redis.publish( + self._task_channel_name(task_id), + event.model_dump_json(exclude_none=True), + ) + self._background_tasks[task_id] = asyncio.create_task( - self._listen_and_relay(task_id) + _wrapped_listen_and_relay() ) + await task_started_event.wait() + await self._redis.sadd(self._task_registry_name, task_id) + logger.debug(f'Started to listen and relay events for task {task_id}') async def _remove_task_id(self, task_id: str) -> bool: if task_id in self._background_tasks: @@ -70,21 +81,51 @@ async def _remove_task_id(self, task_id: str) -> bool: return await self._redis.srem(self._task_registry_name, task_id) == 1 async def _subscribe_remote_task_events(self, task_id: str) -> None: - await self._pubsub.subscribe( - **{ - self._task_channel_name(task_id): partial( - self._relay_remote_events, task_id - ) - } - ) - - def _unsubscribe_remote_task_events(self, task_id: str) -> None: - self._pubsub.unsubscribe(self._task_channel_name(task_id)) - - def _relay_remote_events(self, task_id: str, event_json: str) -> None: - if task_id in self._proxy_queue: - event = Event.model_validate_json(event_json) - self._proxy_queue[task_id].enqueue_event(event) + channel_id = self._task_channel_name(task_id) + await self._pubsub.subscribe(**{channel_id: self._relay_remote_events}) + + # this is a global listener to handle incoming pubsub events + if not self._pubsub_listener_task: + logger.debug('Creating pubsub listener task.') + self._pubsub_listener_task = asyncio.create_task(self._consume_pubsub_messages()) + + logger.debug(f"Subscribed for remote events for task {task_id}") + + async def _consume_pubsub_messages(self): + async for _ in self._pubsub.listen(): + pass + + async def _relay_remote_events(self, subscription_event) -> None: + if 'channel' not in subscription_event or 'data' not in subscription_event: + logger.warning(f"channel or data is absent in subscription event: {subscription_event}") + return + + channel_id: str = subscription_event['channel'].decode('utf-8') + data_string: str = subscription_event['data'].decode('utf-8') + task_id = channel_id.split('.')[-1] + if task_id not in self._proxy_queue: + logger.warning(f"task_id {task_id} not found in proxy queue") + return + + try: + logger.debug(f"Received event for task_id {task_id} in QM {self}: {data_string}") + event = TypeAdapter(Event).validate_json(data_string) + except Exception as e: + logger.warning(f"Failed to parse event from subscription event: {subscription_event}: {e}") + return + + logger.debug(f"Enqueuing event for task_id {task_id} in QM {self}: {event}") + await self._proxy_queue[task_id].enqueue_event(event) + + + async def _unsubscribe_remote_task_events(self, task_id: str) -> None: + # unsubscribe channel for given task_id + await self._pubsub.unsubscribe(self._task_channel_name(task_id)) + # release global listener if not channel is subscribed + async with self._lock: + if not self._pubsub.subscribed and self._pubsub_listener_task: + self._pubsub_listener_task.cancel() + self._pubsub_listener_task = None async def add(self, task_id: str, queue: EventQueue) -> None: """Add a new local event queue for the specified task. @@ -96,11 +137,13 @@ async def add(self, task_id: str, queue: EventQueue) -> None: Raises: TaskQueueExists: If a queue for the task already exists. """ + logger.debug(f"add {task_id}") async with self._lock: if await self._has_task_id(task_id): raise TaskQueueExists() self._local_queue[task_id] = queue await self._register_task_id(task_id) + logger.debug(f"Local queue is created for task {task_id}") async def get(self, task_id: str) -> EventQueue | None: """Get the event queue associated with the given task ID. @@ -115,17 +158,22 @@ async def get(self, task_id: str) -> EventQueue | None: Returns: EventQueue | None: The event queue if found, otherwise None. """ + logger.debug(f"get {task_id}") async with self._lock: # lookup locally if task_id in self._local_queue: + logger.debug(f"Got local queue for task_id {task_id}") return self._local_queue[task_id] # lookup globally if await self._has_task_id(task_id): if task_id not in self._proxy_queue: + logger.debug(f"Creating proxy queue for {task_id}") queue = EventQueue() self._proxy_queue[task_id] = queue await self._subscribe_remote_task_events(task_id) + logger.debug(f"Got proxy queue for task_id {task_id}") return self._proxy_queue[task_id] + logger.warning(f"Attempted to get non-existing queue for task {task_id}") return None async def tap(self, task_id: str) -> EventQueue | None: @@ -137,8 +185,10 @@ async def tap(self, task_id: str) -> EventQueue | None: Returns: EventQueue | None: A new reference to the event queue if it exists, otherwise None. """ + logger.debug(f"tap {task_id}") event_queue = await self.get(task_id) if event_queue: + logger.debug(f'Tapping event queue for task: {task_id}') return event_queue.tap() return None @@ -155,6 +205,7 @@ async def close(self, task_id: str) -> None: Raises: NoTaskQueue: If no queue exists for the given task ID. """ + logger.debug(f"close {task_id}") async with self._lock: if task_id in self._local_queue: # close locally @@ -162,6 +213,7 @@ async def close(self, task_id: str) -> None: await queue.close() # remove from global registry if a local queue is closed await self._remove_task_id(task_id) + logger.debug(f"Closing local queue for task {task_id}") return if task_id in self._proxy_queue: @@ -169,9 +221,11 @@ async def close(self, task_id: str) -> None: queue = self._proxy_queue.pop(task_id) await queue.close() # unsubscribe from remote, but don't remove from global registry - self._unsubscribe_remote_task_events(task_id) + await self._unsubscribe_remote_task_events(task_id) + logger.debug(f"Closing proxy queue for task {task_id}") return + logger.warning(f"Attempted to close non-existing queue found for task {task_id}") raise NoTaskQueue() async def create_or_tap(self, task_id: str) -> EventQueue: @@ -186,23 +240,28 @@ async def create_or_tap(self, task_id: str) -> EventQueue: Returns: EventQueue: An event queue associated with the given task ID. """ + logger.debug(f"create_or_tap {task_id}") async with self._lock: if await self._has_task_id(task_id): # if it's a local queue, tap directly if task_id in self._local_queue: + logger.debug(f"Tapping a local queue for task {task_id}") return self._local_queue[task_id].tap() # if it's a proxy queue, tap the proxy if task_id in self._proxy_queue: + logger.debug(f"Tapping a proxy queue for task {task_id}") return self._proxy_queue[task_id].tap() # if the proxy is not created, create the proxy and return queue = EventQueue() self._proxy_queue[task_id] = queue await self._subscribe_remote_task_events(task_id) + logger.debug(f"Creating a proxy queue for task {task_id}") return self._proxy_queue[task_id] # the task doesn't exist before, create a local queue queue = EventQueue() self._local_queue[task_id] = queue await self._register_task_id(task_id) + logger.debug(f"Creating a local queue for task {task_id}") return queue diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 5581711e..f736c1e6 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -1,9 +1,13 @@ import asyncio +import logging +import os from typing import Any from unittest import mock import pytest +import pytest_asyncio +from redis.asyncio import Redis from starlette.authentication import ( AuthCredentials, @@ -22,6 +26,8 @@ A2AFastAPIApplication, A2AStarletteApplication, ) +from a2a.server.events import EventQueue +from a2a.server.events.redis_queue_manager import RedisQueueManager from a2a.types import ( AgentCapabilities, AgentCard, @@ -44,6 +50,7 @@ TextPart, UnsupportedOperationError, ) +from a2a.utils import new_agent_text_message from a2a.utils.errors import MethodNotImplementedError @@ -884,3 +891,88 @@ def test_non_dict_json(client: TestClient): data = response.json() assert 'error' in data assert data['error']['code'] == InvalidRequestError().code + + +# === RedisQueueManager === +@pytest.mark.asyncio +@pytest_asyncio.fixture(scope="function") +async def asyncio_redis(): + redis_server_url = os.getenv("REDIS_SERVER_URL") + if redis_server_url: + redis = Redis.from_url(redis_server_url) + else: + # use fake redis instead if no redis server url is given + from fakeredis import FakeAsyncRedis + redis = FakeAsyncRedis() + logging.info("flush redis for next test case") + await redis.flushall(asynchronous=False) + yield redis + await redis.close() + + +@pytest.mark.asyncio +async def test_redis_queue_local_only_queue(asyncio_redis): + queue_manager = RedisQueueManager(asyncio_redis) + + # setup local queues + q1 = EventQueue() + await queue_manager.add('task_1', q1) + q2 = EventQueue() + await queue_manager.add('task_2', q2) + q3 = await queue_manager.tap("task_1") + assert await queue_manager.get('task_1') == q1 + assert await queue_manager.get('task_2') == q2 + + # send and receive locally + msg1 = new_agent_text_message('hello') + await q1.enqueue_event(msg1) + assert await q1.dequeue_event(no_wait=True) == msg1 + assert await q3.dequeue_event(no_wait=True) == msg1 + # raise error if queue is empty + with pytest.raises(asyncio.QueueEmpty): + await q1.dequeue_event(no_wait=True) + # q2 is empty + with pytest.raises(asyncio.QueueEmpty): + await q2.dequeue_event(no_wait=True) + + +@pytest.mark.asyncio +async def test_redis_queue_mixed_queue(asyncio_redis): + qm1 = RedisQueueManager(asyncio_redis) + qm2 = RedisQueueManager(asyncio_redis) + qm3 = RedisQueueManager(asyncio_redis) + + # create local queue in qm1 + q1 = EventQueue() + await qm1.add('task_1', q1) + assert 'task_1' in qm1._local_queue + assert await qm1.get('task_1') == q1 + + # create proxy queue in qm2 through `get` method + q1_1 = await qm2.get('task_1') + assert 'task_1' in qm2._proxy_queue and 'task_1' not in qm2._local_queue + assert q1_1 != q1 + + # create proxy queue in qm3 through `tap` method + q1_2 = await qm3.tap("task_1") + assert 'task_1' in qm3._proxy_queue and 'task_1' not in qm3._local_queue + + # enqueue and dequeue in q1 + msg1 = new_agent_text_message('hello') + await q1.enqueue_event(msg1) + assert await q1.dequeue_event() == msg1 + with pytest.raises(asyncio.QueueEmpty): + await q1.dequeue_event(no_wait=True) + + # dequeue in q1_1 + msg1_1: Message = await q1_1.dequeue_event() + assert msg1_1.parts[0].root.text == msg1.parts[0].root.text + with pytest.raises(asyncio.QueueEmpty): + await q1_1.dequeue_event(no_wait=True) + + # dequeue in q1_2 + msg1_2: Message = await q1_2.dequeue_event() + assert msg1_2.parts[0].root.text == msg1.parts[0].root.text + with pytest.raises(asyncio.QueueEmpty): + await q1_2.dequeue_event(no_wait=True) + From 84c6bef04071418ecee5785f2ff0501f33762f7d Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Thu, 3 Jul 2025 22:26:26 +0800 Subject: [PATCH 05/11] feat(server): implement TTL for task IDs in Redis - Replace Redis set with sorted set to store task IDs with timestamp - Add TTL update mechanism for active task IDs - Implement periodic cleanup of expired task IDs - Update task registration and removal to use new sorted set structure --- .github/actions/spelling/allow.txt | 6 ++- src/a2a/server/events/redis_queue_manager.py | 43 +++++++++++++++----- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 532bdeb8..d786ce34 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -74,4 +74,8 @@ typeerror vulnz sadd sismember -srem \ No newline at end of file +srem +zadd +zscore +zrem +zremrangebyscore \ No newline at end of file diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index fe544076..d3f57022 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -1,11 +1,12 @@ import asyncio import logging +import random +import time from asyncio import Task -from functools import partial -from typing import Any, Dict, Optional +from typing import Any -from pydantic import ValidationError, TypeAdapter +from pydantic import TypeAdapter from redis.asyncio import Redis from a2a.server.events import ( @@ -17,9 +18,12 @@ TaskQueueExists, ) + logger = logging.getLogger(__name__) +CLEAN_EXPIRED_PROBABILITY = 0.5 + class RedisQueueManager(QueueManager): """This implements the `QueueManager` interface using Redis for event. @@ -36,6 +40,7 @@ def __init__( redis_client: Redis, relay_channel_key_prefix: str = 'a2a.event.relay.', task_registry_key: str = 'a2a.event.registry', + task_id_ttl_in_second: int = 60 * 60 * 24, ): self._redis = redis_client self._local_queue: dict[str, EventQueue] = {} @@ -45,16 +50,18 @@ def __init__( self._relay_channel_name = relay_channel_key_prefix self._background_tasks: dict[str, Task] = {} self._task_registry_name = task_registry_key - self._pubsub_listener_task: Optional[Task] = None + self._pubsub_listener_task: Task | None = None + self._task_id_ttl_in_second = task_id_ttl_in_second def _task_channel_name(self, task_id: str) -> str: return self._relay_channel_name + task_id async def _has_task_id(self, task_id: str) -> bool: - ret = await self._redis.sismember(self._task_registry_name, task_id) - return ret == 1 + ret = await self._redis.zscore(self._task_registry_name, task_id) + return ret is not None async def _register_task_id(self, task_id: str) -> None: + assert await self._redis.zadd(self._task_registry_name, {task_id: time.time()}, nx=True), 'task_id should not exist in global registry: ' + task_id task_started_event = asyncio.Event() async def _wrapped_listen_and_relay() -> None: task_started_event.set() @@ -65,12 +72,16 @@ async def _wrapped_listen_and_relay() -> None: self._task_channel_name(task_id), event.model_dump_json(exclude_none=True), ) + # update TTL for task_id + await self._update_task_id_ttl(task_id) + # clean expired task_ids with certain possibility + if random.random() < CLEAN_EXPIRED_PROBABILITY: + await self._clean_expired_task_ids() self._background_tasks[task_id] = asyncio.create_task( _wrapped_listen_and_relay() ) await task_started_event.wait() - await self._redis.sadd(self._task_registry_name, task_id) logger.debug(f'Started to listen and relay events for task {task_id}') async def _remove_task_id(self, task_id: str) -> bool: @@ -78,7 +89,19 @@ async def _remove_task_id(self, task_id: str) -> bool: self._background_tasks[task_id].cancel( 'task_id is closed: ' + task_id ) - return await self._redis.srem(self._task_registry_name, task_id) == 1 + return await self._redis.zrem(self._task_registry_name, task_id) == 1 + + async def _update_task_id_ttl(self, task_id: str) -> bool: + ret = await self._redis.zadd( + self._task_registry_name, + {task_id: time.time()}, + xx=True + ) + return ret is not None + + async def _clean_expired_task_ids(self) -> None: + count = await self._redis.zremrangebyscore(self._task_registry_name, 0, time.time() - self._task_id_ttl_in_second) + logger.debug(f'Removed {count} expired task ids') async def _subscribe_remote_task_events(self, task_id: str) -> None: channel_id = self._task_channel_name(task_id) @@ -91,11 +114,11 @@ async def _subscribe_remote_task_events(self, task_id: str) -> None: logger.debug(f"Subscribed for remote events for task {task_id}") - async def _consume_pubsub_messages(self): + async def _consume_pubsub_messages(self) -> None: async for _ in self._pubsub.listen(): pass - async def _relay_remote_events(self, subscription_event) -> None: + async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None: if 'channel' not in subscription_event or 'data' not in subscription_event: logger.warning(f"channel or data is absent in subscription event: {subscription_event}") return From 77863b6dcf995c202b2bd83e1c33f0c51f215377 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Thu, 3 Jul 2025 22:34:40 +0800 Subject: [PATCH 06/11] refactor(server/events): improve event parsing and update test assertions - Add type annotation for parsed event in RedisQueueManager - Update test assertions to check tap call count in queue manager tests --- src/a2a/server/events/redis_queue_manager.py | 2 +- tests/server/events/test_redis_queue_manager.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index d3f57022..775f5223 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -132,7 +132,7 @@ async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None try: logger.debug(f"Received event for task_id {task_id} in QM {self}: {data_string}") - event = TypeAdapter(Event).validate_json(data_string) + event: Event = TypeAdapter(Event).validate_json(data_string) except Exception as e: logger.warning(f"Failed to parse event from subscription event: {subscription_event}: {e}") return diff --git a/tests/server/events/test_redis_queue_manager.py b/tests/server/events/test_redis_queue_manager.py index 3d9b4a4d..2826ab75 100644 --- a/tests/server/events/test_redis_queue_manager.py +++ b/tests/server/events/test_redis_queue_manager.py @@ -64,10 +64,11 @@ async def test_get_nonexistent_queue(self, queue_manager): async def test_tap_existing_queue(self, queue_manager, event_queue): task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) + event_queue.tap.assert_called_once() result = await queue_manager.tap(task_id) assert result == event_queue - event_queue.tap.assert_called_once() + assert event_queue.tap.call_count == 2 @pytest.mark.asyncio async def test_tap_nonexistent_queue(self, queue_manager): @@ -88,11 +89,12 @@ async def test_create_or_tap_existing_queue( ): task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) + event_queue.tap.assert_called_once() result = await queue_manager.create_or_tap(task_id) assert result == event_queue - event_queue.tap.assert_called_once() + assert event_queue.tap.call_count == 2 @pytest.mark.asyncio async def test_concurrency(self, queue_manager): From ad5ca3c69c132b5170f341bbfdf1991e055e55cd Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Fri, 4 Jul 2025 14:01:34 +0800 Subject: [PATCH 07/11] fix(server/events): improve queue management and add test cases - Enhance EventQueue to handle shutdown process more efficiently- Optimize RedisQueueManager for better task event handling - Add comprehensive test cases for queue management scenarios --- src/a2a/server/events/event_queue.py | 4 ++ src/a2a/server/events/redis_queue_manager.py | 13 +++---- tests/server/test_integration.py | 39 +++++++++++++++++++- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 76a0b512..036cd984 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -153,6 +153,10 @@ async def close(self) -> None: await child.close() # Otherwise, join the queue else: + # drain the queue or self.queue.join() would wait forever. This makes this piece of code equivalent to self.queue.shutdown() in python 3.13+ + while not self.queue.empty(): + await self.queue.get() + self.queue.task_done() tasks = [asyncio.create_task(self.queue.join())] for child in self._children: tasks.append(asyncio.create_task(child.close())) diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index 775f5223..0a02be54 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -140,15 +140,14 @@ async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None logger.debug(f"Enqueuing event for task_id {task_id} in QM {self}: {event}") await self._proxy_queue[task_id].enqueue_event(event) - async def _unsubscribe_remote_task_events(self, task_id: str) -> None: # unsubscribe channel for given task_id await self._pubsub.unsubscribe(self._task_channel_name(task_id)) # release global listener if not channel is subscribed - async with self._lock: - if not self._pubsub.subscribed and self._pubsub_listener_task: - self._pubsub_listener_task.cancel() - self._pubsub_listener_task = None + if not self._pubsub.subscribed and self._pubsub_listener_task: + self._pubsub_listener_task.cancel() + self._pubsub_listener_task = None + async def add(self, task_id: str, queue: EventQueue) -> None: """Add a new local event queue for the specified task. @@ -231,11 +230,11 @@ async def close(self, task_id: str) -> None: logger.debug(f"close {task_id}") async with self._lock: if task_id in self._local_queue: + # remove from global registry if a local queue is closed + await self._remove_task_id(task_id) # close locally queue = self._local_queue.pop(task_id) await queue.close() - # remove from global registry if a local queue is closed - await self._remove_task_id(task_id) logger.debug(f"Closing local queue for task {task_id}") return diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index f736c1e6..50da1551 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -26,7 +26,7 @@ A2AFastAPIApplication, A2AStarletteApplication, ) -from a2a.server.events import EventQueue +from a2a.server.events import EventQueue, NoTaskQueue from a2a.server.events.redis_queue_manager import RedisQueueManager from a2a.types import ( AgentCapabilities, @@ -941,6 +941,8 @@ async def test_redis_queue_mixed_queue(asyncio_redis): qm1 = RedisQueueManager(asyncio_redis) qm2 = RedisQueueManager(asyncio_redis) qm3 = RedisQueueManager(asyncio_redis) + qm4 = RedisQueueManager(asyncio_redis) + qm5 = RedisQueueManager(asyncio_redis) # create local queue in qm1 q1 = EventQueue() @@ -954,25 +956,58 @@ async def test_redis_queue_mixed_queue(asyncio_redis): assert q1_1 != q1 # create proxy queue in qm3 through `tap` method - q1_2 = await qm3.tap("task_1") + q1_2 = await qm3.tap('task_1') assert 'task_1' in qm3._proxy_queue and 'task_1' not in qm3._local_queue + # create proxy queue in qm4 through `create_or_tap` method + q1_3 = await qm4.create_or_tap('task_1') + assert 'task_1' in qm4._proxy_queue and 'task_1' not in qm4._local_queue + + # create local queue in qm5 through `create_or_tap` method + q2 = await qm5.create_or_tap('task_2') + assert 'task_2' in qm5._local_queue and 'task_2' not in qm5._proxy_queue + # enqueue and dequeue in q1 msg1 = new_agent_text_message('hello') await q1.enqueue_event(msg1) assert await q1.dequeue_event() == msg1 + q1.task_done() with pytest.raises(asyncio.QueueEmpty): await q1.dequeue_event(no_wait=True) # dequeue in q1_1 msg1_1: Message = await q1_1.dequeue_event() assert msg1_1.parts[0].root.text == msg1.parts[0].root.text + q1_1.task_done() with pytest.raises(asyncio.QueueEmpty): await q1_1.dequeue_event(no_wait=True) # dequeue in q1_2 msg1_2: Message = await q1_2.dequeue_event() assert msg1_2.parts[0].root.text == msg1.parts[0].root.text + q1_2.task_done() with pytest.raises(asyncio.QueueEmpty): await q1_2.dequeue_event(no_wait=True) + # dequeue in q1_3 + msg1_3: Message = await q1_3.dequeue_event() + assert msg1_3.parts[0].root.text == msg1.parts[0].root.text + q1_3.task_done() + with pytest.raises(asyncio.QueueEmpty): + await q1_3.dequeue_event(no_wait=True) + + # enqueue and dequeue in q2 + msg2 = new_agent_text_message('world') + await q2.enqueue_event(msg2) + assert await q2.dequeue_event() == msg2 + q2.task_done() + + # close queues + await qm1.close('task_1') + await qm2.close('task_1') + await qm3.close('task_1') + await qm4.close('task_1') + await qm5.close('task_2') + with pytest.raises(NoTaskQueue): + await qm5.close('task_10000') + From 1fe3ccc3f306d681960b601320fa2015a5ef5002 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Fri, 4 Jul 2025 17:10:11 +0800 Subject: [PATCH 08/11] feat(server): implement task queue with TTL and node affiliation - Add TTL (Time To Live) for task IDs in the global registry - Implement node affiliation for task IDs to prevent message broadcasting to outdated task queues - Update task registration and message relaying logic to support the new features - Add test cases for task ID expiration and node affiliation --- .github/actions/spelling/allow.txt | 9 +- src/a2a/server/events/event_queue.py | 4 - src/a2a/server/events/redis_queue_manager.py | 160 +++++++++++-------- tests/server/test_integration.py | 56 ++++++- 4 files changed, 149 insertions(+), 80 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index d786ce34..b79ed49a 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -65,17 +65,16 @@ pypistats pyversions respx resub +sadd +sismember socio +srem sse tagwords taskupdate testuuid typeerror vulnz -sadd -sismember -srem zadd -zscore zrem -zremrangebyscore \ No newline at end of file +zremrangebyscorezscore diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 036cd984..76a0b512 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -153,10 +153,6 @@ async def close(self) -> None: await child.close() # Otherwise, join the queue else: - # drain the queue or self.queue.join() would wait forever. This makes this piece of code equivalent to self.queue.shutdown() in python 3.13+ - while not self.queue.empty(): - await self.queue.get() - self.queue.task_done() tasks = [asyncio.create_task(self.queue.join())] for child in self._children: tasks.append(asyncio.create_task(child.close())) diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index 0a02be54..4608e28b 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -1,7 +1,6 @@ import asyncio import logging -import random -import time +import uuid from asyncio import Task from typing import Any @@ -22,8 +21,6 @@ logger = logging.getLogger(__name__) -CLEAN_EXPIRED_PROBABILITY = 0.5 - class RedisQueueManager(QueueManager): """This implements the `QueueManager` interface using Redis for event. @@ -33,6 +30,8 @@ class RedisQueueManager(QueueManager): redis_client(Redis): asyncio redis connection. relay_channel_key_prefix(str): prefix for pubsub channel key generation. task_registry_key(str): key for set data where stores active `task_id`s. + task_id_ttl_in_second: TTL for task id in global registry + node_id: A unique id to be associated with task id in global registry. If node id is not matched, events won't be populated to queues in other `RedisQueueManager`s. """ def __init__( @@ -41,6 +40,7 @@ def __init__( relay_channel_key_prefix: str = 'a2a.event.relay.', task_registry_key: str = 'a2a.event.registry', task_id_ttl_in_second: int = 60 * 60 * 24, + node_id: str = str(uuid.uuid4()), ): self._redis = redis_client self._local_queue: dict[str, EventQueue] = {} @@ -52,31 +52,60 @@ def __init__( self._task_registry_name = task_registry_key self._pubsub_listener_task: Task | None = None self._task_id_ttl_in_second = task_id_ttl_in_second + self._node_id = node_id def _task_channel_name(self, task_id: str) -> str: return self._relay_channel_name + task_id async def _has_task_id(self, task_id: str) -> bool: - ret = await self._redis.zscore(self._task_registry_name, task_id) + ret = await self._redis.hget(self._task_registry_name, task_id) return ret is not None async def _register_task_id(self, task_id: str) -> None: - assert await self._redis.zadd(self._task_registry_name, {task_id: time.time()}, nx=True), 'task_id should not exist in global registry: ' + task_id + await self._redis.hsetex( + name=self._task_registry_name, + key=task_id, + value=self._node_id, + ex=self._task_id_ttl_in_second, + ) + logger.debug( + f'Registered task_id {task_id} to node {self._node_id} in registry.' + ) task_started_event = asyncio.Event() + async def _wrapped_listen_and_relay() -> None: task_started_event.set() c = EventConsumer(self._local_queue[task_id].tap()) async for event in c.consume_all(): - logger.debug(f'Publishing event for task {task_id} in QM {self}: {event}') - await self._redis.publish( - self._task_channel_name(task_id), - event.model_dump_json(exclude_none=True), + logger.debug( + f'Publishing event for task {task_id} in QM {self}: {event}' ) - # update TTL for task_id - await self._update_task_id_ttl(task_id) - # clean expired task_ids with certain possibility - if random.random() < CLEAN_EXPIRED_PROBABILITY: - await self._clean_expired_task_ids() + expected_node_id = await self._redis.hget( + self._task_registry_name, task_id + ) + expected_node_id = ( + expected_node_id.decode('utf-8') + if hasattr(expected_node_id, 'decode') + else expected_node_id + ) + if expected_node_id == self._node_id: + # publish message + await self._redis.publish( + self._task_channel_name(task_id), + event.model_dump_json(exclude_none=True), + ) + # update TTL for task_id + await self._redis.hsetex( + name=self._task_registry_name, + key=task_id, + value=self._node_id, + ex=self._task_id_ttl_in_second, + ) + else: + logger.error( + f'Task {task_id} is not registered on this node. Expected node id: {expected_node_id}' + ) + break self._background_tasks[task_id] = asyncio.create_task( _wrapped_listen_and_relay() @@ -89,55 +118,56 @@ async def _remove_task_id(self, task_id: str) -> bool: self._background_tasks[task_id].cancel( 'task_id is closed: ' + task_id ) - return await self._redis.zrem(self._task_registry_name, task_id) == 1 - - async def _update_task_id_ttl(self, task_id: str) -> bool: - ret = await self._redis.zadd( - self._task_registry_name, - {task_id: time.time()}, - xx=True - ) - return ret is not None - - async def _clean_expired_task_ids(self) -> None: - count = await self._redis.zremrangebyscore(self._task_registry_name, 0, time.time() - self._task_id_ttl_in_second) - logger.debug(f'Removed {count} expired task ids') + return await self._redis.hdel(self._task_registry_name, task_id) == 1 async def _subscribe_remote_task_events(self, task_id: str) -> None: channel_id = self._task_channel_name(task_id) await self._pubsub.subscribe(**{channel_id: self._relay_remote_events}) - # this is a global listener to handle incoming pubsub events if not self._pubsub_listener_task: logger.debug('Creating pubsub listener task.') - self._pubsub_listener_task = asyncio.create_task(self._consume_pubsub_messages()) - - logger.debug(f"Subscribed for remote events for task {task_id}") + self._pubsub_listener_task = asyncio.create_task( + self._consume_pubsub_messages() + ) + logger.debug(f'Subscribed for remote events for task {task_id}') async def _consume_pubsub_messages(self) -> None: async for _ in self._pubsub.listen(): pass - async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None: - if 'channel' not in subscription_event or 'data' not in subscription_event: - logger.warning(f"channel or data is absent in subscription event: {subscription_event}") + async def _relay_remote_events( + self, subscription_event: dict[str, Any] + ) -> None: + if ( + 'channel' not in subscription_event + or 'data' not in subscription_event + ): + logger.warning( + f'channel or data is absent in subscription event: {subscription_event}' + ) return channel_id: str = subscription_event['channel'].decode('utf-8') data_string: str = subscription_event['data'].decode('utf-8') task_id = channel_id.split('.')[-1] if task_id not in self._proxy_queue: - logger.warning(f"task_id {task_id} not found in proxy queue") + logger.warning(f'task_id {task_id} not found in proxy queue') return try: - logger.debug(f"Received event for task_id {task_id} in QM {self}: {data_string}") + logger.debug( + f'Received event for task_id {task_id} in QM {self}: {data_string}' + ) event: Event = TypeAdapter(Event).validate_json(data_string) except Exception as e: - logger.warning(f"Failed to parse event from subscription event: {subscription_event}: {e}") + logger.warning( + f'Failed to parse event from subscription event: {subscription_event}: {e}' + ) return - logger.debug(f"Enqueuing event for task_id {task_id} in QM {self}: {event}") + logger.debug( + f'Enqueuing event for task_id {task_id} in QM {self}: {event}' + ) await self._proxy_queue[task_id].enqueue_event(event) async def _unsubscribe_remote_task_events(self, task_id: str) -> None: @@ -148,7 +178,6 @@ async def _unsubscribe_remote_task_events(self, task_id: str) -> None: self._pubsub_listener_task.cancel() self._pubsub_listener_task = None - async def add(self, task_id: str, queue: EventQueue) -> None: """Add a new local event queue for the specified task. @@ -159,13 +188,13 @@ async def add(self, task_id: str, queue: EventQueue) -> None: Raises: TaskQueueExists: If a queue for the task already exists. """ - logger.debug(f"add {task_id}") + logger.debug(f'add {task_id}') async with self._lock: if await self._has_task_id(task_id): raise TaskQueueExists() self._local_queue[task_id] = queue await self._register_task_id(task_id) - logger.debug(f"Local queue is created for task {task_id}") + logger.debug(f'Local queue is created for task {task_id}') async def get(self, task_id: str) -> EventQueue | None: """Get the event queue associated with the given task ID. @@ -180,22 +209,24 @@ async def get(self, task_id: str) -> EventQueue | None: Returns: EventQueue | None: The event queue if found, otherwise None. """ - logger.debug(f"get {task_id}") + logger.debug(f'get {task_id}') async with self._lock: # lookup locally if task_id in self._local_queue: - logger.debug(f"Got local queue for task_id {task_id}") + logger.debug(f'Got local queue for task_id {task_id}') return self._local_queue[task_id] # lookup globally if await self._has_task_id(task_id): if task_id not in self._proxy_queue: - logger.debug(f"Creating proxy queue for {task_id}") + logger.debug(f'Creating proxy queue for {task_id}') queue = EventQueue() self._proxy_queue[task_id] = queue await self._subscribe_remote_task_events(task_id) - logger.debug(f"Got proxy queue for task_id {task_id}") + logger.debug(f'Got proxy queue for task_id {task_id}') return self._proxy_queue[task_id] - logger.warning(f"Attempted to get non-existing queue for task {task_id}") + logger.warning( + f'Attempted to get non-existing queue for task {task_id}' + ) return None async def tap(self, task_id: str) -> EventQueue | None: @@ -207,7 +238,7 @@ async def tap(self, task_id: str) -> EventQueue | None: Returns: EventQueue | None: A new reference to the event queue if it exists, otherwise None. """ - logger.debug(f"tap {task_id}") + logger.debug(f'tap {task_id}') event_queue = await self.get(task_id) if event_queue: logger.debug(f'Tapping event queue for task: {task_id}') @@ -227,7 +258,7 @@ async def close(self, task_id: str) -> None: Raises: NoTaskQueue: If no queue exists for the given task ID. """ - logger.debug(f"close {task_id}") + logger.debug(f'close {task_id}') async with self._lock: if task_id in self._local_queue: # remove from global registry if a local queue is closed @@ -235,7 +266,7 @@ async def close(self, task_id: str) -> None: # close locally queue = self._local_queue.pop(task_id) await queue.close() - logger.debug(f"Closing local queue for task {task_id}") + logger.debug(f'Closing local queue for task {task_id}') return if task_id in self._proxy_queue: @@ -244,10 +275,12 @@ async def close(self, task_id: str) -> None: await queue.close() # unsubscribe from remote, but don't remove from global registry await self._unsubscribe_remote_task_events(task_id) - logger.debug(f"Closing proxy queue for task {task_id}") + logger.debug(f'Closing proxy queue for task {task_id}') return - logger.warning(f"Attempted to close non-existing queue found for task {task_id}") + logger.warning( + f'Attempted to close non-existing queue found for task {task_id}' + ) raise NoTaskQueue() async def create_or_tap(self, task_id: str) -> EventQueue: @@ -262,28 +295,25 @@ async def create_or_tap(self, task_id: str) -> EventQueue: Returns: EventQueue: An event queue associated with the given task ID. """ - logger.debug(f"create_or_tap {task_id}") + logger.debug(f'create_or_tap {task_id}') async with self._lock: if await self._has_task_id(task_id): # if it's a local queue, tap directly if task_id in self._local_queue: - logger.debug(f"Tapping a local queue for task {task_id}") + logger.debug(f'Tapping a local queue for task {task_id}') return self._local_queue[task_id].tap() # if it's a proxy queue, tap the proxy - if task_id in self._proxy_queue: - logger.debug(f"Tapping a proxy queue for task {task_id}") - return self._proxy_queue[task_id].tap() - - # if the proxy is not created, create the proxy and return - queue = EventQueue() - self._proxy_queue[task_id] = queue - await self._subscribe_remote_task_events(task_id) - logger.debug(f"Creating a proxy queue for task {task_id}") - return self._proxy_queue[task_id] + if task_id not in self._proxy_queue: + # if the proxy is not created, create the proxy + queue = EventQueue() + self._proxy_queue[task_id] = queue + await self._subscribe_remote_task_events(task_id) + logger.debug(f'Tapping a proxy queue for task {task_id}') + return self._proxy_queue[task_id].tap() # the task doesn't exist before, create a local queue queue = EventQueue() self._local_queue[task_id] = queue await self._register_task_id(task_id) - logger.debug(f"Creating a local queue for task {task_id}") + logger.debug(f'Creating a local queue for task {task_id}') return queue diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 50da1551..e3452cb7 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -2,13 +2,14 @@ import logging import os +from asyncio import QueueEmpty from typing import Any from unittest import mock import pytest import pytest_asyncio -from redis.asyncio import Redis +from redis.asyncio import Redis from starlette.authentication import ( AuthCredentials, AuthenticationBackend, @@ -26,7 +27,7 @@ A2AFastAPIApplication, A2AStarletteApplication, ) -from a2a.server.events import EventQueue, NoTaskQueue +from a2a.server.events import EventQueue, NoTaskQueue, TaskQueueExists from a2a.server.events.redis_queue_manager import RedisQueueManager from a2a.types import ( AgentCapabilities, @@ -895,16 +896,17 @@ def test_non_dict_json(client: TestClient): # === RedisQueueManager === @pytest.mark.asyncio -@pytest_asyncio.fixture(scope="function") +@pytest_asyncio.fixture(scope='function') async def asyncio_redis(): - redis_server_url = os.getenv("REDIS_SERVER_URL") + redis_server_url = os.getenv('REDIS_SERVER_URL') if redis_server_url: redis = Redis.from_url(redis_server_url) else: # use fake redis instead if no redis server url is given from fakeredis import FakeAsyncRedis + redis = FakeAsyncRedis() - logging.info("flush redis for next test case") + logging.info('flush redis for next test case') await redis.flushall(asynchronous=False) yield redis await redis.close() @@ -919,7 +921,7 @@ async def test_redis_queue_local_only_queue(asyncio_redis): await queue_manager.add('task_1', q1) q2 = EventQueue() await queue_manager.add('task_2', q2) - q3 = await queue_manager.tap("task_1") + q3 = await queue_manager.tap('task_1') assert await queue_manager.get('task_1') == q1 assert await queue_manager.get('task_2') == q2 @@ -988,6 +990,10 @@ async def test_redis_queue_mixed_queue(asyncio_redis): q1_2.task_done() with pytest.raises(asyncio.QueueEmpty): await q1_2.dequeue_event(no_wait=True) + # get proxy queue task_1 and dequeue or close method will block forever. q1_2 is a tapped queue. + _ = await qm3.get('task_1') + await _.dequeue_event() + _.task_done() # dequeue in q1_3 msg1_3: Message = await q1_3.dequeue_event() @@ -995,6 +1001,10 @@ async def test_redis_queue_mixed_queue(asyncio_redis): q1_3.task_done() with pytest.raises(asyncio.QueueEmpty): await q1_3.dequeue_event(no_wait=True) + # get proxy queue task_1 and dequeue or close method will block forever. q1_3 is a tapped queue. + _ = await qm4.get('task_1') + await _.dequeue_event() + _.task_done() # enqueue and dequeue in q2 msg2 = new_agent_text_message('world') @@ -1011,3 +1021,37 @@ async def test_redis_queue_mixed_queue(asyncio_redis): with pytest.raises(NoTaskQueue): await qm5.close('task_10000') + +@pytest.mark.asyncio +async def test_redis_queue_task_id_expiration(asyncio_redis): + qm1 = RedisQueueManager( + asyncio_redis, probability_to_clean_expired=1, task_id_ttl_in_second=1 + ) + qm2 = RedisQueueManager( + asyncio_redis, probability_to_clean_expired=1, task_id_ttl_in_second=1 + ) + q1 = EventQueue() + await qm1.add('task_1', q1) + # add task_1 again to trigger exception + with pytest.raises(TaskQueueExists): + await qm2.add('task_1', q1) + q2 = await qm2.get('task_1') + assert q2 + # enqueue message to q1, and dequeue in q2 + msg1 = new_agent_text_message('hello') + await q1.enqueue_event(msg1) + assert await q2.dequeue_event() + q2.task_done() + + # sleep for 1 second to expire + await asyncio.sleep(1) + assert await qm1._has_task_id('task_1') is False + assert await qm2._has_task_id('task_1') is False + + # enqueue a message in order to update TTL + msg2 = new_agent_text_message('world') + await q1.enqueue_event(msg2) + assert await qm1._has_task_id('task_1') is False + # after qm1's ownership for task_1 is expired, enqueue in q1 won't broadcast to q2 + with pytest.raises(QueueEmpty): + await q2.dequeue_event(no_wait=True) From 5bd314ced68bf6153944902bef7d6930c0469647 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Fri, 4 Jul 2025 17:12:38 +0800 Subject: [PATCH 09/11] test: simplify RedisQueueManager instantiation- Remove redundant 'probability_to_clean_expired' parameter - Update test_integration.py to reflect changes in RedisQueueManager --- .github/actions/spelling/allow.txt | 3 +++ tests/server/test_integration.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index b79ed49a..91b12d1c 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -78,3 +78,6 @@ vulnz zadd zrem zremrangebyscorezscore +hdel +hget +hsetex \ No newline at end of file diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index e3452cb7..acc1ade6 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -1025,10 +1025,10 @@ async def test_redis_queue_mixed_queue(asyncio_redis): @pytest.mark.asyncio async def test_redis_queue_task_id_expiration(asyncio_redis): qm1 = RedisQueueManager( - asyncio_redis, probability_to_clean_expired=1, task_id_ttl_in_second=1 + asyncio_redis, task_id_ttl_in_second=1 ) qm2 = RedisQueueManager( - asyncio_redis, probability_to_clean_expired=1, task_id_ttl_in_second=1 + asyncio_redis, task_id_ttl_in_second=1 ) q1 = EventQueue() await qm1.add('task_1', q1) From 8dab5c761c08174bd3804650c6ce7ff7ebda82c5 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Fri, 4 Jul 2025 17:25:40 +0800 Subject: [PATCH 10/11] chores: fix lint errors --- src/a2a/server/events/redis_queue_manager.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index 4608e28b..60b39489 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -58,7 +58,7 @@ def _task_channel_name(self, task_id: str) -> str: return self._relay_channel_name + task_id async def _has_task_id(self, task_id: str) -> bool: - ret = await self._redis.hget(self._task_registry_name, task_id) + ret = await self._redis.hget(self._task_registry_name, task_id) # type: ignore [misc] return ret is not None async def _register_task_id(self, task_id: str) -> None: @@ -67,7 +67,7 @@ async def _register_task_id(self, task_id: str) -> None: key=task_id, value=self._node_id, ex=self._task_id_ttl_in_second, - ) + ) # type: ignore [misc] logger.debug( f'Registered task_id {task_id} to node {self._node_id} in registry.' ) @@ -82,7 +82,9 @@ async def _wrapped_listen_and_relay() -> None: ) expected_node_id = await self._redis.hget( self._task_registry_name, task_id - ) + ) # type: ignore [misc] + if not expected_node_id: + continue expected_node_id = ( expected_node_id.decode('utf-8') if hasattr(expected_node_id, 'decode') @@ -93,14 +95,14 @@ async def _wrapped_listen_and_relay() -> None: await self._redis.publish( self._task_channel_name(task_id), event.model_dump_json(exclude_none=True), - ) + ) # type: ignore [misc] # update TTL for task_id await self._redis.hsetex( name=self._task_registry_name, key=task_id, value=self._node_id, ex=self._task_id_ttl_in_second, - ) + ) # type: ignore [misc] else: logger.error( f'Task {task_id} is not registered on this node. Expected node id: {expected_node_id}' @@ -117,8 +119,8 @@ async def _remove_task_id(self, task_id: str) -> bool: if task_id in self._background_tasks: self._background_tasks[task_id].cancel( 'task_id is closed: ' + task_id - ) - return await self._redis.hdel(self._task_registry_name, task_id) == 1 + ) # type: ignore [misc] + return await self._redis.hdel(self._task_registry_name, task_id) == 1 # type: ignore [misc] async def _subscribe_remote_task_events(self, task_id: str) -> None: channel_id = self._task_channel_name(task_id) From e40519a1914baf3a3718f1ac3865408c07c9f282 Mon Sep 17 00:00:00 2001 From: "long.qul" Date: Sat, 5 Jul 2025 08:27:25 +0800 Subject: [PATCH 11/11] build(deps): remove fakeredis and make redis optional - Remove fakeredis from dependencies - Make redis optional by moving it to the 'redis' extra - Update RedisQueueManager to handle optional node_id - Improve error handling and logging in RedisQueueManager --- pyproject.toml | 3 +-- src/a2a/server/events/redis_queue_manager.py | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9b9020f..f8de695f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,6 @@ dependencies = [ "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0", "protobuf==5.29.5", - "fakeredis>=2.30.1", - "redis>=6.2.0", ] classifiers = [ @@ -43,6 +41,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"] +redis = ["redis>=6.2.0"] [project.urls] homepage = "https://a2aproject.github.io/A2A/" diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py index 60b39489..f5981cc4 100644 --- a/src/a2a/server/events/redis_queue_manager.py +++ b/src/a2a/server/events/redis_queue_manager.py @@ -40,7 +40,7 @@ def __init__( relay_channel_key_prefix: str = 'a2a.event.relay.', task_registry_key: str = 'a2a.event.registry', task_id_ttl_in_second: int = 60 * 60 * 24, - node_id: str = str(uuid.uuid4()), + node_id: str | None = None, ): self._redis = redis_client self._local_queue: dict[str, EventQueue] = {} @@ -52,7 +52,7 @@ def __init__( self._task_registry_name = task_registry_key self._pubsub_listener_task: Task | None = None self._task_id_ttl_in_second = task_id_ttl_in_second - self._node_id = node_id + self._node_id = node_id or str(uuid.uuid4()) def _task_channel_name(self, task_id: str) -> str: return self._relay_channel_name + task_id @@ -62,12 +62,12 @@ async def _has_task_id(self, task_id: str) -> bool: return ret is not None async def _register_task_id(self, task_id: str) -> None: - await self._redis.hsetex( + assert await self._redis.hsetex( name=self._task_registry_name, key=task_id, value=self._node_id, ex=self._task_id_ttl_in_second, - ) # type: ignore [misc] + ) == 1, 'should have registered task id' # type: ignore [misc] logger.debug( f'Registered task_id {task_id} to node {self._node_id} in registry.' ) @@ -84,6 +84,7 @@ async def _wrapped_listen_and_relay() -> None: self._task_registry_name, task_id ) # type: ignore [misc] if not expected_node_id: + logger.warning(f'Task {task_id} is expired or not registered yet.') continue expected_node_id = ( expected_node_id.decode('utf-8') @@ -104,7 +105,7 @@ async def _wrapped_listen_and_relay() -> None: ex=self._task_id_ttl_in_second, ) # type: ignore [misc] else: - logger.error( + logger.warning( f'Task {task_id} is not registered on this node. Expected node id: {expected_node_id}' ) break