feat(server): implement Redis-based event queue manager #269

Open · wants to merge 11 commits into base: main

3 changes: 2 additions & 1 deletion .gitignore
@@ -8,4 +8,5 @@ __pycache__
.venv
coverage.xml
.nox
-spec.json
+spec.json
+.idea
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
    "grpcio-tools>=1.60",
    "grpcio_reflection>=1.7.0",
    "protobuf==5.29.5",
+    "fakeredis>=2.30.1",
]

classifiers = [
9 changes: 8 additions & 1 deletion src/a2a/server/events/event_queue.py
@@ -2,6 +2,10 @@
import logging
import sys

+from typing import Annotated
+
+from pydantic import Field
+
from a2a.types import (
Message,
Task,
@@ -14,7 +14,10 @@
logger = logging.getLogger(__name__)


-Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
+Event = Annotated[
+    Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
+    Field(discriminator='kind'),
+]
"""Type alias for events that can be enqueued."""

DEFAULT_MAX_QUEUE_SIZE = 1024
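The discriminator matters for this PR because events now cross process boundaries as JSON over Redis pub/sub, and the consuming side must revive each payload into the correct concrete type. Since `Event` is an `Annotated` union rather than a model class, deserialization goes through a pydantic `TypeAdapter`. A minimal sketch (the JSON payload is illustrative; keys follow the A2A wire format):

```python
from pydantic import TypeAdapter

from a2a.server.events import Event
from a2a.types import Message

# Event is an Annotated union, not a BaseModel, so it has no
# model_validate_json(); TypeAdapter provides the equivalent.
event_adapter = TypeAdapter(Event)

# The 'kind' field tells pydantic which union member to construct.
payload = '{"kind": "message", "messageId": "msg-1", "role": "agent", "parts": []}'
event = event_adapter.validate_json(payload)
assert isinstance(event, Message)
```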
152 changes: 152 additions & 0 deletions src/a2a/server/events/redis_queue_manager.py
@@ -0,0 +1,152 @@
import asyncio

from asyncio import Task
from functools import partial

from pydantic import TypeAdapter
from redis.asyncio import Redis

from a2a.server.events import (
Event,
EventConsumer,
EventQueue,
NoTaskQueue,
QueueManager,
TaskQueueExists,
)


# `Event` is an Annotated union, not a model class, so JSON validation
# must go through a TypeAdapter rather than `Event.model_validate_json()`.
_event_adapter = TypeAdapter(Event)


class RedisQueueManager(QueueManager):
    """A `QueueManager` implementation backed by Redis.

    Its two primary jobs:
    1. Broadcast local events to proxy queues in other processes via Redis pub/sub.
    2. Subscribe to event messages from Redis pub/sub and replay them to local proxy queues.
    """

def __init__(
self,
redis_client: Redis,
relay_channel_key_prefix: str = 'a2a.event.relay.',
task_registry_key: str = 'a2a.event.registry',
):
self._redis = redis_client
self._local_queue: dict[str, EventQueue] = {}
self._proxy_queue: dict[str, EventQueue] = {}
self._lock = asyncio.Lock()
self._pubsub = redis_client.pubsub()
self._relay_channel_name = relay_channel_key_prefix
        self._background_tasks: dict[str, Task] = {}
        self._task_registry_name = task_registry_key
        # Reader task driving the pub/sub connection; handlers registered
        # via subscribe() only fire while this loop is running.
        self._pubsub_reader: Task | None = None

    async def _listen_and_relay(self, task_id: str):
        consumer = EventConsumer(self._local_queue[task_id])
        async for event in consumer.consume_all():
            await self._redis.publish(
                self._task_channel_name(task_id),
                event.model_dump_json(exclude_none=True),
            )

def _task_channel_name(self, task_id: str):
return self._relay_channel_name + task_id

async def _has_task_id(self, task_id: str):
        ret = await self._redis.sismember(self._task_registry_name, task_id)
        return ret

async def _register_task_id(self, task_id: str):
        await self._redis.sadd(self._task_registry_name, task_id)
self._background_tasks[task_id] = asyncio.create_task(
self._listen_and_relay(task_id)
)

    async def _remove_task_id(self, task_id: str):
        if task_id in self._background_tasks:
            # Pop so cancelled relay tasks are not retained.
            self._background_tasks.pop(task_id).cancel(
                'task_id is closed: ' + task_id
            )
        return await self._redis.srem(self._task_registry_name, task_id)

    async def _subscribe_remote_task_events(self, task_id: str):
        await self._pubsub.subscribe(
            **{
                self._task_channel_name(task_id): partial(
                    self._relay_remote_events, task_id
                )
            }
        )
        # Lazily start the reader loop that dispatches incoming messages
        # to the registered handlers; without it the handlers never fire.
        if self._pubsub_reader is None or self._pubsub_reader.done():
            self._pubsub_reader = asyncio.create_task(self._pubsub.run())

    async def _unsubscribe_remote_task_events(self, task_id: str):
        # On the asyncio client, unsubscribe is a coroutine and must be awaited.
        await self._pubsub.unsubscribe(self._task_channel_name(task_id))

    async def _relay_remote_events(self, task_id: str, message: dict):
        if task_id in self._proxy_queue:
            # redis-py hands the handler a message dict; the JSON payload
            # published by _listen_and_relay sits under 'data'.
            event = _event_adapter.validate_json(message['data'])
            await self._proxy_queue[task_id].enqueue_event(event)

async def add(self, task_id: str, queue: EventQueue) -> None:
async with self._lock:
if await self._has_task_id(task_id):
raise TaskQueueExists()
self._local_queue[task_id] = queue
await self._register_task_id(task_id)

async def get(self, task_id: str) -> EventQueue | None:
async with self._lock:
# lookup locally
if task_id in self._local_queue:
return self._local_queue[task_id]
# lookup globally
if await self._has_task_id(task_id):
if task_id not in self._proxy_queue:
queue = EventQueue()
self._proxy_queue[task_id] = queue
await self._subscribe_remote_task_events(task_id)
return self._proxy_queue[task_id]
return None

async def tap(self, task_id: str) -> EventQueue | None:
event_queue = await self.get(task_id)
if event_queue:
return event_queue.tap()
return None

async def close(self, task_id: str) -> None:
async with self._lock:
if task_id in self._local_queue:
# close locally
queue = self._local_queue.pop(task_id)
await queue.close()
# remove from global registry if a local queue is closed
await self._remove_task_id(task_id)
return

if task_id in self._proxy_queue:
# close proxy queue
queue = self._proxy_queue.pop(task_id)
await queue.close()
# unsubscribe from remote, but don't remove from global registry
                await self._unsubscribe_remote_task_events(task_id)
return

raise NoTaskQueue()

async def create_or_tap(self, task_id: str) -> EventQueue:
async with self._lock:
if await self._has_task_id(task_id):
# if it's a local queue, tap directly
if task_id in self._local_queue:
return self._local_queue[task_id].tap()

# if it's a proxy queue, tap the proxy
if task_id in self._proxy_queue:
return self._proxy_queue[task_id].tap()

# if the proxy is not created, create the proxy and return
queue = EventQueue()
self._proxy_queue[task_id] = queue
await self._subscribe_remote_task_events(task_id)
return self._proxy_queue[task_id]
            # the task didn't exist before; create a local queue
queue = EventQueue()
self._local_queue[task_id] = queue
await self._register_task_id(task_id)
return queue
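For reviewers, a minimal usage sketch of the manager across processes (the Redis URL and task id are placeholders, and the second-process call is shown as a comment since it would run elsewhere):

```python
import asyncio

from redis.asyncio import Redis

from a2a.server.events.redis_queue_manager import RedisQueueManager


async def main() -> None:
    # Each server process builds its own manager over a shared Redis.
    redis = Redis.from_url('redis://localhost:6379/0')
    manager = RedisQueueManager(redis)

    # First caller for a task id: no registry entry exists yet, so a
    # local queue is created and its events are relayed over pub/sub.
    queue = await manager.create_or_tap('task-123')

    # In a second process, the same call would find the registry entry
    # and return a proxy queue fed by the pub/sub subscription:
    # proxy_queue = await manager.create_or_tap('task-123')

    await manager.close('task-123')
    await redis.aclose()


asyncio.run(main())
```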
126 changes: 126 additions & 0 deletions tests/server/events/test_redis_queue_manager.py
@@ -0,0 +1,126 @@
import asyncio

from unittest.mock import MagicMock

import pytest

from fakeredis import FakeAsyncRedis

from a2a.server.events import EventQueue, TaskQueueExists
from a2a.server.events.redis_queue_manager import RedisQueueManager


class TestRedisQueueManager:
@pytest.fixture
def redis(self):
return FakeAsyncRedis()

@pytest.fixture
def queue_manager(self, redis):
return RedisQueueManager(redis)

@pytest.fixture
def event_queue(self):
queue = MagicMock(spec=EventQueue)
# Mock the tap method to return itself
queue.tap.return_value = queue
return queue

@pytest.mark.asyncio
async def test_init(self, queue_manager):
assert queue_manager._local_queue == {}
assert queue_manager._proxy_queue == {}
assert isinstance(queue_manager._lock, asyncio.Lock)

@pytest.mark.asyncio
async def test_add_new_queue(self, queue_manager, event_queue):
"""Test adding a new queue to the manager."""
task_id = 'test_task_id'
await queue_manager.add(task_id, event_queue)
assert queue_manager._local_queue[task_id] == event_queue

@pytest.mark.asyncio
async def test_add_existing_queue(self, queue_manager, event_queue):
task_id = 'test_task_id'
await queue_manager.add(task_id, event_queue)

with pytest.raises(TaskQueueExists):
await queue_manager.add(task_id, event_queue)

@pytest.mark.asyncio
async def test_get_existing_queue(self, queue_manager, event_queue):
task_id = 'test_task_id'
await queue_manager.add(task_id, event_queue)

result = await queue_manager.get(task_id)
assert result == event_queue

@pytest.mark.asyncio
async def test_get_nonexistent_queue(self, queue_manager):
result = await queue_manager.get('nonexistent_task_id')
assert result is None

@pytest.mark.asyncio
async def test_tap_existing_queue(self, queue_manager, event_queue):
task_id = 'test_task_id'
await queue_manager.add(task_id, event_queue)

result = await queue_manager.tap(task_id)
assert result == event_queue
event_queue.tap.assert_called_once()

@pytest.mark.asyncio
async def test_tap_nonexistent_queue(self, queue_manager):
result = await queue_manager.tap('nonexistent_task_id')
assert result is None

@pytest.mark.asyncio
async def test_close_existing_queue(self, queue_manager, event_queue):
task_id = 'test_task_id'
await queue_manager.add(task_id, event_queue)

await queue_manager.close(task_id)
assert task_id not in queue_manager._local_queue

@pytest.mark.asyncio
async def test_create_or_tap_existing_queue(
self, queue_manager, event_queue
):
task_id = 'test_task_id'
await queue_manager.add(task_id, event_queue)

result = await queue_manager.create_or_tap(task_id)

assert result == event_queue
event_queue.tap.assert_called_once()

@pytest.mark.asyncio
async def test_concurrency(self, queue_manager):
async def add_task(task_id):
queue = EventQueue()
await queue_manager.add(task_id, queue)
return task_id

async def get_task(task_id):
return await queue_manager.get(task_id)

# Create 10 different task IDs
task_ids = [f'task_{i}' for i in range(10)]

# Add tasks concurrently
add_tasks = [add_task(task_id) for task_id in task_ids]
added_task_ids = await asyncio.gather(*add_tasks)

# Verify all tasks were added
assert set(added_task_ids) == set(task_ids)

# Get tasks concurrently
get_tasks = [get_task(task_id) for task_id in task_ids]
queues = await asyncio.gather(*get_tasks)

# Verify all queues are not None
assert all(queue is not None for queue in queues)

# Verify all tasks are in the manager
for task_id in task_ids:
assert task_id in queue_manager._local_queue