Skip to content

refactor: Update FastA2A to use Google A2A SDK #1980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
node_modules/
**.idea/
.coverage*
.mypy_cache/
20 changes: 15 additions & 5 deletions fasta2a/fasta2a/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from .applications import FastA2A
from .broker import Broker
from .schema import Skill
from .storage import Storage
from .worker import Worker
from .schema import Artifact, Message, Part, Skill, Task, TaskState
from .storage import InMemoryStorage
from .worker import TaskStore, Worker

__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker']
__all__ = [
"FastA2A",
"Skill",
"TaskStore",
"InMemoryStorage",
"Worker",
"Task",
"Message",
"Artifact",
"Part",
"TaskState",
]
193 changes: 93 additions & 100 deletions fasta2a/fasta2a/applications.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,88 @@
from __future__ import annotations as _annotations

from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from typing import Any
from typing import TYPE_CHECKING, Any, Sequence

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps.jsonrpc import A2AStarletteApplication
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.types import (
AgentCard,
Capabilities,
InvalidParamsError,
MessageSendParams,
TaskIdParams,
TaskState,
)
from a2a.utils.errors import ServerError
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.types import ExceptionHandler, Lifespan

from .broker import Broker
from .schema import (
AgentCard,
Authentication,
Capabilities,
Provider,
Skill,
a2a_request_ta,
a2a_response_ta,
agent_card_ta,
)
from .storage import Storage
from .task_manager import TaskManager
from .worker import Worker

if TYPE_CHECKING:
from .schema import Provider, Skill


class _WorkerExecutor(AgentExecutor):
"""An adapter to make a fasta2a.Worker compatible with a2a.AgentExecutor."""

def __init__(self, worker: Worker, storage: Storage):
self.worker = worker
self.storage = storage

async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
from a2a.server.tasks import TaskUpdater

self.worker.storage = self.storage

if not (context.task_id and context.context_id and context.message):
raise ServerError(
InvalidParamsError(
message="task_id, context_id, and message are required for execution"
)
)

params = MessageSendParams(
message=context.message, configuration=context.configuration
)

updater = TaskUpdater(event_queue, context.task_id, context.context_id)
await self.worker.run_task(params, updater)

class FastA2A(Starlette):
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
from a2a.server.tasks import TaskUpdater

self.worker.storage = self.storage

if not context.task_id or not context.context_id:
raise ServerError(
InvalidParamsError(
message="task_id and context_id are required for cancellation"
)
)

params = TaskIdParams(id=context.task_id)
updater = TaskUpdater(event_queue, context.task_id, context.context_id)
await self.worker.cancel_task(params, updater)
await updater.update_status(TaskState.canceled, final=True)


class FastA2A:
"""The main class for the FastA2A library."""

def __init__(
self,
*,
storage: Storage,
broker: Broker,
worker: Worker,
# Agent card
name: str | None = None,
url: str = 'http://localhost:8000',
version: str = '1.0.0',
url: str = "http://localhost:8000",
version: str = "1.0.0",
description: str | None = None,
provider: Provider | None = None,
skills: list[Skill] | None = None,
Expand All @@ -46,90 +91,38 @@ def __init__(
routes: Sequence[Route] | None = None,
middleware: Sequence[Middleware] | None = None,
exception_handlers: dict[Any, ExceptionHandler] | None = None,
lifespan: Lifespan[FastA2A] | None = None,
lifespan: Lifespan | None = None,
):
if lifespan is None:
lifespan = _default_lifespan
agent_executor = _WorkerExecutor(worker, storage)

request_handler = DefaultRequestHandler(
agent_executor=agent_executor, task_store=storage
)

super().__init__(
agent_card = AgentCard(
name=name or "Agent",
url=url,
version=version,
description=description,
provider=provider,
skills=skills or [],
defaultInputModes=["application/json"],
defaultOutputModes=["application/json"],
capabilities=Capabilities(
streaming=True, pushNotifications=False, stateTransitionHistory=True
),
)

app_builder = A2AStarletteApplication(
agent_card=agent_card, http_handler=request_handler
)
self.app: Starlette = app_builder.build(
debug=debug,
routes=routes,
middleware=middleware,
exception_handlers=exception_handlers,
lifespan=lifespan,
)

self.name = name or 'Agent'
self.url = url
self.version = version
self.description = description
self.provider = provider
self.skills = skills or []
# NOTE: For now, I don't think there's any reason to support any other input/output modes.
self.default_input_modes = ['application/json']
self.default_output_modes = ['application/json']

self.task_manager = TaskManager(broker=broker, storage=storage)

# Setup
self._agent_card_json_schema: bytes | None = None
self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS'])
self.router.add_route('/', self._agent_run_endpoint, methods=['POST'])

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope['type'] == 'http' and not self.task_manager.is_running:
raise RuntimeError('TaskManager was not properly initialized.')
await super().__call__(scope, receive, send)

async def _agent_card_endpoint(self, request: Request) -> Response:
if self._agent_card_json_schema is None:
agent_card = AgentCard(
name=self.name,
url=self.url,
version=self.version,
skills=self.skills,
default_input_modes=self.default_input_modes,
default_output_modes=self.default_output_modes,
capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False),
authentication=Authentication(schemes=[]),
)
if self.description is not None:
agent_card['description'] = self.description
if self.provider is not None:
agent_card['provider'] = self.provider
self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True)
return Response(content=self._agent_card_json_schema, media_type='application/json')

async def _agent_run_endpoint(self, request: Request) -> Response:
"""This is the main endpoint for the A2A server.

Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions.

1. The server will always either send a "submitted" or a "failed" on `tasks/send`.
Never a "completed" on the first message.
2. There are three possible ends for the task:
2.1. The task was "completed" successfully.
2.2. The task was "canceled".
2.3. The task "failed".
3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`.
"""
data = await request.body()
a2a_request = a2a_request_ta.validate_json(data)

if a2a_request['method'] == 'tasks/send':
jsonrpc_response = await self.task_manager.send_task(a2a_request)
elif a2a_request['method'] == 'tasks/get':
jsonrpc_response = await self.task_manager.get_task(a2a_request)
elif a2a_request['method'] == 'tasks/cancel':
jsonrpc_response = await self.task_manager.cancel_task(a2a_request)
else:
raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.')
return Response(
content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json'
)


@asynccontextmanager
async def _default_lifespan(app: FastA2A) -> AsyncIterator[None]:
async with app.task_manager:
yield
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
await self.app(scope, receive, send)
98 changes: 0 additions & 98 deletions fasta2a/fasta2a/broker.py

This file was deleted.

Loading
Loading