diff --git a/.gitignore b/.gitignore index acb015449..092411ded 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite node_modules/ **.idea/ .coverage* +.mypy_cache/ diff --git a/fasta2a/fasta2a/__init__.py b/fasta2a/fasta2a/__init__.py index 4a8b10629..17881c198 100644 --- a/fasta2a/fasta2a/__init__.py +++ b/fasta2a/fasta2a/__init__.py @@ -1,7 +1,17 @@ from .applications import FastA2A -from .broker import Broker -from .schema import Skill -from .storage import Storage -from .worker import Worker +from .schema import Artifact, Message, Part, Skill, Task, TaskState +from .storage import InMemoryStorage +from .worker import TaskStore, Worker -__all__ = ['FastA2A', 'Skill', 'Storage', 'Broker', 'Worker'] +__all__ = [ + "FastA2A", + "Skill", + "TaskStore", + "InMemoryStorage", + "Worker", + "Task", + "Message", + "Artifact", + "Part", + "TaskState", +] diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 61301262b..a4bd4858f 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -1,43 +1,88 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager -from typing import Any +from typing import TYPE_CHECKING, Any, Sequence +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.apps.jsonrpc import A2AStarletteApplication +from a2a.server.events import EventQueue +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.types import ( + AgentCard, + Capabilities, + InvalidParamsError, + MessageSendParams, + TaskIdParams, + TaskState, +) +from a2a.utils.errors import ServerError from starlette.applications import Starlette from starlette.middleware import Middleware -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Route -from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send +from starlette.types import ExceptionHandler, Lifespan -from .broker import Broker -from .schema import ( - AgentCard, - Authentication, - Capabilities, - Provider, - Skill, - a2a_request_ta, - a2a_response_ta, - agent_card_ta, -) from .storage import Storage -from .task_manager import TaskManager +from .worker import Worker + +if TYPE_CHECKING: + from .schema import Provider, Skill + + +class _WorkerExecutor(AgentExecutor): + """An adapter to make a fasta2a.Worker compatible with a2a.AgentExecutor.""" + + def __init__(self, worker: Worker, storage: Storage): + self.worker = worker + self.storage = storage + + async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: + from a2a.server.tasks import TaskUpdater + + self.worker.storage = self.storage + + if not (context.task_id and context.context_id and context.message): + raise ServerError( + InvalidParamsError( + message="task_id, context_id, and message are required for execution" + ) + ) + + params = MessageSendParams( + message=context.message, configuration=context.configuration + ) + updater = TaskUpdater(event_queue, context.task_id, context.context_id) + await self.worker.run_task(params, updater) -class FastA2A(Starlette): + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + from a2a.server.tasks import TaskUpdater + + self.worker.storage = self.storage + + if not context.task_id or not context.context_id: + raise ServerError( + InvalidParamsError( + message="task_id and context_id are required for cancellation" + ) + ) + + params = TaskIdParams(id=context.task_id) + updater = TaskUpdater(event_queue, context.task_id, context.context_id) + await self.worker.cancel_task(params, updater) + await updater.update_status(TaskState.canceled, final=True) + + +class FastA2A: """The main class for the FastA2A library.""" def __init__( self, *, storage: Storage, - broker: Broker, + worker: Worker, # Agent card name: str | None = None, - url: str = 'http://localhost:8000', - version: str = '1.0.0', + url: str = "http://localhost:8000", + version: str = "1.0.0", description: str | None = None, provider: Provider | None = None, skills: list[Skill] | None = None, @@ -46,12 +91,32 @@ def __init__( routes: Sequence[Route] | None = None, middleware: Sequence[Middleware] | None = None, exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan[FastA2A] | None = None, + lifespan: Lifespan | None = None, ): - if lifespan is None: - lifespan = _default_lifespan + agent_executor = _WorkerExecutor(worker, storage) + + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, task_store=storage + ) - super().__init__( + agent_card = AgentCard( + name=name or "Agent", + url=url, + version=version, + description=description, + provider=provider, + skills=skills or [], + defaultInputModes=["application/json"], + defaultOutputModes=["application/json"], + capabilities=Capabilities( + streaming=True, pushNotifications=False, stateTransitionHistory=True + ), + ) + + app_builder = A2AStarletteApplication( + agent_card=agent_card, http_handler=request_handler + ) + self.app: Starlette = app_builder.build( debug=debug, routes=routes, middleware=middleware, @@ -59,77 +124,5 @@ def __init__( lifespan=lifespan, ) - self.name = name or 'Agent' - self.url = url - self.version = version - self.description = description - self.provider = provider - self.skills = skills or [] - # NOTE: For now, I don't think there's any reason to support any other input/output modes. - self.default_input_modes = ['application/json'] - self.default_output_modes = ['application/json'] - - self.task_manager = TaskManager(broker=broker, storage=storage) - - # Setup - self._agent_card_json_schema: bytes | None = None - self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS']) - self.router.add_route('/', self._agent_run_endpoint, methods=['POST']) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope['type'] == 'http' and not self.task_manager.is_running: - raise RuntimeError('TaskManager was not properly initialized.') - await super().__call__(scope, receive, send) - - async def _agent_card_endpoint(self, request: Request) -> Response: - if self._agent_card_json_schema is None: - agent_card = AgentCard( - name=self.name, - url=self.url, - version=self.version, - skills=self.skills, - default_input_modes=self.default_input_modes, - default_output_modes=self.default_output_modes, - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), - authentication=Authentication(schemes=[]), - ) - if self.description is not None: - agent_card['description'] = self.description - if self.provider is not None: - agent_card['provider'] = self.provider - self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) - return Response(content=self._agent_card_json_schema, media_type='application/json') - - async def _agent_run_endpoint(self, request: Request) -> Response: - """This is the main endpoint for the A2A server. - - Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. - - 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. - Never a "completed" on the first message. - 2. There are three possible ends for the task: - 2.1. The task was "completed" successfully. - 2.2. The task was "canceled". - 2.3. The task "failed". - 3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`. - """ - data = await request.body() - a2a_request = a2a_request_ta.validate_json(data) - - if a2a_request['method'] == 'tasks/send': - jsonrpc_response = await self.task_manager.send_task(a2a_request) - elif a2a_request['method'] == 'tasks/get': - jsonrpc_response = await self.task_manager.get_task(a2a_request) - elif a2a_request['method'] == 'tasks/cancel': - jsonrpc_response = await self.task_manager.cancel_task(a2a_request) - else: - raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') - return Response( - content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json' - ) - - -@asynccontextmanager -async def _default_lifespan(app: FastA2A) -> AsyncIterator[None]: - async with app.task_manager: - yield + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + await self.app(scope, receive, send) diff --git a/fasta2a/fasta2a/broker.py b/fasta2a/fasta2a/broker.py deleted file mode 100644 index c84b73872..000000000 --- a/fasta2a/fasta2a/broker.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations as _annotations - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from contextlib import AsyncExitStack -from dataclasses import dataclass -from typing import Annotated, Any, Generic, Literal, TypeVar - -import anyio -from opentelemetry.trace import Span, get_current_span, get_tracer -from pydantic import Discriminator -from typing_extensions import Self, TypedDict - -from .schema import TaskIdParams, TaskSendParams - -tracer = get_tracer(__name__) - - -@dataclass -class Broker(ABC): - """The broker class is in charge of scheduling the tasks. - - The HTTP server uses the broker to schedule tasks. - - The simple implementation is the `InMemoryBroker`, which is the broker that - runs the tasks in the same process as the HTTP server. That said, this class can be - extended to support remote workers. - """ - - @abstractmethod - async def run_task(self, params: TaskSendParams) -> None: - """Send a task to be executed by the worker.""" - raise NotImplementedError('send_run_task is not implemented yet.') - - @abstractmethod - async def cancel_task(self, params: TaskIdParams) -> None: - """Cancel a task.""" - raise NotImplementedError('send_cancel_task is not implemented yet.') - - @abstractmethod - async def __aenter__(self) -> Self: ... - - @abstractmethod - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): ... - - @abstractmethod - def receive_task_operations(self) -> AsyncIterator[TaskOperation]: - """Receive task operations from the broker. - - On a multi-worker setup, the broker will need to round-robin the task operations - between the workers. - """ - - -OperationT = TypeVar('OperationT') -ParamsT = TypeVar('ParamsT') - - -class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]): - """A task operation.""" - - operation: OperationT - params: ParamsT - _current_span: Span - - -_RunTask = _TaskOperation[Literal['run'], TaskSendParams] -_CancelTask = _TaskOperation[Literal['cancel'], TaskIdParams] - -TaskOperation = Annotated['_RunTask | _CancelTask', Discriminator('operation')] - - -class InMemoryBroker(Broker): - """A broker that schedules tasks in memory.""" - - async def __aenter__(self): - self.aexit_stack = AsyncExitStack() - await self.aexit_stack.__aenter__() - - self._write_stream, self._read_stream = anyio.create_memory_object_stream[TaskOperation]() - await self.aexit_stack.enter_async_context(self._read_stream) - await self.aexit_stack.enter_async_context(self._write_stream) - - return self - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): - await self.aexit_stack.__aexit__(exc_type, exc_value, traceback) - - async def run_task(self, params: TaskSendParams) -> None: - await self._write_stream.send(_RunTask(operation='run', params=params, _current_span=get_current_span())) - - async def cancel_task(self, params: TaskIdParams) -> None: - await self._write_stream.send(_CancelTask(operation='cancel', params=params, _current_span=get_current_span())) - - async def receive_task_operations(self) -> AsyncIterator[TaskOperation]: - """Receive task operations from the broker.""" - async for task_operation in self._read_stream: - yield task_operation diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 5c5aabd81..fb0f80eab 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -3,39 +3,34 @@ import uuid from typing import Any -import pydantic - -from .schema import ( +import httpx +from a2a.client import A2AClient as SDKA2AClient +from a2a.types import ( GetTaskRequest, GetTaskResponse, Message, - PushNotificationConfig, - SendTaskRequest, - SendTaskResponse, - TaskSendParams, - a2a_request_ta, + MessageSendConfiguration, + SendMessageRequest, + SendMessageResponse, + TaskQueryParams, ) -send_task_response_ta = pydantic.TypeAdapter(SendTaskResponse) -get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) - -try: - import httpx -except ImportError as _import_error: - raise ImportError( - 'httpx is required to use the A2AClient. Please install it with `pip install httpx`.', - ) from _import_error +from .schema import PushNotificationConfig class A2AClient: """A client for the A2A protocol.""" - def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.AsyncClient | None = None) -> None: + def __init__( + self, + base_url: str = "http://localhost:8000", + http_client: httpx.AsyncClient | None = None, + ) -> None: if http_client is None: - self.http_client = httpx.AsyncClient(base_url=base_url) - else: - self.http_client = http_client - self.http_client.base_url = base_url + http_client = httpx.AsyncClient() + # The SDK's client will be initialized with a URL that points to the JSON-RPC endpoint. + # fasta2a server sets this at root '/', so we append it. + self.sdk_client = SDKA2AClient(http_client, url=f'{base_url.rstrip("/")}/') async def send_task( self, @@ -43,36 +38,28 @@ async def send_task( history_length: int | None = None, push_notification: PushNotificationConfig | None = None, metadata: dict[str, Any] | None = None, - ) -> SendTaskResponse: - task = TaskSendParams(message=message, id=str(uuid.uuid4())) - if history_length is not None: - task['history_length'] = history_length - if push_notification is not None: - task['push_notification'] = push_notification - if metadata is not None: - task['metadata'] = metadata + ) -> SendMessageResponse: + """Sends a task to the agent. - payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) - content = a2a_request_ta.dump_json(payload, by_alias=True) - response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) - self._raise_for_status(response) - return send_task_response_ta.validate_json(response.content) + This now maps to the 'message/send' A2A method. + """ + if metadata: + message.metadata = (message.metadata or {}) | metadata - async def get_task(self, task_id: str) -> GetTaskResponse: - payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) - content = a2a_request_ta.dump_json(payload, by_alias=True) - response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) - self._raise_for_status(response) - return get_task_response_ta.validate_json(response.content) + configuration = MessageSendConfiguration( + historyLength=history_length, + pushNotificationConfig=push_notification, + ) - def _raise_for_status(self, response: httpx.Response) -> None: - if response.status_code >= 400: - raise UnexpectedResponseError(response.status_code, response.text) + request = SendMessageRequest( + id=str(uuid.uuid4()), + params={"message": message, "configuration": configuration}, + ) + return await self.sdk_client.send_message(request) - -class UnexpectedResponseError(Exception): - """An error raised when an unexpected response is received from the server.""" - - def __init__(self, status_code: int, content: str) -> None: - self.status_code = status_code - self.content = content + async def get_task(self, task_id: str) -> GetTaskResponse: + """Retrieves a task from the agent.""" + request = GetTaskRequest( + id=str(uuid.uuid4()), params=TaskQueryParams(id=task_id) + ) + return await self.sdk_client.get_task(request) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 57193a97d..3043c1284 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -2,558 +2,104 @@ from __future__ import annotations as _annotations -from typing import Annotated, Any, Generic, Literal, TypeVar, Union - -import pydantic -from pydantic import Discriminator, TypeAdapter -from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypeAlias, TypedDict - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class AgentCard(TypedDict): - """The card that describes an agent.""" - - name: str - """Human readable name of the agent e.g. "Recipe Agent".""" - - description: NotRequired[str] - """A human-readable description of the agent. - - Used to assist users and other agents in understanding what the agent can do. - (e.g. "Agent that helps users with recipes and cooking.") - """ - - # TODO(Marcelo): The spec makes url required. - url: NotRequired[str] - """A URL to the address the agent is hosted at.""" - - provider: NotRequired[Provider] - """The service provider of the agent.""" - - # TODO(Marcelo): The spec makes version required. - version: NotRequired[str] - """The version of the agent - format is up to the provider. (e.g. "1.0.0")""" - - documentation_url: NotRequired[str] - """A URL to documentation for the agent.""" - - capabilities: Capabilities - """The capabilities of the agent.""" - - authentication: Authentication - """The authentication schemes supported by the agent. - - Intended to match OpenAPI authentication structure. - """ - - default_input_modes: list[str] - """Supported mime types for input data.""" - - default_output_modes: list[str] - """Supported mime types for output data.""" - - skills: list[Skill] - - -agent_card_ta = pydantic.TypeAdapter(AgentCard) - - -class Provider(TypedDict): - """The service provider of the agent.""" - - organization: str - url: str - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Capabilities(TypedDict): - """The capabilities of the agent.""" - - streaming: NotRequired[bool] - """Whether the agent supports streaming.""" - - push_notifications: NotRequired[bool] - """Whether the agent can notify updates to client.""" - - state_transition_history: NotRequired[bool] - """Whether the agent exposes status change history for tasks.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Authentication(TypedDict): - """The authentication schemes supported by the agent.""" - - schemes: list[str] - """The authentication schemes supported by the agent. (e.g. "Basic", "Bearer")""" - - credentials: NotRequired[str] - """The credentials a client should use for private cards.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Skill(TypedDict): - """Skills are a unit of capability that an agent can perform.""" - - id: str - """A unique identifier for the skill.""" - - name: str - """Human readable name of the skill.""" - - description: str - """A human-readable description of the skill. - - It will be used by the client or a human as a hint to understand the skill. - """ - - tags: list[str] - """Set of tag-words describing classes of capabilities for this specific skill. - - Examples: "cooking", "customer support", "billing". - """ - - examples: NotRequired[list[str]] - """The set of example scenarios that the skill can perform. - - Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") - """ - - input_modes: list[str] - """Supported mime types for input data.""" - - output_modes: list[str] - """Supported mime types for output data.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Artifact(TypedDict): - """Agents generate Artifacts as an end result of a Task. - - Artifacts are immutable, can be named, and can have multiple parts. A streaming response can append parts to - existing Artifacts. - - A single Task can generate many Artifacts. For example, "create a webpage" could create separate HTML and image - Artifacts. - """ - - name: NotRequired[str] - """The name of the artifact.""" - - description: NotRequired[str] - """A description of the artifact.""" - - parts: list[Part] - """The parts that make up the artifact.""" - - metadata: NotRequired[dict[str, Any]] - """Metadata about the artifact.""" - - index: int - """The index of the artifact.""" - - append: NotRequired[bool] - """Whether to append this artifact to an existing one.""" - - last_chunk: NotRequired[bool] - """Whether this is the last chunk of the artifact.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class PushNotificationConfig(TypedDict): - """Configuration for push notifications. - - A2A supports a secure notification mechanism whereby an agent can notify a client of an update - outside of a connected session via a PushNotificationService. Within and across enterprises, - it is critical that the agent verifies the identity of the notification service, authenticates - itself with the service, and presents an identifier that ties the notification to the executing - Task. - - The target server of the PushNotificationService should be considered a separate service, and - is not guaranteed (or even expected) to be the client directly. This PushNotificationService is - responsible for authenticating and authorizing the agent and for proxying the verified notification - to the appropriate endpoint (which could be anything from a pub/sub queue, to an email inbox or - other service, etc). - - For contrived scenarios with isolated client-agent pairs (e.g. local service mesh in a contained - VPC, etc.) or isolated environments without enterprise security concerns, the client may choose to - simply open a port and act as its own PushNotificationService. Any enterprise implementation will - likely have a centralized service that authenticates the remote agents with trusted notification - credentials and can handle online/offline scenarios. (This should be thought of similarly to a - mobile Push Notification Service). - """ - - url: str - """The URL to send push notifications to.""" - - token: NotRequired[str] - """Token unique to this task/session.""" - - authentication: NotRequired[Authentication] - """Authentication details for push notifications.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskPushNotificationConfig(TypedDict): - """Configuration for task push notifications.""" - - id: str - """The task id.""" - - push_notification_config: PushNotificationConfig - """The push notification configuration.""" - - -class Message(TypedDict): - """A Message contains any content that is not an Artifact. - - This can include things like agent thoughts, user context, instructions, errors, status, or metadata. - - All content from a client comes in the form of a Message. Agents send Messages to communicate status or to provide - instructions (whereas generated results are sent as Artifacts). - - A Message can have multiple parts to denote different pieces of content. For example, a user request could include - a textual description from a user and then multiple files used as context from the client. - """ - - role: Literal['user', 'agent'] - """The role of the message.""" - - parts: list[Part] - """The parts of the message.""" - - metadata: NotRequired[dict[str, Any]] - """Metadata about the message.""" - - -class _BasePart(TypedDict): - """A base class for all parts.""" - - metadata: NotRequired[dict[str, Any]] - - -class TextPart(_BasePart): - """A part that contains text.""" - - type: Literal['text'] - """The type of the part.""" - - text: str - """The text of the part.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class FilePart(_BasePart): - """A part that contains a file.""" - - type: Literal['file'] - """The type of the part.""" - - file: File - """The file of the part.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _BaseFile(_BasePart): - """A base class for all file types.""" - - name: NotRequired[str] - """The name of the file.""" - - mime_type: str - """The mime type of the file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _BinaryFile(_BaseFile): - """A binary file.""" - - data: str - """The base64 encoded bytes of the file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class _URLFile(_BaseFile): - """A file that is hosted on a remote URL.""" - - url: str - """The URL of the file.""" - - -File: TypeAlias = Union[_BinaryFile, _URLFile] -"""A file is a binary file or a URL file.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class DataPart(_BasePart): - """A part that contains data.""" - - type: Literal['data'] - """The type of the part.""" - - data: dict[str, Any] - """The data of the part.""" - - -Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='type')] -"""A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. - -Each Part has its own content type and metadata. -""" - -TaskState: TypeAlias = Literal['submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'unknown'] -"""The possible states of a task.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskStatus(TypedDict): - """Status and accompanying message for a task.""" - - state: TaskState - """The current state of the task.""" - - message: NotRequired[Message] - """Additional status updates for client.""" - - timestamp: NotRequired[str] - """ISO datetime value of when the status was updated.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class Task(TypedDict): - """A Task is a stateful entity that allows Clients and Remote Agents to achieve a specific outcome. - - Clients and Remote Agents exchange Messages within a Task. Remote Agents generate results as Artifacts. - A Task is always created by a Client and the status is always determined by the Remote Agent. - """ - - id: str - """Unique identifier for the task.""" - - session_id: NotRequired[str] - """Client-generated id for the session holding the task.""" - - status: TaskStatus - """Current status of the task.""" - - history: NotRequired[list[Message]] - """Optional history of messages.""" - - artifacts: NotRequired[list[Artifact]] - """Collection of artifacts created by the agent.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskStatusUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" - - id: str - """The id of the task.""" - - status: TaskStatus - """The status of the task.""" - - final: bool - """Indicates the end of the event stream.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskArtifactUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" - - id: str - """The id of the task.""" - - artifact: Artifact - """The artifact that was updated.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskIdParams(TypedDict): - """Parameters for a task id.""" - - id: str - metadata: NotRequired[dict[str, Any]] - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskQueryParams(TaskIdParams): - """Query parameters for a task.""" - - history_length: NotRequired[int] - """Number of recent messages to be retrieved.""" - - -@pydantic.with_config(config={'alias_generator': to_camel}) -class TaskSendParams(TypedDict): - """Sent by the client to the agent to create, continue, or restart a task.""" - - id: str - """The id of the task.""" - - session_id: NotRequired[str] - """The server creates a new sessionId for new tasks if not set.""" - - message: Message - """The message to send to the agent.""" - - history_length: NotRequired[int] - """Number of recent messages to be retrieved.""" - - push_notification: NotRequired[PushNotificationConfig] - """Where the server should send notifications when disconnected.""" - - metadata: NotRequired[dict[str, Any]] - """Extension metadata.""" - - -class JSONRPCMessage(TypedDict): - """A JSON RPC message.""" - - jsonrpc: Literal['2.0'] - """The JSON RPC version.""" - - id: int | str | None - """The request id.""" - - -Method = TypeVar('Method') -Params = TypeVar('Params') - - -class JSONRPCRequest(JSONRPCMessage, Generic[Method, Params]): - """A JSON RPC request.""" - - method: Method - """The method to call.""" - - params: Params - """The parameters to pass to the method.""" - - -############################################################################################### -####################################### Error codes ####################################### -############################################################################################### - -CodeT = TypeVar('CodeT', bound=int) -MessageT = TypeVar('MessageT', bound=str) - - -class JSONRPCError(TypedDict, Generic[CodeT, MessageT]): - """A JSON RPC error.""" - - code: CodeT - message: MessageT - data: NotRequired[Any] - - -ResultT = TypeVar('ResultT') -ErrorT = TypeVar('ErrorT', bound=JSONRPCError[Any, Any]) - - -class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): - """A JSON RPC response.""" - - result: NotRequired[ResultT] - error: NotRequired[ErrorT] - - -JSONParseError = JSONRPCError[Literal[-32700], Literal['Invalid JSON payload']] -"""A JSON RPC error for a parse error.""" - -InvalidRequestError = JSONRPCError[Literal[-32600], Literal['Request payload validation error']] -"""A JSON RPC error for an invalid request.""" - -MethodNotFoundError = JSONRPCError[Literal[-32601], Literal['Method not found']] -"""A JSON RPC error for a method not found.""" - -InvalidParamsError = JSONRPCError[Literal[-32602], Literal['Invalid parameters']] -"""A JSON RPC error for invalid parameters.""" - -InternalError = JSONRPCError[Literal[-32603], Literal['Internal error']] -"""A JSON RPC error for an internal error.""" - -TaskNotFoundError = JSONRPCError[Literal[-32001], Literal['Task not found']] -"""A JSON RPC error for a task not found.""" - -TaskNotCancelableError = JSONRPCError[Literal[-32002], Literal['Task not cancelable']] -"""A JSON RPC error for a task not cancelable.""" - -PushNotificationNotSupportedError = JSONRPCError[Literal[-32003], Literal['Push notification not supported']] -"""A JSON RPC error for a push notification not supported.""" - -UnsupportedOperationError = JSONRPCError[Literal[-32004], Literal['This operation is not supported']] -"""A JSON RPC error for an unsupported operation.""" - -ContentTypeNotSupportedError = JSONRPCError[Literal[-32005], Literal['Incompatible content types']] -"""A JSON RPC error for incompatible content types.""" - -############################################################################################### -####################################### Requests and responses ############################ -############################################################################################### - -SendTaskRequest = JSONRPCRequest[Literal['tasks/send'], TaskSendParams] -"""A JSON RPC request to send a task.""" - -SendTaskResponse = JSONRPCResponse[Task, JSONRPCError[Any, Any]] -"""A JSON RPC response to send a task.""" - -SendTaskStreamingRequest = JSONRPCRequest[Literal['tasks/sendSubscribe'], TaskSendParams] -"""A JSON RPC request to send a task and receive updates.""" - -SendTaskStreamingResponse = JSONRPCResponse[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError] -"""A JSON RPC response to send a task and receive updates.""" - -GetTaskRequest = JSONRPCRequest[Literal['tasks/get'], TaskQueryParams] -"""A JSON RPC request to get a task.""" - -GetTaskResponse = JSONRPCResponse[Task, TaskNotFoundError] -"""A JSON RPC response to get a task.""" - -CancelTaskRequest = JSONRPCRequest[Literal['tasks/cancel'], TaskIdParams] -"""A JSON RPC request to cancel a task.""" - -CancelTaskResponse = JSONRPCResponse[Task, Union[TaskNotCancelableError, TaskNotFoundError]] -"""A JSON RPC response to cancel a task.""" - -SetTaskPushNotificationRequest = JSONRPCRequest[Literal['tasks/pushNotification/set'], TaskPushNotificationConfig] -"""A JSON RPC request to set a task push notification.""" - -SetTaskPushNotificationResponse = JSONRPCResponse[TaskPushNotificationConfig, PushNotificationNotSupportedError] -"""A JSON RPC response to set a task push notification.""" - -GetTaskPushNotificationRequest = JSONRPCRequest[Literal['tasks/pushNotification/get'], TaskIdParams] -"""A JSON RPC request to get a task push notification.""" - -GetTaskPushNotificationResponse = JSONRPCResponse[TaskPushNotificationConfig, PushNotificationNotSupportedError] -"""A JSON RPC response to get a task push notification.""" +from typing import Union + +from a2a.types import ( + A2ARequest as _A2ARequest, + A2AResponse as _A2AResponse, + AgentCard, + AgentProvider as Provider, + AgentSkill as Skill, + Artifact, + AuthenticationInfo as Authentication, + CancelTaskRequest, + CancelTaskResponse, + Capabilities, + ContentTypeNotSupportedError, + GetTaskPushNotificationConfigRequest as GetTaskPushNotificationRequest, + GetTaskPushNotificationConfigResponse as GetTaskPushNotificationResponse, + GetTaskRequest, + GetTaskResponse, + InternalError, + JSONRPCError, + Message, + MessageSendParams as TaskSendParams, + Part, + PushNotificationConfig, + PushNotificationNotSupportedError, + SendMessageRequest as SendTaskRequest, + SendMessageResponse as SendTaskResponse, + SendStreamingMessageRequest as SendTaskStreamingRequest, + SendStreamingMessageResponse as SendTaskStreamingResponse, + SetTaskPushNotificationConfigRequest as SetTaskPushNotificationRequest, + SetTaskPushNotificationConfigResponse as SetTaskPushNotificationResponse, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskNotCancelableError, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest as ResubscribeTaskRequest, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, + UnsupportedOperationError, +) +from pydantic import TypeAdapter + +__all__ = [ + "AgentCard", + "Provider", + "Skill", + "Artifact", + "Authentication", + "Capabilities", + "Message", + "TaskSendParams", + "Part", + "PushNotificationConfig", + "Task", + "TaskIdParams", + "TaskQueryParams", + "TextPart", + "CancelTaskRequest", + "CancelTaskResponse", + "GetTaskPushNotificationRequest", + "GetTaskPushNotificationResponse", + "GetTaskRequest", + "GetTaskResponse", + "SendTaskStreamingRequest", + "SendTaskStreamingResponse", + "SendTaskRequest", + "SendTaskResponse", + "SetTaskPushNotificationRequest", + "SetTaskPushNotificationResponse", + "ResubscribeTaskRequest", + "TaskState", + "TaskStatus", + "TaskArtifactUpdateEvent", + "TaskPushNotificationConfig", + "TaskStatusUpdateEvent", + "JSONRPCError", + "TaskNotFoundError", + "TaskNotCancelableError", + "PushNotificationNotSupportedError", + "UnsupportedOperationError", + "ContentTypeNotSupportedError", + "InternalError", + "a2a_request_ta", + "a2a_response_ta", + "A2ARequest", + "A2AResponse", +] -ResubscribeTaskRequest = JSONRPCRequest[Literal['tasks/resubscribe'], TaskIdParams] -"""A JSON RPC request to resubscribe to a task.""" -A2ARequest = Annotated[ - Union[ - SendTaskRequest, - GetTaskRequest, - CancelTaskRequest, - SetTaskPushNotificationRequest, - GetTaskPushNotificationRequest, - ResubscribeTaskRequest, - ], - Discriminator('method'), -] +A2ARequest = _A2ARequest """A JSON RPC request to the A2A server.""" -A2AResponse: TypeAlias = Union[ +A2AResponse = Union[ SendTaskResponse, GetTaskResponse, CancelTaskResponse, @@ -562,6 +108,5 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ] """A JSON RPC response from the A2A server.""" - a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index c06bc1cb7..7ffb3d63a 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -2,90 +2,47 @@ from __future__ import annotations as _annotations +import asyncio from abc import ABC, abstractmethod -from datetime import datetime -from .schema import Artifact, Message, Task, TaskState, TaskStatus +from a2a.types import Task class Storage(ABC): - """A storage to retrieve and save tasks. - - The storage is used to update the status of a task and to save the result of a task. - """ + """A storage to retrieve and save tasks.""" @abstractmethod - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - """Load a task from storage. - - If the task is not found, return None. - """ + async def get(self, task_id: str) -> Task | None: + """Retrieves a task from the store by its ID.""" @abstractmethod - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: - """Submit a task to storage.""" + async def save(self, task: Task) -> None: + """Saves or updates a task in the store.""" @abstractmethod - async def update_task( - self, - task_id: str, - state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, - ) -> Task: - """Update the state of a task.""" + async def delete(self, task_id: str) -> None: + """Deletes a task from the store by its ID.""" class InMemoryStorage(Storage): """A storage to retrieve and save tasks in memory.""" - def __init__(self): + def __init__(self) -> None: self.tasks: dict[str, Task] = {} - - async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: - """Load a task from memory. - - Args: - task_id: The id of the task to load. - history_length: The number of messages to return in the history. - - Returns: - The task. - """ - if task_id not in self.tasks: - return None - - task = self.tasks[task_id] - if history_length and 'history' in task: - task['history'] = task['history'][-history_length:] - return task - - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: - """Submit a task to storage.""" - if task_id in self.tasks: - raise ValueError(f'Task {task_id} already exists') - - task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) - self.tasks[task_id] = task - return task - - async def update_task( - self, - task_id: str, - state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, - ) -> Task: - """Save the task as "working".""" - task = self.tasks[task_id] - task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if message: - if 'history' not in task: - task['history'] = [] - task['history'].append(message) - if artifacts: - if 'artifacts' not in task: - task['artifacts'] = [] - task['artifacts'].extend(artifacts) - return task + self.lock = asyncio.Lock() + + async def get(self, task_id: str) -> Task | None: + """Retrieves a task from memory.""" + async with self.lock: + return self.tasks.get(task_id) + + async def save(self, task: Task) -> None: + """Saves or updates a task in memory.""" + async with self.lock: + self.tasks[task.id] = task + + async def delete(self, task_id: str) -> None: + """Deletes a task from memory.""" + async with self.lock: + if task_id in self.tasks: + del self.tasks[task_id] diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py deleted file mode 100644 index 0baaeba04..000000000 --- a/fasta2a/fasta2a/task_manager.py +++ /dev/null @@ -1,169 +0,0 @@ -"""This module defines the TaskManager class, which is responsible for managing tasks. - -In our structure, we have the following components: - -- TaskManager: A class that manages tasks. -- Scheduler: A class that schedules tasks to be sent to the worker. -- Worker: A class that executes tasks. -- Runner: A class that defines how tasks run and how history is structured. -- Storage: A class that stores tasks and artifacts. - -Architecture: -``` - +-----------------+ - | HTTP Server | - +-------+---------+ - ^ - | Sends Requests/ - | Receives Results - v - +-------+---------+ - | | - | TaskManager |<-----------------+ - | (coordinates) | | - | | | - +-------+---------+ | - | | - | Schedules Tasks | - v v - +------------------+ +----------------+ - | | | | - | Broker | . | Storage | - | (queues) . | | (persistence) | - | | | | - +------------------+ +----------------+ - ^ ^ - | | - | Delegates Execution | - v | - +------------------+ | - | | | - | Worker | | - | (implementation) |-----------------+ - | | - +------------------+ -``` - -The flow: -1. The HTTP server sends a task to TaskManager -2. TaskManager stores initial task state in Storage -3. TaskManager passes task to Scheduler -4. Scheduler determines when to send tasks to Worker -5. Worker delegates to Runner for task execution -6. Runner defines how tasks run and how history is structured -7. Worker processes task results from Runner -8. Worker reads from and writes to Storage directly -9. Worker updates task status in Storage as execution progresses -10. TaskManager can also read/write from Storage for task management -11. Client queries TaskManager for results, which reads from Storage -""" - -from __future__ import annotations as _annotations - -import uuid -from contextlib import AsyncExitStack -from dataclasses import dataclass, field -from typing import Any - -from .broker import Broker -from .schema import ( - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationRequest, - GetTaskPushNotificationResponse, - GetTaskRequest, - GetTaskResponse, - ResubscribeTaskRequest, - SendTaskRequest, - SendTaskResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, - SetTaskPushNotificationRequest, - SetTaskPushNotificationResponse, - TaskNotFoundError, -) -from .storage import Storage - - -@dataclass -class TaskManager: - """A task manager responsible for managing tasks.""" - - broker: Broker - storage: Storage - - _aexit_stack: AsyncExitStack | None = field(default=None, init=False) - - async def __aenter__(self): - self._aexit_stack = AsyncExitStack() - await self._aexit_stack.__aenter__() - await self._aexit_stack.enter_async_context(self.broker) - - return self - - @property - def is_running(self) -> bool: - return self._aexit_stack is not None - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): - if self._aexit_stack is None: - raise RuntimeError('TaskManager was not properly initialized.') - await self._aexit_stack.__aexit__(exc_type, exc_value, traceback) - self._aexit_stack = None - - async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """Send a task to the worker.""" - request_id = str(uuid.uuid4()) - task_id = request['params']['id'] - task = await self.storage.load_task(task_id) - - if task is None: - session_id = request['params'].get('session_id', str(uuid.uuid4())) - message = request['params']['message'] - task = await self.storage.submit_task(task_id, session_id, message) - - await self.broker.run_task(request['params']) - return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) - - async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: - """Get a task, and return it to the client. - - No further actions are needed here. - """ - task_id = request['params']['id'] - history_length = request['params'].get('history_length') - task = await self.storage.load_task(task_id, history_length) - if task is None: - return GetTaskResponse( - jsonrpc='2.0', - id=request['id'], - error=TaskNotFoundError(code=-32001, message='Task not found'), - ) - return GetTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - - async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: - await self.broker.cancel_task(request['params']) - task = await self.storage.load_task(request['params']['id']) - if task is None: - return CancelTaskResponse( - jsonrpc='2.0', - id=request['id'], - error=TaskNotFoundError(code=-32001, message='Task not found'), - ) - return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - - async def send_task_streaming(self, request: SendTaskStreamingRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('SendTaskStreaming is not implemented yet.') - - async def set_task_push_notification( - self, request: SetTaskPushNotificationRequest - ) -> SetTaskPushNotificationResponse: - raise NotImplementedError('SetTaskPushNotification is not implemented yet.') - - async def get_task_push_notification( - self, request: GetTaskPushNotificationRequest - ) -> GetTaskPushNotificationResponse: - raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('Resubscribe is not implemented yet.') diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 9bbde6b25..7bcc4cb45 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -1,65 +1,25 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass from typing import TYPE_CHECKING, Any -import anyio -from opentelemetry.trace import get_tracer, use_span -from typing_extensions import assert_never - if TYPE_CHECKING: - from .broker import Broker, TaskOperation + from a2a.server.tasks import TaskUpdater + from .schema import Artifact, Message, TaskIdParams, TaskSendParams from .storage import Storage -tracer = get_tracer(__name__) - -@dataclass class Worker(ABC): """A worker is responsible for executing tasks.""" - broker: Broker storage: Storage - @asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """Run the worker. - - It connects to the broker, and it makes itself available to receive commands. - """ - async with anyio.create_task_group() as tg: - tg.start_soon(self._loop) - yield - tg.cancel_scope.cancel() - - async def _loop(self) -> None: - async for task_operation in self.broker.receive_task_operations(): - await self._handle_task_operation(task_operation) - - async def _handle_task_operation(self, task_operation: TaskOperation) -> None: - try: - with use_span(task_operation['_current_span']): - with tracer.start_as_current_span( - f'{task_operation["operation"]} task', attributes={'logfire.tags': ['fasta2a']} - ): - if task_operation['operation'] == 'run': - await self.run_task(task_operation['params']) - elif task_operation['operation'] == 'cancel': - await self.cancel_task(task_operation['params']) - else: - assert_never(task_operation) - except Exception: - await self.storage.update_task(task_operation['params']['id'], state='failed') - @abstractmethod - async def run_task(self, params: TaskSendParams) -> None: ... + async def run_task(self, params: TaskSendParams, updater: TaskUpdater) -> None: ... @abstractmethod - async def cancel_task(self, params: TaskIdParams) -> None: ... + async def cancel_task(self, params: TaskIdParams, updater: TaskUpdater) -> None: ... @abstractmethod def build_message_history(self, task_history: list[Message]) -> list[Any]: ... diff --git a/fasta2a/pyproject.toml b/fasta2a/pyproject.toml index 2abe809aa..651c341b5 100644 --- a/fasta2a/pyproject.toml +++ b/fasta2a/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "starlette>0.29.0", "pydantic>=2.10", "opentelemetry-api>=1.28.0", + "a2a-sdk>=0.2.8", ] [project.optional-dependencies]