From 9df4cb7da3f3c09e931ff26c79adb23c2821547f Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 13 Jun 2025 13:44:27 -0400 Subject: [PATCH 1/6] Update Gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index acb015449..092411ded 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite node_modules/ **.idea/ .coverage* +.mypy_cache/ From 619432cfd7c992883ad7f5dadde6279d1ce9a56a Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 13 Jun 2025 13:50:32 -0400 Subject: [PATCH 2/6] refactor: Update FastA2A to use Google A2A SDK --- fasta2a/fasta2a/__init__.py | 16 +- fasta2a/fasta2a/applications.py | 179 ++++------ fasta2a/fasta2a/broker.py | 98 ----- fasta2a/fasta2a/client.py | 95 ++--- fasta2a/fasta2a/schema.py | 614 +++----------------------------- fasta2a/fasta2a/storage.py | 97 +---- fasta2a/fasta2a/task_manager.py | 169 --------- fasta2a/fasta2a/worker.py | 82 ++--- fasta2a/pyproject.toml | 1 + 9 files changed, 225 insertions(+), 1126 deletions(-) delete mode 100644 fasta2a/fasta2a/broker.py delete mode 100644 fasta2a/fasta2a/task_manager.py diff --git a/fasta2a/fasta2a/__init__.py b/fasta2a/fasta2a/__init__.py index 4a8b10629..fdc8247f4 100644 --- a/fasta2a/fasta2a/__init__.py +++ b/fasta2a/fasta2a/__init__.py @@ -1,7 +1,17 @@ from .applications import FastA2A -from .broker import Broker -from .schema import Skill +from .client import A2AClient +from .schema import Message, Part, Role, Skill, TextPart from .storage import Storage from .worker import Worker -__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker'] +__all__ = [ + "FastA2A", + "A2AClient", + "Worker", + "Storage", + "Skill", + "Message", + "Part", + "Role", + "TextPart", +] diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 61301262b..489d0237e 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -1,57 +1,100 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager from typing import Any -from starlette.applications import Starlette +import httpx +from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler +from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentProvider from starlette.middleware import Middleware -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Route from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send -from .broker import Broker -from .schema import ( - AgentCard, - Authentication, - Capabilities, - Provider, - Skill, - a2a_request_ta, - a2a_response_ta, - agent_card_ta, -) +from .schema import Skill from .storage import Storage -from .task_manager import TaskManager +from .worker import Worker -class FastA2A(Starlette): - """The main class for the FastA2A library.""" +class FastA2A: + """ + The main class for the FastA2A library. It provides a simple way to create + an A2A server by wrapping the Google A2A SDK. + """ def __init__( self, *, - storage: Storage, - broker: Broker, + worker: Worker, + storage: Storage | None = None, # Agent card - name: str | None = None, - url: str = 'http://localhost:8000', - version: str = '1.0.0', + name: str = "Agent", + url: str = "http://localhost:8000", + version: str = "1.0.0", description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, - routes: Sequence[Route] | None = None, - middleware: Sequence[Middleware] | None = None, + routes: list[Route] | None = None, + middleware: list[Middleware] | None = None, exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan[FastA2A] | None = None, + lifespan: Lifespan | None = None, ): - if lifespan is None: - lifespan = _default_lifespan + """ + Initializes the FastA2A application. + + Args: + worker: An implementation of `fasta2a.Worker` (which is an `a2a.server.agent_execution.AgentExecutor`). + storage: An implementation of `fasta2a.Storage` (which is an `a2a.server.tasks.TaskStore`). + Defaults to `InMemoryTaskStore`. + name: The human-readable name of the agent. + url: The URL where the agent is hosted. + version: The version of the agent. + description: A human-readable description of the agent. + provider: The service provider of the agent. + skills: A list of skills the agent can perform. + debug: Starlette's debug flag. + routes: A list of additional Starlette routes. + middleware: A list of Starlette middleware. + exception_handlers: A dictionary of Starlette exception handlers. + lifespan: A Starlette lifespan context manager. + """ + self.agent_card = AgentCard( + name=name, + url=url, + version=version, + description=description or "A FastA2A Agent", + provider=provider, + skills=skills or [], + capabilities=AgentCapabilities( + streaming=True, pushNotifications=True, stateTransitionHistory=True + ), + defaultInputModes=["application/json"], + defaultOutputModes=["application/json"], + securitySchemes={}, + ) - super().__init__( + self.storage = storage or InMemoryTaskStore() + self.worker = worker + + # The SDK's DefaultRequestHandler uses httpx to send push notifications + http_client = httpx.AsyncClient() + push_notifier = InMemoryPushNotifier(httpx_client) + + request_handler = DefaultRequestHandler( + agent_executor=self.worker, + task_store=self.storage, + push_notifier=push_notifier, + ) + + a2a_app = A2AStarletteApplication( + agent_card=self.agent_card, + http_handler=request_handler, + ) + + self.app = a2a_app.build( debug=debug, routes=routes, middleware=middleware, @@ -59,77 +102,5 @@ def __init__( lifespan=lifespan, ) - self.name = name or 'Agent' - self.url = url - self.version = version - self.description = description - self.provider = provider - self.skills = skills or [] - # NOTE: For now, I don't think there's any reason to support any other input/output modes. - self.default_input_modes = ['application/json'] - self.default_output_modes = ['application/json'] - - self.task_manager = TaskManager(broker=broker, storage=storage) - - # Setup - self._agent_card_json_schema: bytes | None = None - self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS']) - self.router.add_route('/', self._agent_run_endpoint, methods=['POST']) - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope['type'] == 'http' and not self.task_manager.is_running: - raise RuntimeError('TaskManager was not properly initialized.') - await super().__call__(scope, receive, send) - - async def _agent_card_endpoint(self, request: Request) -> Response: - if self._agent_card_json_schema is None: - agent_card = AgentCard( - name=self.name, - url=self.url, - version=self.version, - skills=self.skills, - default_input_modes=self.default_input_modes, - default_output_modes=self.default_output_modes, - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), - authentication=Authentication(schemes=[]), - ) - if self.description is not None: - agent_card['description'] = self.description - if self.provider is not None: - agent_card['provider'] = self.provider - self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) - return Response(content=self._agent_card_json_schema, media_type='application/json') - - async def _agent_run_endpoint(self, request: Request) -> Response: - """This is the main endpoint for the A2A server. - - Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. - - 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. - Never a "completed" on the first message. - 2. There are three possible ends for the task: - 2.1. The task was "completed" successfully. - 2.2. The task was "canceled". - 2.3. The task "failed". - 3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`. - """ - data = await request.body() - a2a_request = a2a_request_ta.validate_json(data) - - if a2a_request['method'] == 'tasks/send': - jsonrpc_response = await self.task_manager.send_task(a2a_request) - elif a2a_request['method'] == 'tasks/get': - jsonrpc_response = await self.task_manager.get_task(a2a_request) - elif a2a_request['method'] == 'tasks/cancel': - jsonrpc_response = await self.task_manager.cancel_task(a2a_request) - else: - raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') - return Response( - content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json' - ) - - -@asynccontextmanager -async def _default_lifespan(app: FastA2A) -> AsyncIterator[None]: - async with app.task_manager: - yield + await self.app(scope, receive, send) diff --git a/fasta2a/fasta2a/broker.py b/fasta2a/fasta2a/broker.py deleted file mode 100644 index c84b73872..000000000 --- a/fasta2a/fasta2a/broker.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations as _annotations - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from contextlib import AsyncExitStack -from dataclasses import dataclass -from typing import Annotated, Any, Generic, Literal, TypeVar - -import anyio -from opentelemetry.trace import Span, get_current_span, get_tracer -from pydantic import Discriminator -from typing_extensions import Self, TypedDict - -from .schema import TaskIdParams, TaskSendParams - -tracer = get_tracer(__name__) - - -@dataclass -class Broker(ABC): - """The broker class is in charge of scheduling the tasks. - - The HTTP server uses the broker to schedule tasks. - - The simple implementation is the `InMemoryBroker`, which is the broker that - runs the tasks in the same process as the HTTP server. That said, this class can be - extended to support remote workers. - """ - - @abstractmethod - async def run_task(self, params: TaskSendParams) -> None: - """Send a task to be executed by the worker.""" - raise NotImplementedError('send_run_task is not implemented yet.') - - @abstractmethod - async def cancel_task(self, params: TaskIdParams) -> None: - """Cancel a task.""" - raise NotImplementedError('send_cancel_task is not implemented yet.') - - @abstractmethod - async def __aenter__(self) -> Self: ... - - @abstractmethod - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): ... - - @abstractmethod - def receive_task_operations(self) -> AsyncIterator[TaskOperation]: - """Receive task operations from the broker. - - On a multi-worker setup, the broker will need to round-robin the task operations - between the workers. - """ - - -OperationT = TypeVar('OperationT') -ParamsT = TypeVar('ParamsT') - - -class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]): - """A task operation.""" - - operation: OperationT - params: ParamsT - _current_span: Span - - -_RunTask = _TaskOperation[Literal['run'], TaskSendParams] -_CancelTask = _TaskOperation[Literal['cancel'], TaskIdParams] - -TaskOperation = Annotated['_RunTask | _CancelTask', Discriminator('operation')] - - -class InMemoryBroker(Broker): - """A broker that schedules tasks in memory.""" - - async def __aenter__(self): - self.aexit_stack = AsyncExitStack() - await self.aexit_stack.__aenter__() - - self._write_stream, self._read_stream = anyio.create_memory_object_stream[TaskOperation]() - await self.aexit_stack.enter_async_context(self._read_stream) - await self.aexit_stack.enter_async_context(self._write_stream) - - return self - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): - await self.aexit_stack.__aexit__(exc_type, exc_value, traceback) - - async def run_task(self, params: TaskSendParams) -> None: - await self._write_stream.send(_RunTask(operation='run', params=params, _current_span=get_current_span())) - - async def cancel_task(self, params: TaskIdParams) -> None: - await self._write_stream.send(_CancelTask(operation='cancel', params=params, _current_span=get_current_span())) - - async def receive_task_operations(self) -> AsyncIterator[TaskOperation]: - """Receive task operations from the broker.""" - async for task_operation in self._read_stream: - yield task_operation diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 5c5aabd81..13c787e60 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -3,39 +3,44 @@ import uuid from typing import Any -import pydantic - -from .schema import ( +from a2a.client.client import A2AClient as BaseA2AClient +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.types import ( GetTaskRequest, GetTaskResponse, Message, + MessageSendParams, PushNotificationConfig, - SendTaskRequest, - SendTaskResponse, - TaskSendParams, - a2a_request_ta, + SendMessageRequest, + SendMessageResponse, ) -send_task_response_ta = pydantic.TypeAdapter(SendTaskResponse) -get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) - try: import httpx except ImportError as _import_error: raise ImportError( - 'httpx is required to use the A2AClient. Please install it with `pip install httpx`.', + "httpx is required to use the A2AClient. Please install it with `pip install httpx`.", ) from _import_error +UnexpectedResponseError = A2AClientHTTPError + class A2AClient: - """A client for the A2A protocol.""" + """A client for the A2A protocol, built on the Google A2A SDK.""" - def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.AsyncClient | None = None) -> None: + def __init__( + self, + base_url: str = "http://localhost:8000", + http_client: httpx.AsyncClient | None = None, + ) -> None: if http_client is None: - self.http_client = httpx.AsyncClient(base_url=base_url) + self.http_client = httpx.AsyncClient() else: self.http_client = http_client - self.http_client.base_url = base_url + + # The SDK client requires a URL on the agent card, but we can initialize it with a dummy + # and then set the URL directly for the internal calls. + self._sdk_client = BaseA2AClient(httpx_client=self.http_client, url=base_url) async def send_task( self, @@ -43,36 +48,36 @@ async def send_task( history_length: int | None = None, push_notification: PushNotificationConfig | None = None, metadata: dict[str, Any] | None = None, - ) -> SendTaskResponse: - task = TaskSendParams(message=message, id=str(uuid.uuid4())) - if history_length is not None: - task['history_length'] = history_length - if push_notification is not None: - task['push_notification'] = push_notification - if metadata is not None: - task['metadata'] = metadata + ) -> SendMessageResponse: + """Sends a task to the A2A server.""" + if not message.taskId: + message.taskId = str(uuid.uuid4()) - payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) - content = a2a_request_ta.dump_json(payload, by_alias=True) - response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) - self._raise_for_status(response) - return send_task_response_ta.validate_json(response.content) - - async def get_task(self, task_id: str) -> GetTaskResponse: - payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) - content = a2a_request_ta.dump_json(payload, by_alias=True) - response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) - self._raise_for_status(response) - return get_task_response_ta.validate_json(response.content) - - def _raise_for_status(self, response: httpx.Response) -> None: - if response.status_code >= 400: - raise UnexpectedResponseError(response.status_code, response.text) + params = MessageSendParams( + message=message, + configuration={ + "historyLength": history_length, + "pushNotificationConfig": push_notification, + }, + metadata=metadata, + ) + payload = SendMessageRequest( + id=str(uuid.uuid4()), method="message/send", params=params + ) + try: + response = await self._sdk_client.send_message(payload) + return response + except (A2AClientHTTPError, A2AClientJSONError) as e: + raise UnexpectedResponseError(getattr(e, "status_code", 500), str(e)) from e -class UnexpectedResponseError(Exception): - """An error raised when an unexpected response is received from the server.""" - - def __init__(self, status_code: int, content: str) -> None: - self.status_code = status_code - self.content = content + async def get_task(self, task_id: str) -> GetTaskResponse: + """Retrieves a task from the A2A server.""" + payload = GetTaskRequest( + id=str(uuid.uuid4()), method="tasks/get", params={"id": task_id} + ) + try: + response = await self._sdk_client.get_task(payload) + return response + except (A2AClientHTTPError, A2AClientJSONError) as e: + raise UnexpectedResponseError(getattr(e, "status_code", 500), str(e)) from e diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 57193a97d..b5c7e331f 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -1,567 +1,55 @@ -"""This module contains the schema for the agent card.""" - -from __future__ import annotations as _annotations - -from typing import Annotated, Any, Generic, Literal, TypeVar, Union - -import pydantic -from pydantic import Discriminator, TypeAdapter -from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypeAlias, TypedDict - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class AgentCard(TypedDict): - """The card that describes an agent.""" - - name: str - """Human readable name of the agent e.g. "Recipe Agent".""" - - description: NotRequired[str] - """A human-readable description of the agent. - - Used to assist users and other agents in understanding what the agent can do. - (e.g. "Agent that helps users with recipes and cooking.") - """ - - # TODO(Marcelo): The spec makes url required. - url: NotRequired[str] - """A URL to the address the agent is hosted at.""" - - provider: NotRequired[Provider] - """The service provider of the agent.""" - - # TODO(Marcelo): The spec makes version required. - version: NotRequired[str] - """The version of the agent - format is up to the provider. (e.g. "1.0.0")""" - - documentation_url: NotRequired[str] - """A URL to documentation for the agent.""" - - capabilities: Capabilities - """The capabilities of the agent.""" - - authentication: Authentication - """The authentication schemes supported by the agent. - - Intended to match OpenAPI authentication structure. - """ - - default_input_modes: list[str] - """Supported mime types for input data.""" - - default_output_modes: list[str] - """Supported mime types for output data.""" - - skills: list[Skill] - - -agent_card_ta = pydantic.TypeAdapter(AgentCard) - - -class Provider(TypedDict): - """The service provider of the agent.""" - - organization: str - url: str - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Capabilities(TypedDict): - """The capabilities of the agent.""" - - streaming: NotRequired[bool] - """Whether the agent supports streaming.""" - - push_notifications: NotRequired[bool] - """Whether the agent can notify updates to client.""" - - state_transition_history: NotRequired[bool] - """Whether the agent exposes status change history for tasks.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Authentication(TypedDict): - """The authentication schemes supported by the agent.""" - - schemes: list[str] - """The authentication schemes supported by the agent. (e.g. "Basic", "Bearer")""" - - credentials: NotRequired[str] - """The credentials a client should use for private cards.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Skill(TypedDict): - """Skills are a unit of capability that an agent can perform.""" - - id: str - """A unique identifier for the skill.""" - - name: str - """Human readable name of the skill.""" - - description: str - """A human-readable description of the skill. - - It will be used by the client or a human as a hint to understand the skill. - """ - - tags: list[str] - """Set of tag-words describing classes of capabilities for this specific skill. - - Examples: "cooking", "customer support", "billing". - """ - - examples: NotRequired[list[str]] - """The set of example scenarios that the skill can perform. - - Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") - """ - - input_modes: list[str] - """Supported mime types for input data.""" - - output_modes: list[str] - """Supported mime types for output data.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Artifact(TypedDict): - """Agents generate Artifacts as an end result of a Task. - - Artifacts are immutable, can be named, and can have multiple parts. A streaming response can append parts to - existing Artifacts. - - A single Task can generate many Artifacts. For example, "create a webpage" could create separate HTML and image - Artifacts. - """ - - name: NotRequired[str] - """The name of the artifact.""" - - description: NotRequired[str] - """A description of the artifact.""" - - parts: list[Part] - """The parts that make up the artifact.""" - - metadata: NotRequired[dict[str, Any]] - """Metadata about the artifact.""" - - index: int - """The index of the artifact.""" - - append: NotRequired[bool] - """Whether to append this artifact to an existing one.""" - - last_chunk: NotRequired[bool] - """Whether this is the last chunk of the artifact.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class PushNotificationConfig(TypedDict): - """Configuration for push notifications. - - A2A supports a secure notification mechanism whereby an agent can notify a client of an update - outside of a connected session via a PushNotificationService. Within and across enterprises, - it is critical that the agent verifies the identity of the notification service, authenticates - itself with the service, and presents an identifier that ties the notification to the executing - Task. - - The target server of the PushNotificationService should be considered a separate service, and - is not guaranteed (or even expected) to be the client directly. This PushNotificationService is - responsible for authenticating and authorizing the agent and for proxying the verified notification - to the appropriate endpoint (which could be anything from a pub/sub queue, to an email inbox or - other service, etc). - - For contrived scenarios with isolated client-agent pairs (e.g. local service mesh in a contained - VPC, etc.) or isolated environments without enterprise security concerns, the client may choose to - simply open a port and act as its own PushNotificationService. Any enterprise implementation will - likely have a centralized service that authenticates the remote agents with trusted notification - credentials and can handle online/offline scenarios. (This should be thought of similarly to a - mobile Push Notification Service). - """ - - url: str - """The URL to send push notifications to.""" - - token: NotRequired[str] - """Token unique to this task/session.""" - - authentication: NotRequired[Authentication] - """Authentication details for push notifications.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskPushNotificationConfig(TypedDict): - """Configuration for task push notifications.""" - - id: str - """The task id.""" - - push_notification_config: PushNotificationConfig - """The push notification configuration.""" - - -class Message(TypedDict): - """A Message contains any content that is not an Artifact. - - This can include things like agent thoughts, user context, instructions, errors, status, or metadata. - - All content from a client comes in the form of a Message. Agents send Messages to communicate status or to provide - instructions (whereas generated results are sent as Artifacts). - - A Message can have multiple parts to denote different pieces of content. For example, a user request could include - a textual description from a user and then multiple files used as context from the client. - """ - - role: Literal['user', 'agent'] - """The role of the message.""" - - parts: list[Part] - """The parts of the message.""" - - metadata: NotRequired[dict[str, Any]] - """Metadata about the message.""" - - -class _BasePart(TypedDict): - """A base class for all parts.""" - - metadata: NotRequired[dict[str, Any]] - - -class TextPart(_BasePart): - """A part that contains text.""" - - type: Literal['text'] - """The type of the part.""" - - text: str - """The text of the part.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class FilePart(_BasePart): - """A part that contains a file.""" - - type: Literal['file'] - """The type of the part.""" - - file: File - """The file of the part.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _BaseFile(_BasePart): - """A base class for all file types.""" - - name: NotRequired[str] - """The name of the file.""" - - mime_type: str - """The mime type of the file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _BinaryFile(_BaseFile): - """A binary file.""" - - data: str - """The base64 encoded bytes of the file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _URLFile(_BaseFile): - """A file that is hosted on a remote URL.""" - - url: str - """The URL of the file.""" - - -File: TypeAlias = Union[_BinaryFile, _URLFile] -"""A file is a binary file or a URL file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class DataPart(_BasePart): - """A part that contains data.""" - - type: Literal['data'] - """The type of the part.""" - - data: dict[str, Any] - """The data of the part.""" - - -Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='type')] -"""A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. - -Each Part has its own content type and metadata. +""" +This module re-exports the core A2A schema types from the `google-a2a` SDK. +The types are based on Pydantic `BaseModel` and align with the official A2A JSON specification. """ -TaskState: TypeAlias = Literal['submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'unknown'] -"""The possible states of a task.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskStatus(TypedDict): - """Status and accompanying message for a task.""" - - state: TaskState - """The current state of the task.""" - - message: NotRequired[Message] - """Additional status updates for client.""" - - timestamp: NotRequired[str] - """ISO datetime value of when the status was updated.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Task(TypedDict): - """A Task is a stateful entity that allows Clients and Remote Agents to achieve a specific outcome. - - Clients and Remote Agents exchange Messages within a Task. Remote Agents generate results as Artifacts. - A Task is always created by a Client and the status is always determined by the Remote Agent. - """ - - id: str - """Unique identifier for the task.""" - - session_id: NotRequired[str] - """Client-generated id for the session holding the task.""" - - status: TaskStatus - """Current status of the task.""" - - history: NotRequired[list[Message]] - """Optional history of messages.""" - - artifacts: NotRequired[list[Artifact]] - """Collection of artifacts created by the agent.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskStatusUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" - - id: str - """The id of the task.""" - - status: TaskStatus - """The status of the task.""" - - final: bool - """Indicates the end of the event stream.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskArtifactUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" - - id: str - """The id of the task.""" - - artifact: Artifact - """The artifact that was updated.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskIdParams(TypedDict): - """Parameters for a task id.""" - - id: str - metadata: NotRequired[dict[str, Any]] - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskQueryParams(TaskIdParams): - """Query parameters for a task.""" - - history_length: NotRequired[int] - """Number of recent messages to be retrieved.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskSendParams(TypedDict): - """Sent by the client to the agent to create, continue, or restart a task.""" - - id: str - """The id of the task.""" - - session_id: NotRequired[str] - """The server creates a new sessionId for new tasks if not set.""" - - message: Message - """The message to send to the agent.""" - - history_length: NotRequired[int] - """Number of recent messages to be retrieved.""" - - push_notification: NotRequired[PushNotificationConfig] - """Where the server should send notifications when disconnected.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -class JSONRPCMessage(TypedDict): - """A JSON RPC message.""" - - jsonrpc: Literal['2.0'] - """The JSON RPC version.""" - - id: int | str | None - """The request id.""" - - -Method = TypeVar('Method') -Params = TypeVar('Params') - - -class JSONRPCRequest(JSONRPCMessage, Generic[Method, Params]): - """A JSON RPC request.""" - - method: Method - """The method to call.""" - - params: Params - """The parameters to pass to the method.""" - - -############################################################################################### -####################################### Error codes ####################################### -############################################################################################### - -CodeT = TypeVar('CodeT', bound=int) -MessageT = TypeVar('MessageT', bound=str) - - -class JSONRPCError(TypedDict, Generic[CodeT, MessageT]): - """A JSON RPC error.""" - - code: CodeT - message: MessageT - data: NotRequired[Any] - - -ResultT = TypeVar('ResultT') -ErrorT = TypeVar('ErrorT', bound=JSONRPCError[Any, Any]) - - -class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): - """A JSON RPC response.""" - - result: NotRequired[ResultT] - error: NotRequired[ErrorT] - - -JSONParseError = JSONRPCError[Literal[-32700], Literal['Invalid JSON payload']] -"""A JSON RPC error for a parse error.""" - -InvalidRequestError = JSONRPCError[Literal[-32600], Literal['Request payload validation error']] -"""A JSON RPC error for an invalid request.""" - -MethodNotFoundError = JSONRPCError[Literal[-32601], Literal['Method not found']] -"""A JSON RPC error for a method not found.""" - -InvalidParamsError = JSONRPCError[Literal[-32602], Literal['Invalid parameters']] -"""A JSON RPC error for invalid parameters.""" - -InternalError = JSONRPCError[Literal[-32603], Literal['Internal error']] -"""A JSON RPC error for an internal error.""" - -TaskNotFoundError = JSONRPCError[Literal[-32001], Literal['Task not found']] -"""A JSON RPC error for a task not found.""" - -TaskNotCancelableError = JSONRPCError[Literal[-32002], Literal['Task not cancelable']] -"""A JSON RPC error for a task not cancelable.""" - -PushNotificationNotSupportedError = JSONRPCError[Literal[-32003], Literal['Push notification not supported']] -"""A JSON RPC error for a push notification not supported.""" - -UnsupportedOperationError = JSONRPCError[Literal[-32004], Literal['This operation is not supported']] -"""A JSON RPC error for an unsupported operation.""" - -ContentTypeNotSupportedError = JSONRPCError[Literal[-32005], Literal['Incompatible content types']] -"""A JSON RPC error for incompatible content types.""" - -############################################################################################### -####################################### Requests and responses ############################ -############################################################################################### - -SendTaskRequest = JSONRPCRequest[Literal['tasks/send'], TaskSendParams] -"""A JSON RPC request to send a task.""" - -SendTaskResponse = JSONRPCResponse[Task, JSONRPCError[Any, Any]] -"""A JSON RPC response to send a task.""" - -SendTaskStreamingRequest = JSONRPCRequest[Literal['tasks/sendSubscribe'], TaskSendParams] -"""A JSON RPC request to send a task and receive updates.""" - -SendTaskStreamingResponse = JSONRPCResponse[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError] -"""A JSON RPC response to send a task and receive updates.""" - -GetTaskRequest = JSONRPCRequest[Literal['tasks/get'], TaskQueryParams] -"""A JSON RPC request to get a task.""" - -GetTaskResponse = JSONRPCResponse[Task, TaskNotFoundError] -"""A JSON RPC response to get a task.""" - -CancelTaskRequest = JSONRPCRequest[Literal['tasks/cancel'], TaskIdParams] -"""A JSON RPC request to cancel a task.""" - -CancelTaskResponse = JSONRPCResponse[Task, Union[TaskNotCancelableError, TaskNotFoundError]] -"""A JSON RPC response to cancel a task.""" - -SetTaskPushNotificationRequest = JSONRPCRequest[Literal['tasks/pushNotification/set'], TaskPushNotificationConfig] -"""A JSON RPC request to set a task push notification.""" - -SetTaskPushNotificationResponse = JSONRPCResponse[TaskPushNotificationConfig, PushNotificationNotSupportedError] -"""A JSON RPC response to set a task push notification.""" - -GetTaskPushNotificationRequest = JSONRPCRequest[Literal['tasks/pushNotification/get'], TaskIdParams] -"""A JSON RPC request to get a task push notification.""" - -GetTaskPushNotificationResponse = JSONRPCResponse[TaskPushNotificationConfig, PushNotificationNotSupportedError] -"""A JSON RPC response to get a task push notification.""" - -ResubscribeTaskRequest = JSONRPCRequest[Literal['tasks/resubscribe'], TaskIdParams] -"""A JSON RPC request to resubscribe to a task.""" - -A2ARequest = Annotated[ - Union[ - SendTaskRequest, - GetTaskRequest, - CancelTaskRequest, - SetTaskPushNotificationRequest, - GetTaskPushNotificationRequest, - ResubscribeTaskRequest, - ], - Discriminator('method'), -] -"""A JSON RPC request to the A2A server.""" - -A2AResponse: TypeAlias = Union[ - SendTaskResponse, - GetTaskResponse, +from a2a.types import ( + AgentCard, + AgentProvider, + AgentSkill as Skill, + Artifact, + CancelTaskRequest, CancelTaskResponse, - SetTaskPushNotificationResponse, - GetTaskPushNotificationResponse, + DataPart, + FilePart, + GetTaskRequest, + GetTaskResponse, + Message, + MessageSendParams, + Part, + PushNotificationConfig, + Role, + SendMessageRequest, + SendMessageResponse, + Task, + TaskState, + TextPart, +) + +__all__ = [ + # Core Models + "AgentCard", + "AgentProvider", + "Artifact", + "Message", + "Part", + "Skill", + "Task", + # Enums + "Role", + "TaskState", + # Part Types + "TextPart", + "FilePart", + "DataPart", + # Request/Response models + "SendMessageRequest", + "SendMessageResponse", + "GetTaskRequest", + "GetTaskResponse", + "CancelTaskRequest", + "CancelTaskResponse", + # Parameter Models + "MessageSendParams", + "PushNotificationConfig", ] -"""A JSON RPC response from the A2A server.""" - - -a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) -a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index c06bc1cb7..4b78cc842 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -1,91 +1,12 @@ -"""This module defines the Storage class, which is responsible for storing and retrieving tasks.""" +""" +This module provides an alias for the TaskStore abstraction from the `google-a2a` SDK. +A TaskStore is responsible for persisting and retrieving task state. +""" -from __future__ import annotations as _annotations +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.tasks.task_store import TaskStore -from abc import ABC, abstractmethod -from datetime import datetime +Storage = TaskStore +"""Alias for `a2a.server.tasks.task_store.TaskStore`.""" -from .schema import Artifact, Message, Task, TaskState, TaskStatus - - -class Storage(ABC): - """A storage to retrieve and save tasks. - - The storage is used to update the status of a task and to save the result of a task. - """ - - @abstractmethod - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - """Load a task from storage. - - If the task is not found, return None. - """ - - @abstractmethod - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: - """Submit a task to storage.""" - - @abstractmethod - async def update_task( - self, - task_id: str, - state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, - ) -> Task: - """Update the state of a task.""" - - -class InMemoryStorage(Storage): - """A storage to retrieve and save tasks in memory.""" - - def __init__(self): - self.tasks: dict[str, Task] = {} - - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - """Load a task from memory. - - Args: - task_id: The id of the task to load. - history_length: The number of messages to return in the history. - - Returns: - The task. - """ - if task_id not in self.tasks: - return None - - task = self.tasks[task_id] - if history_length and 'history' in task: - task['history'] = task['history'][-history_length:] - return task - - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: - """Submit a task to storage.""" - if task_id in self.tasks: - raise ValueError(f'Task {task_id} already exists') - - task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) - self.tasks[task_id] = task - return task - - async def update_task( - self, - task_id: str, - state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, - ) -> Task: - """Save the task as "working".""" - task = self.tasks[task_id] - task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if message: - if 'history' not in task: - task['history'] = [] - task['history'].append(message) - if artifacts: - if 'artifacts' not in task: - task['artifacts'] = [] - task['artifacts'].extend(artifacts) - return task +__all__ = ["Storage", "TaskStore", "InMemoryTaskStore"] diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py deleted file mode 100644 index 0baaeba04..000000000 --- a/fasta2a/fasta2a/task_manager.py +++ /dev/null @@ -1,169 +0,0 @@ -"""This module defines the TaskManager class, which is responsible for managing tasks. - -In our structure, we have the following components: - -- TaskManager: A class that manages tasks. -- Scheduler: A class that schedules tasks to be sent to the worker. -- Worker: A class that executes tasks. -- Runner: A class that defines how tasks run and how history is structured. -- Storage: A class that stores tasks and artifacts. - -Architecture: -``` - +-----------------+ - | HTTP Server | - +-------+---------+ - ^ - | Sends Requests/ - | Receives Results - v - +-------+---------+ - | | - | TaskManager |<-----------------+ - | (coordinates) | | - | | | - +-------+---------+ | - | | - | Schedules Tasks | - v v - +------------------+ +----------------+ - | | | | - | Broker | . | Storage | - | (queues) . | | (persistence) | - | | | | - +------------------+ +----------------+ - ^ ^ - | | - | Delegates Execution | - v | - +------------------+ | - | | | - | Worker | | - | (implementation) |-----------------+ - | | - +------------------+ -``` - -The flow: -1. The HTTP server sends a task to TaskManager -2. TaskManager stores initial task state in Storage -3. TaskManager passes task to Scheduler -4. Scheduler determines when to send tasks to Worker -5. Worker delegates to Runner for task execution -6. Runner defines how tasks run and how history is structured -7. Worker processes task results from Runner -8. Worker reads from and writes to Storage directly -9. Worker updates task status in Storage as execution progresses -10. TaskManager can also read/write from Storage for task management -11. Client queries TaskManager for results, which reads from Storage -""" - -from __future__ import annotations as _annotations - -import uuid -from contextlib import AsyncExitStack -from dataclasses import dataclass, field -from typing import Any - -from .broker import Broker -from .schema import ( - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationRequest, - GetTaskPushNotificationResponse, - GetTaskRequest, - GetTaskResponse, - ResubscribeTaskRequest, - SendTaskRequest, - SendTaskResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, - SetTaskPushNotificationRequest, - SetTaskPushNotificationResponse, - TaskNotFoundError, -) -from .storage import Storage - - -@dataclass -class TaskManager: - """A task manager responsible for managing tasks.""" - - broker: Broker - storage: Storage - - _aexit_stack: AsyncExitStack | None = field(default=None, init=False) - - async def __aenter__(self): - self._aexit_stack = AsyncExitStack() - await self._aexit_stack.__aenter__() - await self._aexit_stack.enter_async_context(self.broker) - - return self - - @property - def is_running(self) -> bool: - return self._aexit_stack is not None - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): - if self._aexit_stack is None: - raise RuntimeError('TaskManager was not properly initialized.') - await self._aexit_stack.__aexit__(exc_type, exc_value, traceback) - self._aexit_stack = None - - async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """Send a task to the worker.""" - request_id = str(uuid.uuid4()) - task_id = request['params']['id'] - task = await self.storage.load_task(task_id) - - if task is None: - session_id = request['params'].get('session_id', str(uuid.uuid4())) - message = request['params']['message'] - task = await self.storage.submit_task(task_id, session_id, message) - - await self.broker.run_task(request['params']) - return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) - - async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: - """Get a task, and return it to the client. - - No further actions are needed here. - """ - task_id = request['params']['id'] - history_length = request['params'].get('history_length') - task = await self.storage.load_task(task_id, history_length) - if task is None: - return GetTaskResponse( - jsonrpc='2.0', - id=request['id'], - error=TaskNotFoundError(code=-32001, message='Task not found'), - ) - return GetTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - - async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: - await self.broker.cancel_task(request['params']) - task = await self.storage.load_task(request['params']['id']) - if task is None: - return CancelTaskResponse( - jsonrpc='2.0', - id=request['id'], - error=TaskNotFoundError(code=-32001, message='Task not found'), - ) - return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - - async def send_task_streaming(self, request: SendTaskStreamingRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('SendTaskStreaming is not implemented yet.') - - async def set_task_push_notification( - self, request: SetTaskPushNotificationRequest - ) -> SetTaskPushNotificationResponse: - raise NotImplementedError('SetTaskPushNotification is not implemented yet.') - - async def get_task_push_notification( - self, request: GetTaskPushNotificationRequest - ) -> GetTaskPushNotificationResponse: - raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('Resubscribe is not implemented yet.') diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 9bbde6b25..e8ba58633 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -1,68 +1,38 @@ from __future__ import annotations as _annotations -from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +from abc import ABC from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any -import anyio -from opentelemetry.trace import get_tracer, use_span -from typing_extensions import assert_never +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.server.tasks.task_updater import TaskUpdater -if TYPE_CHECKING: - from .broker import Broker, TaskOperation - from .schema import Artifact, Message, TaskIdParams, TaskSendParams - from .storage import Storage -tracer = get_tracer(__name__) +class Worker(AgentExecutor, ABC): + """ + An abstract class for implementing the core logic of an A2A agent. - -@dataclass -class Worker(ABC): - """A worker is responsible for executing tasks.""" - - broker: Broker - storage: Storage + This class inherits from the `a2a.server.agent_execution.AgentExecutor` + and must be subclassed to define the agent's behavior. + """ @asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """Run the worker. - - It connects to the broker, and it makes itself available to receive commands. + async def task_updater( + self, context: RequestContext, event_queue: EventQueue + ) -> TaskUpdater: """ - async with anyio.create_task_group() as tg: - tg.start_soon(self._loop) - yield - tg.cancel_scope.cancel() - - async def _loop(self) -> None: - async for task_operation in self.broker.receive_task_operations(): - await self._handle_task_operation(task_operation) + A convenience context manager to get a `TaskUpdater` for the current task. - async def _handle_task_operation(self, task_operation: TaskOperation) -> None: - try: - with use_span(task_operation['_current_span']): - with tracer.start_as_current_span( - f'{task_operation["operation"]} task', attributes={'logfire.tags': ['fasta2a']} - ): - if task_operation['operation'] == 'run': - await self.run_task(task_operation['params']) - elif task_operation['operation'] == 'cancel': - await self.cancel_task(task_operation['params']) - else: - assert_never(task_operation) - except Exception: - await self.storage.update_task(task_operation['params']['id'], state='failed') + Args: + context: The `RequestContext` for the current execution. + event_queue: The `EventQueue` to publish updates to. - @abstractmethod - async def run_task(self, params: TaskSendParams) -> None: ... - - @abstractmethod - async def cancel_task(self, params: TaskIdParams) -> None: ... - - @abstractmethod - def build_message_history(self, task_history: list[Message]) -> list[Any]: ... - - @abstractmethod - def build_artifacts(self, result: Any) -> list[Artifact]: ... + Yields: + A `TaskUpdater` instance for the current task. + """ + if not context.task_id or not context.context_id: + raise ValueError( + "RequestContext must have a task_id and context_id to create a TaskUpdater." + ) + yield TaskUpdater(event_queue, context.task_id, context.context_id) diff --git a/fasta2a/pyproject.toml b/fasta2a/pyproject.toml index 2abe809aa..651c341b5 100644 --- a/fasta2a/pyproject.toml +++ b/fasta2a/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "starlette>0.29.0", "pydantic>=2.10", "opentelemetry-api>=1.28.0", + "a2a-sdk>=0.2.8", ] [project.optional-dependencies] From f19eea216e2b33c663c121b464a1f9c114417931 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 13 Jun 2025 13:44:27 -0400 Subject: [PATCH 3/6] Update Gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index acb015449..092411ded 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite node_modules/ **.idea/ .coverage* +.mypy_cache/ From 6b6961b02f533d095971cf9fe2d99c753cbe75d0 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 13 Jun 2025 13:50:32 -0400 Subject: [PATCH 4/6] refactor: Update FastA2A to use Google A2A SDK --- fasta2a/fasta2a/__init__.py | 16 +- fasta2a/fasta2a/applications.py | 179 ++++------ fasta2a/fasta2a/broker.py | 98 ----- fasta2a/fasta2a/client.py | 95 ++--- fasta2a/fasta2a/schema.py | 614 +++----------------------------- fasta2a/fasta2a/storage.py | 97 +---- fasta2a/fasta2a/task_manager.py | 169 --------- fasta2a/fasta2a/worker.py | 82 ++--- fasta2a/pyproject.toml | 1 + 9 files changed, 225 insertions(+), 1126 deletions(-) delete mode 100644 fasta2a/fasta2a/broker.py delete mode 100644 fasta2a/fasta2a/task_manager.py diff --git a/fasta2a/fasta2a/__init__.py b/fasta2a/fasta2a/__init__.py index 4a8b10629..fdc8247f4 100644 --- a/fasta2a/fasta2a/__init__.py +++ b/fasta2a/fasta2a/__init__.py @@ -1,7 +1,17 @@ from .applications import FastA2A -from .broker import Broker -from .schema import Skill +from .client import A2AClient +from .schema import Message, Part, Role, Skill, TextPart from .storage import Storage from .worker import Worker -__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker'] +__all__ = [ + "FastA2A", + "A2AClient", + "Worker", + "Storage", + "Skill", + "Message", + "Part", + "Role", + "TextPart", +] diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 61301262b..489d0237e 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -1,57 +1,100 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager from typing import Any -from starlette.applications import Starlette +import httpx +from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler +from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentProvider from starlette.middleware import Middleware -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Route from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send -from .broker import Broker -from .schema import ( - AgentCard, - Authentication, - Capabilities, - Provider, - Skill, - a2a_request_ta, - a2a_response_ta, - agent_card_ta, -) +from .schema import Skill from .storage import Storage -from .task_manager import TaskManager +from .worker import Worker -class FastA2A(Starlette): - """The main class for the FastA2A library.""" +class FastA2A: + """ + The main class for the FastA2A library. It provides a simple way to create + an A2A server by wrapping the Google A2A SDK. + """ def __init__( self, *, - storage: Storage, - broker: Broker, + worker: Worker, + storage: Storage | None = None, # Agent card - name: str | None = None, - url: str = 'http://localhost:8000', - version: str = '1.0.0', + name: str = "Agent", + url: str = "http://localhost:8000", + version: str = "1.0.0", description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, - routes: Sequence[Route] | None = None, - middleware: Sequence[Middleware] | None = None, + routes: list[Route] | None = None, + middleware: list[Middleware] | None = None, exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan[FastA2A] | None = None, + lifespan: Lifespan | None = None, ): - if lifespan is None: - lifespan = _default_lifespan + """ + Initializes the FastA2A application. + + Args: + worker: An implementation of `fasta2a.Worker` (which is an `a2a.server.agent_execution.AgentExecutor`). + storage: An implementation of `fasta2a.Storage` (which is an `a2a.server.tasks.TaskStore`). + Defaults to `InMemoryTaskStore`. + name: The human-readable name of the agent. + url: The URL where the agent is hosted. + version: The version of the agent. + description: A human-readable description of the agent. + provider: The service provider of the agent. + skills: A list of skills the agent can perform. + debug: Starlette's debug flag. + routes: A list of additional Starlette routes. + middleware: A list of Starlette middleware. + exception_handlers: A dictionary of Starlette exception handlers. + lifespan: A Starlette lifespan context manager. + """ + self.agent_card = AgentCard( + name=name, + url=url, + version=version, + description=description or "A FastA2A Agent", + provider=provider, + skills=skills or [], + capabilities=AgentCapabilities( + streaming=True, pushNotifications=True, stateTransitionHistory=True + ), + defaultInputModes=["application/json"], + defaultOutputModes=["application/json"], + securitySchemes={}, + ) - super().__init__( + self.storage = storage or InMemoryTaskStore() + self.worker = worker + + # The SDK's DefaultRequestHandler uses httpx to send push notifications + http_client = httpx.AsyncClient() + push_notifier = InMemoryPushNotifier(httpx_client) + + request_handler = DefaultRequestHandler( + agent_executor=self.worker, + task_store=self.storage, + push_notifier=push_notifier, + ) + + a2a_app = A2AStarletteApplication( + agent_card=self.agent_card, + http_handler=request_handler, + ) + + self.app = a2a_app.build( debug=debug, routes=routes, middleware=middleware, @@ -59,77 +102,5 @@ def __init__( lifespan=lifespan, ) - self.name = name or 'Agent' - self.url = url - self.version = version - self.description = description - self.provider = provider - self.skills = skills or [] - # NOTE: For now, I don't think there's any reason to support any other input/output modes. - self.default_input_modes = ['application/json'] - self.default_output_modes = ['application/json'] - - self.task_manager = TaskManager(broker=broker, storage=storage) - - # Setup - self._agent_card_json_schema: bytes | None = None - self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS']) - self.router.add_route('/', self._agent_run_endpoint, methods=['POST']) - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope['type'] == 'http' and not self.task_manager.is_running: - raise RuntimeError('TaskManager was not properly initialized.') - await super().__call__(scope, receive, send) - - async def _agent_card_endpoint(self, request: Request) -> Response: - if self._agent_card_json_schema is None: - agent_card = AgentCard( - name=self.name, - url=self.url, - version=self.version, - skills=self.skills, - default_input_modes=self.default_input_modes, - default_output_modes=self.default_output_modes, - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), - authentication=Authentication(schemes=[]), - ) - if self.description is not None: - agent_card['description'] = self.description - if self.provider is not None: - agent_card['provider'] = self.provider - self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) - return Response(content=self._agent_card_json_schema, media_type='application/json') - - async def _agent_run_endpoint(self, request: Request) -> Response: - """This is the main endpoint for the A2A server. - - Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. - - 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. - Never a "completed" on the first message. - 2. There are three possible ends for the task: - 2.1. The task was "completed" successfully. - 2.2. The task was "canceled". - 2.3. The task "failed". - 3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`. - """ - data = await request.body() - a2a_request = a2a_request_ta.validate_json(data) - - if a2a_request['method'] == 'tasks/send': - jsonrpc_response = await self.task_manager.send_task(a2a_request) - elif a2a_request['method'] == 'tasks/get': - jsonrpc_response = await self.task_manager.get_task(a2a_request) - elif a2a_request['method'] == 'tasks/cancel': - jsonrpc_response = await self.task_manager.cancel_task(a2a_request) - else: - raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') - return Response( - content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json' - ) - - -@asynccontextmanager -async def _default_lifespan(app: FastA2A) -> AsyncIterator[None]: - async with app.task_manager: - yield + await self.app(scope, receive, send) diff --git a/fasta2a/fasta2a/broker.py b/fasta2a/fasta2a/broker.py deleted file mode 100644 index c84b73872..000000000 --- a/fasta2a/fasta2a/broker.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations as _annotations - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from contextlib import AsyncExitStack -from dataclasses import dataclass -from typing import Annotated, Any, Generic, Literal, TypeVar - -import anyio -from opentelemetry.trace import Span, get_current_span, get_tracer -from pydantic import Discriminator -from typing_extensions import Self, TypedDict - -from .schema import TaskIdParams, TaskSendParams - -tracer = get_tracer(__name__) - - -@dataclass -class Broker(ABC): - """The broker class is in charge of scheduling the tasks. - - The HTTP server uses the broker to schedule tasks. - - The simple implementation is the `InMemoryBroker`, which is the broker that - runs the tasks in the same process as the HTTP server. That said, this class can be - extended to support remote workers. - """ - - @abstractmethod - async def run_task(self, params: TaskSendParams) -> None: - """Send a task to be executed by the worker.""" - raise NotImplementedError('send_run_task is not implemented yet.') - - @abstractmethod - async def cancel_task(self, params: TaskIdParams) -> None: - """Cancel a task.""" - raise NotImplementedError('send_cancel_task is not implemented yet.') - - @abstractmethod - async def __aenter__(self) -> Self: ... - - @abstractmethod - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): ... - - @abstractmethod - def receive_task_operations(self) -> AsyncIterator[TaskOperation]: - """Receive task operations from the broker. - - On a multi-worker setup, the broker will need to round-robin the task operations - between the workers. - """ - - -OperationT = TypeVar('OperationT') -ParamsT = TypeVar('ParamsT') - - -class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]): - """A task operation.""" - - operation: OperationT - params: ParamsT - _current_span: Span - - -_RunTask = _TaskOperation[Literal['run'], TaskSendParams] -_CancelTask = _TaskOperation[Literal['cancel'], TaskIdParams] - -TaskOperation = Annotated['_RunTask | _CancelTask', Discriminator('operation')] - - -class InMemoryBroker(Broker): - """A broker that schedules tasks in memory.""" - - async def __aenter__(self): - self.aexit_stack = AsyncExitStack() - await self.aexit_stack.__aenter__() - - self._write_stream, self._read_stream = anyio.create_memory_object_stream[TaskOperation]() - await self.aexit_stack.enter_async_context(self._read_stream) - await self.aexit_stack.enter_async_context(self._write_stream) - - return self - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): - await self.aexit_stack.__aexit__(exc_type, exc_value, traceback) - - async def run_task(self, params: TaskSendParams) -> None: - await self._write_stream.send(_RunTask(operation='run', params=params, _current_span=get_current_span())) - - async def cancel_task(self, params: TaskIdParams) -> None: - await self._write_stream.send(_CancelTask(operation='cancel', params=params, _current_span=get_current_span())) - - async def receive_task_operations(self) -> AsyncIterator[TaskOperation]: - """Receive task operations from the broker.""" - async for task_operation in self._read_stream: - yield task_operation diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 5c5aabd81..13c787e60 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -3,39 +3,44 @@ import uuid from typing import Any -import pydantic - -from .schema import ( +from a2a.client.client import A2AClient as BaseA2AClient +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.types import ( GetTaskRequest, GetTaskResponse, Message, + MessageSendParams, PushNotificationConfig, - SendTaskRequest, - SendTaskResponse, - TaskSendParams, - a2a_request_ta, + SendMessageRequest, + SendMessageResponse, ) -send_task_response_ta = pydantic.TypeAdapter(SendTaskResponse) -get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) - try: import httpx except ImportError as _import_error: raise ImportError( - 'httpx is required to use the A2AClient. Please install it with `pip install httpx`.', + "httpx is required to use the A2AClient. Please install it with `pip install httpx`.", ) from _import_error +UnexpectedResponseError = A2AClientHTTPError + class A2AClient: - """A client for the A2A protocol.""" + """A client for the A2A protocol, built on the Google A2A SDK.""" - def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.AsyncClient | None = None) -> None: + def __init__( + self, + base_url: str = "http://localhost:8000", + http_client: httpx.AsyncClient | None = None, + ) -> None: if http_client is None: - self.http_client = httpx.AsyncClient(base_url=base_url) + self.http_client = httpx.AsyncClient() else: self.http_client = http_client - self.http_client.base_url = base_url + + # The SDK client requires a URL on the agent card, but we can initialize it with a dummy + # and then set the URL directly for the internal calls. + self._sdk_client = BaseA2AClient(httpx_client=self.http_client, url=base_url) async def send_task( self, @@ -43,36 +48,36 @@ async def send_task( history_length: int | None = None, push_notification: PushNotificationConfig | None = None, metadata: dict[str, Any] | None = None, - ) -> SendTaskResponse: - task = TaskSendParams(message=message, id=str(uuid.uuid4())) - if history_length is not None: - task['history_length'] = history_length - if push_notification is not None: - task['push_notification'] = push_notification - if metadata is not None: - task['metadata'] = metadata + ) -> SendMessageResponse: + """Sends a task to the A2A server.""" + if not message.taskId: + message.taskId = str(uuid.uuid4()) - payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) - content = a2a_request_ta.dump_json(payload, by_alias=True) - response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) - self._raise_for_status(response) - return send_task_response_ta.validate_json(response.content) - - async def get_task(self, task_id: str) -> GetTaskResponse: - payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) - content = a2a_request_ta.dump_json(payload, by_alias=True) - response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) - self._raise_for_status(response) - return get_task_response_ta.validate_json(response.content) - - def _raise_for_status(self, response: httpx.Response) -> None: - if response.status_code >= 400: - raise UnexpectedResponseError(response.status_code, response.text) + params = MessageSendParams( + message=message, + configuration={ + "historyLength": history_length, + "pushNotificationConfig": push_notification, + }, + metadata=metadata, + ) + payload = SendMessageRequest( + id=str(uuid.uuid4()), method="message/send", params=params + ) + try: + response = await self._sdk_client.send_message(payload) + return response + except (A2AClientHTTPError, A2AClientJSONError) as e: + raise UnexpectedResponseError(getattr(e, "status_code", 500), str(e)) from e -class UnexpectedResponseError(Exception): - """An error raised when an unexpected response is received from the server.""" - - def __init__(self, status_code: int, content: str) -> None: - self.status_code = status_code - self.content = content + async def get_task(self, task_id: str) -> GetTaskResponse: + """Retrieves a task from the A2A server.""" + payload = GetTaskRequest( + id=str(uuid.uuid4()), method="tasks/get", params={"id": task_id} + ) + try: + response = await self._sdk_client.get_task(payload) + return response + except (A2AClientHTTPError, A2AClientJSONError) as e: + raise UnexpectedResponseError(getattr(e, "status_code", 500), str(e)) from e diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 57193a97d..b5c7e331f 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -1,567 +1,55 @@ -"""This module contains the schema for the agent card.""" - -from __future__ import annotations as _annotations - -from typing import Annotated, Any, Generic, Literal, TypeVar, Union - -import pydantic -from pydantic import Discriminator, TypeAdapter -from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypeAlias, TypedDict - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class AgentCard(TypedDict): - """The card that describes an agent.""" - - name: str - """Human readable name of the agent e.g. "Recipe Agent".""" - - description: NotRequired[str] - """A human-readable description of the agent. - - Used to assist users and other agents in understanding what the agent can do. - (e.g. "Agent that helps users with recipes and cooking.") - """ - - # TODO(Marcelo): The spec makes url required. - url: NotRequired[str] - """A URL to the address the agent is hosted at.""" - - provider: NotRequired[Provider] - """The service provider of the agent.""" - - # TODO(Marcelo): The spec makes version required. - version: NotRequired[str] - """The version of the agent - format is up to the provider. (e.g. "1.0.0")""" - - documentation_url: NotRequired[str] - """A URL to documentation for the agent.""" - - capabilities: Capabilities - """The capabilities of the agent.""" - - authentication: Authentication - """The authentication schemes supported by the agent. - - Intended to match OpenAPI authentication structure. - """ - - default_input_modes: list[str] - """Supported mime types for input data.""" - - default_output_modes: list[str] - """Supported mime types for output data.""" - - skills: list[Skill] - - -agent_card_ta = pydantic.TypeAdapter(AgentCard) - - -class Provider(TypedDict): - """The service provider of the agent.""" - - organization: str - url: str - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Capabilities(TypedDict): - """The capabilities of the agent.""" - - streaming: NotRequired[bool] - """Whether the agent supports streaming.""" - - push_notifications: NotRequired[bool] - """Whether the agent can notify updates to client.""" - - state_transition_history: NotRequired[bool] - """Whether the agent exposes status change history for tasks.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Authentication(TypedDict): - """The authentication schemes supported by the agent.""" - - schemes: list[str] - """The authentication schemes supported by the agent. (e.g. "Basic", "Bearer")""" - - credentials: NotRequired[str] - """The credentials a client should use for private cards.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Skill(TypedDict): - """Skills are a unit of capability that an agent can perform.""" - - id: str - """A unique identifier for the skill.""" - - name: str - """Human readable name of the skill.""" - - description: str - """A human-readable description of the skill. - - It will be used by the client or a human as a hint to understand the skill. - """ - - tags: list[str] - """Set of tag-words describing classes of capabilities for this specific skill. - - Examples: "cooking", "customer support", "billing". - """ - - examples: NotRequired[list[str]] - """The set of example scenarios that the skill can perform. - - Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") - """ - - input_modes: list[str] - """Supported mime types for input data.""" - - output_modes: list[str] - """Supported mime types for output data.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Artifact(TypedDict): - """Agents generate Artifacts as an end result of a Task. - - Artifacts are immutable, can be named, and can have multiple parts. A streaming response can append parts to - existing Artifacts. - - A single Task can generate many Artifacts. For example, "create a webpage" could create separate HTML and image - Artifacts. - """ - - name: NotRequired[str] - """The name of the artifact.""" - - description: NotRequired[str] - """A description of the artifact.""" - - parts: list[Part] - """The parts that make up the artifact.""" - - metadata: NotRequired[dict[str, Any]] - """Metadata about the artifact.""" - - index: int - """The index of the artifact.""" - - append: NotRequired[bool] - """Whether to append this artifact to an existing one.""" - - last_chunk: NotRequired[bool] - """Whether this is the last chunk of the artifact.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class PushNotificationConfig(TypedDict): - """Configuration for push notifications. - - A2A supports a secure notification mechanism whereby an agent can notify a client of an update - outside of a connected session via a PushNotificationService. Within and across enterprises, - it is critical that the agent verifies the identity of the notification service, authenticates - itself with the service, and presents an identifier that ties the notification to the executing - Task. - - The target server of the PushNotificationService should be considered a separate service, and - is not guaranteed (or even expected) to be the client directly. This PushNotificationService is - responsible for authenticating and authorizing the agent and for proxying the verified notification - to the appropriate endpoint (which could be anything from a pub/sub queue, to an email inbox or - other service, etc). - - For contrived scenarios with isolated client-agent pairs (e.g. local service mesh in a contained - VPC, etc.) or isolated environments without enterprise security concerns, the client may choose to - simply open a port and act as its own PushNotificationService. Any enterprise implementation will - likely have a centralized service that authenticates the remote agents with trusted notification - credentials and can handle online/offline scenarios. (This should be thought of similarly to a - mobile Push Notification Service). - """ - - url: str - """The URL to send push notifications to.""" - - token: NotRequired[str] - """Token unique to this task/session.""" - - authentication: NotRequired[Authentication] - """Authentication details for push notifications.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskPushNotificationConfig(TypedDict): - """Configuration for task push notifications.""" - - id: str - """The task id.""" - - push_notification_config: PushNotificationConfig - """The push notification configuration.""" - - -class Message(TypedDict): - """A Message contains any content that is not an Artifact. - - This can include things like agent thoughts, user context, instructions, errors, status, or metadata. - - All content from a client comes in the form of a Message. Agents send Messages to communicate status or to provide - instructions (whereas generated results are sent as Artifacts). - - A Message can have multiple parts to denote different pieces of content. For example, a user request could include - a textual description from a user and then multiple files used as context from the client. - """ - - role: Literal['user', 'agent'] - """The role of the message.""" - - parts: list[Part] - """The parts of the message.""" - - metadata: NotRequired[dict[str, Any]] - """Metadata about the message.""" - - -class _BasePart(TypedDict): - """A base class for all parts.""" - - metadata: NotRequired[dict[str, Any]] - - -class TextPart(_BasePart): - """A part that contains text.""" - - type: Literal['text'] - """The type of the part.""" - - text: str - """The text of the part.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class FilePart(_BasePart): - """A part that contains a file.""" - - type: Literal['file'] - """The type of the part.""" - - file: File - """The file of the part.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _BaseFile(_BasePart): - """A base class for all file types.""" - - name: NotRequired[str] - """The name of the file.""" - - mime_type: str - """The mime type of the file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _BinaryFile(_BaseFile): - """A binary file.""" - - data: str - """The base64 encoded bytes of the file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _URLFile(_BaseFile): - """A file that is hosted on a remote URL.""" - - url: str - """The URL of the file.""" - - -File: TypeAlias = Union[_BinaryFile, _URLFile] -"""A file is a binary file or a URL file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class DataPart(_BasePart): - """A part that contains data.""" - - type: Literal['data'] - """The type of the part.""" - - data: dict[str, Any] - """The data of the part.""" - - -Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='type')] -"""A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. - -Each Part has its own content type and metadata. +""" +This module re-exports the core A2A schema types from the `google-a2a` SDK. +The types are based on Pydantic `BaseModel` and align with the official A2A JSON specification. """ -TaskState: TypeAlias = Literal['submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'unknown'] -"""The possible states of a task.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskStatus(TypedDict): - """Status and accompanying message for a task.""" - - state: TaskState - """The current state of the task.""" - - message: NotRequired[Message] - """Additional status updates for client.""" - - timestamp: NotRequired[str] - """ISO datetime value of when the status was updated.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Task(TypedDict): - """A Task is a stateful entity that allows Clients and Remote Agents to achieve a specific outcome. - - Clients and Remote Agents exchange Messages within a Task. Remote Agents generate results as Artifacts. - A Task is always created by a Client and the status is always determined by the Remote Agent. - """ - - id: str - """Unique identifier for the task.""" - - session_id: NotRequired[str] - """Client-generated id for the session holding the task.""" - - status: TaskStatus - """Current status of the task.""" - - history: NotRequired[list[Message]] - """Optional history of messages.""" - - artifacts: NotRequired[list[Artifact]] - """Collection of artifacts created by the agent.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskStatusUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" - - id: str - """The id of the task.""" - - status: TaskStatus - """The status of the task.""" - - final: bool - """Indicates the end of the event stream.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskArtifactUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" - - id: str - """The id of the task.""" - - artifact: Artifact - """The artifact that was updated.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskIdParams(TypedDict): - """Parameters for a task id.""" - - id: str - metadata: NotRequired[dict[str, Any]] - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskQueryParams(TaskIdParams): - """Query parameters for a task.""" - - history_length: NotRequired[int] - """Number of recent messages to be retrieved.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskSendParams(TypedDict): - """Sent by the client to the agent to create, continue, or restart a task.""" - - id: str - """The id of the task.""" - - session_id: NotRequired[str] - """The server creates a new sessionId for new tasks if not set.""" - - message: Message - """The message to send to the agent.""" - - history_length: NotRequired[int] - """Number of recent messages to be retrieved.""" - - push_notification: NotRequired[PushNotificationConfig] - """Where the server should send notifications when disconnected.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -class JSONRPCMessage(TypedDict): - """A JSON RPC message.""" - - jsonrpc: Literal['2.0'] - """The JSON RPC version.""" - - id: int | str | None - """The request id.""" - - -Method = TypeVar('Method') -Params = TypeVar('Params') - - -class JSONRPCRequest(JSONRPCMessage, Generic[Method, Params]): - """A JSON RPC request.""" - - method: Method - """The method to call.""" - - params: Params - """The parameters to pass to the method.""" - - -############################################################################################### -####################################### Error codes ####################################### -############################################################################################### - -CodeT = TypeVar('CodeT', bound=int) -MessageT = TypeVar('MessageT', bound=str) - - -class JSONRPCError(TypedDict, Generic[CodeT, MessageT]): - """A JSON RPC error.""" - - code: CodeT - message: MessageT - data: NotRequired[Any] - - -ResultT = TypeVar('ResultT') -ErrorT = TypeVar('ErrorT', bound=JSONRPCError[Any, Any]) - - -class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): - """A JSON RPC response.""" - - result: NotRequired[ResultT] - error: NotRequired[ErrorT] - - -JSONParseError = JSONRPCError[Literal[-32700], Literal['Invalid JSON payload']] -"""A JSON RPC error for a parse error.""" - -InvalidRequestError = JSONRPCError[Literal[-32600], Literal['Request payload validation error']] -"""A JSON RPC error for an invalid request.""" - -MethodNotFoundError = JSONRPCError[Literal[-32601], Literal['Method not found']] -"""A JSON RPC error for a method not found.""" - -InvalidParamsError = JSONRPCError[Literal[-32602], Literal['Invalid parameters']] -"""A JSON RPC error for invalid parameters.""" - -InternalError = JSONRPCError[Literal[-32603], Literal['Internal error']] -"""A JSON RPC error for an internal error.""" - -TaskNotFoundError = JSONRPCError[Literal[-32001], Literal['Task not found']] -"""A JSON RPC error for a task not found.""" - -TaskNotCancelableError = JSONRPCError[Literal[-32002], Literal['Task not cancelable']] -"""A JSON RPC error for a task not cancelable.""" - -PushNotificationNotSupportedError = JSONRPCError[Literal[-32003], Literal['Push notification not supported']] -"""A JSON RPC error for a push notification not supported.""" - -UnsupportedOperationError = JSONRPCError[Literal[-32004], Literal['This operation is not supported']] -"""A JSON RPC error for an unsupported operation.""" - -ContentTypeNotSupportedError = JSONRPCError[Literal[-32005], Literal['Incompatible content types']] -"""A JSON RPC error for incompatible content types.""" - -############################################################################################### -####################################### Requests and responses ############################ -############################################################################################### - -SendTaskRequest = JSONRPCRequest[Literal['tasks/send'], TaskSendParams] -"""A JSON RPC request to send a task.""" - -SendTaskResponse = JSONRPCResponse[Task, JSONRPCError[Any, Any]] -"""A JSON RPC response to send a task.""" - -SendTaskStreamingRequest = JSONRPCRequest[Literal['tasks/sendSubscribe'], TaskSendParams] -"""A JSON RPC request to send a task and receive updates.""" - -SendTaskStreamingResponse = JSONRPCResponse[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError] -"""A JSON RPC response to send a task and receive updates.""" - -GetTaskRequest = JSONRPCRequest[Literal['tasks/get'], TaskQueryParams] -"""A JSON RPC request to get a task.""" - -GetTaskResponse = JSONRPCResponse[Task, TaskNotFoundError] -"""A JSON RPC response to get a task.""" - -CancelTaskRequest = JSONRPCRequest[Literal['tasks/cancel'], TaskIdParams] -"""A JSON RPC request to cancel a task.""" - -CancelTaskResponse = JSONRPCResponse[Task, Union[TaskNotCancelableError, TaskNotFoundError]] -"""A JSON RPC response to cancel a task.""" - -SetTaskPushNotificationRequest = JSONRPCRequest[Literal['tasks/pushNotification/set'], TaskPushNotificationConfig] -"""A JSON RPC request to set a task push notification.""" - -SetTaskPushNotificationResponse = JSONRPCResponse[TaskPushNotificationConfig, PushNotificationNotSupportedError] -"""A JSON RPC response to set a task push notification.""" - -GetTaskPushNotificationRequest = JSONRPCRequest[Literal['tasks/pushNotification/get'], TaskIdParams] -"""A JSON RPC request to get a task push notification.""" - -GetTaskPushNotificationResponse = JSONRPCResponse[TaskPushNotificationConfig, PushNotificationNotSupportedError] -"""A JSON RPC response to get a task push notification.""" - -ResubscribeTaskRequest = JSONRPCRequest[Literal['tasks/resubscribe'], TaskIdParams] -"""A JSON RPC request to resubscribe to a task.""" - -A2ARequest = Annotated[ - Union[ - SendTaskRequest, - GetTaskRequest, - CancelTaskRequest, - SetTaskPushNotificationRequest, - GetTaskPushNotificationRequest, - ResubscribeTaskRequest, - ], - Discriminator('method'), -] -"""A JSON RPC request to the A2A server.""" - -A2AResponse: TypeAlias = Union[ - SendTaskResponse, - GetTaskResponse, +from a2a.types import ( + AgentCard, + AgentProvider, + AgentSkill as Skill, + Artifact, + CancelTaskRequest, CancelTaskResponse, - SetTaskPushNotificationResponse, - GetTaskPushNotificationResponse, + DataPart, + FilePart, + GetTaskRequest, + GetTaskResponse, + Message, + MessageSendParams, + Part, + PushNotificationConfig, + Role, + SendMessageRequest, + SendMessageResponse, + Task, + TaskState, + TextPart, +) + +__all__ = [ + # Core Models + "AgentCard", + "AgentProvider", + "Artifact", + "Message", + "Part", + "Skill", + "Task", + # Enums + "Role", + "TaskState", + # Part Types + "TextPart", + "FilePart", + "DataPart", + # Request/Response models + "SendMessageRequest", + "SendMessageResponse", + "GetTaskRequest", + "GetTaskResponse", + "CancelTaskRequest", + "CancelTaskResponse", + # Parameter Models + "MessageSendParams", + "PushNotificationConfig", ] -"""A JSON RPC response from the A2A server.""" - - -a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) -a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index c06bc1cb7..4b78cc842 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -1,91 +1,12 @@ -"""This module defines the Storage class, which is responsible for storing and retrieving tasks.""" +""" +This module provides an alias for the TaskStore abstraction from the `google-a2a` SDK. +A TaskStore is responsible for persisting and retrieving task state. +""" -from __future__ import annotations as _annotations +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.tasks.task_store import TaskStore -from abc import ABC, abstractmethod -from datetime import datetime +Storage = TaskStore +"""Alias for `a2a.server.tasks.task_store.TaskStore`.""" -from .schema import Artifact, Message, Task, TaskState, TaskStatus - - -class Storage(ABC): - """A storage to retrieve and save tasks. - - The storage is used to update the status of a task and to save the result of a task. - """ - - @abstractmethod - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - """Load a task from storage. - - If the task is not found, return None. - """ - - @abstractmethod - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: - """Submit a task to storage.""" - - @abstractmethod - async def update_task( - self, - task_id: str, - state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, - ) -> Task: - """Update the state of a task.""" - - -class InMemoryStorage(Storage): - """A storage to retrieve and save tasks in memory.""" - - def __init__(self): - self.tasks: dict[str, Task] = {} - - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - """Load a task from memory. - - Args: - task_id: The id of the task to load. - history_length: The number of messages to return in the history. - - Returns: - The task. - """ - if task_id not in self.tasks: - return None - - task = self.tasks[task_id] - if history_length and 'history' in task: - task['history'] = task['history'][-history_length:] - return task - - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: - """Submit a task to storage.""" - if task_id in self.tasks: - raise ValueError(f'Task {task_id} already exists') - - task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) - self.tasks[task_id] = task - return task - - async def update_task( - self, - task_id: str, - state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, - ) -> Task: - """Save the task as "working".""" - task = self.tasks[task_id] - task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if message: - if 'history' not in task: - task['history'] = [] - task['history'].append(message) - if artifacts: - if 'artifacts' not in task: - task['artifacts'] = [] - task['artifacts'].extend(artifacts) - return task +__all__ = ["Storage", "TaskStore", "InMemoryTaskStore"] diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py deleted file mode 100644 index 0baaeba04..000000000 --- a/fasta2a/fasta2a/task_manager.py +++ /dev/null @@ -1,169 +0,0 @@ -"""This module defines the TaskManager class, which is responsible for managing tasks. - -In our structure, we have the following components: - -- TaskManager: A class that manages tasks. -- Scheduler: A class that schedules tasks to be sent to the worker. -- Worker: A class that executes tasks. -- Runner: A class that defines how tasks run and how history is structured. -- Storage: A class that stores tasks and artifacts. - -Architecture: -``` - +-----------------+ - | HTTP Server | - +-------+---------+ - ^ - | Sends Requests/ - | Receives Results - v - +-------+---------+ - | | - | TaskManager |<-----------------+ - | (coordinates) | | - | | | - +-------+---------+ | - | | - | Schedules Tasks | - v v - +------------------+ +----------------+ - | | | | - | Broker | . | Storage | - | (queues) . | | (persistence) | - | | | | - +------------------+ +----------------+ - ^ ^ - | | - | Delegates Execution | - v | - +------------------+ | - | | | - | Worker | | - | (implementation) |-----------------+ - | | - +------------------+ -``` - -The flow: -1. The HTTP server sends a task to TaskManager -2. TaskManager stores initial task state in Storage -3. TaskManager passes task to Scheduler -4. Scheduler determines when to send tasks to Worker -5. Worker delegates to Runner for task execution -6. Runner defines how tasks run and how history is structured -7. Worker processes task results from Runner -8. Worker reads from and writes to Storage directly -9. Worker updates task status in Storage as execution progresses -10. TaskManager can also read/write from Storage for task management -11. Client queries TaskManager for results, which reads from Storage -""" - -from __future__ import annotations as _annotations - -import uuid -from contextlib import AsyncExitStack -from dataclasses import dataclass, field -from typing import Any - -from .broker import Broker -from .schema import ( - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationRequest, - GetTaskPushNotificationResponse, - GetTaskRequest, - GetTaskResponse, - ResubscribeTaskRequest, - SendTaskRequest, - SendTaskResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, - SetTaskPushNotificationRequest, - SetTaskPushNotificationResponse, - TaskNotFoundError, -) -from .storage import Storage - - -@dataclass -class TaskManager: - """A task manager responsible for managing tasks.""" - - broker: Broker - storage: Storage - - _aexit_stack: AsyncExitStack | None = field(default=None, init=False) - - async def __aenter__(self): - self._aexit_stack = AsyncExitStack() - await self._aexit_stack.__aenter__() - await self._aexit_stack.enter_async_context(self.broker) - - return self - - @property - def is_running(self) -> bool: - return self._aexit_stack is not None - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): - if self._aexit_stack is None: - raise RuntimeError('TaskManager was not properly initialized.') - await self._aexit_stack.__aexit__(exc_type, exc_value, traceback) - self._aexit_stack = None - - async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """Send a task to the worker.""" - request_id = str(uuid.uuid4()) - task_id = request['params']['id'] - task = await self.storage.load_task(task_id) - - if task is None: - session_id = request['params'].get('session_id', str(uuid.uuid4())) - message = request['params']['message'] - task = await self.storage.submit_task(task_id, session_id, message) - - await self.broker.run_task(request['params']) - return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) - - async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: - """Get a task, and return it to the client. - - No further actions are needed here. - """ - task_id = request['params']['id'] - history_length = request['params'].get('history_length') - task = await self.storage.load_task(task_id, history_length) - if task is None: - return GetTaskResponse( - jsonrpc='2.0', - id=request['id'], - error=TaskNotFoundError(code=-32001, message='Task not found'), - ) - return GetTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - - async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: - await self.broker.cancel_task(request['params']) - task = await self.storage.load_task(request['params']['id']) - if task is None: - return CancelTaskResponse( - jsonrpc='2.0', - id=request['id'], - error=TaskNotFoundError(code=-32001, message='Task not found'), - ) - return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - - async def send_task_streaming(self, request: SendTaskStreamingRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('SendTaskStreaming is not implemented yet.') - - async def set_task_push_notification( - self, request: SetTaskPushNotificationRequest - ) -> SetTaskPushNotificationResponse: - raise NotImplementedError('SetTaskPushNotification is not implemented yet.') - - async def get_task_push_notification( - self, request: GetTaskPushNotificationRequest - ) -> GetTaskPushNotificationResponse: - raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('Resubscribe is not implemented yet.') diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 9bbde6b25..e8ba58633 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -1,68 +1,38 @@ from __future__ import annotations as _annotations -from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +from abc import ABC from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any -import anyio -from opentelemetry.trace import get_tracer, use_span -from typing_extensions import assert_never +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.server.tasks.task_updater import TaskUpdater -if TYPE_CHECKING: - from .broker import Broker, TaskOperation - from .schema import Artifact, Message, TaskIdParams, TaskSendParams - from .storage import Storage -tracer = get_tracer(__name__) +class Worker(AgentExecutor, ABC): + """ + An abstract class for implementing the core logic of an A2A agent. - -@dataclass -class Worker(ABC): - """A worker is responsible for executing tasks.""" - - broker: Broker - storage: Storage + This class inherits from the `a2a.server.agent_execution.AgentExecutor` + and must be subclassed to define the agent's behavior. + """ @asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """Run the worker. - - It connects to the broker, and it makes itself available to receive commands. + async def task_updater( + self, context: RequestContext, event_queue: EventQueue + ) -> TaskUpdater: """ - async with anyio.create_task_group() as tg: - tg.start_soon(self._loop) - yield - tg.cancel_scope.cancel() - - async def _loop(self) -> None: - async for task_operation in self.broker.receive_task_operations(): - await self._handle_task_operation(task_operation) + A convenience context manager to get a `TaskUpdater` for the current task. - async def _handle_task_operation(self, task_operation: TaskOperation) -> None: - try: - with use_span(task_operation['_current_span']): - with tracer.start_as_current_span( - f'{task_operation["operation"]} task', attributes={'logfire.tags': ['fasta2a']} - ): - if task_operation['operation'] == 'run': - await self.run_task(task_operation['params']) - elif task_operation['operation'] == 'cancel': - await self.cancel_task(task_operation['params']) - else: - assert_never(task_operation) - except Exception: - await self.storage.update_task(task_operation['params']['id'], state='failed') + Args: + context: The `RequestContext` for the current execution. + event_queue: The `EventQueue` to publish updates to. - @abstractmethod - async def run_task(self, params: TaskSendParams) -> None: ... - - @abstractmethod - async def cancel_task(self, params: TaskIdParams) -> None: ... - - @abstractmethod - def build_message_history(self, task_history: list[Message]) -> list[Any]: ... - - @abstractmethod - def build_artifacts(self, result: Any) -> list[Artifact]: ... + Yields: + A `TaskUpdater` instance for the current task. + """ + if not context.task_id or not context.context_id: + raise ValueError( + "RequestContext must have a task_id and context_id to create a TaskUpdater." + ) + yield TaskUpdater(event_queue, context.task_id, context.context_id) diff --git a/fasta2a/pyproject.toml b/fasta2a/pyproject.toml index 2abe809aa..651c341b5 100644 --- a/fasta2a/pyproject.toml +++ b/fasta2a/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "starlette>0.29.0", "pydantic>=2.10", "opentelemetry-api>=1.28.0", + "a2a-sdk>=0.2.8", ] [project.optional-dependencies] From 6c31b16958e2fb547ea0f6587a16e037a9ef01e8 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Fri, 13 Jun 2025 14:00:04 -0400 Subject: [PATCH 5/6] Re-refactoring --- fasta2a/fasta2a/__init__.py | 18 ++--- fasta2a/fasta2a/applications.py | 115 ++++++++++++-------------------- fasta2a/fasta2a/client.py | 83 ----------------------- fasta2a/fasta2a/schema.py | 47 ++++--------- fasta2a/fasta2a/storage.py | 39 ++++++++--- fasta2a/fasta2a/worker.py | 43 +++--------- 6 files changed, 103 insertions(+), 242 deletions(-) delete mode 100644 fasta2a/fasta2a/client.py diff --git a/fasta2a/fasta2a/__init__.py b/fasta2a/fasta2a/__init__.py index fdc8247f4..17881c198 100644 --- a/fasta2a/fasta2a/__init__.py +++ b/fasta2a/fasta2a/__init__.py @@ -1,17 +1,17 @@ from .applications import FastA2A -from .client import A2AClient -from .schema import Message, Part, Role, Skill, TextPart -from .storage import Storage -from .worker import Worker +from .schema import Artifact, Message, Part, Skill, Task, TaskState +from .storage import InMemoryStorage +from .worker import TaskStore, Worker __all__ = [ "FastA2A", - "A2AClient", - "Worker", - "Storage", "Skill", + "TaskStore", + "InMemoryStorage", + "Worker", + "Task", "Message", + "Artifact", "Part", - "Role", - "TextPart", + "TaskState", ] diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 489d0237e..2b4f097bf 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -2,105 +2,72 @@ from typing import Any -import httpx -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication -from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler -from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier -from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore -from a2a.types import AgentCapabilities, AgentCard, AgentProvider +from a2a.server.agent_execution import AgentExecutor +from a2a.server.apps.jsonrpc import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import TaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentProvider, AgentSkill as Skill from starlette.middleware import Middleware from starlette.routing import Route -from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send - -from .schema import Skill -from .storage import Storage -from .worker import Worker +from starlette.types import Receive, Scope, Send class FastA2A: - """ - The main class for the FastA2A library. It provides a simple way to create - an A2A server by wrapping the Google A2A SDK. - """ + """The main class for the FastA2A library.""" def __init__( self, *, - worker: Worker, - storage: Storage | None = None, + worker: AgentExecutor, + storage: TaskStore, # Agent card - name: str = "Agent", + name: str | None = None, url: str = "http://localhost:8000", version: str = "1.0.0", description: str | None = None, provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette - debug: bool = False, routes: list[Route] | None = None, middleware: list[Middleware] | None = None, - exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan | None = None, + **starlette_kwargs: Any, ): - """ - Initializes the FastA2A application. - - Args: - worker: An implementation of `fasta2a.Worker` (which is an `a2a.server.agent_execution.AgentExecutor`). - storage: An implementation of `fasta2a.Storage` (which is an `a2a.server.tasks.TaskStore`). - Defaults to `InMemoryTaskStore`. - name: The human-readable name of the agent. - url: The URL where the agent is hosted. - version: The version of the agent. - description: A human-readable description of the agent. - provider: The service provider of the agent. - skills: A list of skills the agent can perform. - debug: Starlette's debug flag. - routes: A list of additional Starlette routes. - middleware: A list of Starlette middleware. - exception_handlers: A dictionary of Starlette exception handlers. - lifespan: A Starlette lifespan context manager. - """ - self.agent_card = AgentCard( - name=name, - url=url, - version=version, - description=description or "A FastA2A Agent", - provider=provider, - skills=skills or [], - capabilities=AgentCapabilities( - streaming=True, pushNotifications=True, stateTransitionHistory=True - ), - defaultInputModes=["application/json"], - defaultOutputModes=["application/json"], - securitySchemes={}, - ) - - self.storage = storage or InMemoryTaskStore() self.worker = worker + self.storage = storage + self.name = name or "Agent" + self.url = url + self.version = version + self.description = description + self.provider = provider + self.skills = skills or [] + self.default_input_modes = ["application/json"] + self.default_output_modes = ["application/json"] + self.capabilities = AgentCapabilities( + streaming=True, + pushNotifications=False, + stateTransitionHistory=False, + ) - # The SDK's DefaultRequestHandler uses httpx to send push notifications - http_client = httpx.AsyncClient() - push_notifier = InMemoryPushNotifier(httpx_client) + agent_card = AgentCard( + name=self.name, + url=self.url, + version=self.version, + description=self.description or "", + provider=self.provider, + skills=self.skills, + defaultInputModes=self.default_input_modes, + defaultOutputModes=self.default_output_modes, + capabilities=self.capabilities, + ) - request_handler = DefaultRequestHandler( + handler = DefaultRequestHandler( agent_executor=self.worker, task_store=self.storage, - push_notifier=push_notifier, - ) - - a2a_app = A2AStarletteApplication( - agent_card=self.agent_card, - http_handler=request_handler, ) - self.app = a2a_app.build( - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - lifespan=lifespan, - ) + self._app = A2AStarletteApplication( + agent_card=agent_card, http_handler=handler + ).build(routes=routes, middleware=middleware, **starlette_kwargs) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await self.app(scope, receive, send) + await self._app(scope, receive, send) diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py deleted file mode 100644 index 13c787e60..000000000 --- a/fasta2a/fasta2a/client.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations as _annotations - -import uuid -from typing import Any - -from a2a.client.client import A2AClient as BaseA2AClient -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError -from a2a.types import ( - GetTaskRequest, - GetTaskResponse, - Message, - MessageSendParams, - PushNotificationConfig, - SendMessageRequest, - SendMessageResponse, -) - -try: - import httpx -except ImportError as _import_error: - raise ImportError( - "httpx is required to use the A2AClient. Please install it with `pip install httpx`.", - ) from _import_error - -UnexpectedResponseError = A2AClientHTTPError - - -class A2AClient: - """A client for the A2A protocol, built on the Google A2A SDK.""" - - def __init__( - self, - base_url: str = "http://localhost:8000", - http_client: httpx.AsyncClient | None = None, - ) -> None: - if http_client is None: - self.http_client = httpx.AsyncClient() - else: - self.http_client = http_client - - # The SDK client requires a URL on the agent card, but we can initialize it with a dummy - # and then set the URL directly for the internal calls. - self._sdk_client = BaseA2AClient(httpx_client=self.http_client, url=base_url) - - async def send_task( - self, - message: Message, - history_length: int | None = None, - push_notification: PushNotificationConfig | None = None, - metadata: dict[str, Any] | None = None, - ) -> SendMessageResponse: - """Sends a task to the A2A server.""" - if not message.taskId: - message.taskId = str(uuid.uuid4()) - - params = MessageSendParams( - message=message, - configuration={ - "historyLength": history_length, - "pushNotificationConfig": push_notification, - }, - metadata=metadata, - ) - payload = SendMessageRequest( - id=str(uuid.uuid4()), method="message/send", params=params - ) - - try: - response = await self._sdk_client.send_message(payload) - return response - except (A2AClientHTTPError, A2AClientJSONError) as e: - raise UnexpectedResponseError(getattr(e, "status_code", 500), str(e)) from e - - async def get_task(self, task_id: str) -> GetTaskResponse: - """Retrieves a task from the A2A server.""" - payload = GetTaskRequest( - id=str(uuid.uuid4()), method="tasks/get", params={"id": task_id} - ) - try: - response = await self._sdk_client.get_task(payload) - return response - except (A2AClientHTTPError, A2AClientJSONError) as e: - raise UnexpectedResponseError(getattr(e, "status_code", 500), str(e)) from e diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index b5c7e331f..3da693f67 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -1,55 +1,34 @@ """ -This module re-exports the core A2A schema types from the `google-a2a` SDK. -The types are based on Pydantic `BaseModel` and align with the official A2A JSON specification. +This module re-exports the core schema types from the Google A2A SDK. + +By using the SDK's types, FastA2A ensures compliance with the A2A specification +and leverages the robust Pydantic models provided by the SDK. """ +from __future__ import annotations as _annotations + from a2a.types import ( AgentCard, AgentProvider, - AgentSkill as Skill, + AgentSkill, Artifact, - CancelTaskRequest, - CancelTaskResponse, - DataPart, - FilePart, - GetTaskRequest, - GetTaskResponse, Message, - MessageSendParams, Part, - PushNotificationConfig, - Role, - SendMessageRequest, - SendMessageResponse, Task, TaskState, - TextPart, ) +# Alias for backward compatibility +Skill = AgentSkill +Provider = AgentProvider + __all__ = [ - # Core Models "AgentCard", - "AgentProvider", + "Provider", + "Skill", "Artifact", "Message", "Part", - "Skill", "Task", - # Enums - "Role", "TaskState", - # Part Types - "TextPart", - "FilePart", - "DataPart", - # Request/Response models - "SendMessageRequest", - "SendMessageResponse", - "GetTaskRequest", - "GetTaskResponse", - "CancelTaskRequest", - "CancelTaskResponse", - # Parameter Models - "MessageSendParams", - "PushNotificationConfig", ] diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 4b78cc842..f7a8be4ff 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -1,12 +1,33 @@ -""" -This module provides an alias for the TaskStore abstraction from the `google-a2a` SDK. -A TaskStore is responsible for persisting and retrieving task state. -""" +"""This module defines the Storage class, which is responsible for storing and retrieving tasks.""" -from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore -from a2a.server.tasks.task_store import TaskStore +from __future__ import annotations as _annotations -Storage = TaskStore -"""Alias for `a2a.server.tasks.task_store.TaskStore`.""" +from a2a.server.tasks import TaskStore +from a2a.types import Task -__all__ = ["Storage", "TaskStore", "InMemoryTaskStore"] + +class InMemoryStorage(TaskStore): + """A storage to retrieve and save tasks in memory.""" + + def __init__(self): + self.tasks: dict[str, Task] = {} + + async def get(self, task_id: str) -> Task | None: + """Load a task from memory. + + Args: + task_id: The id of the task to load. + + Returns: + The task. + """ + return self.tasks.get(task_id) + + async def save(self, task: Task) -> None: + """Saves or updates a task in the in-memory store.""" + self.tasks[task.id] = task + + async def delete(self, task_id: str) -> None: + """Deletes a task from the in-memory store by ID.""" + if task_id in self.tasks: + del self.tasks[task_id] diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index e8ba58633..ae5bb5381 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -1,38 +1,15 @@ from __future__ import annotations as _annotations -from abc import ABC -from contextlib import asynccontextmanager +from a2a.server.agent_execution import AgentExecutor +from a2a.server.tasks import TaskStore -from a2a.server.agent_execution.agent_executor import AgentExecutor -from a2a.server.agent_execution.context import RequestContext -from a2a.server.events.event_queue import EventQueue -from a2a.server.tasks.task_updater import TaskUpdater +Worker = AgentExecutor +""" +The `Worker` is the core component where you implement your agent's logic. +It is an alias for the `a2a.server.agent_execution.AgentExecutor` class from the +Google A2A SDK. You should create a class that inherits from `Worker` and +implement the `execute` and `cancel` methods. +""" -class Worker(AgentExecutor, ABC): - """ - An abstract class for implementing the core logic of an A2A agent. - - This class inherits from the `a2a.server.agent_execution.AgentExecutor` - and must be subclassed to define the agent's behavior. - """ - - @asynccontextmanager - async def task_updater( - self, context: RequestContext, event_queue: EventQueue - ) -> TaskUpdater: - """ - A convenience context manager to get a `TaskUpdater` for the current task. - - Args: - context: The `RequestContext` for the current execution. - event_queue: The `EventQueue` to publish updates to. - - Yields: - A `TaskUpdater` instance for the current task. - """ - if not context.task_id or not context.context_id: - raise ValueError( - "RequestContext must have a task_id and context_id to create a TaskUpdater." - ) - yield TaskUpdater(event_queue, context.task_id, context.context_id) +__all__ = ["Worker", "TaskStore"] From 2b0503ce5108ef4685f7824ee45b48707e7caa31 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Mon, 16 Jun 2025 10:18:30 -0500 Subject: [PATCH 6/6] Updates to FastA2A Port --- fasta2a/fasta2a/applications.py | 141 ++++++++++++++++++++++---------- fasta2a/fasta2a/client.py | 65 +++++++++++++++ fasta2a/fasta2a/schema.py | 102 ++++++++++++++++++++--- fasta2a/fasta2a/storage.py | 49 +++++++---- fasta2a/fasta2a/worker.py | 33 +++++--- 5 files changed, 308 insertions(+), 82 deletions(-) create mode 100644 fasta2a/fasta2a/client.py diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 2b4f097bf..a4bd4858f 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -1,15 +1,74 @@ from __future__ import annotations as _annotations -from typing import Any +from typing import TYPE_CHECKING, Any, Sequence -from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps.jsonrpc import A2AStarletteApplication +from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler -from a2a.server.tasks import TaskStore -from a2a.types import AgentCapabilities, AgentCard, AgentProvider, AgentSkill as Skill +from a2a.types import ( + AgentCard, + Capabilities, + InvalidParamsError, + MessageSendParams, + TaskIdParams, + TaskState, +) +from a2a.utils.errors import ServerError +from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.routing import Route -from starlette.types import Receive, Scope, Send +from starlette.types import ExceptionHandler, Lifespan + +from .storage import Storage +from .worker import Worker + +if TYPE_CHECKING: + from .schema import Provider, Skill + + +class _WorkerExecutor(AgentExecutor): + """An adapter to make a fasta2a.Worker compatible with a2a.AgentExecutor.""" + + def __init__(self, worker: Worker, storage: Storage): + self.worker = worker + self.storage = storage + + async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: + from a2a.server.tasks import TaskUpdater + + self.worker.storage = self.storage + + if not (context.task_id and context.context_id and context.message): + raise ServerError( + InvalidParamsError( + message="task_id, context_id, and message are required for execution" + ) + ) + + params = MessageSendParams( + message=context.message, configuration=context.configuration + ) + + updater = TaskUpdater(event_queue, context.task_id, context.context_id) + await self.worker.run_task(params, updater) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + from a2a.server.tasks import TaskUpdater + + self.worker.storage = self.storage + + if not context.task_id or not context.context_id: + raise ServerError( + InvalidParamsError( + message="task_id and context_id are required for cancellation" + ) + ) + + params = TaskIdParams(id=context.task_id) + updater = TaskUpdater(event_queue, context.task_id, context.context_id) + await self.worker.cancel_task(params, updater) + await updater.update_status(TaskState.canceled, final=True) class FastA2A: @@ -18,56 +77,52 @@ class FastA2A: def __init__( self, *, - worker: AgentExecutor, - storage: TaskStore, + storage: Storage, + worker: Worker, # Agent card name: str | None = None, url: str = "http://localhost:8000", version: str = "1.0.0", description: str | None = None, - provider: AgentProvider | None = None, + provider: Provider | None = None, skills: list[Skill] | None = None, # Starlette - routes: list[Route] | None = None, - middleware: list[Middleware] | None = None, - **starlette_kwargs: Any, + debug: bool = False, + routes: Sequence[Route] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: dict[Any, ExceptionHandler] | None = None, + lifespan: Lifespan | None = None, ): - self.worker = worker - self.storage = storage - self.name = name or "Agent" - self.url = url - self.version = version - self.description = description - self.provider = provider - self.skills = skills or [] - self.default_input_modes = ["application/json"] - self.default_output_modes = ["application/json"] - self.capabilities = AgentCapabilities( - streaming=True, - pushNotifications=False, - stateTransitionHistory=False, + agent_executor = _WorkerExecutor(worker, storage) + + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, task_store=storage ) agent_card = AgentCard( - name=self.name, - url=self.url, - version=self.version, - description=self.description or "", - provider=self.provider, - skills=self.skills, - defaultInputModes=self.default_input_modes, - defaultOutputModes=self.default_output_modes, - capabilities=self.capabilities, + name=name or "Agent", + url=url, + version=version, + description=description, + provider=provider, + skills=skills or [], + defaultInputModes=["application/json"], + defaultOutputModes=["application/json"], + capabilities=Capabilities( + streaming=True, pushNotifications=False, stateTransitionHistory=True + ), ) - handler = DefaultRequestHandler( - agent_executor=self.worker, - task_store=self.storage, + app_builder = A2AStarletteApplication( + agent_card=agent_card, http_handler=request_handler + ) + self.app: Starlette = app_builder.build( + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + lifespan=lifespan, ) - self._app = A2AStarletteApplication( - agent_card=agent_card, http_handler=handler - ).build(routes=routes, middleware=middleware, **starlette_kwargs) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await self._app(scope, receive, send) + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + await self.app(scope, receive, send) diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py new file mode 100644 index 000000000..fb0f80eab --- /dev/null +++ b/fasta2a/fasta2a/client.py @@ -0,0 +1,65 @@ +from __future__ import annotations as _annotations + +import uuid +from typing import Any + +import httpx +from a2a.client import A2AClient as SDKA2AClient +from a2a.types import ( + GetTaskRequest, + GetTaskResponse, + Message, + MessageSendConfiguration, + SendMessageRequest, + SendMessageResponse, + TaskQueryParams, +) + +from .schema import PushNotificationConfig + + +class A2AClient: + """A client for the A2A protocol.""" + + def __init__( + self, + base_url: str = "http://localhost:8000", + http_client: httpx.AsyncClient | None = None, + ) -> None: + if http_client is None: + http_client = httpx.AsyncClient() + # The SDK's client will be initialized with a URL that points to the JSON-RPC endpoint. + # fasta2a server sets this at root '/', so we append it. + self.sdk_client = SDKA2AClient(http_client, url=f'{base_url.rstrip("/")}/') + + async def send_task( + self, + message: Message, + history_length: int | None = None, + push_notification: PushNotificationConfig | None = None, + metadata: dict[str, Any] | None = None, + ) -> SendMessageResponse: + """Sends a task to the agent. + + This now maps to the 'message/send' A2A method. + """ + if metadata: + message.metadata = (message.metadata or {}) | metadata + + configuration = MessageSendConfiguration( + historyLength=history_length, + pushNotificationConfig=push_notification, + ) + + request = SendMessageRequest( + id=str(uuid.uuid4()), + params={"message": message, "configuration": configuration}, + ) + return await self.sdk_client.send_message(request) + + async def get_task(self, task_id: str) -> GetTaskResponse: + """Retrieves a task from the agent.""" + request = GetTaskRequest( + id=str(uuid.uuid4()), params=TaskQueryParams(id=task_id) + ) + return await self.sdk_client.get_task(request) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 3da693f67..3043c1284 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -1,34 +1,112 @@ -""" -This module re-exports the core schema types from the Google A2A SDK. - -By using the SDK's types, FastA2A ensures compliance with the A2A specification -and leverages the robust Pydantic models provided by the SDK. -""" +"""This module contains the schema for the agent card.""" from __future__ import annotations as _annotations +from typing import Union + from a2a.types import ( + A2ARequest as _A2ARequest, + A2AResponse as _A2AResponse, AgentCard, - AgentProvider, - AgentSkill, + AgentProvider as Provider, + AgentSkill as Skill, Artifact, + AuthenticationInfo as Authentication, + CancelTaskRequest, + CancelTaskResponse, + Capabilities, + ContentTypeNotSupportedError, + GetTaskPushNotificationConfigRequest as GetTaskPushNotificationRequest, + GetTaskPushNotificationConfigResponse as GetTaskPushNotificationResponse, + GetTaskRequest, + GetTaskResponse, + InternalError, + JSONRPCError, Message, + MessageSendParams as TaskSendParams, Part, + PushNotificationConfig, + PushNotificationNotSupportedError, + SendMessageRequest as SendTaskRequest, + SendMessageResponse as SendTaskResponse, + SendStreamingMessageRequest as SendTaskStreamingRequest, + SendStreamingMessageResponse as SendTaskStreamingResponse, + SetTaskPushNotificationConfigRequest as SetTaskPushNotificationRequest, + SetTaskPushNotificationConfigResponse as SetTaskPushNotificationResponse, Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskNotCancelableError, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest as ResubscribeTaskRequest, TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, + UnsupportedOperationError, ) - -# Alias for backward compatibility -Skill = AgentSkill -Provider = AgentProvider +from pydantic import TypeAdapter __all__ = [ "AgentCard", "Provider", "Skill", "Artifact", + "Authentication", + "Capabilities", "Message", + "TaskSendParams", "Part", + "PushNotificationConfig", "Task", + "TaskIdParams", + "TaskQueryParams", + "TextPart", + "CancelTaskRequest", + "CancelTaskResponse", + "GetTaskPushNotificationRequest", + "GetTaskPushNotificationResponse", + "GetTaskRequest", + "GetTaskResponse", + "SendTaskStreamingRequest", + "SendTaskStreamingResponse", + "SendTaskRequest", + "SendTaskResponse", + "SetTaskPushNotificationRequest", + "SetTaskPushNotificationResponse", + "ResubscribeTaskRequest", "TaskState", + "TaskStatus", + "TaskArtifactUpdateEvent", + "TaskPushNotificationConfig", + "TaskStatusUpdateEvent", + "JSONRPCError", + "TaskNotFoundError", + "TaskNotCancelableError", + "PushNotificationNotSupportedError", + "UnsupportedOperationError", + "ContentTypeNotSupportedError", + "InternalError", + "a2a_request_ta", + "a2a_response_ta", + "A2ARequest", + "A2AResponse", +] + + +A2ARequest = _A2ARequest +"""A JSON RPC request to the A2A server.""" + +A2AResponse = Union[ + SendTaskResponse, + GetTaskResponse, + CancelTaskResponse, + SetTaskPushNotificationResponse, + GetTaskPushNotificationResponse, ] +"""A JSON RPC response from the A2A server.""" + +a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) +a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index f7a8be4ff..7ffb3d63a 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -2,32 +2,47 @@ from __future__ import annotations as _annotations -from a2a.server.tasks import TaskStore +import asyncio +from abc import ABC, abstractmethod + from a2a.types import Task -class InMemoryStorage(TaskStore): +class Storage(ABC): + """A storage to retrieve and save tasks.""" + + @abstractmethod + async def get(self, task_id: str) -> Task | None: + """Retrieves a task from the store by its ID.""" + + @abstractmethod + async def save(self, task: Task) -> None: + """Saves or updates a task in the store.""" + + @abstractmethod + async def delete(self, task_id: str) -> None: + """Deletes a task from the store by its ID.""" + + +class InMemoryStorage(Storage): """A storage to retrieve and save tasks in memory.""" - def __init__(self): + def __init__(self) -> None: self.tasks: dict[str, Task] = {} + self.lock = asyncio.Lock() async def get(self, task_id: str) -> Task | None: - """Load a task from memory. - - Args: - task_id: The id of the task to load. - - Returns: - The task. - """ - return self.tasks.get(task_id) + """Retrieves a task from memory.""" + async with self.lock: + return self.tasks.get(task_id) async def save(self, task: Task) -> None: - """Saves or updates a task in the in-memory store.""" - self.tasks[task.id] = task + """Saves or updates a task in memory.""" + async with self.lock: + self.tasks[task.id] = task async def delete(self, task_id: str) -> None: - """Deletes a task from the in-memory store by ID.""" - if task_id in self.tasks: - del self.tasks[task_id] + """Deletes a task from memory.""" + async with self.lock: + if task_id in self.tasks: + del self.tasks[task_id] diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index ae5bb5381..7bcc4cb45 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -1,15 +1,28 @@ from __future__ import annotations as _annotations -from a2a.server.agent_execution import AgentExecutor -from a2a.server.tasks import TaskStore +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any -Worker = AgentExecutor -""" -The `Worker` is the core component where you implement your agent's logic. +if TYPE_CHECKING: + from a2a.server.tasks import TaskUpdater -It is an alias for the `a2a.server.agent_execution.AgentExecutor` class from the -Google A2A SDK. You should create a class that inherits from `Worker` and -implement the `execute` and `cancel` methods. -""" + from .schema import Artifact, Message, TaskIdParams, TaskSendParams + from .storage import Storage -__all__ = ["Worker", "TaskStore"] + +class Worker(ABC): + """A worker is responsible for executing tasks.""" + + storage: Storage + + @abstractmethod + async def run_task(self, params: TaskSendParams, updater: TaskUpdater) -> None: ... + + @abstractmethod + async def cancel_task(self, params: TaskIdParams, updater: TaskUpdater) -> None: ... + + @abstractmethod + def build_message_history(self, task_history: list[Message]) -> list[Any]: ... + + @abstractmethod + def build_artifacts(self, result: Any) -> list[Artifact]: ...