From be6e37fd72c682cea4d91aed4dceb686006502cf Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 6 Jul 2025 01:51:05 +0000 Subject: [PATCH 01/24] feat: upgrade A2A protocol from v0.1 to v0.2.3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update protocol methods: tasks/send → message/send - Replace 'type' with 'kind' throughout schema - Replace 'session_id' with 'context_id' for conversation tracking - Add Message and Part types (TextPart, FilePart, DataPart) - Implement dual message/artifact approach for agent outputs - Add metadata to artifacts including type info and JSON schema - Add proper error handling with task state updates - Add NotImplementedError stubs for streaming methods - Rename test directory to avoid import conflicts --- fasta2a/fasta2a/applications.py | 14 +- fasta2a/fasta2a/client.py | 41 +- fasta2a/fasta2a/schema.py | 212 +++++++--- fasta2a/fasta2a/storage.py | 60 ++- fasta2a/fasta2a/task_manager.py | 49 ++- fasta2a/fasta2a/worker.py | 2 +- pydantic_ai_slim/pydantic_ai/_a2a.py | 154 ++++++-- tests/test_a2a.py | 362 +++++++++++++----- tests/{fasta2a => test_fasta2a}/__init__.py | 0 .../test_applications.py | 0 10 files changed, 660 insertions(+), 234 deletions(-) rename tests/{fasta2a => test_fasta2a}/__init__.py (100%) rename tests/{fasta2a => test_fasta2a}/test_applications.py (100%) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 61301262b..9de5f9433 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -21,6 +21,7 @@ a2a_request_ta, a2a_response_ta, agent_card_ta, + send_message_request_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -116,8 +117,17 @@ async def _agent_run_endpoint(self, request: Request) -> Response: data = await request.body() a2a_request = a2a_request_ta.validate_json(data) - if a2a_request['method'] == 'tasks/send': - jsonrpc_response = await self.task_manager.send_task(a2a_request) + if a2a_request['method'] == 'message/send': + message_request = send_message_request_ta.validate_json(data) + jsonrpc_response = await self.task_manager.send_message(message_request) + elif a2a_request['method'] == 'message/stream': + # Streaming support not yet implemented + raise NotImplementedError( + 'message/stream method is not implemented yet. Streaming support will be added in a future update.' + ) + elif a2a_request['method'] == 'tasks/send': # type: ignore[comparison-overlap] + # Legacy method - no longer supported + raise NotImplementedError('tasks/send is deprecated. Use message/send instead.') elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 5c5aabd81..dc3449623 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -9,14 +9,15 @@ GetTaskRequest, GetTaskResponse, Message, - PushNotificationConfig, - SendTaskRequest, - SendTaskResponse, - TaskSendParams, + MessageSendConfiguration, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, a2a_request_ta, + send_message_request_ta, + send_message_response_ta, ) -send_task_response_ta = pydantic.TypeAdapter(SendTaskResponse) get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) try: @@ -37,26 +38,30 @@ def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.A self.http_client = http_client self.http_client.base_url = base_url - async def send_task( + async def send_message( self, message: Message, - history_length: int | None = None, - push_notification: PushNotificationConfig | None = None, + *, metadata: dict[str, Any] | None = None, - ) -> SendTaskResponse: - task = TaskSendParams(message=message, id=str(uuid.uuid4())) - if history_length is not None: - task['history_length'] = history_length - if push_notification is not None: - task['push_notification'] = push_notification + configuration: MessageSendConfiguration | None = None, + ) -> SendMessageResponse: + """Send a message using the A2A protocol. + + Returns a JSON-RPC response containing either a result (Task | Message) or an error. + """ + params = MessageSendParams(message=message) if metadata is not None: - task['metadata'] = metadata + params['metadata'] = metadata + if configuration is not None: + params['configuration'] = configuration - payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) - content = a2a_request_ta.dump_json(payload, by_alias=True) + request_id = str(uuid.uuid4()) + payload = SendMessageRequest(jsonrpc='2.0', id=request_id, method='message/send', params=params) + content = send_message_request_ta.dump_json(payload, by_alias=True) response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) - return send_task_response_ta.validate_json(response.content) + + return send_message_response_ta.validate_json(response.content) async def get_task(self, task_id: str) -> GetTaskResponse: payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index cab5d2057..1ccf98055 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -7,7 +7,7 @@ import pydantic from pydantic import Discriminator, TypeAdapter from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypeAlias, TypedDict +from typing_extensions import NotRequired, TypeAlias, TypedDict, TypeGuard @pydantic.with_config({'alias_generator': to_camel}) @@ -137,6 +137,9 @@ class Artifact(TypedDict): Artifacts. """ + artifact_id: str + """Unique identifier for the artifact.""" + name: NotRequired[str] """The name of the artifact.""" @@ -149,8 +152,8 @@ class Artifact(TypedDict): metadata: NotRequired[dict[str, Any]] """Metadata about the artifact.""" - index: int - """The index of the artifact.""" + extensions: NotRequired[list[Any]] + """Array of extensions.""" append: NotRequired[bool] """Whether to append this artifact to an existing one.""" @@ -183,6 +186,9 @@ class PushNotificationConfig(TypedDict): mobile Push Notification Service). """ + id: NotRequired[str] + """Server-assigned identifier.""" + url: str """The URL to send push notifications to.""" @@ -204,6 +210,7 @@ class TaskPushNotificationConfig(TypedDict): """The push notification configuration.""" +@pydantic.with_config({'alias_generator': to_camel}) class Message(TypedDict): """A Message contains any content that is not an Artifact. @@ -222,9 +229,28 @@ class Message(TypedDict): parts: list[Part] """The parts of the message.""" + kind: Literal['message'] + """Event type.""" + metadata: NotRequired[dict[str, Any]] """Metadata about the message.""" + # Additional fields + message_id: NotRequired[str] + """Identifier created by the message creator.""" + + context_id: NotRequired[str] + """The context the message is associated with.""" + + task_id: NotRequired[str] + """Identifier of task the message is related to.""" + + reference_task_ids: NotRequired[list[str]] + """Array of task IDs this message references.""" + + extensions: NotRequired[list[Any]] + """Array of extensions.""" + class _BasePart(TypedDict): """A base class for all parts.""" @@ -232,76 +258,73 @@ class _BasePart(TypedDict): metadata: NotRequired[dict[str, Any]] +@pydantic.with_config({'alias_generator': to_camel}) class TextPart(_BasePart): """A part that contains text.""" - type: Literal['text'] - """The type of the part.""" + kind: Literal['text'] + """The kind of the part.""" text: str """The text of the part.""" @pydantic.with_config({'alias_generator': to_camel}) -class FilePart(_BasePart): - """A part that contains a file.""" - - type: Literal['file'] - """The type of the part.""" - - file: File - """The file of the part.""" +class FileWithBytes(TypedDict): + """File with base64 encoded data.""" - -@pydantic.with_config({'alias_generator': to_camel}) -class _BaseFile(_BasePart): - """A base class for all file types.""" - - name: NotRequired[str] - """The name of the file.""" + data: str + """The base64 encoded data.""" mime_type: str """The mime type of the file.""" @pydantic.with_config({'alias_generator': to_camel}) -class _BinaryFile(_BaseFile): - """A binary file.""" +class FileWithUri(TypedDict): + """File with URI reference.""" - data: str - """The base64 encoded bytes of the file.""" + uri: str + """The URI of the file.""" + mime_type: NotRequired[str] + """The mime type of the file.""" -@pydantic.with_config({'alias_generator': to_camel}) -class _URLFile(_BaseFile): - """A file that is hosted on a remote URL.""" - url: str - """The URL of the file.""" +@pydantic.with_config({'alias_generator': to_camel}) +class FilePart(_BasePart): + """A part that contains a file.""" + kind: Literal['file'] + """The kind of the part.""" -File: TypeAlias = Union[_BinaryFile, _URLFile] -"""A file is a binary file or a URL file.""" + file: FileWithBytes | FileWithUri + """The file content - either bytes or URI.""" @pydantic.with_config({'alias_generator': to_camel}) class DataPart(_BasePart): - """A part that contains data.""" + """A part that contains structured data.""" - type: Literal['data'] - """The type of the part.""" + kind: Literal['data'] + """The kind of the part.""" - data: dict[str, Any] + data: Any """The data of the part.""" + description: NotRequired[str] + """A description of the data.""" -Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='type')] + +Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='kind')] """A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. Each Part has its own content type and metadata. """ -TaskState: TypeAlias = Literal['submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'unknown'] +TaskState: TypeAlias = Literal[ + 'submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'rejected', 'auth-required' +] """The possible states of a task.""" @@ -330,8 +353,11 @@ class Task(TypedDict): id: str """Unique identifier for the task.""" - session_id: NotRequired[str] - """Client-generated id for the session holding the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['task'] + """Event type.""" status: TaskStatus """Current status of the task.""" @@ -348,11 +374,17 @@ class Task(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskStatusUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" + """Sent by server during message/stream requests.""" - id: str + task_id: str """The id of the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['status-update'] + """Event type.""" + status: TaskStatus """The status of the task.""" @@ -365,14 +397,26 @@ class TaskStatusUpdateEvent(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskArtifactUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" + """Sent by server during message/stream requests.""" - id: str + task_id: str """The id of the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['artifact-update'] + """Event type identification.""" + artifact: Artifact """The artifact that was updated.""" + append: NotRequired[bool] + """Whether to append to existing artifact (true) or replace (false).""" + + last_chunk: NotRequired[bool] + """Indicates this is the final chunk of the artifact.""" + metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -393,25 +437,57 @@ class TaskQueryParams(TaskIdParams): """Number of recent messages to be retrieved.""" +@pydantic.with_config({'alias_generator': to_camel}) +class MessageSendConfiguration(TypedDict): + """Configuration for the send message request.""" + + accepted_output_modes: list[str] + """Accepted output modalities by the client.""" + + blocking: NotRequired[bool] + """If the server should treat the client as a blocking request.""" + + history_length: NotRequired[int] + """Number of recent messages to be retrieved.""" + + push_notification_config: NotRequired[PushNotificationConfig] + """Where the server should send notifications when disconnected.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class MessageSendParams(TypedDict): + """Parameters for message/send method.""" + + configuration: NotRequired[MessageSendConfiguration] + """Send message configuration.""" + + message: Message + """The message being sent to the server.""" + + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + @pydantic.with_config({'alias_generator': to_camel}) class TaskSendParams(TypedDict): - """Sent by the client to the agent to create, continue, or restart a task.""" + """Internal parameters for task execution within the framework. + + Note: This is not part of the A2A protocol - it's used internally + for broker/worker communication. + """ id: str """The id of the task.""" - session_id: NotRequired[str] - """The server creates a new sessionId for new tasks if not set.""" + context_id: str + """The context id for the task.""" message: Message - """The message to send to the agent.""" + """The message to process.""" history_length: NotRequired[int] """Number of recent messages to be retrieved.""" - push_notification: NotRequired[PushNotificationConfig] - """Where the server should send notifications when disconnected.""" - metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -497,21 +573,21 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ContentTypeNotSupportedError = JSONRPCError[Literal[-32005], Literal['Incompatible content types']] """A JSON RPC error for incompatible content types.""" +InvalidAgentResponseError = JSONRPCError[Literal[-32006], Literal['Invalid agent response']] +"""A JSON RPC error for invalid agent response.""" + ############################################################################################### ####################################### Requests and responses ############################ ############################################################################################### -SendTaskRequest = JSONRPCRequest[Literal['tasks/send'], TaskSendParams] -"""A JSON RPC request to send a task.""" - -SendTaskResponse = JSONRPCResponse[Task, JSONRPCError[Any, Any]] -"""A JSON RPC response to send a task.""" +SendMessageRequest = JSONRPCRequest[Literal['message/send'], MessageSendParams] +"""A JSON RPC request to send a message.""" -SendTaskStreamingRequest = JSONRPCRequest[Literal['tasks/sendSubscribe'], TaskSendParams] -"""A JSON RPC request to send a task and receive updates.""" +SendMessageResponse = JSONRPCResponse[Union[Task, Message], JSONRPCError[Any, Any]] +"""A JSON RPC response to send a message.""" -SendTaskStreamingResponse = JSONRPCResponse[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError] -"""A JSON RPC response to send a task and receive updates.""" +StreamMessageRequest = JSONRPCRequest[Literal['message/stream'], MessageSendParams] +"""A JSON RPC request to stream a message.""" GetTaskRequest = JSONRPCRequest[Literal['tasks/get'], TaskQueryParams] """A JSON RPC request to get a task.""" @@ -542,7 +618,8 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): A2ARequest = Annotated[ Union[ - SendTaskRequest, + SendMessageRequest, + StreamMessageRequest, GetTaskRequest, CancelTaskRequest, SetTaskPushNotificationRequest, @@ -554,7 +631,7 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): """A JSON RPC request to the A2A server.""" A2AResponse: TypeAlias = Union[ - SendTaskResponse, + SendMessageResponse, GetTaskResponse, CancelTaskResponse, SetTaskPushNotificationResponse, @@ -565,3 +642,16 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) +send_message_request_ta: TypeAdapter[SendMessageRequest] = TypeAdapter(SendMessageRequest) +send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) +stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) + + +def is_task(response: Task | Message) -> TypeGuard[Task]: + """Type guard to check if a response is a Task.""" + return 'id' in response and 'status' in response and 'context_id' in response and response.get('kind') == 'task' + + +def is_message(response: Task | Message) -> TypeGuard[Message]: + """Type guard to check if a response is a Message.""" + return 'role' in response and 'parts' in response and response.get('kind') == 'message' diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index c06bc1cb7..03f04c34e 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -22,7 +22,7 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta """ @abstractmethod - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: + async def submit_task(self, task_id: str, context_id: str, message: Message) -> Task: """Submit a task to storage.""" @abstractmethod @@ -30,17 +30,29 @@ async def update_task( self, task_id: str, state: TaskState, - message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task: """Update the state of a task.""" + @abstractmethod + async def add_message(self, message: Message) -> None: + """Add a message to the history for both its task and context. + + This should be called for messages created during task execution, + not for the initial message (which is handled by submit_task). + """ + + @abstractmethod + async def get_context_history(self, context_id: str, history_length: int | None = None) -> list[Message]: + """Get all messages across tasks in a context.""" + class InMemoryStorage(Storage): """A storage to retrieve and save tasks in memory.""" def __init__(self): self.tasks: dict[str, Task] = {} + self.context_messages: dict[str, list[Message]] = {} async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from memory. @@ -60,32 +72,60 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta task['history'] = task['history'][-history_length:] return task - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: + async def submit_task(self, task_id: str, context_id: str, message: Message) -> Task: """Submit a task to storage.""" if task_id in self.tasks: raise ValueError(f'Task {task_id} already exists') + # Add IDs to the message + message['task_id'] = task_id + message['context_id'] = context_id + task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) + task = Task(id=task_id, context_id=context_id, kind='task', status=task_status, history=[message]) self.tasks[task_id] = task + + # Add message to context storage directly (not via add_message to avoid duplication) + if context_id not in self.context_messages: + self.context_messages[context_id] = [] + self.context_messages[context_id].append(message) + return task async def update_task( self, task_id: str, state: TaskState, - message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task: - """Save the task as "working".""" + """Update the state of a task.""" task = self.tasks[task_id] task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if message: - if 'history' not in task: - task['history'] = [] - task['history'].append(message) if artifacts: if 'artifacts' not in task: task['artifacts'] = [] task['artifacts'].extend(artifacts) return task + + async def add_message(self, message: Message) -> None: + """Add a message to the history for both its task and context.""" + if 'task_id' in message and message['task_id']: + task_id = message['task_id'] + if task_id in self.tasks: + task = self.tasks[task_id] + if 'history' not in task: + task['history'] = [] + task['history'].append(message) + + if 'context_id' in message and message['context_id']: + context_id = message['context_id'] + if context_id not in self.context_messages: + self.context_messages[context_id] = [] + self.context_messages[context_id].append(message) + + async def get_context_history(self, context_id: str, history_length: int | None = None) -> list[Message]: + """Get all messages across tasks in a context.""" + messages = self.context_messages.get(context_id, []) + if history_length: + return messages[-history_length:] + return messages diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 0baaeba04..637a439f9 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -74,13 +74,13 @@ GetTaskRequest, GetTaskResponse, ResubscribeTaskRequest, - SendTaskRequest, - SendTaskResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, + SendMessageRequest, + SendMessageResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, + StreamMessageRequest, TaskNotFoundError, + TaskSendParams, ) from .storage import Storage @@ -111,19 +111,31 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self._aexit_stack.__aexit__(exc_type, exc_value, traceback) self._aexit_stack = None - async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """Send a task to the worker.""" - request_id = str(uuid.uuid4()) - task_id = request['params']['id'] - task = await self.storage.load_task(task_id) + async def send_message(self, request: SendMessageRequest) -> SendMessageResponse: + """Send a message using the A2A v0.2.3 protocol.""" + request_id = request['id'] + task_id = str(uuid.uuid4()) + message = request['params']['message'] - if task is None: - session_id = request['params'].get('session_id', str(uuid.uuid4())) - message = request['params']['message'] - task = await self.storage.submit_task(task_id, session_id, message) + # Use provided context_id or create new one + context_id = message.get('context_id') or str(uuid.uuid4()) + + # Create a new task + task = await self.storage.submit_task(task_id, context_id, message) + + # Prepare params for broker + broker_params: TaskSendParams = { + 'id': task_id, + 'context_id': context_id, + 'message': message, + } + config = request['params'].get('configuration', {}) + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length - await self.broker.run_task(request['params']) - return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) + await self.broker.run_task(broker_params) + return SendMessageResponse(jsonrpc='2.0', id=request_id, result=task) async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: """Get a task, and return it to the client. @@ -152,8 +164,9 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: ) return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - async def send_task_streaming(self, request: SendTaskStreamingRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('SendTaskStreaming is not implemented yet.') + async def stream_message(self, request: StreamMessageRequest) -> None: + """Stream messages using Server-Sent Events. Not implemented.""" + raise NotImplementedError('message/stream method is not implemented yet.') async def set_task_push_notification( self, request: SetTaskPushNotificationRequest @@ -165,5 +178,5 @@ async def get_task_push_notification( ) -> GetTaskPushNotificationResponse: raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> SendTaskStreamingResponse: + async def resubscribe_task(self, request: ResubscribeTaskRequest) -> None: raise NotImplementedError('Resubscribe is not implemented yet.') diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 9bbde6b25..34a0cf565 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -62,7 +62,7 @@ async def run_task(self, params: TaskSendParams) -> None: ... async def cancel_task(self, params: TaskIdParams) -> None: ... @abstractmethod - def build_message_history(self, task_history: list[Message]) -> list[Any]: ... + def build_message_history(self, history: list[Message]) -> list[Any]: ... @abstractmethod def build_artifacts(self, result: Any) -> list[Artifact]: ... diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 99bbe37ad..b186bf78e 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -1,11 +1,13 @@ from __future__ import annotations, annotations as _annotations +import uuid from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from dataclasses import dataclass +from dataclasses import asdict, dataclass, is_dataclass from functools import partial -from typing import Any, Generic +from typing import Any, Generic, TypeVar +from pydantic import TypeAdapter from typing_extensions import assert_never from pydantic_ai.messages import ( @@ -25,6 +27,9 @@ from .agent import Agent, AgentDepsT, OutputDataT +# AgentWorker output type needs to be invariant for use in both parameter and return positions +WorkerOutputT = TypeVar('WorkerOutputT') + try: from starlette.middleware import Middleware from starlette.routing import Route @@ -34,6 +39,7 @@ from fasta2a.broker import Broker, InMemoryBroker from fasta2a.schema import ( Artifact, + DataPart, Message, Part, Provider, @@ -106,40 +112,127 @@ def agent_to_a2a( @dataclass -class AgentWorker(Worker, Generic[AgentDepsT, OutputDataT]): +# Generic parameters are reversed compared to Agent because AgentDepsT has a default +class AgentWorker(Worker, Generic[WorkerOutputT, AgentDepsT]): """A worker that uses an agent to execute tasks.""" - agent: Agent[AgentDepsT, OutputDataT] + agent: Agent[AgentDepsT, WorkerOutputT] async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) - assert task is not None, f'Task {params["id"]} not found' - assert 'session_id' in task, 'Task must have a session_id' + if task is None: + raise ValueError(f'Task {params["id"]} not found') + if 'context_id' not in task: + raise ValueError('Task must have a context_id') + + task_id = task['id'] + context_id = task['context_id'] + + try: + await self.storage.update_task(task_id, state='working') + + # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe + # a custom `output_type` with a `more_info_required` field, or something like that. - await self.storage.update_task(task['id'], state='working') + history = task.get('history', []) + message_history = self.build_message_history(history) - # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe - # a custom `output_type` with a `more_info_required` field, or something like that. + # TODO(Marcelo): We need to make this more customizable e.g. pass deps. + result = await self.agent.run(message_history=message_history) # type: ignore - task_history = task.get('history', []) - message_history = self.build_message_history(task_history=task_history) + # Create both a message and artifact for the result + # This ensures the complete conversation is preserved in history while + # also marking the output as a durable artifact + message_id = str(uuid.uuid4()) - # TODO(Marcelo): We need to make this more customizable e.g. pass deps. - result = await self.agent.run(message_history=message_history) # type: ignore + # Create message parts based on output type + message_part = self._convert_result_to_part(result.output) + message_parts: list[Part] = [message_part] - artifacts = self.build_artifacts(result.output) - await self.storage.update_task(task['id'], state='completed', artifacts=artifacts) + # Add result as a message to preserve conversation flow + result_message = Message( + role='agent', + parts=message_parts, + kind='message', + message_id=message_id, + task_id=task_id, + context_id=context_id, + ) + await self.storage.add_message(result_message) + + # Also create artifacts for durable outputs + artifacts = self.build_artifacts(result.output) + await self.storage.update_task(task_id, state='completed', artifacts=artifacts) + + except Exception: + # Ensure task is marked as failed on any error + await self.storage.update_task(task_id, state='failed') + raise # Re-raise to maintain error visibility async def cancel_task(self, params: TaskIdParams) -> None: pass - def build_artifacts(self, result: Any) -> list[Artifact]: - # TODO(Marcelo): We need to send the json schema of the result on the metadata of the message. - return [Artifact(name='result', index=0, parts=[A2ATextPart(type='text', text=str(result))])] + def build_artifacts(self, result: WorkerOutputT) -> list[Artifact]: + """Build artifacts from agent result. + + All agent outputs become artifacts to mark them as durable task outputs. + For string results, we use TextPart. For structured data, we use DataPart. + Metadata is included to preserve type information. + """ + artifact_id = str(uuid.uuid4()) + part = self._convert_result_to_part(result) + metadata = self._build_result_metadata(result) + return [Artifact(artifact_id=artifact_id, name='result', parts=[part], metadata=metadata)] + + def _convert_result_to_part(self, result: WorkerOutputT) -> Part: + """Convert agent result to a Part (TextPart or DataPart). - def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: + For string results, returns a TextPart. + For structured data, returns a DataPart with properly serialized data. + """ + if isinstance(result, str): + return A2ATextPart(kind='text', text=result) + else: + # For structured data, create a DataPart + try: + # Try using TypeAdapter for proper serialization + output_type = type(result) + type_adapter: TypeAdapter[WorkerOutputT] = TypeAdapter(output_type) + data = type_adapter.dump_python(result, mode='json') + except Exception: + # Fallback for types that TypeAdapter can't handle + if is_dataclass(result) and not isinstance(result, type): + data = asdict(result) + else: + # Last resort - convert to string + data = str(result) + + return DataPart(kind='data', data=data) + + def _build_result_metadata(self, result: WorkerOutputT) -> dict[str, Any]: + """Build metadata for the result artifact. + + Captures type information and JSON schema when available. + """ + metadata: dict[str, Any] = { + 'type': type(result).__name__, + } + + # For non-string types, attempt to capture JSON schema + if not isinstance(result, str): + output_type = type(result) + type_adapter: TypeAdapter[WorkerOutputT] = TypeAdapter(output_type) + try: + metadata['json_schema'] = type_adapter.json_schema() + except Exception: + # Some types don't support JSON schema generation + pass + + return metadata + + def build_message_history(self, history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] - for message in task_history: + for message in history: if message['role'] == 'user': model_messages.append(ModelRequest(parts=self._map_request_parts(message['parts']))) else: @@ -149,18 +242,19 @@ def build_message_history(self, task_history: list[Message]) -> list[ModelMessag def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: model_parts: list[ModelRequestPart] = [] for part in parts: - if part['type'] == 'text': + if part['kind'] == 'text': model_parts.append(UserPromptPart(content=part['text'])) - elif part['type'] == 'file': + elif part['kind'] == 'file': file = part['file'] if 'data' in file: data = file['data'].encode('utf-8') content = BinaryContent(data=data, media_type=file['mime_type']) model_parts.append(UserPromptPart(content=[content])) - else: - url = file['url'] + elif 'uri' in file: + uri = file['uri'] + mime_type = file.get('mime_type', 'application/octet-stream') for url_cls in (DocumentUrl, AudioUrl, ImageUrl, VideoUrl): - content = url_cls(url=url) + content = url_cls(url=uri) try: content.media_type except ValueError: # pragma: no cover @@ -168,9 +262,9 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: else: break else: - raise ValueError(f'Unknown file type: {file["mime_type"]}') # pragma: no cover + raise ValueError(f'Unknown file type: {mime_type}') # pragma: no cover model_parts.append(UserPromptPart(content=[content])) - elif part['type'] == 'data': + elif part['kind'] == 'data': # TODO(Marcelo): Maybe we should use this for `ToolReturnPart`, and `RetryPromptPart`. raise NotImplementedError('Data parts are not supported yet.') else: @@ -180,11 +274,11 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: model_parts: list[ModelResponsePart] = [] for part in parts: - if part['type'] == 'text': + if part['kind'] == 'text': model_parts.append(TextPart(content=part['text'])) - elif part['type'] == 'file': # pragma: no cover + elif part['kind'] == 'file': # pragma: no cover raise NotImplementedError('File parts are not supported yet.') - elif part['type'] == 'data': # pragma: no cover + elif part['kind'] == 'data': # pragma: no cover raise NotImplementedError('Data parts are not supported yet.') else: # pragma: no cover assert_never(part) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index fae117781..3c7bcf4f7 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -12,7 +12,7 @@ with try_import() as imports_successful: from fasta2a.client import A2AClient - from fasta2a.schema import DataPart, FilePart, Message, TextPart + from fasta2a.schema import DataPart, FilePart, Message, TextPart, is_task from fasta2a.storage import InMemoryStorage @@ -40,10 +40,10 @@ async def test_a2a_runtime_error_without_lifespan(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') with pytest.raises(RuntimeError, match='TaskManager was not properly initialized.'): - await a2a_client.send_task(message=message) + await a2a_client.send_message(message=message) async def test_a2a_simple(): @@ -55,23 +55,31 @@ async def test_a2a_simple(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -83,11 +91,36 @@ async def test_a2a_simple(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + { + 'role': 'agent', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + }, + } ], }, } @@ -107,37 +140,41 @@ async def test_a2a_file_message_with_file(): role='user', parts=[ FilePart( - type='file', - file={'url': 'https://example.com/file.txt', 'mime_type': 'text/plain'}, + kind='file', + file={'uri': 'https://example.com/file.txt', 'mime_type': 'text/plain'}, ) ], + kind='message', ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [ - { - 'role': 'user', - 'parts': [ - { - 'type': 'file', - 'file': {'mime_type': 'text/plain', 'url': 'https://example.com/file.txt'}, - } - ], - } - ], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [ + { + 'kind': 'file', + 'file': {'mime_type': 'text/plain', 'uri': 'https://example.com/file.txt'}, + } + ], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -149,21 +186,41 @@ async def test_a2a_file_message_with_file(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { 'role': 'user', 'parts': [ { - 'type': 'file', - 'file': {'mime_type': 'text/plain', 'url': 'https://example.com/file.txt'}, + 'kind': 'file', + 'file': {'mime_type': 'text/plain', 'uri': 'https://example.com/file.txt'}, } ], - } + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + { + 'role': 'agent', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + }, + } ], }, } @@ -182,30 +239,34 @@ async def test_a2a_file_message_with_file_content(): message = Message( role='user', parts=[ - FilePart(type='file', file={'data': 'foo', 'mime_type': 'text/plain'}), + FilePart(file={'data': 'foo', 'mime_type': 'text/plain'}, kind='file'), ], + kind='message', ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [ - { - 'role': 'user', - 'parts': [{'type': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], - } - ], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -217,16 +278,36 @@ async def test_a2a_file_message_with_file_content(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { 'role': 'user', - 'parts': [{'type': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], - } + 'parts': [{'kind': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + { + 'role': 'agent', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + }, + } ], }, } @@ -244,24 +325,33 @@ async def test_a2a_file_message_with_data(): message = Message( role='user', - parts=[DataPart(type='data', data={'foo': 'bar'})], + parts=[DataPart(kind='data', data={'foo': 'bar'})], + kind='message', ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'data', 'data': {'foo': 'bar'}}]}], - }, + 'context_id': IsStr(), + 'kind': 'task', + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'failed': @@ -273,14 +363,55 @@ async def test_a2a_file_message_with_data(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'failed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'data', 'data': {'foo': 'bar'}}]}], + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], }, } ) +async def test_a2a_error_handling(): + """Test that errors during task execution properly update task state.""" + + def raise_error(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + raise RuntimeError('Test error during agent execution') + + error_model = FunctionModel(raise_error) + agent = Agent(model=error_model, output_type=str) + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + + task_id = result['id'] + + # Wait for task to fail + await anyio.sleep(0.1) + task = await a2a_client.get_task(task_id) + + assert 'result' in task + assert task['result']['status']['state'] == 'failed' + + async def test_a2a_multiple_messages(): agent = Agent(model=model, output_type=tuple[str, str]) storage = InMemoryStorage() @@ -291,27 +422,40 @@ async def test_a2a_multiple_messages(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) - response = await a2a_client.send_task(message=message) + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + response = await a2a_client.send_message(message=message) assert response == snapshot( { 'jsonrpc': '2.0', 'id': IsStr(), 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ], }, } ) # NOTE: We include the agent history before we start working on the task. assert 'result' in response - task_id = response['result']['id'] + result = response['result'] + assert is_task(result) + task_id = result['id'] task = storage.tasks[task_id] assert 'history' in task - task['history'].append(Message(role='agent', parts=[TextPart(text='Whats up?', type='text')])) + task['history'].append( + Message(role='agent', parts=[TextPart(text='Whats up?', kind='text')], kind='message') + ) response = await a2a_client.get_task(task_id) assert response == snapshot( @@ -320,11 +464,18 @@ async def test_a2a_multiple_messages(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'type': 'text', 'text': 'Whats up?'}]}, + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, ], }, } @@ -338,14 +489,37 @@ async def test_a2a_multiple_messages(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'type': 'text', 'text': 'Whats up?'}]}, + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, + { + 'role': 'agent', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + }, + } ], }, } diff --git a/tests/fasta2a/__init__.py b/tests/test_fasta2a/__init__.py similarity index 100% rename from tests/fasta2a/__init__.py rename to tests/test_fasta2a/__init__.py diff --git a/tests/fasta2a/test_applications.py b/tests/test_fasta2a/test_applications.py similarity index 100% rename from tests/fasta2a/test_applications.py rename to tests/test_fasta2a/test_applications.py From 22e825f8b847c6e5ecf1239c08ed5f153d28fa8e Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 6 Jul 2025 05:59:02 +0000 Subject: [PATCH 02/24] test: add comprehensive test for Pydantic model outputs with metadata - Test that Pydantic model outputs are correctly serialized as DataPart - Verify metadata includes type name and JSON schema - Ensure dual message/artifact approach works for complex types - Confirm that both message history and artifacts contain the data --- tests/test_a2a.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 3c7bcf4f7..946d55dcb 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -3,6 +3,7 @@ import pytest from asgi_lifespan import LifespanManager from inline_snapshot import snapshot +from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart @@ -32,6 +33,83 @@ def return_string(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: model = FunctionModel(return_string) +# Define a test Pydantic model +class UserProfile(BaseModel): + name: str + age: int + email: str + + +def return_pydantic_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"name": "John Doe", "age": 30, "email": "john@example.com"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + +pydantic_model = FunctionModel(return_pydantic_model) + + +async def test_a2a_pydantic_model_output(): + """Test that Pydantic model outputs have correct metadata including JSON schema.""" + agent = Agent(model=pydantic_model, output_type=UserProfile) + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message(role='user', parts=[TextPart(text='Get user profile', kind='text')], kind='message') + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + + task_id = result['id'] + + # Wait for completion + await anyio.sleep(0.1) + task = await a2a_client.get_task(task_id) + + assert 'result' in task + result = task['result'] + assert result['status']['state'] == 'completed' + + # Check artifacts + assert 'artifacts' in result + assert len(result['artifacts']) == 1 + artifact = result['artifacts'][0] + + # Verify the data + assert artifact['parts'][0]['kind'] == 'data' + assert artifact['parts'][0]['data'] == {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + + # Verify metadata + assert 'metadata' in artifact + metadata = artifact['metadata'] + assert metadata['type'] == 'UserProfile' + + # Verify JSON schema is present and correct + assert 'json_schema' in metadata + json_schema = metadata['json_schema'] + assert json_schema['type'] == 'object' + assert 'properties' in json_schema + assert set(json_schema['properties'].keys()) == {'name', 'age', 'email'} + assert json_schema['properties']['name']['type'] == 'string' + assert json_schema['properties']['age']['type'] == 'integer' + assert json_schema['properties']['email']['type'] == 'string' + assert json_schema['required'] == ['name', 'age', 'email'] + + # Check the message history also has the data + assert 'history' in result + assert len(result['history']) == 2 + agent_message = result['history'][1] + assert agent_message['role'] == 'agent' + assert agent_message['parts'][0]['kind'] == 'data' + assert agent_message['parts'][0]['data'] == {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + + async def test_a2a_runtime_error_without_lifespan(): agent = Agent(model=model, output_type=tuple[str, str]) app = agent.to_a2a() From a6d958b47b0c31cc1dd8f3476ecc295eb11840df Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 7 Jul 2025 12:19:13 +0000 Subject: [PATCH 03/24] feat: add context storage for conversation continuity - Add update_context() and get_context() methods to Storage - Store full pydantic-ai message history (including tool calls) in context - Preserve conversation state across multiple tasks with same context_id - Update docs to explain task vs context distinction - Add test for monotonic message history growth - Clean up run_task: remove history_length, add state check, fix comments --- docs/a2a.md | 28 +++- fasta2a/fasta2a/storage.py | 77 ++++++----- pydantic_ai_slim/pydantic_ai/_a2a.py | 185 +++++++++++++++++++++------ tests/test_a2a.py | 96 +++++++++++++- 4 files changed, 303 insertions(+), 83 deletions(-) diff --git a/docs/a2a.md b/docs/a2a.md index 28f7093fd..ca32e8419 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -32,7 +32,7 @@ The library is designed to be used with any agentic framework, and is **not excl Given the nature of the A2A protocol, it's important to understand the design before using it, as a developer you'll need to provide some components: -- [`Storage`][fasta2a.Storage]: to save and load tasks +- [`Storage`][fasta2a.Storage]: to save and load tasks, as well as store context for conversations - [`Broker`][fasta2a.Broker]: to schedule tasks - [`Worker`][fasta2a.Worker]: to execute tasks @@ -55,6 +55,28 @@ flowchart TB FastA2A allows you to bring your own [`Storage`][fasta2a.Storage], [`Broker`][fasta2a.Broker] and [`Worker`][fasta2a.Worker]. +#### Understanding Tasks and Context + +In the A2A protocol: + +- **Task**: Represents one complete execution of an agent. When a client sends a message to the agent, a new task is created. The agent runs until completion (or failure), and this entire execution is considered one task. The final output is stored as a task artifact. + +- **Context**: Represents a conversation thread that can span multiple tasks. The A2A protocol uses a `context_id` to maintain conversation continuity: + - When a new message is sent without a `context_id`, the server generates a new one + - Subsequent messages can include the same `context_id` to continue the conversation + - All tasks sharing the same `context_id` have access to the complete message history + +#### Storage Architecture + +The [`Storage`][fasta2a.Storage] component serves two purposes: + +1. **Task Storage**: Stores tasks in A2A protocol format, including their status, artifacts, and message history +2. **Context Storage**: Stores conversation context in a format optimized for the specific agent implementation + +This dual-purpose design allows flexibility for agents to store rich internal state (e.g., tool calls, reasoning traces) while maintaining efficient conversation continuity across multiple task executions. + +For example, a PydanticAI agent might store its complete internal message format (including tool calls and responses) in the context storage, while storing only the A2A-compliant messages in the task history. + ### Installation @@ -94,3 +116,7 @@ uvicorn agent_to_a2a:app --host 0.0.0.0 --port 8000 ``` Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor. + +When using `to_a2a()`, PydanticAI automatically: +- Stores the complete conversation history (including tool calls and responses) in the context storage +- Ensures that subsequent messages with the same `context_id` have access to the full conversation history diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 03f04c34e..6fcb8205e 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -4,14 +4,17 @@ from abc import ABC, abstractmethod from datetime import datetime +from typing import Any from .schema import Artifact, Message, Task, TaskState, TaskStatus class Storage(ABC): - """A storage to retrieve and save tasks. + """A storage to retrieve and save tasks, as well as retrieve and save context. - The storage is used to update the status of a task and to save the result of a task. + The storage serves two purposes: + 1. Task storage: Stores tasks in A2A protocol format with their status, artifacts, and message history + 2. Context storage: Stores conversation context in a format optimized for the specific agent implementation """ @abstractmethod @@ -30,21 +33,21 @@ async def update_task( self, task_id: str, state: TaskState, - artifacts: list[Artifact] | None = None, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, ) -> Task: - """Update the state of a task.""" + """Update the state of a task. Appends artifacts and messages, if specified.""" @abstractmethod - async def add_message(self, message: Message) -> None: - """Add a message to the history for both its task and context. + async def update_context(self, context_id: str, context: Any) -> None: + """Updates the context for a context_id. - This should be called for messages created during task execution, - not for the initial message (which is handled by submit_task). + Implementing agent can decide what to store in context. """ @abstractmethod - async def get_context_history(self, context_id: str, history_length: int | None = None) -> list[Message]: - """Get all messages across tasks in a context.""" + async def get_context(self, context_id: str) -> Any: + """Retrieve the stored context for a context_id.""" class InMemoryStorage(Storage): @@ -52,7 +55,7 @@ class InMemoryStorage(Storage): def __init__(self): self.tasks: dict[str, Task] = {} - self.context_messages: dict[str, list[Message]] = {} + self.contexts: dict[str, Any] = {} async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from memory. @@ -77,7 +80,7 @@ async def submit_task(self, task_id: str, context_id: str, message: Message) -> if task_id in self.tasks: raise ValueError(f'Task {task_id} already exists') - # Add IDs to the message + # Add IDs to the message for A2A protocol message['task_id'] = task_id message['context_id'] = context_id @@ -85,47 +88,39 @@ async def submit_task(self, task_id: str, context_id: str, message: Message) -> task = Task(id=task_id, context_id=context_id, kind='task', status=task_status, history=[message]) self.tasks[task_id] = task - # Add message to context storage directly (not via add_message to avoid duplication) - if context_id not in self.context_messages: - self.context_messages[context_id] = [] - self.context_messages[context_id].append(message) - return task async def update_task( self, task_id: str, state: TaskState, - artifacts: list[Artifact] | None = None, + new_artifacts: list[Artifact] | None = None, + new_messages: list[Message] | None = None, ) -> Task: """Update the state of a task.""" task = self.tasks[task_id] task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if artifacts: + + if new_artifacts: if 'artifacts' not in task: task['artifacts'] = [] - task['artifacts'].extend(artifacts) + task['artifacts'].extend(new_artifacts) + + if new_messages: + if 'history' not in task: + task['history'] = [] + # Add IDs to messages for consistency + for message in new_messages: + message['task_id'] = task_id + message['context_id'] = task['context_id'] + task['history'].append(message) + return task - async def add_message(self, message: Message) -> None: - """Add a message to the history for both its task and context.""" - if 'task_id' in message and message['task_id']: - task_id = message['task_id'] - if task_id in self.tasks: - task = self.tasks[task_id] - if 'history' not in task: - task['history'] = [] - task['history'].append(message) + async def update_context(self, context_id: str, context: Any) -> None: + """Updates the context for a context_id.""" + self.contexts[context_id] = context - if 'context_id' in message and message['context_id']: - context_id = message['context_id'] - if context_id not in self.context_messages: - self.context_messages[context_id] = [] - self.context_messages[context_id].append(message) - - async def get_context_history(self, context_id: str, history_length: int | None = None) -> list[Message]: - """Get all messages across tasks in a context.""" - messages = self.context_messages.get(context_id, []) - if history_length: - return messages[-history_length:] - return messages + async def get_context(self, context_id: str) -> Any: + """Retrieve the stored context for a context_id.""" + return self.contexts.get(context_id) diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index b186bf78e..21aa0cae2 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -21,6 +21,8 @@ ModelResponse, ModelResponsePart, TextPart, + ThinkingPart, + ToolCallPart, UserPromptPart, VideoUrl, ) @@ -119,12 +121,16 @@ class AgentWorker(Worker, Generic[WorkerOutputT, AgentDepsT]): agent: Agent[AgentDepsT, WorkerOutputT] async def run_task(self, params: TaskSendParams) -> None: - task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) + task = await self.storage.load_task(params['id']) if task is None: raise ValueError(f'Task {params["id"]} not found') if 'context_id' not in task: raise ValueError('Task must have a context_id') + # Ensure this task hasn't been run before + if task['status']['state'] != 'submitted': + raise ValueError(f'Task {params["id"]} has already been processed (state: {task["status"]["state"]})') + task_id = task['id'] context_id = task['context_id'] @@ -134,8 +140,17 @@ async def run_task(self, params: TaskSendParams) -> None: # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe # a custom `output_type` with a `more_info_required` field, or something like that. - history = task.get('history', []) - message_history = self.build_message_history(history) + # Load context - contains pydantic-ai message history from previous tasks in this conversation + context = await self.storage.get_context(context_id) + message_history: list[ModelMessage] = context if context else [] + + # Add the current task's initial message to the history + # Tasks start with a user message that triggered this task + if task.get('history'): + for a2a_msg in task['history']: + if a2a_msg['role'] == 'user': + # Convert user message to pydantic-ai format + message_history.append(ModelRequest(parts=self._request_parts_from_a2a(a2a_msg['parts']))) # TODO(Marcelo): We need to make this more customizable e.g. pass deps. result = await self.agent.run(message_history=message_history) # type: ignore @@ -145,24 +160,55 @@ async def run_task(self, params: TaskSendParams) -> None: # also marking the output as a durable artifact message_id = str(uuid.uuid4()) - # Create message parts based on output type - message_part = self._convert_result_to_part(result.output) - message_parts: list[Part] = [message_part] - - # Add result as a message to preserve conversation flow - result_message = Message( - role='agent', - parts=message_parts, - kind='message', - message_id=message_id, - task_id=task_id, - context_id=context_id, - ) - await self.storage.add_message(result_message) - - # Also create artifacts for durable outputs + # Update context with complete message history including new messages + # This preserves tool calls, thinking, and all internal state + all_messages = result.all_messages() + await self.storage.update_context(context_id, all_messages) + + # Convert new messages to A2A format for task history + new_messages = result.new_messages() + a2a_messages: list[Message] = [] + + for msg in new_messages: + if isinstance(msg, ModelRequest): + # Skip user prompts - they're already in task history + continue + elif isinstance(msg, ModelResponse): + # Convert response parts to A2A format + a2a_parts = self._response_parts_to_a2a(msg.parts) + if a2a_parts: # Add if there are visible parts (text/thinking) + a2a_messages.append( + Message( + role='agent', + parts=a2a_parts, + kind='message', + message_id=str(uuid.uuid4()), + ) + ) + + # Also add the final output as a message if it's not just text + # This ensures structured outputs appear in the message history + if result.output and not isinstance(result.output, str): + output_part = self._convert_result_to_part(result.output) + a2a_messages.append( + Message( + role='agent', + parts=[output_part], + kind='message', + message_id=message_id, + ) + ) + + # Create artifacts for durable outputs artifacts = self.build_artifacts(result.output) - await self.storage.update_task(task_id, state='completed', artifacts=artifacts) + + # Update task with completion status, new A2A messages, and artifacts + await self.storage.update_task( + task_id, + state='completed', + new_artifacts=artifacts, + new_messages=a2a_messages if a2a_messages else None, + ) except Exception: # Ensure task is marked as failed on any error @@ -234,36 +280,48 @@ def build_message_history(self, history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] for message in history: if message['role'] == 'user': - model_messages.append(ModelRequest(parts=self._map_request_parts(message['parts']))) + model_messages.append(ModelRequest(parts=self._request_parts_from_a2a(message['parts']))) else: - model_messages.append(ModelResponse(parts=self._map_response_parts(message['parts']))) + model_messages.append(ModelResponse(parts=self._response_parts_from_a2a(message['parts']))) return model_messages - def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: + def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]: + """Convert A2A Part objects to pydantic-ai ModelRequestPart objects. + + This handles the conversion from A2A protocol parts (text, file, data) to + pydantic-ai's internal request parts (UserPromptPart with various content types). + + Args: + parts: List of A2A Part objects from incoming messages + + Returns: + List of ModelRequestPart objects for the pydantic-ai agent + """ model_parts: list[ModelRequestPart] = [] for part in parts: if part['kind'] == 'text': model_parts.append(UserPromptPart(content=part['text'])) elif part['kind'] == 'file': - file = part['file'] - if 'data' in file: - data = file['data'].encode('utf-8') - content = BinaryContent(data=data, media_type=file['mime_type']) + file_content = part['file'] + if 'data' in file_content: + data = file_content['data'].encode('utf-8') + mime_type = file_content.get('mime_type', 'application/octet-stream') + content = BinaryContent(data=data, media_type=mime_type) model_parts.append(UserPromptPart(content=[content])) - elif 'uri' in file: - uri = file['uri'] - mime_type = file.get('mime_type', 'application/octet-stream') - for url_cls in (DocumentUrl, AudioUrl, ImageUrl, VideoUrl): - content = url_cls(url=uri) - try: - content.media_type - except ValueError: # pragma: no cover - continue - else: - break + elif 'uri' in file_content: + url = file_content['uri'] + mime_type = file_content.get('mime_type', '') + if mime_type.startswith('image/'): + content = ImageUrl(url=url) + elif mime_type.startswith('audio/'): + content = AudioUrl(url=url) + elif mime_type.startswith('video/'): + content = VideoUrl(url=url) else: - raise ValueError(f'Unknown file type: {mime_type}') # pragma: no cover + content = DocumentUrl(url=url) model_parts.append(UserPromptPart(content=[content])) + else: + raise ValueError('FilePart.file must have either data or uri') elif part['kind'] == 'data': # TODO(Marcelo): Maybe we should use this for `ToolReturnPart`, and `RetryPromptPart`. raise NotImplementedError('Data parts are not supported yet.') @@ -271,7 +329,19 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: assert_never(part) return model_parts - def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: + def _response_parts_from_a2a(self, parts: list[Part]) -> list[ModelResponsePart]: + """Convert A2A Part objects to pydantic-ai ModelResponsePart objects. + + This handles the conversion from A2A protocol parts (text, file, data) to + pydantic-ai's internal response parts. Currently only supports text parts + as agent responses in A2A are expected to be text-based. + + Args: + parts: List of A2A Part objects from stored agent messages + + Returns: + List of ModelResponsePart objects for message history + """ model_parts: list[ModelResponsePart] = [] for part in parts: if part['kind'] == 'text': @@ -283,3 +353,38 @@ def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: else: # pragma: no cover assert_never(part) return model_parts + + def _response_parts_to_a2a(self, parts: list[ModelResponsePart]) -> list[Part]: + """Convert pydantic-ai ModelResponsePart objects to A2A Part objects. + + This handles the conversion from pydantic-ai's internal response parts to + A2A protocol parts. Different part types are handled as follows: + - TextPart: Converted directly to A2A TextPart + - ThinkingPart: Converted to TextPart with metadata indicating it's thinking + - ToolCallPart: Skipped (internal to agent execution) + + Args: + parts: List of ModelResponsePart objects from agent response + + Returns: + List of A2A Part objects suitable for sending via A2A protocol + """ + a2a_parts: list[Part] = [] + for part in parts: + if isinstance(part, TextPart): + if part.content: # Only add non-empty text + a2a_parts.append(A2ATextPart(kind='text', text=part.content)) + elif isinstance(part, ThinkingPart): + if part.content: # Only add non-empty thinking + # Convert thinking to text with metadata + a2a_parts.append( + A2ATextPart( + kind='text', + text=part.content, + metadata={'type': 'thinking', 'thinking_id': part.id, 'signature': part.signature}, + ) + ) + elif isinstance(part, ToolCallPart): + # Skip tool calls - they're internal to agent execution + pass + return a2a_parts diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 946d55dcb..2df179f85 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from pydantic_ai import Agent -from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, ToolCallPart from pydantic_ai.models.function import AgentInfo, FunctionModel from .conftest import IsDatetime, IsStr, try_import @@ -490,6 +490,100 @@ def raise_error(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert task['result']['status']['state'] == 'failed' +async def test_a2a_multiple_tasks_same_context(): + """Test that multiple tasks can share the same context_id with accumulated history.""" + + messages_received = [] + + def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Store a copy of the messages received by the model + messages_received.append(messages.copy()) + # Return the standard response + assert info.output_tools is not None + args_json = '{"response": ["foo", "bar"]}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + tracking_model = FunctionModel(track_messages) + agent = Agent(model=tracking_model, output_type=tuple[str, str]) + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + # First message - should create a new context + message1 = Message(role='user', parts=[TextPart(text='First message', kind='text')], kind='message') + response1 = await a2a_client.send_message(message=message1) + assert 'error' not in response1 + assert 'result' in response1 + result1 = response1['result'] + assert is_task(result1) + + task1_id = result1['id'] + context_id = result1['context_id'] + + # Wait for first task to complete + await anyio.sleep(0.1) + task1 = await a2a_client.get_task(task1_id) + assert 'result' in task1 + assert task1['result']['status']['state'] == 'completed' + + # Verify the model received at least one message + assert len(messages_received) == 1 + first_run_history = messages_received[0] + assert len(first_run_history) >= 1 + assert first_run_history[0].parts[0].content == 'First message' + + # Second message - reuse the same context_id + message2 = Message( + role='user', parts=[TextPart(text='Second message', kind='text')], kind='message', context_id=context_id + ) + response2 = await a2a_client.send_message(message=message2) + assert 'error' not in response2 + assert 'result' in response2 + result2 = response2['result'] + assert is_task(result2) + + task2_id = result2['id'] + # Verify we got a new task ID but same context ID + assert task2_id != task1_id + assert result2['context_id'] == context_id + + # Wait for second task to complete + await anyio.sleep(0.1) + task2 = await a2a_client.get_task(task2_id) + assert 'result' in task2 + if task2['result']['status']['state'] == 'failed': + print(f'Task 2 failed: {task2}') + print(f'Messages received so far: {messages_received}') + assert task2['result']['status']['state'] == 'completed' + + # Verify the model received the full history on the second call + assert len(messages_received) == 2 + second_run_history = messages_received[1] + + # Check that history is monotonically increasing - all previous messages should be there + assert len(second_run_history) > len(first_run_history), ( + f'Expected more messages, got {len(second_run_history)} <= {len(first_run_history)}' + ) + + # Check that all messages from first run are still in second run (in same order) + for i, msg in enumerate(first_run_history): + assert second_run_history[i] == msg, f'Message {i} changed between runs' + + # Verify the new message is there + # Find the user message with 'Second message' in the new history + found_second_message = False + for msg in second_run_history: + if isinstance(msg, ModelRequest): + for part in msg.parts: + if hasattr(part, 'content') and part.content == 'Second message': + found_second_message = True + break + assert found_second_message, 'Second message not found in history' + + async def test_a2a_multiple_messages(): agent = Agent(model=model, output_type=tuple[str, str]) storage = InMemoryStorage() From 39019542c67ec3cd56e05143dc58048ea3d04c9c Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 7 Jul 2025 12:23:21 +0000 Subject: [PATCH 04/24] Fix types --- pydantic_ai_slim/pydantic_ai/_a2a.py | 5 +++-- tests/test_a2a.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 21aa0cae2..a8a3b254f 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -146,8 +146,9 @@ async def run_task(self, params: TaskSendParams) -> None: # Add the current task's initial message to the history # Tasks start with a user message that triggered this task - if task.get('history'): - for a2a_msg in task['history']: + task_history = task.get('history') + if task_history: + for a2a_msg in task_history: if a2a_msg['role'] == 'user': # Convert user message to pydantic-ai format message_history.append(ModelRequest(parts=self._request_parts_from_a2a(a2a_msg['parts']))) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 2df179f85..70a4b55c6 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -493,7 +493,7 @@ def raise_error(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: async def test_a2a_multiple_tasks_same_context(): """Test that multiple tasks can share the same context_id with accumulated history.""" - messages_received = [] + messages_received: list[list[ModelMessage]] = [] def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # Store a copy of the messages received by the model @@ -533,7 +533,11 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon assert len(messages_received) == 1 first_run_history = messages_received[0] assert len(first_run_history) >= 1 - assert first_run_history[0].parts[0].content == 'First message' + # Check first message is a ModelRequest with UserPromptPart + first_msg = first_run_history[0] + assert isinstance(first_msg, ModelRequest) + first_part = first_msg.parts[0] + assert hasattr(first_part, 'content') and first_part.content == 'First message' # Second message - reuse the same context_id message2 = Message( From 55f3cedaa13463651a9d120743199bc2afd1612b Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 7 Jul 2025 12:42:17 +0000 Subject: [PATCH 05/24] Fix misattributed TODOs -- oops! --- pydantic_ai_slim/pydantic_ai/_a2a.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index a8a3b254f..5cc6be9bd 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -137,9 +137,6 @@ async def run_task(self, params: TaskSendParams) -> None: try: await self.storage.update_task(task_id, state='working') - # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe - # a custom `output_type` with a `more_info_required` field, or something like that. - # Load context - contains pydantic-ai message history from previous tasks in this conversation context = await self.storage.get_context(context_id) message_history: list[ModelMessage] = context if context else [] @@ -153,7 +150,6 @@ async def run_task(self, params: TaskSendParams) -> None: # Convert user message to pydantic-ai format message_history.append(ModelRequest(parts=self._request_parts_from_a2a(a2a_msg['parts']))) - # TODO(Marcelo): We need to make this more customizable e.g. pass deps. result = await self.agent.run(message_history=message_history) # type: ignore # Create both a message and artifact for the result @@ -324,7 +320,6 @@ def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]: else: raise ValueError('FilePart.file must have either data or uri') elif part['kind'] == 'data': - # TODO(Marcelo): Maybe we should use this for `ToolReturnPart`, and `RetryPromptPart`. raise NotImplementedError('Data parts are not supported yet.') else: assert_never(part) From 0a95bf9b090f8edf010b2bf0905d0b31f1a6946a Mon Sep 17 00:00:00 2001 From: Rob Porter Date: Mon, 7 Jul 2025 21:15:21 -0700 Subject: [PATCH 06/24] Update docs/a2a.md Co-authored-by: Marcelo Trylesinski --- docs/a2a.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/a2a.md b/docs/a2a.md index ca32e8419..8646864d7 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -118,5 +118,6 @@ uvicorn agent_to_a2a:app --host 0.0.0.0 --port 8000 Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor. When using `to_a2a()`, PydanticAI automatically: + - Stores the complete conversation history (including tool calls and responses) in the context storage - Ensures that subsequent messages with the same `context_id` have access to the full conversation history From 578590158a21a042157844a1b43aad4475289c8f Mon Sep 17 00:00:00 2001 From: Rob Porter Date: Mon, 7 Jul 2025 21:18:42 -0700 Subject: [PATCH 07/24] Update fasta2a/fasta2a/applications.py Co-authored-by: Marcelo Trylesinski --- fasta2a/fasta2a/applications.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 9de5f9433..8ad1e379b 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -118,8 +118,7 @@ async def _agent_run_endpoint(self, request: Request) -> Response: a2a_request = a2a_request_ta.validate_json(data) if a2a_request['method'] == 'message/send': - message_request = send_message_request_ta.validate_json(data) - jsonrpc_response = await self.task_manager.send_message(message_request) + jsonrpc_response = await self.task_manager.send_message(a2a_request) elif a2a_request['method'] == 'message/stream': # Streaming support not yet implemented raise NotImplementedError( From d1ff90a3d2e58f1405abc85cd68373e3f7a80fb7 Mon Sep 17 00:00:00 2001 From: Rob Porter Date: Mon, 7 Jul 2025 21:19:10 -0700 Subject: [PATCH 08/24] Update fasta2a/fasta2a/applications.py Co-authored-by: Marcelo Trylesinski --- fasta2a/fasta2a/applications.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 8ad1e379b..fe996de75 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -120,7 +120,6 @@ async def _agent_run_endpoint(self, request: Request) -> Response: if a2a_request['method'] == 'message/send': jsonrpc_response = await self.task_manager.send_message(a2a_request) elif a2a_request['method'] == 'message/stream': - # Streaming support not yet implemented raise NotImplementedError( 'message/stream method is not implemented yet. Streaming support will be added in a future update.' ) From 15df5242643919425c0a98b03afb11a6826c4511 Mon Sep 17 00:00:00 2001 From: Rob Porter Date: Mon, 7 Jul 2025 21:29:36 -0700 Subject: [PATCH 09/24] Update fasta2a/fasta2a/schema.py Co-authored-by: Marcelo Trylesinski --- fasta2a/fasta2a/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 1ccf98055..fbd6a4a70 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -152,7 +152,7 @@ class Artifact(TypedDict): metadata: NotRequired[dict[str, Any]] """Metadata about the artifact.""" - extensions: NotRequired[list[Any]] + extensions: NotRequired[list[str]] """Array of extensions.""" append: NotRequired[bool] From 1a245a77af393ce2bfc7f8429add9e05a3072eb3 Mon Sep 17 00:00:00 2001 From: Rob Porter Date: Mon, 7 Jul 2025 21:31:11 -0700 Subject: [PATCH 10/24] Update fasta2a/fasta2a/schema.py Co-authored-by: Marcelo Trylesinski --- fasta2a/fasta2a/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index fbd6a4a70..5bfe19daa 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -236,7 +236,7 @@ class Message(TypedDict): """Metadata about the message.""" # Additional fields - message_id: NotRequired[str] + message_id: str """Identifier created by the message creator.""" context_id: NotRequired[str] From ffe8a6e4b2b6dce50221f9f50cf0b78f46fdd24e Mon Sep 17 00:00:00 2001 From: Rob Porter Date: Mon, 7 Jul 2025 21:31:28 -0700 Subject: [PATCH 11/24] Update fasta2a/fasta2a/schema.py Co-authored-by: Marcelo Trylesinski --- fasta2a/fasta2a/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 5bfe19daa..048c5b1fc 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -248,7 +248,7 @@ class Message(TypedDict): reference_task_ids: NotRequired[list[str]] """Array of task IDs this message references.""" - extensions: NotRequired[list[Any]] + extensions: NotRequired[list[str]] """Array of extensions.""" From 5fd0dde17b8f4ab138e693e3132d86fcfe314b91 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 04:17:15 +0000 Subject: [PATCH 12/24] Update docs --- docs/a2a.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/a2a.md b/docs/a2a.md index 8646864d7..6ec1be4d3 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -73,7 +73,7 @@ The [`Storage`][fasta2a.Storage] component serves two purposes: 1. **Task Storage**: Stores tasks in A2A protocol format, including their status, artifacts, and message history 2. **Context Storage**: Stores conversation context in a format optimized for the specific agent implementation -This dual-purpose design allows flexibility for agents to store rich internal state (e.g., tool calls, reasoning traces) while maintaining efficient conversation continuity across multiple task executions. +This design allows for agents to store rich internal state (e.g., tool calls, reasoning traces) as well as store task-specific A2A-formatted messages and artifacts. For example, a PydanticAI agent might store its complete internal message format (including tool calls and responses) in the context storage, while storing only the A2A-compliant messages in the task history. From 872714e53a9a607a5223fe7cf302edba1e5e4199 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 04:30:22 +0000 Subject: [PATCH 13/24] Remove deprecation exception --- fasta2a/fasta2a/applications.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index fe996de75..318a16b8e 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -123,9 +123,6 @@ async def _agent_run_endpoint(self, request: Request) -> Response: raise NotImplementedError( 'message/stream method is not implemented yet. Streaming support will be added in a future update.' ) - elif a2a_request['method'] == 'tasks/send': # type: ignore[comparison-overlap] - # Legacy method - no longer supported - raise NotImplementedError('tasks/send is deprecated. Use message/send instead.') elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': From 9372d3c00e1d633942c3f88cacb15436723cfd0e Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 04:30:44 +0000 Subject: [PATCH 14/24] Remove deprecation exception --- fasta2a/fasta2a/applications.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 318a16b8e..7b64a40d4 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -21,7 +21,6 @@ a2a_request_ta, a2a_response_ta, agent_card_ta, - send_message_request_ta, ) from .storage import Storage from .task_manager import TaskManager From c5dd525339ca0c421ce27cd5a4083a595e579ee4 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 04:47:29 +0000 Subject: [PATCH 15/24] Change DataPart data type back to dict[str, Any]; Update result artifacts to conform --- fasta2a/fasta2a/schema.py | 2 +- pydantic_ai_slim/pydantic_ai/_a2a.py | 2 +- tests/test_a2a.py | 24 ++++++++++++++---------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 048c5b1fc..debfa6659 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -309,7 +309,7 @@ class DataPart(_BasePart): kind: Literal['data'] """The kind of the part.""" - data: Any + data: dict[str, Any] """The data of the part.""" description: NotRequired[str] diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 5cc6be9bd..987c699e7 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -250,7 +250,7 @@ def _convert_result_to_part(self, result: WorkerOutputT) -> Part: # Last resort - convert to string data = str(result) - return DataPart(kind='data', data=data) + return DataPart(kind='data', data={'result': data}) def _build_result_metadata(self, result: WorkerOutputT) -> dict[str, Any]: """Build metadata for the result artifact. diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 70a4b55c6..532ecfe53 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -83,7 +83,9 @@ async def test_a2a_pydantic_model_output(): # Verify the data assert artifact['parts'][0]['kind'] == 'data' - assert artifact['parts'][0]['data'] == {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + assert artifact['parts'][0]['data'] == { + 'result': {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + } # Verify metadata assert 'metadata' in artifact @@ -107,7 +109,9 @@ async def test_a2a_pydantic_model_output(): agent_message = result['history'][1] assert agent_message['role'] == 'agent' assert agent_message['parts'][0]['kind'] == 'data' - assert agent_message['parts'][0]['data'] == {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + assert agent_message['parts'][0]['data'] == { + 'result': {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} + } async def test_a2a_runtime_error_without_lifespan(): @@ -182,7 +186,7 @@ async def test_a2a_simple(): }, { 'role': 'agent', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'kind': 'message', 'message_id': IsStr(), 'context_id': IsStr(), @@ -193,7 +197,7 @@ async def test_a2a_simple(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'metadata': { 'type': 'tuple', 'json_schema': {'items': {}, 'type': 'array'}, @@ -282,7 +286,7 @@ async def test_a2a_file_message_with_file(): }, { 'role': 'agent', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'kind': 'message', 'message_id': IsStr(), 'context_id': IsStr(), @@ -293,7 +297,7 @@ async def test_a2a_file_message_with_file(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'metadata': { 'type': 'tuple', 'json_schema': {'items': {}, 'type': 'array'}, @@ -369,7 +373,7 @@ async def test_a2a_file_message_with_file_content(): }, { 'role': 'agent', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'kind': 'message', 'message_id': IsStr(), 'context_id': IsStr(), @@ -380,7 +384,7 @@ async def test_a2a_file_message_with_file_content(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'metadata': { 'type': 'tuple', 'json_schema': {'items': {}, 'type': 'array'}, @@ -679,7 +683,7 @@ async def test_a2a_multiple_messages(): {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, { 'role': 'agent', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'kind': 'message', 'message_id': IsStr(), 'context_id': IsStr(), @@ -690,7 +694,7 @@ async def test_a2a_multiple_messages(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], 'metadata': { 'type': 'tuple', 'json_schema': {'items': {}, 'type': 'array'}, From a1314829a48adc3f458e303cc681b888c80c4989 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 04:51:16 +0000 Subject: [PATCH 16/24] More PR feedback on spec --- fasta2a/fasta2a/schema.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index debfa6659..0daf20599 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -312,9 +312,6 @@ class DataPart(_BasePart): data: dict[str, Any] """The data of the part.""" - description: NotRequired[str] - """A description of the data.""" - Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='kind')] """A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. @@ -323,7 +320,7 @@ class DataPart(_BasePart): """ TaskState: TypeAlias = Literal[ - 'submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'rejected', 'auth-required' + 'submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'rejected', 'auth-required', 'unknown' ] """The possible states of a task.""" From bfd305fa0a4981a20fa09262a8907eb51827bb97 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 05:10:53 +0000 Subject: [PATCH 17/24] fix: make DataPart spec-compliant and improve message/artifact separation - Change DataPart.data type from Any to dict[str, Any] per A2A spec - Wrap non-dict agent results as {"result": } for consistency - Remove DataPart.description field (not in spec) - Improve message vs artifact separation: - String outputs appear in both messages and artifacts - Structured data only appears as artifacts (not duplicated in messages) - Update tests to reflect new behavior - Update docs to clarify artifact handling --- docs/a2a.md | 4 +++ pydantic_ai_slim/pydantic_ai/_a2a.py | 17 ++++------ tests/test_a2a.py | 49 ++++------------------------ 3 files changed, 16 insertions(+), 54 deletions(-) diff --git a/docs/a2a.md b/docs/a2a.md index 6ec1be4d3..b96e84d87 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -121,3 +121,7 @@ When using `to_a2a()`, PydanticAI automatically: - Stores the complete conversation history (including tool calls and responses) in the context storage - Ensures that subsequent messages with the same `context_id` have access to the full conversation history +- Persists agent results as A2A artifacts: + - String results become `TextPart` artifacts and also appear in the message history + - Structured data (Pydantic models, dataclasses, tuples, etc.) become `DataPart` artifacts with the data wrapped as `{"result": }` + - Artifacts include metadata with type information and JSON schema when available diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 987c699e7..0d400efeb 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -147,16 +147,11 @@ async def run_task(self, params: TaskSendParams) -> None: if task_history: for a2a_msg in task_history: if a2a_msg['role'] == 'user': - # Convert user message to pydantic-ai format + # Convert user message from A2A format to pydantic-ai format message_history.append(ModelRequest(parts=self._request_parts_from_a2a(a2a_msg['parts']))) result = await self.agent.run(message_history=message_history) # type: ignore - # Create both a message and artifact for the result - # This ensures the complete conversation is preserved in history while - # also marking the output as a durable artifact - message_id = str(uuid.uuid4()) - # Update context with complete message history including new messages # This preserves tool calls, thinking, and all internal state all_messages = result.all_messages() @@ -183,14 +178,14 @@ async def run_task(self, params: TaskSendParams) -> None: ) ) - # Also add the final output as a message if it's not just text - # This ensures structured outputs appear in the message history - if result.output and not isinstance(result.output, str): - output_part = self._convert_result_to_part(result.output) + # Also add the final output as a message if it's a string + # This ensures string outputs appear in the message history for a task + if result.output and isinstance(result.output, str): + message_id = str(uuid.uuid4()) a2a_messages.append( Message( role='agent', - parts=[output_part], + parts=[A2ATextPart(kind='text', text=result.output)], kind='message', message_id=message_id, ) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 532ecfe53..ceb9947d8 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -103,15 +103,10 @@ async def test_a2a_pydantic_model_output(): assert json_schema['properties']['email']['type'] == 'string' assert json_schema['required'] == ['name', 'age', 'email'] - # Check the message history also has the data + # Check the message history - structured outputs don't appear as messages assert 'history' in result - assert len(result['history']) == 2 - agent_message = result['history'][1] - assert agent_message['role'] == 'agent' - assert agent_message['parts'][0]['kind'] == 'data' - assert agent_message['parts'][0]['data'] == { - 'result': {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} - } + assert len(result['history']) == 1 # Only the user message + assert result['history'][0]['role'] == 'user' async def test_a2a_runtime_error_without_lifespan(): @@ -183,15 +178,7 @@ async def test_a2a_simple(): 'kind': 'message', 'context_id': IsStr(), 'task_id': IsStr(), - }, - { - 'role': 'agent', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + } ], 'artifacts': [ { @@ -283,15 +270,7 @@ async def test_a2a_file_message_with_file(): 'kind': 'message', 'context_id': IsStr(), 'task_id': IsStr(), - }, - { - 'role': 'agent', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + } ], 'artifacts': [ { @@ -370,15 +349,7 @@ async def test_a2a_file_message_with_file_content(): 'kind': 'message', 'context_id': IsStr(), 'task_id': IsStr(), - }, - { - 'role': 'agent', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + } ], 'artifacts': [ { @@ -681,14 +652,6 @@ async def test_a2a_multiple_messages(): 'task_id': IsStr(), }, {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, - { - 'role': 'agent', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, ], 'artifacts': [ { From 2f05b6c59e11767c3ea9e12622acedd5adba71eb Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 05:17:13 +0000 Subject: [PATCH 18/24] Remove is_task/is_message --- fasta2a/fasta2a/schema.py | 12 +----------- tests/test_a2a.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 0daf20599..34ac1bf82 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -7,7 +7,7 @@ import pydantic from pydantic import Discriminator, TypeAdapter from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypeAlias, TypedDict, TypeGuard +from typing_extensions import NotRequired, TypeAlias, TypedDict @pydantic.with_config({'alias_generator': to_camel}) @@ -642,13 +642,3 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): send_message_request_ta: TypeAdapter[SendMessageRequest] = TypeAdapter(SendMessageRequest) send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) - - -def is_task(response: Task | Message) -> TypeGuard[Task]: - """Type guard to check if a response is a Task.""" - return 'id' in response and 'status' in response and 'context_id' in response and response.get('kind') == 'task' - - -def is_message(response: Task | Message) -> TypeGuard[Message]: - """Type guard to check if a response is a Message.""" - return 'role' in response and 'parts' in response and response.get('kind') == 'message' diff --git a/tests/test_a2a.py b/tests/test_a2a.py index ceb9947d8..417bb71be 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -13,7 +13,7 @@ with try_import() as imports_successful: from fasta2a.client import A2AClient - from fasta2a.schema import DataPart, FilePart, Message, TextPart, is_task + from fasta2a.schema import DataPart, FilePart, Message, TextPart from fasta2a.storage import InMemoryStorage @@ -64,7 +64,7 @@ async def test_a2a_pydantic_model_output(): assert 'error' not in response assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' task_id = result['id'] @@ -137,7 +137,7 @@ async def test_a2a_simple(): assert 'error' not in response assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' assert result == snapshot( { 'id': IsStr(), @@ -219,7 +219,7 @@ async def test_a2a_file_message_with_file(): assert 'error' not in response assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' assert result == snapshot( { 'id': IsStr(), @@ -308,7 +308,7 @@ async def test_a2a_file_message_with_file_content(): assert 'error' not in response assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' assert result == snapshot( { 'id': IsStr(), @@ -385,7 +385,7 @@ async def test_a2a_file_message_with_data(): assert 'error' not in response assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' assert result == snapshot( { 'id': IsStr(), @@ -453,7 +453,7 @@ def raise_error(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert 'error' not in response assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' task_id = result['id'] @@ -493,7 +493,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon assert 'error' not in response1 assert 'result' in response1 result1 = response1['result'] - assert is_task(result1) + assert result1['kind'] == 'task' task1_id = result1['id'] context_id = result1['context_id'] @@ -522,7 +522,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon assert 'error' not in response2 assert 'result' in response2 result2 = response2['result'] - assert is_task(result2) + assert result2['kind'] == 'task' task2_id = result2['id'] # Verify we got a new task ID but same context ID @@ -600,7 +600,7 @@ async def test_a2a_multiple_messages(): # NOTE: We include the agent history before we start working on the task. assert 'result' in response result = response['result'] - assert is_task(result) + assert result['kind'] == 'task' task_id = result['id'] task = storage.tasks[task_id] assert 'history' in task From a23582a4cd9ec27c905b6d08ced7a8aa8ca6d96b Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 05:40:44 +0000 Subject: [PATCH 19/24] Address PR feedback on task_manager --- fasta2a/fasta2a/storage.py | 9 +++++---- fasta2a/fasta2a/task_manager.py | 7 +++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 6fcb8205e..5a917efcc 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -2,6 +2,7 @@ from __future__ import annotations as _annotations +import uuid from abc import ABC, abstractmethod from datetime import datetime from typing import Any @@ -25,7 +26,7 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta """ @abstractmethod - async def submit_task(self, task_id: str, context_id: str, message: Message) -> Task: + async def submit_task(self, context_id: str, message: Message) -> Task: """Submit a task to storage.""" @abstractmethod @@ -75,10 +76,10 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta task['history'] = task['history'][-history_length:] return task - async def submit_task(self, task_id: str, context_id: str, message: Message) -> Task: + async def submit_task(self, context_id: str, message: Message) -> Task: """Submit a task to storage.""" - if task_id in self.tasks: - raise ValueError(f'Task {task_id} already exists') + # Generate a unique task ID + task_id = str(uuid.uuid4()) # Add IDs to the message for A2A protocol message['task_id'] = task_id diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 637a439f9..0c50c43ce 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -114,18 +114,17 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): async def send_message(self, request: SendMessageRequest) -> SendMessageResponse: """Send a message using the A2A v0.2.3 protocol.""" request_id = request['id'] - task_id = str(uuid.uuid4()) message = request['params']['message'] # Use provided context_id or create new one context_id = message.get('context_id') or str(uuid.uuid4()) # Create a new task - task = await self.storage.submit_task(task_id, context_id, message) + task = await self.storage.submit_task(context_id, message) # Prepare params for broker broker_params: TaskSendParams = { - 'id': task_id, + 'id': task['id'], 'context_id': context_id, 'message': message, } @@ -165,7 +164,7 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) async def stream_message(self, request: StreamMessageRequest) -> None: - """Stream messages using Server-Sent Events. Not implemented.""" + """Stream messages using Server-Sent Events.""" raise NotImplementedError('message/stream method is not implemented yet.') async def set_task_push_notification( From 06b2d88306913afd0a2c3bf799a56fbe9441ab92 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 8 Jul 2025 05:56:35 +0000 Subject: [PATCH 20/24] Update tests for requiring message_id --- tests/test_a2a.py | 85 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 75 insertions(+), 10 deletions(-) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 417bb71be..b8af35594 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -1,3 +1,5 @@ +import uuid + import anyio import httpx import pytest @@ -59,7 +61,12 @@ async def test_a2a_pydantic_model_output(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Get user profile', kind='text')], kind='message') + message = Message( + role='user', + parts=[TextPart(text='Get user profile', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response @@ -117,7 +124,12 @@ async def test_a2a_runtime_error_without_lifespan(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) with pytest.raises(RuntimeError, match='TaskManager was not properly initialized.'): await a2a_client.send_message(message=message) @@ -132,7 +144,12 @@ async def test_a2a_simple(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response @@ -149,6 +166,7 @@ async def test_a2a_simple(): 'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -176,6 +194,7 @@ async def test_a2a_simple(): 'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -214,6 +233,7 @@ async def test_a2a_file_message_with_file(): ) ], kind='message', + message_id=str(uuid.uuid4()), ) response = await a2a_client.send_message(message=message) assert 'error' not in response @@ -236,6 +256,7 @@ async def test_a2a_file_message_with_file(): } ], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -268,6 +289,7 @@ async def test_a2a_file_message_with_file(): } ], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -303,6 +325,7 @@ async def test_a2a_file_message_with_file_content(): FilePart(file={'data': 'foo', 'mime_type': 'text/plain'}, kind='file'), ], kind='message', + message_id=str(uuid.uuid4()), ) response = await a2a_client.send_message(message=message) assert 'error' not in response @@ -320,6 +343,7 @@ async def test_a2a_file_message_with_file_content(): 'role': 'user', 'parts': [{'kind': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -347,6 +371,7 @@ async def test_a2a_file_message_with_file_content(): 'role': 'user', 'parts': [{'kind': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -380,6 +405,7 @@ async def test_a2a_file_message_with_data(): role='user', parts=[DataPart(kind='data', data={'foo': 'bar'})], kind='message', + message_id=str(uuid.uuid4()), ) response = await a2a_client.send_message(message=message) assert 'error' not in response @@ -397,6 +423,7 @@ async def test_a2a_file_message_with_data(): 'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -424,6 +451,7 @@ async def test_a2a_file_message_with_data(): 'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -448,7 +476,12 @@ def raise_error(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response @@ -488,7 +521,12 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon a2a_client = A2AClient(http_client=http_client) # First message - should create a new context - message1 = Message(role='user', parts=[TextPart(text='First message', kind='text')], kind='message') + message1 = Message( + role='user', + parts=[TextPart(text='First message', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) response1 = await a2a_client.send_message(message=message1) assert 'error' not in response1 assert 'result' in response1 @@ -516,7 +554,11 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon # Second message - reuse the same context_id message2 = Message( - role='user', parts=[TextPart(text='Second message', kind='text')], kind='message', context_id=context_id + role='user', + parts=[TextPart(text='Second message', kind='text')], + kind='message', + context_id=context_id, + message_id=str(uuid.uuid4()), ) response2 = await a2a_client.send_message(message=message2) assert 'error' not in response2 @@ -573,7 +615,12 @@ async def test_a2a_multiple_messages(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) response = await a2a_client.send_message(message=message) assert response == snapshot( { @@ -589,6 +636,7 @@ async def test_a2a_multiple_messages(): 'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), } @@ -605,7 +653,12 @@ async def test_a2a_multiple_messages(): task = storage.tasks[task_id] assert 'history' in task task['history'].append( - Message(role='agent', parts=[TextPart(text='Whats up?', kind='text')], kind='message') + Message( + role='agent', + parts=[TextPart(text='Whats up?', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) ) response = await a2a_client.get_task(task_id) @@ -623,10 +676,16 @@ async def test_a2a_multiple_messages(): 'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), }, - {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, + { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Whats up?'}], + 'kind': 'message', + 'message_id': IsStr(), + }, ], }, } @@ -648,10 +707,16 @@ async def test_a2a_multiple_messages(): 'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message', + 'message_id': IsStr(), 'context_id': IsStr(), 'task_id': IsStr(), }, - {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, + { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Whats up?'}], + 'kind': 'message', + 'message_id': IsStr(), + }, ], 'artifacts': [ { From 36fc9e85a20cc2b338f1f51646565da17db0ea96 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 8 Jul 2025 09:06:51 +0200 Subject: [PATCH 21/24] apply my comments --- fasta2a/fasta2a/applications.py | 4 ---- fasta2a/fasta2a/client.py | 2 +- pydantic_ai_slim/pydantic_ai/_a2a.py | 3 +-- tests/{test_fasta2a => fasta2a}/__init__.py | 0 tests/{test_fasta2a => fasta2a}/test_applications.py | 0 5 files changed, 2 insertions(+), 7 deletions(-) rename tests/{test_fasta2a => fasta2a}/__init__.py (100%) rename tests/{test_fasta2a => fasta2a}/test_applications.py (100%) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 7b64a40d4..9f473efa5 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -118,10 +118,6 @@ async def _agent_run_endpoint(self, request: Request) -> Response: if a2a_request['method'] == 'message/send': jsonrpc_response = await self.task_manager.send_message(a2a_request) - elif a2a_request['method'] == 'message/stream': - raise NotImplementedError( - 'message/stream method is not implemented yet. Streaming support will be added in a future update.' - ) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index dc3449623..cd8449923 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -47,7 +47,7 @@ async def send_message( ) -> SendMessageResponse: """Send a message using the A2A protocol. - Returns a JSON-RPC response containing either a result (Task | Message) or an error. + Returns a JSON-RPC response containing either a result (Task) or an error. """ params = MessageSendParams(message=message) if metadata is not None: diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 0d400efeb..810a31b62 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -124,9 +124,8 @@ async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id']) if task is None: raise ValueError(f'Task {params["id"]} not found') - if 'context_id' not in task: - raise ValueError('Task must have a context_id') + # TODO(Marcelo): Should we lock `run_task` on the `context_id`? # Ensure this task hasn't been run before if task['status']['state'] != 'submitted': raise ValueError(f'Task {params["id"]} has already been processed (state: {task["status"]["state"]})') diff --git a/tests/test_fasta2a/__init__.py b/tests/fasta2a/__init__.py similarity index 100% rename from tests/test_fasta2a/__init__.py rename to tests/fasta2a/__init__.py diff --git a/tests/test_fasta2a/test_applications.py b/tests/fasta2a/test_applications.py similarity index 100% rename from tests/test_fasta2a/test_applications.py rename to tests/fasta2a/test_applications.py From d50cd8e9d21d538877abab27d65da28bbebe443b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 8 Jul 2025 09:27:53 +0200 Subject: [PATCH 22/24] update types with claude --- fasta2a/fasta2a/applications.py | 16 +-- fasta2a/fasta2a/schema.py | 174 ++++++++++++++++++++++---- pydantic_ai_slim/pydantic_ai/_a2a.py | 8 +- pydantic_ai_slim/pydantic_ai/agent.py | 4 +- tests/test_a2a.py | 2 +- 5 files changed, 163 insertions(+), 41 deletions(-) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 9f473efa5..987d08ece 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -13,10 +13,9 @@ from .broker import Broker from .schema import ( + AgentCapabilities, AgentCard, - Authentication, - Capabilities, - Provider, + AgentProvider, Skill, a2a_request_ta, a2a_response_ta, @@ -39,7 +38,7 @@ def __init__( url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, @@ -85,16 +84,17 @@ async def _agent_card_endpoint(self, request: Request) -> Response: if self._agent_card_json_schema is None: agent_card = AgentCard( name=self.name, + description=self.description or 'FastA2A Agent', url=self.url, version=self.version, + protocol_version='0.2.5', skills=self.skills, default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), - authentication=Authentication(schemes=[]), + capabilities=AgentCapabilities( + streaming=False, push_notifications=False, state_transition_history=False + ), ) - if self.description is not None: - agent_card['description'] = self.description if self.provider is not None: agent_card['provider'] = self.provider self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 34ac1bf82..80adabcea 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -17,35 +17,45 @@ class AgentCard(TypedDict): name: str """Human readable name of the agent e.g. "Recipe Agent".""" - description: NotRequired[str] + description: str """A human-readable description of the agent. Used to assist users and other agents in understanding what the agent can do. (e.g. "Agent that helps users with recipes and cooking.") """ - # TODO(Marcelo): The spec makes url required. - url: NotRequired[str] + url: str """A URL to the address the agent is hosted at.""" - provider: NotRequired[Provider] - """The service provider of the agent.""" - - # TODO(Marcelo): The spec makes version required. - version: NotRequired[str] + version: str """The version of the agent - format is up to the provider. (e.g. "1.0.0")""" + protocol_version: str + """The version of the A2A protocol this agent supports.""" + + provider: NotRequired[AgentProvider] + """The service provider of the agent.""" + documentation_url: NotRequired[str] """A URL to documentation for the agent.""" - capabilities: Capabilities + icon_url: NotRequired[str] + """A URL to an icon for the agent.""" + + preferred_transport: NotRequired[str] + """The transport of the preferred endpoint. If empty, defaults to JSONRPC.""" + + additional_interfaces: NotRequired[list[AgentInterface]] + """Announcement of additional supported transports.""" + + capabilities: AgentCapabilities """The capabilities of the agent.""" - authentication: Authentication - """The authentication schemes supported by the agent. + security: NotRequired[list[dict[str, list[str]]]] + """Security requirements for contacting the agent.""" - Intended to match OpenAPI authentication structure. - """ + security_schemes: NotRequired[dict[str, SecurityScheme]] + """Security scheme definitions.""" default_input_modes: list[str] """Supported mime types for input data.""" @@ -59,7 +69,7 @@ class AgentCard(TypedDict): agent_card_ta = pydantic.TypeAdapter(AgentCard) -class Provider(TypedDict): +class AgentProvider(TypedDict): """The service provider of the agent.""" organization: str @@ -67,7 +77,7 @@ class Provider(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) -class Capabilities(TypedDict): +class AgentCapabilities(TypedDict): """The capabilities of the agent.""" streaming: NotRequired[bool] @@ -81,14 +91,89 @@ class Capabilities(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) -class Authentication(TypedDict): - """The authentication schemes supported by the agent.""" +class HttpSecurityScheme(TypedDict): + """HTTP security scheme.""" + + type: Literal['http'] + scheme: str + """The name of the HTTP Authorization scheme.""" + bearer_format: NotRequired[str] + """A hint to the client to identify how the bearer token is formatted.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class ApiKeySecurityScheme(TypedDict): + """API Key security scheme.""" + + type: Literal['apiKey'] + name: str + """The name of the header, query or cookie parameter to be used.""" + in_: Literal['query', 'header', 'cookie'] + """The location of the API key.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class OAuth2SecurityScheme(TypedDict): + """OAuth2 security scheme.""" + + type: Literal['oauth2'] + flows: dict[str, Any] + """An object containing configuration information for the flow types supported.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class OpenIdConnectSecurityScheme(TypedDict): + """OpenID Connect security scheme.""" + + type: Literal['openIdConnect'] + open_id_connect_url: str + """OpenId Connect URL to discover OAuth2 configuration values.""" + description: NotRequired[str] + """Description of this security scheme.""" + + +SecurityScheme = Annotated[ + Union[HttpSecurityScheme, ApiKeySecurityScheme, OAuth2SecurityScheme, OpenIdConnectSecurityScheme], + pydantic.Field(discriminator='type'), +] +"""A security scheme for authentication.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class AgentInterface(TypedDict): + """An interface that the agent supports.""" - schemes: list[str] - """The authentication schemes supported by the agent. (e.g. "Basic", "Bearer")""" + transport: str + """The transport protocol (e.g., 'jsonrpc', 'websocket').""" - credentials: NotRequired[str] - """The credentials a client should use for private cards.""" + url: str + """The URL endpoint for this transport.""" + + description: NotRequired[str] + """Description of this interface.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class AgentExtension(TypedDict): + """A declaration of an extension supported by an Agent.""" + + uri: str + """The URI of the extension.""" + + description: NotRequired[str] + """A description of how this agent uses this extension.""" + + required: NotRequired[bool] + """Whether the client must follow specific requirements of the extension.""" + + params: NotRequired[dict[str, Any]] + """Optional configuration for the extension.""" @pydantic.with_config({'alias_generator': to_camel}) @@ -195,7 +280,7 @@ class PushNotificationConfig(TypedDict): token: NotRequired[str] """Token unique to this task/session.""" - authentication: NotRequired[Authentication] + authentication: NotRequired[SecurityScheme] """Authentication details for push notifications.""" @@ -273,11 +358,11 @@ class TextPart(_BasePart): class FileWithBytes(TypedDict): """File with base64 encoded data.""" - data: str - """The base64 encoded data.""" + bytes: str + """The base64 encoded content of the file.""" - mime_type: str - """The mime type of the file.""" + mime_type: NotRequired[str] + """Optional mime type for the file.""" @pydantic.with_config({'alias_generator': to_camel}) @@ -489,6 +574,31 @@ class TaskSendParams(TypedDict): """Extension metadata.""" +@pydantic.with_config({'alias_generator': to_camel}) +class ListTaskPushNotificationConfigParams(TypedDict): + """Parameters for getting list of pushNotificationConfigurations associated with a Task.""" + + id: str + """Task id.""" + + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + +@pydantic.with_config({'alias_generator': to_camel}) +class DeleteTaskPushNotificationConfigParams(TypedDict): + """Parameters for removing pushNotificationConfiguration associated with a Task.""" + + id: str + """Task id.""" + + push_notification_config_id: str + """The push notification config id to delete.""" + + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + class JSONRPCMessage(TypedDict): """A JSON RPC message.""" @@ -613,6 +723,16 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ResubscribeTaskRequest = JSONRPCRequest[Literal['tasks/resubscribe'], TaskIdParams] """A JSON RPC request to resubscribe to a task.""" +ListTaskPushNotificationConfigRequest = JSONRPCRequest[ + Literal['tasks/pushNotificationConfig/list'], ListTaskPushNotificationConfigParams +] +"""A JSON RPC request to list task push notification configs.""" + +DeleteTaskPushNotificationConfigRequest = JSONRPCRequest[ + Literal['tasks/pushNotificationConfig/delete'], DeleteTaskPushNotificationConfigParams +] +"""A JSON RPC request to delete a task push notification config.""" + A2ARequest = Annotated[ Union[ SendMessageRequest, @@ -622,6 +742,8 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, ResubscribeTaskRequest, + ListTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, ], Discriminator('method'), ] diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 810a31b62..58d057b10 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -40,11 +40,11 @@ from fasta2a.applications import FastA2A from fasta2a.broker import Broker, InMemoryBroker from fasta2a.schema import ( + AgentProvider, Artifact, DataPart, Message, Part, - Provider, Skill, TaskIdParams, TaskSendParams, @@ -80,7 +80,7 @@ def agent_to_a2a( url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, @@ -294,8 +294,8 @@ def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]: model_parts.append(UserPromptPart(content=part['text'])) elif part['kind'] == 'file': file_content = part['file'] - if 'data' in file_content: - data = file_content['data'].encode('utf-8') + if 'bytes' in file_content: + data = file_content['bytes'].encode('utf-8') mime_type = file_content.get('mime_type', 'application/octet-stream') content = BinaryContent(data=data, media_type=mime_type) model_parts.append(UserPromptPart(content=[content])) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 6b8a5a5a6..2c9eb3a3d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -63,7 +63,7 @@ from fasta2a.applications import FastA2A from fasta2a.broker import Broker - from fasta2a.schema import Provider, Skill + from fasta2a.schema import AgentProvider, Skill from fasta2a.storage import Storage from pydantic_ai.mcp import MCPServer @@ -1764,7 +1764,7 @@ def to_a2a( url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, - provider: Provider | None = None, + provider: AgentProvider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, diff --git a/tests/test_a2a.py b/tests/test_a2a.py index b8af35594..8126aab0c 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -322,7 +322,7 @@ async def test_a2a_file_message_with_file_content(): message = Message( role='user', parts=[ - FilePart(file={'data': 'foo', 'mime_type': 'text/plain'}, kind='file'), + FilePart(file={'bytes': 'foo', 'mime_type': 'text/plain'}, kind='file'), ], kind='message', message_id=str(uuid.uuid4()), From d24fccac3e6cf8b31db498a52a7846397f55df25 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 8 Jul 2025 09:50:13 +0200 Subject: [PATCH 23/24] update tests --- tests/fasta2a/test_applications.py | 3 ++- tests/test_a2a.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/fasta2a/test_applications.py b/tests/fasta2a/test_applications.py index 3b3aa437d..955933151 100644 --- a/tests/fasta2a/test_applications.py +++ b/tests/fasta2a/test_applications.py @@ -34,12 +34,13 @@ async def test_agent_card(): assert response.json() == snapshot( { 'name': 'Agent', + 'description': 'FastA2A Agent', 'url': 'http://localhost:8000', 'version': '1.0.0', + 'protocolVersion': '0.2.5', 'skills': [], 'defaultInputModes': ['application/json'], 'defaultOutputModes': ['application/json'], 'capabilities': {'streaming': False, 'pushNotifications': False, 'stateTransitionHistory': False}, - 'authentication': {'schemes': []}, } ) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 8126aab0c..32ddb6b09 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -341,7 +341,7 @@ async def test_a2a_file_message_with_file_content(): 'history': [ { 'role': 'user', - 'parts': [{'kind': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], + 'parts': [{'kind': 'file', 'file': {'bytes': 'foo', 'mime_type': 'text/plain'}}], 'kind': 'message', 'message_id': IsStr(), 'context_id': IsStr(), @@ -369,7 +369,7 @@ async def test_a2a_file_message_with_file_content(): 'history': [ { 'role': 'user', - 'parts': [{'kind': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], + 'parts': [{'kind': 'file', 'file': {'bytes': 'foo', 'mime_type': 'text/plain'}}], 'kind': 'message', 'message_id': IsStr(), 'context_id': IsStr(), From 270872c2d794915e601a11a5e4e4eaaf6bccc2bb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 8 Jul 2025 12:47:30 +0200 Subject: [PATCH 24/24] update code --- fasta2a/fasta2a/storage.py | 32 ++++---- fasta2a/fasta2a/task_manager.py | 14 +--- fasta2a/fasta2a/worker.py | 9 ++- pydantic_ai_slim/pydantic_ai/_a2a.py | 115 ++++++--------------------- tests/test_a2a.py | 96 ++++++++++++---------- 5 files changed, 107 insertions(+), 159 deletions(-) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 5a917efcc..58934c1c5 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -5,12 +5,16 @@ import uuid from abc import ABC, abstractmethod from datetime import datetime -from typing import Any +from typing import Any, Generic + +from typing_extensions import TypeVar from .schema import Artifact, Message, Task, TaskState, TaskStatus +ContextT = TypeVar('ContextT', default=Any) + -class Storage(ABC): +class Storage(ABC, Generic[ContextT]): """A storage to retrieve and save tasks, as well as retrieve and save context. The storage serves two purposes: @@ -40,23 +44,23 @@ async def update_task( """Update the state of a task. Appends artifacts and messages, if specified.""" @abstractmethod - async def update_context(self, context_id: str, context: Any) -> None: - """Updates the context for a context_id. + async def load_context(self, context_id: str) -> ContextT | None: + """Retrieve the stored context given the `context_id`.""" + + @abstractmethod + async def update_context(self, context_id: str, context: ContextT) -> None: + """Updates the context for a `context_id`. Implementing agent can decide what to store in context. """ - @abstractmethod - async def get_context(self, context_id: str) -> Any: - """Retrieve the stored context for a context_id.""" - -class InMemoryStorage(Storage): +class InMemoryStorage(Storage[ContextT]): """A storage to retrieve and save tasks in memory.""" def __init__(self): self.tasks: dict[str, Task] = {} - self.contexts: dict[str, Any] = {} + self.contexts: dict[str, ContextT] = {} async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from memory. @@ -118,10 +122,10 @@ async def update_task( return task - async def update_context(self, context_id: str, context: Any) -> None: - """Updates the context for a context_id.""" + async def update_context(self, context_id: str, context: ContextT) -> None: + """Updates the context given the `context_id`.""" self.contexts[context_id] = context - async def get_context(self, context_id: str) -> Any: - """Retrieve the stored context for a context_id.""" + async def load_context(self, context_id: str) -> ContextT | None: + """Retrieve the stored context given the `context_id`.""" return self.contexts.get(context_id) diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 0c50c43ce..6845d78b1 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -90,7 +90,7 @@ class TaskManager: """A task manager responsible for managing tasks.""" broker: Broker - storage: Storage + storage: Storage[Any] _aexit_stack: AsyncExitStack | None = field(default=None, init=False) @@ -115,19 +115,11 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse """Send a message using the A2A v0.2.3 protocol.""" request_id = request['id'] message = request['params']['message'] + context_id = message.get('context_id', str(uuid.uuid4())) - # Use provided context_id or create new one - context_id = message.get('context_id') or str(uuid.uuid4()) - - # Create a new task task = await self.storage.submit_task(context_id, message) - # Prepare params for broker - broker_params: TaskSendParams = { - 'id': task['id'], - 'context_id': context_id, - 'message': message, - } + broker_params: TaskSendParams = {'id': task['id'], 'context_id': context_id, 'message': message} config = request['params'].get('configuration', {}) history_length = config.get('history_length') if history_length is not None: diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 34a0cf565..bcb017276 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -4,26 +4,27 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic import anyio from opentelemetry.trace import get_tracer, use_span from typing_extensions import assert_never +from .storage import ContextT, Storage + if TYPE_CHECKING: from .broker import Broker, TaskOperation from .schema import Artifact, Message, TaskIdParams, TaskSendParams - from .storage import Storage tracer = get_tracer(__name__) @dataclass -class Worker(ABC): +class Worker(ABC, Generic[ContextT]): """A worker is responsible for executing tasks.""" broker: Broker - storage: Storage + storage: Storage[ContextT] @asynccontextmanager async def run(self) -> AsyncIterator[None]: diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 58d057b10..a46dd69f3 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -3,7 +3,7 @@ import uuid from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from dataclasses import asdict, dataclass, is_dataclass +from dataclasses import dataclass from functools import partial from typing import Any, Generic, TypeVar @@ -114,8 +114,7 @@ def agent_to_a2a( @dataclass -# Generic parameters are reversed compared to Agent because AgentDepsT has a default -class AgentWorker(Worker, Generic[WorkerOutputT, AgentDepsT]): +class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]): """A worker that uses an agent to execute tasks.""" agent: Agent[AgentDepsT, WorkerOutputT] @@ -130,82 +129,50 @@ async def run_task(self, params: TaskSendParams) -> None: if task['status']['state'] != 'submitted': raise ValueError(f'Task {params["id"]} has already been processed (state: {task["status"]["state"]})') - task_id = task['id'] - context_id = task['context_id'] + await self.storage.update_task(task['id'], state='working') - try: - await self.storage.update_task(task_id, state='working') - - # Load context - contains pydantic-ai message history from previous tasks in this conversation - context = await self.storage.get_context(context_id) - message_history: list[ModelMessage] = context if context else [] + # Load context - contains pydantic-ai message history from previous tasks in this conversation + message_history = await self.storage.load_context(task['context_id']) + if message_history is None: + message_history = [] - # Add the current task's initial message to the history + try: # Tasks start with a user message that triggered this task + # Add the current task's initial message to the history task_history = task.get('history') if task_history: for a2a_msg in task_history: if a2a_msg['role'] == 'user': - # Convert user message from A2A format to pydantic-ai format message_history.append(ModelRequest(parts=self._request_parts_from_a2a(a2a_msg['parts']))) result = await self.agent.run(message_history=message_history) # type: ignore - # Update context with complete message history including new messages - # This preserves tool calls, thinking, and all internal state - all_messages = result.all_messages() - await self.storage.update_context(context_id, all_messages) + await self.storage.update_context(task['context_id'], result.all_messages()) # Convert new messages to A2A format for task history - new_messages = result.new_messages() a2a_messages: list[Message] = [] - for msg in new_messages: - if isinstance(msg, ModelRequest): + for message in result.new_messages(): + if isinstance(message, ModelRequest): # Skip user prompts - they're already in task history continue - elif isinstance(msg, ModelResponse): + elif isinstance(message, ModelResponse): # Convert response parts to A2A format - a2a_parts = self._response_parts_to_a2a(msg.parts) + a2a_parts = self._response_parts_to_a2a(message.parts) if a2a_parts: # Add if there are visible parts (text/thinking) a2a_messages.append( - Message( - role='agent', - parts=a2a_parts, - kind='message', - message_id=str(uuid.uuid4()), - ) + Message(role='agent', parts=a2a_parts, kind='message', message_id=str(uuid.uuid4())) ) - # Also add the final output as a message if it's a string - # This ensures string outputs appear in the message history for a task - if result.output and isinstance(result.output, str): - message_id = str(uuid.uuid4()) - a2a_messages.append( - Message( - role='agent', - parts=[A2ATextPart(kind='text', text=result.output)], - kind='message', - message_id=message_id, - ) - ) - - # Create artifacts for durable outputs artifacts = self.build_artifacts(result.output) - - # Update task with completion status, new A2A messages, and artifacts + except Exception: + await self.storage.update_task(task['id'], state='failed') + raise + else: await self.storage.update_task( - task_id, - state='completed', - new_artifacts=artifacts, - new_messages=a2a_messages if a2a_messages else None, + task['id'], state='completed', new_artifacts=artifacts, new_messages=a2a_messages ) - except Exception: - # Ensure task is marked as failed on any error - await self.storage.update_task(task_id, state='failed') - raise # Re-raise to maintain error visibility - async def cancel_task(self, params: TaskIdParams) -> None: pass @@ -218,8 +185,7 @@ def build_artifacts(self, result: WorkerOutputT) -> list[Artifact]: """ artifact_id = str(uuid.uuid4()) part = self._convert_result_to_part(result) - metadata = self._build_result_metadata(result) - return [Artifact(artifact_id=artifact_id, name='result', parts=[part], metadata=metadata)] + return [Artifact(artifact_id=artifact_id, name='result', parts=[part])] def _convert_result_to_part(self, result: WorkerOutputT) -> Part: """Convert agent result to a Part (TextPart or DataPart). @@ -230,42 +196,11 @@ def _convert_result_to_part(self, result: WorkerOutputT) -> Part: if isinstance(result, str): return A2ATextPart(kind='text', text=result) else: - # For structured data, create a DataPart - try: - # Try using TypeAdapter for proper serialization - output_type = type(result) - type_adapter: TypeAdapter[WorkerOutputT] = TypeAdapter(output_type) - data = type_adapter.dump_python(result, mode='json') - except Exception: - # Fallback for types that TypeAdapter can't handle - if is_dataclass(result) and not isinstance(result, type): - data = asdict(result) - else: - # Last resort - convert to string - data = str(result) - - return DataPart(kind='data', data={'result': data}) - - def _build_result_metadata(self, result: WorkerOutputT) -> dict[str, Any]: - """Build metadata for the result artifact. - - Captures type information and JSON schema when available. - """ - metadata: dict[str, Any] = { - 'type': type(result).__name__, - } - - # For non-string types, attempt to capture JSON schema - if not isinstance(result, str): output_type = type(result) - type_adapter: TypeAdapter[WorkerOutputT] = TypeAdapter(output_type) - try: - metadata['json_schema'] = type_adapter.json_schema() - except Exception: - # Some types don't support JSON schema generation - pass - - return metadata + type_adapter = TypeAdapter(output_type) + data = type_adapter.dump_python(result, mode='json') + json_schema = type_adapter.json_schema(mode='serialization') + return DataPart(kind='data', data={'result': data}, metadata={'json_schema': json_schema}) def build_message_history(self, history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 32ddb6b09..a3f28d937 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -94,26 +94,34 @@ async def test_a2a_pydantic_model_output(): 'result': {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'} } - # Verify metadata - assert 'metadata' in artifact - metadata = artifact['metadata'] - assert metadata['type'] == 'UserProfile' - - # Verify JSON schema is present and correct - assert 'json_schema' in metadata - json_schema = metadata['json_schema'] - assert json_schema['type'] == 'object' - assert 'properties' in json_schema - assert set(json_schema['properties'].keys()) == {'name', 'age', 'email'} - assert json_schema['properties']['name']['type'] == 'string' - assert json_schema['properties']['age']['type'] == 'integer' - assert json_schema['properties']['email']['type'] == 'string' - assert json_schema['required'] == ['name', 'age', 'email'] - - # Check the message history - structured outputs don't appear as messages - assert 'history' in result - assert len(result['history']) == 1 # Only the user message - assert result['history'][0]['role'] == 'user' + metadata = artifact['parts'][0].get('metadata') + assert metadata is not None + + assert metadata['json_schema'] == snapshot( + { + 'properties': { + 'name': {'title': 'Name', 'type': 'string'}, + 'age': {'title': 'Age', 'type': 'integer'}, + 'email': {'title': 'Email', 'type': 'string'}, + }, + 'required': ['name', 'age', 'email'], + 'title': 'UserProfile', + 'type': 'object', + } + ) + + assert result.get('history') == snapshot( + [ + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Get user profile'}], + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + } + ] + ) async def test_a2a_runtime_error_without_lifespan(): @@ -203,11 +211,13 @@ async def test_a2a_simple(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'metadata': { - 'type': 'tuple', - 'json_schema': {'items': {}, 'type': 'array'}, - }, + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], } ], }, @@ -298,11 +308,13 @@ async def test_a2a_file_message_with_file(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'metadata': { - 'type': 'tuple', - 'json_schema': {'items': {}, 'type': 'array'}, - }, + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], } ], }, @@ -380,11 +392,13 @@ async def test_a2a_file_message_with_file_content(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'metadata': { - 'type': 'tuple', - 'json_schema': {'items': {}, 'type': 'array'}, - }, + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], } ], }, @@ -722,11 +736,13 @@ async def test_a2a_multiple_messages(): { 'artifact_id': IsStr(), 'name': 'result', - 'parts': [{'kind': 'data', 'data': {'result': ['foo', 'bar']}}], - 'metadata': { - 'type': 'tuple', - 'json_schema': {'items': {}, 'type': 'array'}, - }, + 'parts': [ + { + 'metadata': {'json_schema': {'items': {}, 'type': 'array'}}, + 'kind': 'data', + 'data': {'result': ['foo', 'bar']}, + } + ], } ], },