diff --git a/docs/a2a.md b/docs/a2a.md index d1989d1fb..df5cd6efb 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -32,7 +32,7 @@ The library is designed to be used with any agentic framework, and is **not excl Given the nature of the A2A protocol, it's important to understand the design before using it, as a developer you'll need to provide some components: -- [`Storage`][fasta2a.Storage]: to save and load tasks +- [`Storage`][fasta2a.Storage]: to save and load tasks, as well as store context for conversations - [`Broker`][fasta2a.Broker]: to schedule tasks - [`Worker`][fasta2a.Worker]: to execute tasks @@ -55,6 +55,28 @@ flowchart TB FastA2A allows you to bring your own [`Storage`][fasta2a.Storage], [`Broker`][fasta2a.Broker] and [`Worker`][fasta2a.Worker]. +#### Understanding Tasks and Context + +In the A2A protocol: + +- **Task**: Represents one complete execution of an agent. When a client sends a message to the agent, a new task is created. The agent runs until completion (or failure), and this entire execution is considered one task. The final output is stored as a task artifact. + +- **Context**: Represents a conversation thread that can span multiple tasks. The A2A protocol uses a `context_id` to maintain conversation continuity: + - When a new message is sent without a `context_id`, the server generates a new one + - Subsequent messages can include the same `context_id` to continue the conversation + - All tasks sharing the same `context_id` have access to the complete message history + +#### Storage Architecture + +The [`Storage`][fasta2a.Storage] component serves two purposes: + +1. **Task Storage**: Stores tasks in A2A protocol format, including their status, artifacts, and message history +2. **Context Storage**: Stores conversation context in a format optimized for the specific agent implementation + +This design allows for agents to store rich internal state (e.g., tool calls, reasoning traces) as well as store task-specific A2A-formatted messages and artifacts. + +For example, a PydanticAI agent might store its complete internal message format (including tool calls and responses) in the context storage, while storing only the A2A-compliant messages in the task history. + ### Installation @@ -94,3 +116,12 @@ uvicorn agent_to_a2a:app --host 0.0.0.0 --port 8000 ``` Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor. + +When using `to_a2a()`, PydanticAI automatically: + +- Stores the complete conversation history (including tool calls and responses) in the context storage +- Ensures that subsequent messages with the same `context_id` have access to the full conversation history +- Persists agent results as A2A artifacts: + - String results become `TextPart` artifacts and also appear in the message history + - Structured data (Pydantic models, dataclasses, tuples, etc.) become `DataPart` artifacts with the data wrapped as `{"result": }` + - Artifacts include metadata with type information and JSON schema when available diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 61301262b..987d08ece 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -13,10 +13,9 @@ from .broker import Broker from .schema import ( + AgentCapabilities, AgentCard, - Authentication, - Capabilities, - Provider, + AgentProvider, Skill, a2a_request_ta, a2a_response_ta, @@ -39,7 +38,7 @@ def __init__( url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, @@ -85,16 +84,17 @@ async def _agent_card_endpoint(self, request: Request) -> Response: if self._agent_card_json_schema is None: agent_card = AgentCard( name=self.name, + description=self.description or 'FastA2A Agent', url=self.url, version=self.version, + protocol_version='0.2.5', skills=self.skills, default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), - authentication=Authentication(schemes=[]), + capabilities=AgentCapabilities( + streaming=False, push_notifications=False, state_transition_history=False + ), ) - if self.description is not None: - agent_card['description'] = self.description if self.provider is not None: agent_card['provider'] = self.provider self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) @@ -116,8 +116,8 @@ async def _agent_run_endpoint(self, request: Request) -> Response: data = await request.body() a2a_request = a2a_request_ta.validate_json(data) - if a2a_request['method'] == 'tasks/send': - jsonrpc_response = await self.task_manager.send_task(a2a_request) + if a2a_request['method'] == 'message/send': + jsonrpc_response = await self.task_manager.send_message(a2a_request) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 5c5aabd81..cd8449923 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -9,14 +9,15 @@ GetTaskRequest, GetTaskResponse, Message, - PushNotificationConfig, - SendTaskRequest, - SendTaskResponse, - TaskSendParams, + MessageSendConfiguration, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, a2a_request_ta, + send_message_request_ta, + send_message_response_ta, ) -send_task_response_ta = pydantic.TypeAdapter(SendTaskResponse) get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) try: @@ -37,26 +38,30 @@ def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.A self.http_client = http_client self.http_client.base_url = base_url - async def send_task( + async def send_message( self, message: Message, - history_length: int | None = None, - push_notification: PushNotificationConfig | None = None, + *, metadata: dict[str, Any] | None = None, - ) -> SendTaskResponse: - task = TaskSendParams(message=message, id=str(uuid.uuid4())) - if history_length is not None: - task['history_length'] = history_length - if push_notification is not None: - task['push_notification'] = push_notification + configuration: MessageSendConfiguration | None = None, + ) -> SendMessageResponse: + """Send a message using the A2A protocol. + + Returns a JSON-RPC response containing either a result (Task) or an error. + """ + params = MessageSendParams(message=message) if metadata is not None: - task['metadata'] = metadata + params['metadata'] = metadata + if configuration is not None: + params['configuration'] = configuration - payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) - content = a2a_request_ta.dump_json(payload, by_alias=True) + request_id = str(uuid.uuid4()) + payload = SendMessageRequest(jsonrpc='2.0', id=request_id, method='message/send', params=params) + content = send_message_request_ta.dump_json(payload, by_alias=True) response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) - return send_task_response_ta.validate_json(response.content) + + return send_message_response_ta.validate_json(response.content) async def get_task(self, task_id: str) -> GetTaskResponse: payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index cab5d2057..80adabcea 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -17,35 +17,45 @@ class AgentCard(TypedDict): name: str """Human readable name of the agent e.g. "Recipe Agent".""" - description: NotRequired[str] + description: str """A human-readable description of the agent. Used to assist users and other agents in understanding what the agent can do. (e.g. "Agent that helps users with recipes and cooking.") """ - # TODO(Marcelo): The spec makes url required. - url: NotRequired[str] + url: str """A URL to the address the agent is hosted at.""" - provider: NotRequired[Provider] - """The service provider of the agent.""" - - # TODO(Marcelo): The spec makes version required. - version: NotRequired[str] + version: str """The version of the agent - format is up to the provider. (e.g. "1.0.0")""" + protocol_version: str + """The version of the A2A protocol this agent supports.""" + + provider: NotRequired[AgentProvider] + """The service provider of the agent.""" + documentation_url: NotRequired[str] """A URL to documentation for the agent.""" - capabilities: Capabilities + icon_url: NotRequired[str] + """A URL to an icon for the agent.""" + + preferred_transport: NotRequired[str] + """The transport of the preferred endpoint. If empty, defaults to JSONRPC.""" + + additional_interfaces: NotRequired[list[AgentInterface]] + """Announcement of additional supported transports.""" + + capabilities: AgentCapabilities """The capabilities of the agent.""" - authentication: Authentication - """The authentication schemes supported by the agent. + security: NotRequired[list[dict[str, list[str]]]] + """Security requirements for contacting the agent.""" - Intended to match OpenAPI authentication structure. - """ + security_schemes: NotRequired[dict[str, SecurityScheme]] + """Security scheme definitions.""" default_input_modes: list[str] """Supported mime types for input data.""" @@ -59,7 +69,7 @@ class AgentCard(TypedDict): agent_card_ta = pydantic.TypeAdapter(AgentCard) -class Provider(TypedDict): +class AgentProvider(TypedDict): """The service provider of the agent.""" organization: str @@ -67,7 +77,7 @@ class Provider(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) -class Capabilities(TypedDict): +class AgentCapabilities(TypedDict): """The capabilities of the agent.""" streaming: NotRequired[bool] @@ -81,14 +91,89 @@ class Capabilities(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) -class Authentication(TypedDict): - """The authentication schemes supported by the agent.""" +class HttpSecurityScheme(TypedDict): + """HTTP security scheme.""" + + type: Literal['http'] + scheme: str + """The name of the HTTP Authorization scheme.""" + bearer_format: NotRequired[str] + """A hint to the client to identify how the bearer token is formatted.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class ApiKeySecurityScheme(TypedDict): + """API Key security scheme.""" + + type: Literal['apiKey'] + name: str + """The name of the header, query or cookie parameter to be used.""" + in_: Literal['query', 'header', 'cookie'] + """The location of the API key.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class OAuth2SecurityScheme(TypedDict): + """OAuth2 security scheme.""" + + type: Literal['oauth2'] + flows: dict[str, Any] + """An object containing configuration information for the flow types supported.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class OpenIdConnectSecurityScheme(TypedDict): + """OpenID Connect security scheme.""" + + type: Literal['openIdConnect'] + open_id_connect_url: str + """OpenId Connect URL to discover OAuth2 configuration values.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +SecurityScheme = Annotated[ + Union[HttpSecurityScheme, ApiKeySecurityScheme, OAuth2SecurityScheme, OpenIdConnectSecurityScheme], + pydantic.Field(discriminator='type'), +] +"""A security scheme for authentication.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class AgentInterface(TypedDict): + """An interface that the agent supports.""" + + transport: str + """The transport protocol (e.g., 'jsonrpc', 'websocket').""" + + url: str + """The URL endpoint for this transport.""" + + description: NotRequired[str] + """Description of this interface.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class AgentExtension(TypedDict): + """A declaration of an extension supported by an Agent.""" - schemes: list[str] - """The authentication schemes supported by the agent. (e.g. "Basic", "Bearer")""" + uri: str + """The URI of the extension.""" - credentials: NotRequired[str] - """The credentials a client should use for private cards.""" + description: NotRequired[str] + """A description of how this agent uses this extension.""" + + required: NotRequired[bool] + """Whether the client must follow specific requirements of the extension.""" + + params: NotRequired[dict[str, Any]] + """Optional configuration for the extension.""" @pydantic.with_config({'alias_generator': to_camel}) @@ -137,6 +222,9 @@ class Artifact(TypedDict): Artifacts. """ + artifact_id: str + """Unique identifier for the artifact.""" + name: NotRequired[str] """The name of the artifact.""" @@ -149,8 +237,8 @@ class Artifact(TypedDict): metadata: NotRequired[dict[str, Any]] """Metadata about the artifact.""" - index: int - """The index of the artifact.""" + extensions: NotRequired[list[str]] + """Array of extensions.""" append: NotRequired[bool] """Whether to append this artifact to an existing one.""" @@ -183,13 +271,16 @@ class PushNotificationConfig(TypedDict): mobile Push Notification Service). """ + id: NotRequired[str] + """Server-assigned identifier.""" + url: str """The URL to send push notifications to.""" token: NotRequired[str] """Token unique to this task/session.""" - authentication: NotRequired[Authentication] + authentication: NotRequired[SecurityScheme] """Authentication details for push notifications.""" @@ -204,6 +295,7 @@ class TaskPushNotificationConfig(TypedDict): """The push notification configuration.""" +@pydantic.with_config({'alias_generator': to_camel}) class Message(TypedDict): """A Message contains any content that is not an Artifact. @@ -222,9 +314,28 @@ class Message(TypedDict): parts: list[Part] """The parts of the message.""" + kind: Literal['message'] + """Event type.""" + metadata: NotRequired[dict[str, Any]] """Metadata about the message.""" + # Additional fields + message_id: str + """Identifier created by the message creator.""" + + context_id: NotRequired[str] + """The context the message is associated with.""" + + task_id: NotRequired[str] + """Identifier of task the message is related to.""" + + reference_task_ids: NotRequired[list[str]] + """Array of task IDs this message references.""" + + extensions: NotRequired[list[str]] + """Array of extensions.""" + class _BasePart(TypedDict): """A base class for all parts.""" @@ -232,76 +343,70 @@ class _BasePart(TypedDict): metadata: NotRequired[dict[str, Any]] +@pydantic.with_config({'alias_generator': to_camel}) class TextPart(_BasePart): """A part that contains text.""" - type: Literal['text'] - """The type of the part.""" + kind: Literal['text'] + """The kind of the part.""" text: str """The text of the part.""" @pydantic.with_config({'alias_generator': to_camel}) -class FilePart(_BasePart): - """A part that contains a file.""" +class FileWithBytes(TypedDict): + """File with base64 encoded data.""" - type: Literal['file'] - """The type of the part.""" + bytes: str + """The base64 encoded content of the file.""" - file: File - """The file of the part.""" + mime_type: NotRequired[str] + """Optional mime type for the file.""" @pydantic.with_config({'alias_generator': to_camel}) -class _BaseFile(_BasePart): - """A base class for all file types.""" +class FileWithUri(TypedDict): + """File with URI reference.""" - name: NotRequired[str] - """The name of the file.""" + uri: str + """The URI of the file.""" - mime_type: str + mime_type: NotRequired[str] """The mime type of the file.""" @pydantic.with_config({'alias_generator': to_camel}) -class _BinaryFile(_BaseFile): - """A binary file.""" - - data: str - """The base64 encoded bytes of the file.""" - - -@pydantic.with_config({'alias_generator': to_camel}) -class _URLFile(_BaseFile): - """A file that is hosted on a remote URL.""" - - url: str - """The URL of the file.""" +class FilePart(_BasePart): + """A part that contains a file.""" + kind: Literal['file'] + """The kind of the part.""" -File: TypeAlias = Union[_BinaryFile, _URLFile] -"""A file is a binary file or a URL file.""" + file: FileWithBytes | FileWithUri + """The file content - either bytes or URI.""" @pydantic.with_config({'alias_generator': to_camel}) class DataPart(_BasePart): - """A part that contains data.""" + """A part that contains structured data.""" - type: Literal['data'] - """The type of the part.""" + kind: Literal['data'] + """The kind of the part.""" data: dict[str, Any] """The data of the part.""" -Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='type')] +Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='kind')] """A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. Each Part has its own content type and metadata. """ -TaskState: TypeAlias = Literal['submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'unknown'] +TaskState: TypeAlias = Literal[ + 'submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'rejected', 'auth-required', 'unknown' +] """The possible states of a task.""" @@ -330,8 +435,11 @@ class Task(TypedDict): id: str """Unique identifier for the task.""" - session_id: NotRequired[str] - """Client-generated id for the session holding the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['task'] + """Event type.""" status: TaskStatus """Current status of the task.""" @@ -348,11 +456,17 @@ class Task(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskStatusUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" + """Sent by server during message/stream requests.""" - id: str + task_id: str """The id of the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['status-update'] + """Event type.""" + status: TaskStatus """The status of the task.""" @@ -365,14 +479,26 @@ class TaskStatusUpdateEvent(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskArtifactUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" + """Sent by server during message/stream requests.""" - id: str + task_id: str """The id of the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['artifact-update'] + """Event type identification.""" + artifact: Artifact """The artifact that was updated.""" + append: NotRequired[bool] + """Whether to append to existing artifact (true) or replace (false).""" + + last_chunk: NotRequired[bool] + """Indicates this is the final chunk of the artifact.""" + metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -393,24 +519,81 @@ class TaskQueryParams(TaskIdParams): """Number of recent messages to be retrieved.""" +@pydantic.with_config({'alias_generator': to_camel}) +class MessageSendConfiguration(TypedDict): + """Configuration for the send message request.""" + + accepted_output_modes: list[str] + """Accepted output modalities by the client.""" + + blocking: NotRequired[bool] + """If the server should treat the client as a blocking request.""" + + history_length: NotRequired[int] + """Number of recent messages to be retrieved.""" + + push_notification_config: NotRequired[PushNotificationConfig] + """Where the server should send notifications when disconnected.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class MessageSendParams(TypedDict): + """Parameters for message/send method.""" + + configuration: NotRequired[MessageSendConfiguration] + """Send message configuration.""" + + message: Message + """The message being sent to the server.""" + + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + @pydantic.with_config({'alias_generator': to_camel}) class TaskSendParams(TypedDict): - """Sent by the client to the agent to create, continue, or restart a task.""" + """Internal parameters for task execution within the framework. + + Note: This is not part of the A2A protocol - it's used internally + for broker/worker communication. + """ id: str """The id of the task.""" - session_id: NotRequired[str] - """The server creates a new sessionId for new tasks if not set.""" + context_id: str + """The context id for the task.""" message: Message - """The message to send to the agent.""" + """The message to process.""" history_length: NotRequired[int] """Number of recent messages to be retrieved.""" - push_notification: NotRequired[PushNotificationConfig] - """Where the server should send notifications when disconnected.""" + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class ListTaskPushNotificationConfigParams(TypedDict): + """Parameters for getting list of pushNotificationConfigurations associated with a Task.""" + + id: str + """Task id.""" + + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class DeleteTaskPushNotificationConfigParams(TypedDict): + """Parameters for removing pushNotificationConfiguration associated with a Task.""" + + id: str + """Task id.""" + + push_notification_config_id: str + """The push notification config id to delete.""" metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -497,21 +680,21 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ContentTypeNotSupportedError = JSONRPCError[Literal[-32005], Literal['Incompatible content types']] """A JSON RPC error for incompatible content types.""" +InvalidAgentResponseError = JSONRPCError[Literal[-32006], Literal['Invalid agent response']] +"""A JSON RPC error for invalid agent response.""" + ############################################################################################### ####################################### Requests and responses ############################ ############################################################################################### -SendTaskRequest = JSONRPCRequest[Literal['tasks/send'], TaskSendParams] -"""A JSON RPC request to send a task.""" +SendMessageRequest = JSONRPCRequest[Literal['message/send'], MessageSendParams] +"""A JSON RPC request to send a message.""" -SendTaskResponse = JSONRPCResponse[Task, JSONRPCError[Any, Any]] -"""A JSON RPC response to send a task.""" +SendMessageResponse = JSONRPCResponse[Union[Task, Message], JSONRPCError[Any, Any]] +"""A JSON RPC response to send a message.""" -SendTaskStreamingRequest = JSONRPCRequest[Literal['tasks/sendSubscribe'], TaskSendParams] -"""A JSON RPC request to send a task and receive updates.""" - -SendTaskStreamingResponse = JSONRPCResponse[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError] -"""A JSON RPC response to send a task and receive updates.""" +StreamMessageRequest = JSONRPCRequest[Literal['message/stream'], MessageSendParams] +"""A JSON RPC request to stream a message.""" GetTaskRequest = JSONRPCRequest[Literal['tasks/get'], TaskQueryParams] """A JSON RPC request to get a task.""" @@ -540,21 +723,34 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ResubscribeTaskRequest = JSONRPCRequest[Literal['tasks/resubscribe'], TaskIdParams] """A JSON RPC request to resubscribe to a task.""" +ListTaskPushNotificationConfigRequest = JSONRPCRequest[ + Literal['tasks/pushNotificationConfig/list'], ListTaskPushNotificationConfigParams +] +"""A JSON RPC request to list task push notification configs.""" + +DeleteTaskPushNotificationConfigRequest = JSONRPCRequest[ + Literal['tasks/pushNotificationConfig/delete'], DeleteTaskPushNotificationConfigParams +] +"""A JSON RPC request to delete a task push notification config.""" + A2ARequest = Annotated[ Union[ - SendTaskRequest, + SendMessageRequest, + StreamMessageRequest, GetTaskRequest, CancelTaskRequest, SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, ResubscribeTaskRequest, + ListTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, ], Discriminator('method'), ] """A JSON RPC request to the A2A server.""" A2AResponse: TypeAlias = Union[ - SendTaskResponse, + SendMessageResponse, GetTaskResponse, CancelTaskResponse, SetTaskPushNotificationResponse, @@ -565,3 +761,6 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) +send_message_request_ta: TypeAdapter[SendMessageRequest] = TypeAdapter(SendMessageRequest) +send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) +stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index c06bc1cb7..58934c1c5 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -2,16 +2,24 @@ from __future__ import annotations as _annotations +import uuid from abc import ABC, abstractmethod from datetime import datetime +from typing import Any, Generic + +from typing_extensions import TypeVar from .schema import Artifact, Message, Task, TaskState, TaskStatus +ContextT = TypeVar('ContextT', default=Any) + -class Storage(ABC): - """A storage to retrieve and save tasks. +class Storage(ABC, Generic[ContextT]): + """A storage to retrieve and save tasks, as well as retrieve and save context. - The storage is used to update the status of a task and to save the result of a task. + The storage serves two purposes: + 1. Task storage: Stores tasks in A2A protocol format with their status, artifacts, and message history + 2. Context storage: Stores conversation context in a format optimized for the specific agent implementation """ @abstractmethod @@ -22,7 +30,7 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta """ @abstractmethod - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: + async def submit_task(self, context_id: str, message: Message) -> Task: """Submit a task to storage.""" @abstractmethod @@ -30,17 +38,29 @@ async def update_task( self, task_id: str, state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, ) -> Task: - """Update the state of a task.""" + """Update the state of a task. Appends artifacts and messages, if specified.""" + @abstractmethod + async def load_context(self, context_id: str) -> ContextT | None: + """Retrieve the stored context given the `context_id`.""" + + @abstractmethod + async def update_context(self, context_id: str, context: ContextT) -> None: + """Updates the context for a `context_id`. -class InMemoryStorage(Storage): + Implementing agent can decide what to store in context. + """ + + +class InMemoryStorage(Storage[ContextT]): """A storage to retrieve and save tasks in memory.""" def __init__(self): self.tasks: dict[str, Task] = {} + self.contexts: dict[str, ContextT] = {} async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from memory. @@ -60,32 +80,52 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta task['history'] = task['history'][-history_length:] return task - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: + async def submit_task(self, context_id: str, message: Message) -> Task: """Submit a task to storage.""" - if task_id in self.tasks: - raise ValueError(f'Task {task_id} already exists') + # Generate a unique task ID + task_id = str(uuid.uuid4()) + + # Add IDs to the message for A2A protocol + message['task_id'] = task_id + message['context_id'] = context_id task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) + task = Task(id=task_id, context_id=context_id, kind='task', status=task_status, history=[message]) self.tasks[task_id] = task + return task async def update_task( self, task_id: str, state: TaskState, - message: Message | None = None, - artifacts: list[Artifact] | None = None, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, ) -> Task: - """Save the task as "working".""" + """Update the state of a task.""" task = self.tasks[task_id] task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if message: - if 'history' not in task: - task['history'] = [] - task['history'].append(message) - if artifacts: + + if new_artifacts: if 'artifacts' not in task: task['artifacts'] = [] - task['artifacts'].extend(artifacts) + task['artifacts'].extend(new_artifacts) + + if new_messages: + if 'history' not in task: + task['history'] = [] + # Add IDs to messages for consistency + for message in new_messages: + message['task_id'] = task_id + message['context_id'] = task['context_id'] + task['history'].append(message) + return task + + async def update_context(self, context_id: str, context: ContextT) -> None: + """Updates the context given the `context_id`.""" + self.contexts[context_id] = context + + async def load_context(self, context_id: str) -> ContextT | None: + """Retrieve the stored context given the `context_id`.""" + return self.contexts.get(context_id) diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 0baaeba04..6845d78b1 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -74,13 +74,13 @@ GetTaskRequest, GetTaskResponse, ResubscribeTaskRequest, - SendTaskRequest, - SendTaskResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, + SendMessageRequest, + SendMessageResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, + StreamMessageRequest, TaskNotFoundError, + TaskSendParams, ) from .storage import Storage @@ -90,7 +90,7 @@ class TaskManager: """A task manager responsible for managing tasks.""" broker: Broker - storage: Storage + storage: Storage[Any] _aexit_stack: AsyncExitStack | None = field(default=None, init=False) @@ -111,19 +111,22 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self._aexit_stack.__aexit__(exc_type, exc_value, traceback) self._aexit_stack = None - async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """Send a task to the worker.""" - request_id = str(uuid.uuid4()) - task_id = request['params']['id'] - task = await self.storage.load_task(task_id) + async def send_message(self, request: SendMessageRequest) -> SendMessageResponse: + """Send a message using the A2A v0.2.3 protocol.""" + request_id = request['id'] + message = request['params']['message'] + context_id = message.get('context_id', str(uuid.uuid4())) - if task is None: - session_id = request['params'].get('session_id', str(uuid.uuid4())) - message = request['params']['message'] - task = await self.storage.submit_task(task_id, session_id, message) + task = await self.storage.submit_task(context_id, message) + + broker_params: TaskSendParams = {'id': task['id'], 'context_id': context_id, 'message': message} + config = request['params'].get('configuration', {}) + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length - await self.broker.run_task(request['params']) - return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) + await self.broker.run_task(broker_params) + return SendMessageResponse(jsonrpc='2.0', id=request_id, result=task) async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: """Get a task, and return it to the client. @@ -152,8 +155,9 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: ) return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - async def send_task_streaming(self, request: SendTaskStreamingRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('SendTaskStreaming is not implemented yet.') + async def stream_message(self, request: StreamMessageRequest) -> None: + """Stream messages using Server-Sent Events.""" + raise NotImplementedError('message/stream method is not implemented yet.') async def set_task_push_notification( self, request: SetTaskPushNotificationRequest @@ -165,5 +169,5 @@ async def get_task_push_notification( ) -> GetTaskPushNotificationResponse: raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> SendTaskStreamingResponse: + async def resubscribe_task(self, request: ResubscribeTaskRequest) -> None: raise NotImplementedError('Resubscribe is not implemented yet.') diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 9bbde6b25..bcb017276 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -4,26 +4,27 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic import anyio from opentelemetry.trace import get_tracer, use_span from typing_extensions import assert_never +from .storage import ContextT, Storage + if TYPE_CHECKING: from .broker import Broker, TaskOperation from .schema import Artifact, Message, TaskIdParams, TaskSendParams - from .storage import Storage tracer = get_tracer(__name__) @dataclass -class Worker(ABC): +class Worker(ABC, Generic[ContextT]): """A worker is responsible for executing tasks.""" broker: Broker - storage: Storage + storage: Storage[ContextT] @asynccontextmanager async def run(self) -> AsyncIterator[None]: @@ -62,7 +63,7 @@ async def run_task(self, params: TaskSendParams) -> None: ... async def cancel_task(self, params: TaskIdParams) -> None: ... @abstractmethod - def build_message_history(self, task_history: list[Message]) -> list[Any]: ... + def build_message_history(self, history: list[Message]) -> list[Any]: ... @abstractmethod def build_artifacts(self, result: Any) -> list[Artifact]: ... diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 99bbe37ad..a46dd69f3 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -1,11 +1,13 @@ from __future__ import annotations, annotations as _annotations +import uuid from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass from functools import partial -from typing import Any, Generic +from typing import Any, Generic, TypeVar +from pydantic import TypeAdapter from typing_extensions import assert_never from pydantic_ai.messages import ( @@ -19,12 +21,17 @@ ModelResponse, ModelResponsePart, TextPart, + ThinkingPart, + ToolCallPart, UserPromptPart, VideoUrl, ) from .agent import Agent, AgentDepsT, OutputDataT +# AgentWorker output type needs to be invariant for use in both parameter and return positions +WorkerOutputT = TypeVar('WorkerOutputT') + try: from starlette.middleware import Middleware from starlette.routing import Route @@ -33,10 +40,11 @@ from fasta2a.applications import FastA2A from fasta2a.broker import Broker, InMemoryBroker from fasta2a.schema import ( + AgentProvider, Artifact, + DataPart, Message, Part, - Provider, Skill, TaskIdParams, TaskSendParams, @@ -72,7 +80,7 @@ def agent_to_a2a( url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, @@ -106,86 +114,202 @@ def agent_to_a2a( @dataclass -class AgentWorker(Worker, Generic[AgentDepsT, OutputDataT]): +class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]): """A worker that uses an agent to execute tasks.""" - agent: Agent[AgentDepsT, OutputDataT] + agent: Agent[AgentDepsT, WorkerOutputT] async def run_task(self, params: TaskSendParams) -> None: - task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) - assert task is not None, f'Task {params["id"]} not found' - assert 'session_id' in task, 'Task must have a session_id' + task = await self.storage.load_task(params['id']) + if task is None: + raise ValueError(f'Task {params["id"]} not found') + + # TODO(Marcelo): Should we lock `run_task` on the `context_id`? + # Ensure this task hasn't been run before + if task['status']['state'] != 'submitted': + raise ValueError(f'Task {params["id"]} has already been processed (state: {task["status"]["state"]})') await self.storage.update_task(task['id'], state='working') - # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe - # a custom `output_type` with a `more_info_required` field, or something like that. + # Load context - contains pydantic-ai message history from previous tasks in this conversation + message_history = await self.storage.load_context(task['context_id']) + if message_history is None: + message_history = [] + + try: + # Tasks start with a user message that triggered this task + # Add the current task's initial message to the history + task_history = task.get('history') + if task_history: + for a2a_msg in task_history: + if a2a_msg['role'] == 'user': + message_history.append(ModelRequest(parts=self._request_parts_from_a2a(a2a_msg['parts']))) - task_history = task.get('history', []) - message_history = self.build_message_history(task_history=task_history) + result = await self.agent.run(message_history=message_history) # type: ignore - # TODO(Marcelo): We need to make this more customizable e.g. pass deps. - result = await self.agent.run(message_history=message_history) # type: ignore + await self.storage.update_context(task['context_id'], result.all_messages()) - artifacts = self.build_artifacts(result.output) - await self.storage.update_task(task['id'], state='completed', artifacts=artifacts) + # Convert new messages to A2A format for task history + a2a_messages: list[Message] = [] + + for message in result.new_messages(): + if isinstance(message, ModelRequest): + # Skip user prompts - they're already in task history + continue + elif isinstance(message, ModelResponse): + # Convert response parts to A2A format + a2a_parts = self._response_parts_to_a2a(message.parts) + if a2a_parts: # Add if there are visible parts (text/thinking) + a2a_messages.append( + Message(role='agent', parts=a2a_parts, kind='message', message_id=str(uuid.uuid4())) + ) + + artifacts = self.build_artifacts(result.output) + except Exception: + await self.storage.update_task(task['id'], state='failed') + raise + else: + await self.storage.update_task( + task['id'], state='completed', new_artifacts=artifacts, new_messages=a2a_messages + ) async def cancel_task(self, params: TaskIdParams) -> None: pass - def build_artifacts(self, result: Any) -> list[Artifact]: - # TODO(Marcelo): We need to send the json schema of the result on the metadata of the message. - return [Artifact(name='result', index=0, parts=[A2ATextPart(type='text', text=str(result))])] + def build_artifacts(self, result: WorkerOutputT) -> list[Artifact]: + """Build artifacts from agent result. + + All agent outputs become artifacts to mark them as durable task outputs. + For string results, we use TextPart. For structured data, we use DataPart. + Metadata is included to preserve type information. + """ + artifact_id = str(uuid.uuid4()) + part = self._convert_result_to_part(result) + return [Artifact(artifact_id=artifact_id, name='result', parts=[part])] - def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: + def _convert_result_to_part(self, result: WorkerOutputT) -> Part: + """Convert agent result to a Part (TextPart or DataPart). + + For string results, returns a TextPart. + For structured data, returns a DataPart with properly serialized data. + """ + if isinstance(result, str): + return A2ATextPart(kind='text', text=result) + else: + output_type = type(result) + type_adapter = TypeAdapter(output_type) + data = type_adapter.dump_python(result, mode='json') + json_schema = type_adapter.json_schema(mode='serialization') + return DataPart(kind='data', data={'result': data}, metadata={'json_schema': json_schema}) + + def build_message_history(self, history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] - for message in task_history: + for message in history: if message['role'] == 'user': - model_messages.append(ModelRequest(parts=self._map_request_parts(message['parts']))) + model_messages.append(ModelRequest(parts=self._request_parts_from_a2a(message['parts']))) else: - model_messages.append(ModelResponse(parts=self._map_response_parts(message['parts']))) + model_messages.append(ModelResponse(parts=self._response_parts_from_a2a(message['parts']))) return model_messages - def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: + def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]: + """Convert A2A Part objects to pydantic-ai ModelRequestPart objects. + + This handles the conversion from A2A protocol parts (text, file, data) to + pydantic-ai's internal request parts (UserPromptPart with various content types). + + Args: + parts: List of A2A Part objects from incoming messages + + Returns: + List of ModelRequestPart objects for the pydantic-ai agent + """ model_parts: list[ModelRequestPart] = [] for part in parts: - if part['type'] == 'text': + if part['kind'] == 'text': model_parts.append(UserPromptPart(content=part['text'])) - elif part['type'] == 'file': - file = part['file'] - if 'data' in file: - data = file['data'].encode('utf-8') - content = BinaryContent(data=data, media_type=file['mime_type']) + elif part['kind'] == 'file': + file_content = part['file'] + if 'bytes' in file_content: + data = file_content['bytes'].encode('utf-8') + mime_type = file_content.get('mime_type', 'application/octet-stream') + content = BinaryContent(data=data, media_type=mime_type) model_parts.append(UserPromptPart(content=[content])) - else: - url = file['url'] - for url_cls in (DocumentUrl, AudioUrl, ImageUrl, VideoUrl): - content = url_cls(url=url) - try: - content.media_type - except ValueError: # pragma: no cover - continue - else: - break + elif 'uri' in file_content: + url = file_content['uri'] + mime_type = file_content.get('mime_type', '') + if mime_type.startswith('image/'): + content = ImageUrl(url=url) + elif mime_type.startswith('audio/'): + content = AudioUrl(url=url) + elif mime_type.startswith('video/'): + content = VideoUrl(url=url) else: - raise ValueError(f'Unknown file type: {file["mime_type"]}') # pragma: no cover + content = DocumentUrl(url=url) model_parts.append(UserPromptPart(content=[content])) - elif part['type'] == 'data': - # TODO(Marcelo): Maybe we should use this for `ToolReturnPart`, and `RetryPromptPart`. + else: + raise ValueError('FilePart.file must have either data or uri') + elif part['kind'] == 'data': raise NotImplementedError('Data parts are not supported yet.') else: assert_never(part) return model_parts - def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: + def _response_parts_from_a2a(self, parts: list[Part]) -> list[ModelResponsePart]: + """Convert A2A Part objects to pydantic-ai ModelResponsePart objects. + + This handles the conversion from A2A protocol parts (text, file, data) to + pydantic-ai's internal response parts. Currently only supports text parts + as agent responses in A2A are expected to be text-based. + + Args: + parts: List of A2A Part objects from stored agent messages + + Returns: + List of ModelResponsePart objects for message history + """ model_parts: list[ModelResponsePart] = [] for part in parts: - if part['type'] == 'text': + if part['kind'] == 'text': model_parts.append(TextPart(content=part['text'])) - elif part['type'] == 'file': # pragma: no cover + elif part['kind'] == 'file': # pragma: no cover raise NotImplementedError('File parts are not supported yet.') - elif part['type'] == 'data': # pragma: no cover + elif part['kind'] == 'data': # pragma: no cover raise NotImplementedError('Data parts are not supported yet.') else: # pragma: no cover assert_never(part) return model_parts + + def _response_parts_to_a2a(self, parts: list[ModelResponsePart]) -> list[Part]: + """Convert pydantic-ai ModelResponsePart objects to A2A Part objects. + + This handles the conversion from pydantic-ai's internal response parts to + A2A protocol parts. Different part types are handled as follows: + - TextPart: Converted directly to A2A TextPart + - ThinkingPart: Converted to TextPart with metadata indicating it's thinking + - ToolCallPart: Skipped (internal to agent execution) + + Args: + parts: List of ModelResponsePart objects from agent response + + Returns: + List of A2A Part objects suitable for sending via A2A protocol + """ + a2a_parts: list[Part] = [] + for part in parts: + if isinstance(part, TextPart): + if part.content: # Only add non-empty text + a2a_parts.append(A2ATextPart(kind='text', text=part.content)) + elif isinstance(part, ThinkingPart): + if part.content: # Only add non-empty thinking + # Convert thinking to text with metadata + a2a_parts.append( + A2ATextPart( + kind='text', + text=part.content, + metadata={'type': 'thinking', 'thinking_id': part.id, 'signature': part.signature}, + ) + ) + elif isinstance(part, ToolCallPart): + # Skip tool calls - they're internal to agent execution + pass + return a2a_parts diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 6b8a5a5a6..2c9eb3a3d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -63,7 +63,7 @@ from fasta2a.applications import FastA2A from fasta2a.broker import Broker - from fasta2a.schema import Provider, Skill + from fasta2a.schema import AgentProvider, Skill from fasta2a.storage import Storage from pydantic_ai.mcp import MCPServer @@ -1764,7 +1764,7 @@ def to_a2a( url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, diff --git a/tests/fasta2a/test_applications.py b/tests/fasta2a/test_applications.py index 3b3aa437d..955933151 100644 --- a/tests/fasta2a/test_applications.py +++ b/tests/fasta2a/test_applications.py @@ -34,12 +34,13 @@ async def test_agent_card(): assert response.json() == snapshot( { 'name': 'Agent', + 'description': 'FastA2A Agent', 'url': 'http://localhost:8000', 'version': '1.0.0', + 'protocolVersion': '0.2.5', 'skills': [], 'defaultInputModes': ['application/json'], 'defaultOutputModes': ['application/json'], 'capabilities': {'streaming': False, 'pushNotifications': False, 'stateTransitionHistory': False}, - 'authentication': {'schemes': []}, } ) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index fae117781..a3f28d937 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -1,11 +1,14 @@ +import uuid + import anyio import httpx import pytest from asgi_lifespan import LifespanManager from inline_snapshot import snapshot +from pydantic import BaseModel from pydantic_ai import Agent -from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, ToolCallPart from pydantic_ai.models.function import AgentInfo, FunctionModel from .conftest import IsDatetime, IsStr, try_import @@ -32,6 +35,95 @@ def return_string(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: model = FunctionModel(return_string) +# Define a test Pydantic model +class UserProfile(BaseModel): + name: str + age: int + email: str + + +def return_pydantic_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"name": "John Doe", "age": 30, "email": "john@example.com"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + +pydantic_model = FunctionModel(return_pydantic_model) + + +async def test_a2a_pydantic_model_output(): + """Test that Pydantic model outputs have correct metadata including JSON schema.""" + agent = Agent(model=pydantic_model, output_type=UserProfile) + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message( + role='user', + parts=[TextPart(text='Get user profile', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert result['kind'] == 'task' + + task_id = result['id'] + + # Wait for completion + await anyio.sleep(0.1) + task = await a2a_client.get_task(task_id) + + assert 'result' in task + result = task['result'] + assert result['status']['state'] == 'completed' + + # Check artifacts + assert 'artifacts' in result + assert len(result['artifacts']) == 1 + artifact = result['artifacts'][0] + + # Verify the data + assert artifact['parts'][0]['kind'] == 'data' + assert artifact['parts'][0]['data'] == { + 'result': {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + } + + metadata = artifact['parts'][0].get('metadata') + assert metadata is not None + + assert metadata['json_schema'] == snapshot( + { + 'properties': { + 'name': {'title': 'Name', 'type': 'string'}, + 'age': {'title': 'Age', 'type': 'integer'}, + 'email': {'title': 'Email', 'type': 'string'}, + }, + 'required': ['name', 'age', 'email'], + 'title': 'UserProfile', + 'type': 'object', + } + ) + + assert result.get('history') == snapshot( + [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Get user profile'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ] + ) + + async def test_a2a_runtime_error_without_lifespan(): agent = Agent(model=model, output_type=tuple[str, str]) app = agent.to_a2a() @@ -40,10 +132,15 @@ async def test_a2a_runtime_error_without_lifespan(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) with pytest.raises(RuntimeError, match='TaskManager was not properly initialized.'): - await a2a_client.send_task(message=message) + await a2a_client.send_message(message=message) async def test_a2a_simple(): @@ -55,23 +152,37 @@ async def test_a2a_simple(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert result['kind'] == 'task' + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -83,11 +194,31 @@ async def test_a2a_simple(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], + } ], }, } @@ -107,37 +238,43 @@ async def test_a2a_file_message_with_file(): role='user', parts=[ FilePart( - type='file', - file={'url': 'https://example.com/file.txt', 'mime_type': 'text/plain'}, + kind='file', + file={'uri': 'https://example.com/file.txt', 'mime_type': 'text/plain'}, ) ], + kind='message', + message_id=str(uuid.uuid4()), ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert result['kind'] == 'task' + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [ - { - 'role': 'user', - 'parts': [ - { - 'type': 'file', - 'file': {'mime_type': 'text/plain', 'url': 'https://example.com/file.txt'}, - } - ], - } - ], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [ + { + 'kind': 'file', + 'file': {'mime_type': 'text/plain', 'uri': 'https://example.com/file.txt'}, + } + ], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -149,21 +286,36 @@ async def test_a2a_file_message_with_file(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { 'role': 'user', 'parts': [ { - 'type': 'file', - 'file': {'mime_type': 'text/plain', 'url': 'https://example.com/file.txt'}, + 'kind': 'file', + 'file': {'mime_type': 'text/plain', 'uri': 'https://example.com/file.txt'}, } ], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), } ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], + } ], }, } @@ -182,30 +334,36 @@ async def test_a2a_file_message_with_file_content(): message = Message( role='user', parts=[ - FilePart(type='file', file={'data': 'foo', 'mime_type': 'text/plain'}), + FilePart(file={'bytes': 'foo', 'mime_type': 'text/plain'}, kind='file'), ], + kind='message', + message_id=str(uuid.uuid4()), ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert result['kind'] == 'task' + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [ - { - 'role': 'user', - 'parts': [{'type': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], - } - ], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'file', 'file': {'bytes': 'foo', 'mime_type': 'text/plain'}}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -217,16 +375,31 @@ async def test_a2a_file_message_with_file_content(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { 'role': 'user', - 'parts': [{'type': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], + 'parts': [{'kind': 'file', 'file': {'bytes': 'foo', 'mime_type': 'text/plain'}}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), } ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], + } ], }, } @@ -244,24 +417,35 @@ async def test_a2a_file_message_with_data(): message = Message( role='user', - parts=[DataPart(type='data', data={'foo': 'bar'})], + parts=[DataPart(kind='data', data={'foo': 'bar'})], + kind='message', + message_id=str(uuid.uuid4()), ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert result['kind'] == 'task' + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'data', 'data': {'foo': 'bar'}}]}], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'failed': @@ -273,14 +457,168 @@ async def test_a2a_file_message_with_data(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'failed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'data', 'data': {'foo': 'bar'}}]}], + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], }, } ) +async def test_a2a_error_handling(): + """Test that errors during task execution properly update task state.""" + + def raise_error(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + raise RuntimeError('Test error during agent execution') + + error_model = FunctionModel(raise_error) + agent = Agent(model=error_model, output_type=str) + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert result['kind'] == 'task' + + task_id = result['id'] + + # Wait for task to fail + await anyio.sleep(0.1) + task = await a2a_client.get_task(task_id) + + assert 'result' in task + assert task['result']['status']['state'] == 'failed' + + +async def test_a2a_multiple_tasks_same_context(): + """Test that multiple tasks can share the same context_id with accumulated history.""" + + messages_received: list[list[ModelMessage]] = [] + + def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Store a copy of the messages received by the model + messages_received.append(messages.copy()) + # Return the standard response + assert info.output_tools is not None + args_json = '{"response": ["foo", "bar"]}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + tracking_model = FunctionModel(track_messages) + agent = Agent(model=tracking_model, output_type=tuple[str, str]) + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + # First message - should create a new context + message1 = Message( + role='user', + parts=[TextPart(text='First message', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response1 = await a2a_client.send_message(message=message1) + assert 'error' not in response1 + assert 'result' in response1 + result1 = response1['result'] + assert result1['kind'] == 'task' + + task1_id = result1['id'] + context_id = result1['context_id'] + + # Wait for first task to complete + await anyio.sleep(0.1) + task1 = await a2a_client.get_task(task1_id) + assert 'result' in task1 + assert task1['result']['status']['state'] == 'completed' + + # Verify the model received at least one message + assert len(messages_received) == 1 + first_run_history = messages_received[0] + assert len(first_run_history) >= 1 + # Check first message is a ModelRequest with UserPromptPart + first_msg = first_run_history[0] + assert isinstance(first_msg, ModelRequest) + first_part = first_msg.parts[0] + assert hasattr(first_part, 'content') and first_part.content == 'First message' + + # Second message - reuse the same context_id + message2 = Message( + role='user', + parts=[TextPart(text='Second message', kind='text')], + kind='message', + context_id=context_id, + message_id=str(uuid.uuid4()), + ) + response2 = await a2a_client.send_message(message=message2) + assert 'error' not in response2 + assert 'result' in response2 + result2 = response2['result'] + assert result2['kind'] == 'task' + + task2_id = result2['id'] + # Verify we got a new task ID but same context ID + assert task2_id != task1_id + assert result2['context_id'] == context_id + + # Wait for second task to complete + await anyio.sleep(0.1) + task2 = await a2a_client.get_task(task2_id) + assert 'result' in task2 + if task2['result']['status']['state'] == 'failed': + print(f'Task 2 failed: {task2}') + print(f'Messages received so far: {messages_received}') + assert task2['result']['status']['state'] == 'completed' + + # Verify the model received the full history on the second call + assert len(messages_received) == 2 + second_run_history = messages_received[1] + + # Check that history is monotonically increasing - all previous messages should be there + assert len(second_run_history) > len(first_run_history), ( + f'Expected more messages, got {len(second_run_history)} <= {len(first_run_history)}' + ) + + # Check that all messages from first run are still in second run (in same order) + for i, msg in enumerate(first_run_history): + assert second_run_history[i] == msg, f'Message {i} changed between runs' + + # Verify the new message is there + # Find the user message with 'Second message' in the new history + found_second_message = False + for msg in second_run_history: + if isinstance(msg, ModelRequest): + for part in msg.parts: + if hasattr(part, 'content') and part.content == 'Second message': + found_second_message = True + break + assert found_second_message, 'Second message not found in history' + + async def test_a2a_multiple_messages(): agent = Agent(model=model, output_type=tuple[str, str]) storage = InMemoryStorage() @@ -291,27 +629,51 @@ async def test_a2a_multiple_messages(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) - response = await a2a_client.send_task(message=message) + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response = await a2a_client.send_message(message=message) assert response == snapshot( { 'jsonrpc': '2.0', 'id': IsStr(), 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], }, } ) # NOTE: We include the agent history before we start working on the task. assert 'result' in response - task_id = response['result']['id'] + result = response['result'] + assert result['kind'] == 'task' + task_id = result['id'] task = storage.tasks[task_id] assert 'history' in task - task['history'].append(Message(role='agent', parts=[TextPart(text='Whats up?', type='text')])) + task['history'].append( + Message( + role='agent', + parts=[TextPart(text='Whats up?', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + ) response = await a2a_client.get_task(task_id) assert response == snapshot( @@ -320,11 +682,24 @@ async def test_a2a_multiple_messages(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'type': 'text', 'text': 'Whats up?'}]}, + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Whats up?'}], + 'kind': 'message', + 'message_id': IsStr(), + }, ], }, } @@ -338,14 +713,37 @@ async def test_a2a_multiple_messages(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'type': 'text', 'text': 'Whats up?'}]}, + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Whats up?'}], + 'kind': 'message', + 'message_id': IsStr(), + }, ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], + } ], }, }