From eee8dc111a7d596dd332ce8c0a2dccfa250c51dc Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 29 Jun 2025 06:09:37 +0000 Subject: [PATCH 01/15] add metadata support to tasks --- fasta2a/fasta2a/storage.py | 11 +++++++++-- fasta2a/fasta2a/task_manager.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index c06bc1cb7..b9b1ebc9a 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from datetime import datetime +from typing import Any from .schema import Artifact, Message, Task, TaskState, TaskStatus @@ -22,7 +23,9 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta """ @abstractmethod - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: + async def submit_task( + self, task_id: str, session_id: str, message: Message, metadata: dict[str, Any] | None = None + ) -> Task: """Submit a task to storage.""" @abstractmethod @@ -60,13 +63,17 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta task['history'] = task['history'][-history_length:] return task - async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: + async def submit_task( + self, task_id: str, session_id: str, message: Message, metadata: dict[str, Any] | None = None + ) -> Task: """Submit a task to storage.""" if task_id in self.tasks: raise ValueError(f'Task {task_id} already exists') task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) + if metadata is not None: + task['metadata'] = metadata self.tasks[task_id] = task return task diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 0baaeba04..46c398a99 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -120,7 +120,8 @@ async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: if task is None: session_id = request['params'].get('session_id', str(uuid.uuid4())) message = request['params']['message'] - task = await self.storage.submit_task(task_id, session_id, message) + metadata = request['params'].get('metadata') + task = await self.storage.submit_task(task_id, session_id, message, metadata) await self.broker.run_task(request['params']) return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) From 93684eae7343b1fea6335cf69450e51fa71a0de1 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 29 Jun 2025 06:10:24 +0000 Subject: [PATCH 02/15] feat: add deps_factory support to Agent.to_a2a() Enable pydantic-ai agents that use dependency injection to work with the A2A protocol by adding a deps_factory parameter to to_a2a(). Changes: - Add deps_factory parameter to Agent.to_a2a() method - Update AgentWorker to use deps_factory when creating dependencies - Add comprehensive tests for the feature - Add bank_support_a2a example showing real-world usage - Update documentation with examples The deps_factory receives the A2A Task object and returns dependencies matching the agent's deps_type. This enables agents to get context from task metadata (e.g., user authentication, database connections). --- docs/a2a.md | 41 ++++++ docs/examples/bank-support-a2a.md | 39 ++++++ .../pydantic_ai_examples/bank_support_a2a.py | 73 ++++++++++ mkdocs.yml | 1 + pydantic_ai_slim/pydantic_ai/_a2a.py | 15 +- pydantic_ai_slim/pydantic_ai/agent.py | 23 +++- tests/test_a2a_deps.py | 129 ++++++++++++++++++ 7 files changed, 316 insertions(+), 5 deletions(-) create mode 100644 docs/examples/bank-support-a2a.md create mode 100644 examples/pydantic_ai_examples/bank_support_a2a.py create mode 100644 tests/test_a2a_deps.py diff --git a/docs/a2a.md b/docs/a2a.md index 28f7093fd..6e90fdbc3 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -93,4 +93,45 @@ Since `app` is an ASGI application, it can be used with any ASGI server. uvicorn agent_to_a2a:app --host 0.0.0.0 --port 8000 ``` +#### Using Agents with Dependencies + +If your agent uses [dependencies](../agents.md#dependencies), you can provide a `deps_factory` function that creates dependencies from the A2A task metadata: + +```python {title="agent_with_deps_to_a2a.py"} +from dataclasses import dataclass +from pydantic_ai import Agent, RunContext + +@dataclass +class SupportDeps: + customer_id: int + +support_agent = Agent( + 'openai:gpt-4.1', + deps_type=SupportDeps, + instructions='You are a support agent.', +) + +@support_agent.system_prompt +def add_customer_info(ctx: RunContext[SupportDeps]) -> str: + return f'The customer ID is {ctx.deps.customer_id}' + +def create_deps(task): + """Create dependencies from task metadata.""" + metadata = task.get('metadata', {}) + return SupportDeps(customer_id=metadata.get('customer_id', 0)) + +# Create A2A app with deps_factory +app = support_agent.to_a2a(deps_factory=create_deps) +``` + +Now when clients send tasks with metadata, the agent will have access to the dependencies: + +```python {title="client_example.py"} +# Client sends a task with metadata +response = await a2a_client.send_task( + message=message, + metadata={'customer_id': 12345} +) +``` + Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor. diff --git a/docs/examples/bank-support-a2a.md b/docs/examples/bank-support-a2a.md new file mode 100644 index 000000000..46b81d6e9 --- /dev/null +++ b/docs/examples/bank-support-a2a.md @@ -0,0 +1,39 @@ +Example showing how to expose the [bank support agent](bank-support.md) as an A2A server with dependency injection. + +Demonstrates: + +* Converting an existing agent to A2A +* Using `deps_factory` to provide customer context +* Passing metadata through A2A protocol + +## Running the Example + +With [dependencies installed and environment variables set](./index.md#usage), run: + +```bash +# Start the A2A server +uvicorn pydantic_ai_examples.bank_support_a2a:app --reload + +# In another terminal, send a request +curl -X POST http://localhost:8000/tasks.send \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "method": "tasks.send", + "params": { + "id": "test-task-1", + "message": { + "role": "user", + "parts": [{"type": "text", "text": "What is my balance?"}] + }, + "metadata": {"customer_id": 123} + }, + "id": "1" + }' +``` + +## Example Code + +```python {title="bank_support_a2a.py"} +#! examples/pydantic_ai_examples/bank_support_a2a.py +``` \ No newline at end of file diff --git a/examples/pydantic_ai_examples/bank_support_a2a.py b/examples/pydantic_ai_examples/bank_support_a2a.py new file mode 100644 index 000000000..c6558c418 --- /dev/null +++ b/examples/pydantic_ai_examples/bank_support_a2a.py @@ -0,0 +1,73 @@ +"""Bank support agent exposed as an A2A server. + +Shows how to use deps_factory to provide customer context from task metadata. + +Run the server: + python -m pydantic_ai_examples.bank_support_a2a + # or + uvicorn pydantic_ai_examples.bank_support_a2a:app --reload + +Test with curl: + curl -X POST http://localhost:8000/ \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "method": "tasks/send", + "params": { + "id": "test-task-1", + "message": { + "role": "user", + "parts": [{"type": "text", "text": "What is my balance?"}] + }, + "metadata": {"customer_id": 123} + }, + "id": "1" + }' + +Then get the result: + curl -X POST http://localhost:8000/ \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "method": "tasks/get", + "params": {"id": "test-task-1"}, + "id": "2" + }' +""" + +from pydantic_ai_examples.bank_support import ( + DatabaseConn, + SupportDependencies, + support_agent, +) + + +def create_deps(task): + """Create dependencies from A2A task metadata. + + In a real application, you might: + - Validate the customer_id + - Look up authentication from a session token + - Connect to a real database with connection pooling + """ + metadata = task.get('metadata', {}) + customer_id = metadata.get('customer_id', 0) + + # In production, you'd validate the customer exists + # and the request is authorized + return SupportDependencies(customer_id=customer_id, db=DatabaseConn()) + + +# Create the A2A application +app = support_agent.to_a2a( + deps_factory=create_deps, + name='Bank Support Agent', + description='AI support agent for banking customers', +) + + +if __name__ == '__main__': + # For development convenience + import uvicorn + + uvicorn.run(app, host='0.0.0.0', port=8000) diff --git a/mkdocs.yml b/mkdocs.yml index d750c29bb..9e2910266 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -53,6 +53,7 @@ nav: - examples/pydantic-model.md - examples/weather-agent.md - examples/bank-support.md + - examples/bank-support-a2a.md - examples/sql-gen.md - examples/flight-booking.md - examples/rag.md diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 99bbe37ad..45f79afe8 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from functools import partial -from typing import Any, Generic +from typing import Any, Callable, Generic from typing_extensions import assert_never @@ -38,6 +38,7 @@ Part, Provider, Skill, + Task, TaskIdParams, TaskSendParams, TextPart as A2ATextPart, @@ -65,6 +66,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker) -> AsyncIterator[None]: def agent_to_a2a( agent: Agent[AgentDepsT, OutputDataT], *, + deps_factory: Callable[[Task], AgentDepsT] | None = None, storage: Storage | None = None, broker: Broker | None = None, # Agent card @@ -84,7 +86,7 @@ def agent_to_a2a( """Create a FastA2A server from an agent.""" storage = storage or InMemoryStorage() broker = broker or InMemoryBroker() - worker = AgentWorker(agent=agent, broker=broker, storage=storage) + worker = AgentWorker(agent=agent, broker=broker, storage=storage, deps_factory=deps_factory) lifespan = lifespan or partial(worker_lifespan, worker=worker) @@ -110,6 +112,7 @@ class AgentWorker(Worker, Generic[AgentDepsT, OutputDataT]): """A worker that uses an agent to execute tasks.""" agent: Agent[AgentDepsT, OutputDataT] + deps_factory: Callable[[Task], AgentDepsT] | None = None async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) @@ -124,8 +127,12 @@ async def run_task(self, params: TaskSendParams) -> None: task_history = task.get('history', []) message_history = self.build_message_history(task_history=task_history) - # TODO(Marcelo): We need to make this more customizable e.g. pass deps. - result = await self.agent.run(message_history=message_history) # type: ignore + # Initialize dependencies if factory provided + deps = None + if self.deps_factory is not None: + deps = self.deps_factory(task) + + result = await self.agent.run(message_history=message_history, deps=deps) artifacts = self.build_artifacts(result.output) await self.storage.update_task(task['id'], state='completed', artifacts=artifacts) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a04ae8646..15cc2ca59 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -62,7 +62,7 @@ from fasta2a.applications import FastA2A from fasta2a.broker import Broker - from fasta2a.schema import Provider, Skill + from fasta2a.schema import Provider, Skill, Task from fasta2a.storage import Storage from pydantic_ai.mcp import MCPServer @@ -1739,6 +1739,7 @@ async def run_mcp_servers( def to_a2a( self, *, + deps_factory: Callable[[Task], AgentDepsT] | None = None, storage: Storage | None = None, broker: Broker | None = None, # Agent card @@ -1772,11 +1773,31 @@ def to_a2a( ```bash uvicorn app:app --host 0.0.0.0 --port 8000 ``` + + Args: + deps_factory: Function that creates agent dependencies from task metadata. + storage: Backend for persisting task state and history. Defaults to in-memory storage. + broker: Message broker for distributing tasks to workers. Defaults to in-memory broker. + name: Display name for this agent in the A2A protocol. + url: Base URL where this agent will be hosted. + version: Version string for this agent. + description: Human-readable description of what this agent does. + provider: Provider metadata for the A2A agent card. + skills: List of capabilities this agent exposes via A2A. + debug: Enable Starlette debug mode with detailed error pages. + routes: Additional Starlette routes to include in the application. + middleware: Additional Starlette middleware to include. + exception_handlers: Custom exception handlers for the application. + lifespan: ASGI lifespan context manager for startup/shutdown logic. + + Returns: + A FastA2A application ready to serve A2A requests. """ from ._a2a import agent_to_a2a return agent_to_a2a( self, + deps_factory=deps_factory, storage=storage, broker=broker, name=name, diff --git a/tests/test_a2a_deps.py b/tests/test_a2a_deps.py new file mode 100644 index 000000000..9a37c0823 --- /dev/null +++ b/tests/test_a2a_deps.py @@ -0,0 +1,129 @@ +"""Test A2A with dependency injection via deps_factory.""" + +import anyio +import httpx +import pytest +from asgi_lifespan import LifespanManager +from dataclasses import dataclass + +from pydantic_ai import Agent, RunContext +from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart, TextPart as TextPartMessage +from pydantic_ai.models.function import AgentInfo, FunctionModel + +from .conftest import try_import + +with try_import() as imports_successful: + from fasta2a.client import A2AClient + from fasta2a.schema import Message, TextPart, Task + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='fasta2a not installed'), + pytest.mark.anyio, +] + + +async def test_a2a_with_deps_factory(): + """Test that deps_factory enables agents with dependencies to work with A2A.""" + + # 1. Define a simple dependency class + @dataclass + class Deps: + user_name: str + multiplier: int = 2 + + # 2. Create a model that returns output based on deps + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # This function doesn't have access to deps, so it just returns a placeholder + if info.output_tools: + # Return a simple string result using the output tool + args_json = '{"response": "Result computed with deps"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + else: + # No output tools, just return text + return ModelResponse(parts=[TextPartMessage(content='Result computed with deps')]) + + model = FunctionModel(model_function) + agent = Agent(model=model, deps_type=Deps, output_type=str) + + # 3. Add a system prompt that uses deps + @agent.system_prompt + def add_user_info(ctx: RunContext[Deps]) -> str: + return f"The user's name is {ctx.deps.user_name} with multiplier {ctx.deps.multiplier}" + + # 4. Create deps_factory that reads from task metadata + def create_deps(task: Task) -> Deps: + metadata = task.get('metadata', {}) + return Deps(user_name=metadata.get('user_name', 'DefaultUser'), multiplier=metadata.get('multiplier', 2)) + + # 5. Create A2A app with deps_factory + app = agent.to_a2a(deps_factory=create_deps) + + # 6. Test the full flow + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + # Send task with metadata + message = Message(role='user', parts=[TextPart(text='Process this', type='text')]) + response = await a2a_client.send_task(message=message, metadata={'user_name': 'Alice', 'multiplier': 5}) + + assert 'result' in response + task_id = response['result']['id'] + + # Wait for task completion + task = None + for _ in range(10): # Max 10 attempts + task = await a2a_client.get_task(task_id) + if 'result' in task and task['result']['status']['state'] in ('completed', 'failed'): + break + await anyio.sleep(0.1) + + # Verify the result + assert task is not None + if task['result']['status']['state'] == 'failed': + print(f'Task failed. Full task: {task}') + assert task['result']['status']['state'] == 'completed' + assert 'artifacts' in task['result'] + artifacts = task['result']['artifacts'] + assert len(artifacts) == 1 + assert artifacts[0]['parts'][0]['text'] == 'Result computed with deps' + + +async def test_a2a_without_deps_factory(): + """Test that agents without deps still work when no deps_factory is provided.""" + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if info.output_tools: + args_json = '{"response": "Hello from agent"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + else: + return ModelResponse(parts=[TextPartMessage(content='Hello from agent')]) + + model = FunctionModel(model_function) + # Agent with no deps_type + agent = Agent(model=model, output_type=str) + + # Create A2A app without deps_factory + app = agent.to_a2a() + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message(role='user', parts=[TextPart(text='Hello', type='text')]) + response = await a2a_client.send_task(message=message) + + task_id = response['result']['id'] + + # Wait for completion + task = None + for _ in range(10): + task = await a2a_client.get_task(task_id) + if 'result' in task and task['result']['status']['state'] == 'completed': + break + await anyio.sleep(0.1) + + assert task['result']['status']['state'] == 'completed' + assert task['result']['artifacts'][0]['parts'][0]['text'] == 'Hello from agent' From 810b91a535be13b2f003e1bcef7395246ba3ef21 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 29 Jun 2025 06:46:06 +0000 Subject: [PATCH 03/15] feat: update A2A schema to v0.2.1 specification - Change Part types to use 'kind' field instead of 'type' - Add new fields to Message type: message_id, context_id, task_id, kind - Add MessageSendConfiguration and MessageSendParams types - Replace tasks/send with message/send and message/stream methods - Update JSON-RPC request/response types for new protocol - Remove deprecated File-related types This is a breaking change that updates fasta2a to comply with the current A2A protocol specification, focusing on clean implementation without backward compatibility. --- fasta2a/fasta2a/schema.py | 134 ++++++++++++++++++++------------------ 1 file changed, 69 insertions(+), 65 deletions(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index cab5d2057..079496442 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -204,6 +204,7 @@ class TaskPushNotificationConfig(TypedDict): """The push notification configuration.""" +@pydantic.with_config({'alias_generator': to_camel}) class Message(TypedDict): """A Message contains any content that is not an Artifact. @@ -224,6 +225,19 @@ class Message(TypedDict): metadata: NotRequired[dict[str, Any]] """Metadata about the message.""" + + # Additional fields + message_id: NotRequired[str] + """Identifier created by the message creator.""" + + context_id: NotRequired[str] + """The context the message is associated with.""" + + task_id: NotRequired[str] + """Identifier of task the message is related to.""" + + kind: NotRequired[Literal['message']] + """Event type.""" class _BasePart(TypedDict): @@ -232,11 +246,12 @@ class _BasePart(TypedDict): metadata: NotRequired[dict[str, Any]] +@pydantic.with_config({'alias_generator': to_camel}) class TextPart(_BasePart): """A part that contains text.""" - type: Literal['text'] - """The type of the part.""" + kind: Literal['text'] + """The kind of the part.""" text: str """The text of the part.""" @@ -246,56 +261,36 @@ class TextPart(_BasePart): class FilePart(_BasePart): """A part that contains a file.""" - type: Literal['file'] - """The type of the part.""" - - file: File - """The file of the part.""" - - -@pydantic.with_config({'alias_generator': to_camel}) -class _BaseFile(_BasePart): - """A base class for all file types.""" - - name: NotRequired[str] - """The name of the file.""" - - mime_type: str + kind: Literal['file'] + """The kind of the part.""" + + data: NotRequired[str] + """The base64 encoded data.""" + + mime_type: NotRequired[str] """The mime type of the file.""" + + uri: NotRequired[str] + """The URI of the file.""" -@pydantic.with_config({'alias_generator': to_camel}) -class _BinaryFile(_BaseFile): - """A binary file.""" - - data: str - """The base64 encoded bytes of the file.""" - - -@pydantic.with_config({'alias_generator': to_camel}) -class _URLFile(_BaseFile): - """A file that is hosted on a remote URL.""" - - url: str - """The URL of the file.""" - - -File: TypeAlias = Union[_BinaryFile, _URLFile] -"""A file is a binary file or a URL file.""" @pydantic.with_config({'alias_generator': to_camel}) class DataPart(_BasePart): - """A part that contains data.""" + """A part that contains structured data.""" - type: Literal['data'] - """The type of the part.""" + kind: Literal['data'] + """The kind of the part.""" - data: dict[str, Any] + data: Any """The data of the part.""" + + description: NotRequired[str] + """A description of the data.""" -Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='type')] +Part = Annotated[Union[TextPart, FilePart, DataPart], pydantic.Field(discriminator='kind')] """A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. Each Part has its own content type and metadata. @@ -394,24 +389,32 @@ class TaskQueryParams(TaskIdParams): @pydantic.with_config({'alias_generator': to_camel}) -class TaskSendParams(TypedDict): - """Sent by the client to the agent to create, continue, or restart a task.""" - - id: str - """The id of the task.""" - - session_id: NotRequired[str] - """The server creates a new sessionId for new tasks if not set.""" - - message: Message - """The message to send to the agent.""" - +class MessageSendConfiguration(TypedDict): + """Configuration for the send message request.""" + + accepted_output_modes: list[str] + """Accepted output modalities by the client.""" + + blocking: NotRequired[bool] + """If the server should treat the client as a blocking request.""" + history_length: NotRequired[int] """Number of recent messages to be retrieved.""" - - push_notification: NotRequired[PushNotificationConfig] + + push_notification_config: NotRequired[PushNotificationConfig] """Where the server should send notifications when disconnected.""" + +@pydantic.with_config({'alias_generator': to_camel}) +class MessageSendParams(TypedDict): + """Parameters for message/send method.""" + + configuration: NotRequired[MessageSendConfiguration] + """Send message configuration.""" + + message: Message + """The message being sent to the server.""" + metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -501,17 +504,14 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ####################################### Requests and responses ############################ ############################################################################################### -SendTaskRequest = JSONRPCRequest[Literal['tasks/send'], TaskSendParams] -"""A JSON RPC request to send a task.""" - -SendTaskResponse = JSONRPCResponse[Task, JSONRPCError[Any, Any]] -"""A JSON RPC response to send a task.""" +SendMessageRequest = JSONRPCRequest[Literal['message/send'], MessageSendParams] +"""A JSON RPC request to send a message.""" -SendTaskStreamingRequest = JSONRPCRequest[Literal['tasks/sendSubscribe'], TaskSendParams] -"""A JSON RPC request to send a task and receive updates.""" +SendMessageResponse = JSONRPCResponse[Union[Task, Message], JSONRPCError[Any, Any]] +"""A JSON RPC response to send a message.""" -SendTaskStreamingResponse = JSONRPCResponse[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError] -"""A JSON RPC response to send a task and receive updates.""" +StreamMessageRequest = JSONRPCRequest[Literal['message/stream'], MessageSendParams] +"""A JSON RPC request to stream a message.""" GetTaskRequest = JSONRPCRequest[Literal['tasks/get'], TaskQueryParams] """A JSON RPC request to get a task.""" @@ -542,7 +542,8 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): A2ARequest = Annotated[ Union[ - SendTaskRequest, + SendMessageRequest, + StreamMessageRequest, GetTaskRequest, CancelTaskRequest, SetTaskPushNotificationRequest, @@ -554,7 +555,7 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): """A JSON RPC request to the A2A server.""" A2AResponse: TypeAlias = Union[ - SendTaskResponse, + SendMessageResponse, GetTaskResponse, CancelTaskResponse, SetTaskPushNotificationResponse, @@ -565,3 +566,6 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): a2a_request_ta: TypeAdapter[A2ARequest] = TypeAdapter(A2ARequest) a2a_response_ta: TypeAdapter[A2AResponse] = TypeAdapter(A2AResponse) +send_message_request_ta: TypeAdapter[SendMessageRequest] = TypeAdapter(SendMessageRequest) +send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) +stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) From 14f637cc9936024d39c422d5e48af696aaa2f61b Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 29 Jun 2025 06:58:56 +0000 Subject: [PATCH 04/15] feat: update application layer for new A2A protocol - Replace tasks/send with message/send in routing - Add send_message method to TaskManager - Remove deprecated send_task method - Add TODO for message/stream implementation - Update imports for new message types Minimal changes to support the new protocol while maintaining the existing code style and architecture. --- fasta2a/fasta2a/applications.py | 15 ++++++++++--- fasta2a/fasta2a/task_manager.py | 40 ++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 61301262b..25ced357e 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -17,10 +17,14 @@ Authentication, Capabilities, Provider, + SendMessageRequest, + SendMessageResponse, Skill, a2a_request_ta, a2a_response_ta, agent_card_ta, + send_message_request_ta, + send_message_response_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -105,7 +109,7 @@ async def _agent_run_endpoint(self, request: Request) -> Response: Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. - 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. + 1. The server will always either send a "submitted" or a "failed" on `message/send`. Never a "completed" on the first message. 2. There are three possible ends for the task: 2.1. The task was "completed" successfully. @@ -116,8 +120,13 @@ async def _agent_run_endpoint(self, request: Request) -> Response: data = await request.body() a2a_request = a2a_request_ta.validate_json(data) - if a2a_request['method'] == 'tasks/send': - jsonrpc_response = await self.task_manager.send_task(a2a_request) + if a2a_request['method'] == 'message/send': + # Handle new message/send method + message_request = send_message_request_ta.validate_json(data) + jsonrpc_response = await self.task_manager.send_message(message_request) + elif a2a_request['method'] == 'message/stream': + # TODO: Implement streaming support + raise NotImplementedError('message/stream not implemented yet') elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 46c398a99..ffb718749 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -74,8 +74,8 @@ GetTaskRequest, GetTaskResponse, ResubscribeTaskRequest, - SendTaskRequest, - SendTaskResponse, + SendMessageRequest, + SendMessageResponse, SendTaskStreamingRequest, SendTaskStreamingResponse, SetTaskPushNotificationRequest, @@ -111,20 +111,30 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self._aexit_stack.__aexit__(exc_type, exc_value, traceback) self._aexit_stack = None - async def send_task(self, request: SendTaskRequest) -> SendTaskResponse: - """Send a task to the worker.""" - request_id = str(uuid.uuid4()) - task_id = request['params']['id'] - task = await self.storage.load_task(task_id) - - if task is None: - session_id = request['params'].get('session_id', str(uuid.uuid4())) - message = request['params']['message'] - metadata = request['params'].get('metadata') - task = await self.storage.submit_task(task_id, session_id, message, metadata) + async def send_message(self, request: SendMessageRequest) -> SendMessageResponse: + """Send a message using the new protocol.""" + request_id = request['id'] + task_id = str(uuid.uuid4()) + session_id = str(uuid.uuid4()) + message = request['params']['message'] + metadata = request['params'].get('metadata') + config = request['params'].get('configuration', {}) + + # Create a new task + task = await self.storage.submit_task(task_id, session_id, message, metadata) + + # Prepare params for broker (compatible with old format for now) + broker_params = { + 'id': task_id, + 'session_id': session_id, + 'message': message, + 'metadata': metadata, + 'history_length': config.get('history_length') + } + + await self.broker.run_task(broker_params) + return SendMessageResponse(jsonrpc='2.0', id=request_id, result=task) - await self.broker.run_task(request['params']) - return SendTaskResponse(jsonrpc='2.0', id=request_id, result=task) async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: """Get a task, and return it to the client. From ce7cd6a39ff990ae4f7ede14e276b27ffc28cbb8 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 29 Jun 2025 07:24:21 +0000 Subject: [PATCH 05/15] feat: update worker layer for new A2A protocol - Add TaskSendParams back as internal framework type - Update AgentWorker to use 'kind' field instead of 'type' - Update file part handling for new schema structure - Maintain type safety in worker interface - Preserve deps_factory functionality TaskSendParams is now explicitly marked as an internal type for broker/worker communication, separate from the A2A protocol. --- fasta2a/fasta2a/schema.py | 24 +++++++++++++++ pydantic_ai_slim/pydantic_ai/_a2a.py | 45 ++++++++++++++-------------- 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 079496442..6291ca386 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -419,6 +419,30 @@ class MessageSendParams(TypedDict): """Extension metadata.""" +@pydantic.with_config({'alias_generator': to_camel}) +class TaskSendParams(TypedDict): + """Internal parameters for task execution within the framework. + + Note: This is not part of the A2A protocol - it's used internally + for broker/worker communication. + """ + + id: str + """The id of the task.""" + + session_id: NotRequired[str] + """The session id for the task.""" + + message: Message + """The message to process.""" + + history_length: NotRequired[int] + """Number of recent messages to be retrieved.""" + + metadata: NotRequired[dict[str, Any]] + """Extension metadata.""" + + class JSONRPCMessage(TypedDict): """A JSON RPC message.""" diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 45f79afe8..515aa32bb 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -142,7 +142,7 @@ async def cancel_task(self, params: TaskIdParams) -> None: def build_artifacts(self, result: Any) -> list[Artifact]: # TODO(Marcelo): We need to send the json schema of the result on the metadata of the message. - return [Artifact(name='result', index=0, parts=[A2ATextPart(type='text', text=str(result))])] + return [Artifact(name='result', index=0, parts=[A2ATextPart(kind='text', text=str(result))])] def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] @@ -156,28 +156,29 @@ def build_message_history(self, task_history: list[Message]) -> list[ModelMessag def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: model_parts: list[ModelRequestPart] = [] for part in parts: - if part['type'] == 'text': + if part['kind'] == 'text': model_parts.append(UserPromptPart(content=part['text'])) - elif part['type'] == 'file': - file = part['file'] - if 'data' in file: - data = file['data'].encode('utf-8') - content = BinaryContent(data=data, media_type=file['mime_type']) + elif part['kind'] == 'file': + if 'data' in part: + data = part['data'].encode('utf-8') + mime_type = part.get('mime_type', 'application/octet-stream') + content = BinaryContent(data=data, media_type=mime_type) model_parts.append(UserPromptPart(content=[content])) - else: - url = file['url'] - for url_cls in (DocumentUrl, AudioUrl, ImageUrl, VideoUrl): - content = url_cls(url=url) - try: - content.media_type - except ValueError: # pragma: no cover - continue - else: - break + elif 'uri' in part: + url = part['uri'] + mime_type = part.get('mime_type', '') + if mime_type.startswith('image/'): + content = ImageUrl(url=url) + elif mime_type.startswith('audio/'): + content = AudioUrl(url=url) + elif mime_type.startswith('video/'): + content = VideoUrl(url=url) else: - raise ValueError(f'Unknown file type: {file["mime_type"]}') # pragma: no cover + content = DocumentUrl(url=url) model_parts.append(UserPromptPart(content=[content])) - elif part['type'] == 'data': + else: + raise ValueError('FilePart must have either data or uri') + elif part['kind'] == 'data': # TODO(Marcelo): Maybe we should use this for `ToolReturnPart`, and `RetryPromptPart`. raise NotImplementedError('Data parts are not supported yet.') else: @@ -187,11 +188,11 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: model_parts: list[ModelResponsePart] = [] for part in parts: - if part['type'] == 'text': + if part['kind'] == 'text': model_parts.append(TextPart(content=part['text'])) - elif part['type'] == 'file': # pragma: no cover + elif part['kind'] == 'file': # pragma: no cover raise NotImplementedError('File parts are not supported yet.') - elif part['type'] == 'data': # pragma: no cover + elif part['kind'] == 'data': # pragma: no cover raise NotImplementedError('Data parts are not supported yet.') else: # pragma: no cover assert_never(part) From 89e2e4dc0880fbd2989efbd9246a9d27adb09839 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Sun, 29 Jun 2025 07:36:51 +0000 Subject: [PATCH 06/15] feat: update A2A client for new protocol - Replace send_task with send_message method - Update to use MessageSendParams and MessageSendConfiguration - Return Task < /dev/null | Message union type from send_message - Remove client-generated task IDs (now server-generated) - Update imports for new protocol types The client now uses message/send instead of tasks/send and properly handles the response which can be either a Task or Message. --- fasta2a/fasta2a/client.py | 42 +++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 5c5aabd81..75839a242 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -9,14 +9,17 @@ GetTaskRequest, GetTaskResponse, Message, + MessageSendConfiguration, + MessageSendParams, PushNotificationConfig, - SendTaskRequest, - SendTaskResponse, - TaskSendParams, + SendMessageRequest, + SendMessageResponse, + Task, a2a_request_ta, + send_message_request_ta, + send_message_response_ta, ) -send_task_response_ta = pydantic.TypeAdapter(SendTaskResponse) get_task_response_ta = pydantic.TypeAdapter(GetTaskResponse) try: @@ -37,26 +40,31 @@ def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.A self.http_client = http_client self.http_client.base_url = base_url - async def send_task( + async def send_message( self, message: Message, - history_length: int | None = None, - push_notification: PushNotificationConfig | None = None, + *, metadata: dict[str, Any] | None = None, - ) -> SendTaskResponse: - task = TaskSendParams(message=message, id=str(uuid.uuid4())) - if history_length is not None: - task['history_length'] = history_length - if push_notification is not None: - task['push_notification'] = push_notification + configuration: MessageSendConfiguration | None = None, + ) -> Task | Message: + """Send a message using the A2A protocol. + + Returns either a Task (for async operations) or a Message (for quick responses). + """ + params = MessageSendParams(message=message) if metadata is not None: - task['metadata'] = metadata + params['metadata'] = metadata + if configuration is not None: + params['configuration'] = configuration - payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) - content = a2a_request_ta.dump_json(payload, by_alias=True) + request_id = str(uuid.uuid4()) + payload = SendMessageRequest(jsonrpc='2.0', id=request_id, method='message/send', params=params) + content = send_message_request_ta.dump_json(payload, by_alias=True) response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) - return send_task_response_ta.validate_json(response.content) + + response_data = send_message_response_ta.validate_json(response.content) + return response_data['result'] async def get_task(self, task_id: str) -> GetTaskResponse: payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) From 913b3e24e35d04f03c32d7e9f6e78b428d7f6b87 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 00:05:03 +0000 Subject: [PATCH 07/15] feat: complete A2A v0.2.1 protocol migration with full type safety - Add type guards (is_task/is_message) for handling Task < /dev/null | Message union types - Update all tests to use new send_message API that returns SendMessageResponse - Fix all test assertions to check for both 'error' and 'result' fields - Update Part type access to check 'kind' field before accessing type-specific fields - Fix GetTaskResponse handling where 'result' is NotRequired - Clean up unused imports after protocol migration - Update event docstrings to reference message/stream instead of legacy methods - Add proper imports and type annotations for streaming support - Fix task_manager to use properly typed TaskSendParams for broker calls - Handle AgentWorker deps with conditional logic based on deps_factory presence - Fix bank_support_a2a.py example with proper Task type import --- .../pydantic_ai_examples/bank_support_a2a.py | 3 +- fasta2a/fasta2a/applications.py | 3 - fasta2a/fasta2a/client.py | 9 +- fasta2a/fasta2a/schema.py | 16 +- fasta2a/fasta2a/task_manager.py | 40 +++- pydantic_ai_slim/pydantic_ai/_a2a.py | 9 +- tests/test_a2a.py | 184 +++++++++--------- tests/test_a2a_deps.py | 59 +++--- 8 files changed, 182 insertions(+), 141 deletions(-) diff --git a/examples/pydantic_ai_examples/bank_support_a2a.py b/examples/pydantic_ai_examples/bank_support_a2a.py index c6558c418..03989aeb5 100644 --- a/examples/pydantic_ai_examples/bank_support_a2a.py +++ b/examples/pydantic_ai_examples/bank_support_a2a.py @@ -35,6 +35,7 @@ }' """ +from fasta2a.schema import Task from pydantic_ai_examples.bank_support import ( DatabaseConn, SupportDependencies, @@ -42,7 +43,7 @@ ) -def create_deps(task): +def create_deps(task: Task) -> SupportDependencies: """Create dependencies from A2A task metadata. In a real application, you might: diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index 25ced357e..a8df4d86c 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -17,14 +17,11 @@ Authentication, Capabilities, Provider, - SendMessageRequest, - SendMessageResponse, Skill, a2a_request_ta, a2a_response_ta, agent_card_ta, send_message_request_ta, - send_message_response_ta, ) from .storage import Storage from .task_manager import TaskManager diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 75839a242..2df5b2f1c 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -11,10 +11,8 @@ Message, MessageSendConfiguration, MessageSendParams, - PushNotificationConfig, SendMessageRequest, SendMessageResponse, - Task, a2a_request_ta, send_message_request_ta, send_message_response_ta, @@ -46,10 +44,10 @@ async def send_message( *, metadata: dict[str, Any] | None = None, configuration: MessageSendConfiguration | None = None, - ) -> Task | Message: + ) -> SendMessageResponse: """Send a message using the A2A protocol. - Returns either a Task (for async operations) or a Message (for quick responses). + Returns a JSON-RPC response containing either a result (Task | Message) or an error. """ params = MessageSendParams(message=message) if metadata is not None: @@ -63,8 +61,7 @@ async def send_message( response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) - response_data = send_message_response_ta.validate_json(response.content) - return response_data['result'] + return send_message_response_ta.validate_json(response.content) async def get_task(self, task_id: str) -> GetTaskResponse: payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 6291ca386..3cabab847 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -7,7 +7,7 @@ import pydantic from pydantic import Discriminator, TypeAdapter from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypeAlias, TypedDict +from typing_extensions import NotRequired, TypeAlias, TypedDict, TypeGuard @pydantic.with_config({'alias_generator': to_camel}) @@ -343,7 +343,7 @@ class Task(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskStatusUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" + """Sent by server during message/stream requests.""" id: str """The id of the task.""" @@ -360,7 +360,7 @@ class TaskStatusUpdateEvent(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskArtifactUpdateEvent(TypedDict): - """Sent by server during sendSubscribe or subscribe requests.""" + """Sent by server during message/stream requests.""" id: str """The id of the task.""" @@ -593,3 +593,13 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): send_message_request_ta: TypeAdapter[SendMessageRequest] = TypeAdapter(SendMessageRequest) send_message_response_ta: TypeAdapter[SendMessageResponse] = TypeAdapter(SendMessageResponse) stream_message_request_ta: TypeAdapter[StreamMessageRequest] = TypeAdapter(StreamMessageRequest) + + +def is_task(response: Task | Message) -> TypeGuard[Task]: + """Type guard to check if a response is a Task.""" + return 'id' in response and 'status' in response + + +def is_message(response: Task | Message) -> TypeGuard[Message]: + """Type guard to check if a response is a Message.""" + return 'role' in response and 'parts' in response diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index ffb718749..cb2afba70 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -61,6 +61,7 @@ from __future__ import annotations as _annotations import uuid +from collections.abc import AsyncGenerator from contextlib import AsyncExitStack from dataclasses import dataclass, field from typing import Any @@ -76,11 +77,13 @@ ResubscribeTaskRequest, SendMessageRequest, SendMessageResponse, - SendTaskStreamingRequest, - SendTaskStreamingResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, + StreamMessageRequest, + TaskArtifactUpdateEvent, TaskNotFoundError, + TaskSendParams, + TaskStatusUpdateEvent, ) from .storage import Storage @@ -123,14 +126,17 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse # Create a new task task = await self.storage.submit_task(task_id, session_id, message, metadata) - # Prepare params for broker (compatible with old format for now) - broker_params = { + # Prepare params for broker + broker_params: TaskSendParams = { 'id': task_id, 'session_id': session_id, 'message': message, - 'metadata': metadata, - 'history_length': config.get('history_length') } + if metadata is not None: + broker_params['metadata'] = metadata + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length await self.broker.run_task(broker_params) return SendMessageResponse(jsonrpc='2.0', id=request_id, result=task) @@ -163,8 +169,17 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: ) return CancelTaskResponse(jsonrpc='2.0', id=request['id'], result=task) - async def send_task_streaming(self, request: SendTaskStreamingRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('SendTaskStreaming is not implemented yet.') + async def stream_message( + self, request: StreamMessageRequest + ) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent, None]: + """Stream task updates using Server-Sent Events. + + Returns an async generator that yields TaskStatusUpdateEvent and + TaskArtifactUpdateEvent objects for the message/stream protocol. + """ + raise NotImplementedError('message/stream is not implemented yet.') + # This is a generator function, so we need to make it yield + yield # type: ignore[unreachable] async def set_task_push_notification( self, request: SetTaskPushNotificationRequest @@ -176,5 +191,10 @@ async def get_task_push_notification( ) -> GetTaskPushNotificationResponse: raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> SendTaskStreamingResponse: - raise NotImplementedError('Resubscribe is not implemented yet.') + async def resubscribe_task(self, request: ResubscribeTaskRequest) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent, None]: + """Resubscribe to task updates. + + Similar to stream_message, returns an async generator for SSE events. + """ + raise NotImplementedError('tasks/resubscribe is not implemented yet.') + yield # type: ignore[unreachable] diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 515aa32bb..de59ae456 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -128,11 +128,14 @@ async def run_task(self, params: TaskSendParams) -> None: message_history = self.build_message_history(task_history=task_history) # Initialize dependencies if factory provided - deps = None if self.deps_factory is not None: deps = self.deps_factory(task) - - result = await self.agent.run(message_history=message_history, deps=deps) + result = await self.agent.run(message_history=message_history, deps=deps) + else: + # No deps_factory provided - this only works if the agent accepts None for deps + # (e.g., Agent[None, ...] or Agent[Optional[...], ...]) + # If the agent requires deps, this will raise TypeError at runtime + result = await self.agent.run(message_history=message_history) # type: ignore[call-arg] artifacts = self.build_artifacts(result.output) await self.storage.update_task(task['id'], state='completed', artifacts=artifacts) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index fae117781..330ebec7b 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -12,7 +12,7 @@ with try_import() as imports_successful: from fasta2a.client import A2AClient - from fasta2a.schema import DataPart, FilePart, Message, TextPart + from fasta2a.schema import DataPart, FilePart, Message, TextPart, is_task from fasta2a.storage import InMemoryStorage @@ -40,10 +40,10 @@ async def test_a2a_runtime_error_without_lifespan(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')]) with pytest.raises(RuntimeError, match='TaskManager was not properly initialized.'): - await a2a_client.send_task(message=message) + await a2a_client.send_message(message=message) async def test_a2a_simple(): @@ -55,23 +55,22 @@ async def test_a2a_simple(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')]) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], - }, + 'session_id': IsStr(), + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -85,9 +84,9 @@ async def test_a2a_simple(): 'id': IsStr(), 'session_id': IsStr(), 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], + 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} ], }, } @@ -107,37 +106,38 @@ async def test_a2a_file_message_with_file(): role='user', parts=[ FilePart( - type='file', - file={'url': 'https://example.com/file.txt', 'mime_type': 'text/plain'}, + kind='file', + uri='https://example.com/file.txt', + mime_type='text/plain', ) ], ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [ - { - 'role': 'user', - 'parts': [ - { - 'type': 'file', - 'file': {'mime_type': 'text/plain', 'url': 'https://example.com/file.txt'}, - } - ], - } - ], - }, + 'session_id': IsStr(), + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [ + { + 'kind': 'file', + 'uri': 'https://example.com/file.txt', + 'mime_type': 'text/plain', + } + ], + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -156,14 +156,15 @@ async def test_a2a_file_message_with_file(): 'role': 'user', 'parts': [ { - 'type': 'file', - 'file': {'mime_type': 'text/plain', 'url': 'https://example.com/file.txt'}, + 'kind': 'file', + 'uri': 'https://example.com/file.txt', + 'mime_type': 'text/plain', } ], } ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} ], }, } @@ -182,30 +183,29 @@ async def test_a2a_file_message_with_file_content(): message = Message( role='user', parts=[ - FilePart(type='file', file={'data': 'foo', 'mime_type': 'text/plain'}), + FilePart(kind='file', data='foo', mime_type='text/plain'), ], ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [ - { - 'role': 'user', - 'parts': [{'type': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], - } - ], - }, + 'session_id': IsStr(), + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [ + { + 'role': 'user', + 'parts': [{'kind': 'file', 'data': 'foo', 'mime_type': 'text/plain'}], + } + ], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'completed': @@ -222,11 +222,11 @@ async def test_a2a_file_message_with_file_content(): 'history': [ { 'role': 'user', - 'parts': [{'type': 'file', 'file': {'mime_type': 'text/plain', 'data': 'foo'}}], + 'parts': [{'kind': 'file', 'data': 'foo', 'mime_type': 'text/plain'}], } ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} ], }, } @@ -244,24 +244,23 @@ async def test_a2a_file_message_with_data(): message = Message( role='user', - parts=[DataPart(type='data', data={'foo': 'bar'})], + parts=[DataPart(kind='data', data={'foo': 'bar'})], ) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'data', 'data': {'foo': 'bar'}}]}], - }, + 'session_id': IsStr(), + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [{'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}]}], } ) - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch if 'result' in task and task['result']['status']['state'] == 'failed': @@ -275,7 +274,7 @@ async def test_a2a_file_message_with_data(): 'id': IsStr(), 'session_id': IsStr(), 'status': {'state': 'failed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'data', 'data': {'foo': 'bar'}}]}], + 'history': [{'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}]}], }, } ) @@ -291,27 +290,26 @@ async def test_a2a_multiple_messages(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', type='text')]) - response = await a2a_client.send_task(message=message) - assert response == snapshot( + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')]) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result) + assert result == snapshot( { - 'jsonrpc': '2.0', 'id': IsStr(), - 'result': { - 'id': IsStr(), - 'session_id': IsStr(), - 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}], - }, + 'session_id': IsStr(), + 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, + 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], } ) # NOTE: We include the agent history before we start working on the task. - assert 'result' in response - task_id = response['result']['id'] + task_id = result['id'] task = storage.tasks[task_id] assert 'history' in task - task['history'].append(Message(role='agent', parts=[TextPart(text='Whats up?', type='text')])) + task['history'].append(Message(role='agent', parts=[TextPart(text='Whats up?', kind='text')])) response = await a2a_client.get_task(task_id) assert response == snapshot( @@ -323,8 +321,8 @@ async def test_a2a_multiple_messages(): 'session_id': IsStr(), 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'type': 'text', 'text': 'Whats up?'}]}, + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, + {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}]}, ], }, } @@ -341,11 +339,11 @@ async def test_a2a_multiple_messages(): 'session_id': IsStr(), 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'type': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'type': 'text', 'text': 'Whats up?'}]}, + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, + {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}]}, ], 'artifacts': [ - {'name': 'result', 'parts': [{'type': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} ], }, } diff --git a/tests/test_a2a_deps.py b/tests/test_a2a_deps.py index 9a37c0823..09602627d 100644 --- a/tests/test_a2a_deps.py +++ b/tests/test_a2a_deps.py @@ -14,7 +14,7 @@ with try_import() as imports_successful: from fasta2a.client import A2AClient - from fasta2a.schema import Message, TextPart, Task + from fasta2a.schema import Message, TextPart, Task, is_task pytestmark = [ pytest.mark.skipif(not imports_successful(), reason='fasta2a not installed'), @@ -65,29 +65,35 @@ def create_deps(task: Task) -> Deps: a2a_client = A2AClient(http_client=http_client) # Send task with metadata - message = Message(role='user', parts=[TextPart(text='Process this', type='text')]) - response = await a2a_client.send_task(message=message, metadata={'user_name': 'Alice', 'multiplier': 5}) - + message = Message(role='user', parts=[TextPart(text='Process this', kind='text')]) + response = await a2a_client.send_message(message=message, metadata={'user_name': 'Alice', 'multiplier': 5}) + assert 'error' not in response assert 'result' in response - task_id = response['result']['id'] + result = response['result'] + assert is_task(result), 'Expected Task response' + task_id = result['id'] # Wait for task completion task = None for _ in range(10): # Max 10 attempts - task = await a2a_client.get_task(task_id) - if 'result' in task and task['result']['status']['state'] in ('completed', 'failed'): - break + response = await a2a_client.get_task(task_id) + if 'result' in response: + task = response['result'] + if task['status']['state'] in ('completed', 'failed'): + break await anyio.sleep(0.1) # Verify the result assert task is not None - if task['result']['status']['state'] == 'failed': + if task['status']['state'] == 'failed': print(f'Task failed. Full task: {task}') - assert task['result']['status']['state'] == 'completed' - assert 'artifacts' in task['result'] - artifacts = task['result']['artifacts'] + assert task['status']['state'] == 'completed' + assert 'artifacts' in task + artifacts = task['artifacts'] assert len(artifacts) == 1 - assert artifacts[0]['parts'][0]['text'] == 'Result computed with deps' + part = artifacts[0]['parts'][0] + assert part['kind'] == 'text' + assert part['text'] == 'Result computed with deps' async def test_a2a_without_deps_factory(): @@ -112,18 +118,27 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello', type='text')]) - response = await a2a_client.send_task(message=message) - - task_id = response['result']['id'] + message = Message(role='user', parts=[TextPart(text='Hello', kind='text')]) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + result = response['result'] + assert is_task(result), 'Expected Task response' + task_id = result['id'] # Wait for completion task = None for _ in range(10): - task = await a2a_client.get_task(task_id) - if 'result' in task and task['result']['status']['state'] == 'completed': - break + response = await a2a_client.get_task(task_id) + if 'result' in response: + task = response['result'] + if task['status']['state'] == 'completed': + break await anyio.sleep(0.1) - assert task['result']['status']['state'] == 'completed' - assert task['result']['artifacts'][0]['parts'][0]['text'] == 'Hello from agent' + assert task is not None + assert task['status']['state'] == 'completed' + assert 'artifacts' in task + part = task['artifacts'][0]['parts'][0] + assert part['kind'] == 'text' + assert part['text'] == 'Hello from agent' From a5edbfe2c02b2d6b99a15dc8b6e426bd2b12d78f Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 00:42:15 +0000 Subject: [PATCH 08/15] feat: update schema to match A2A v0.2.3 specification - Add required context_id field to Task (spec line 89) - Add required kind: 'task' field to Task (spec line 94) - Update streaming events to use taskId instead of id (spec lines 175-176) - Add contextId field to streaming events - Add kind field to TaskArtifactUpdateEvent for event type identification - Add InvalidAgentResponseError (-32006) error code - Remove session_id throughout codebase (not in v0.2.3 spec) - Replace all session_id usage with context_id - Update type guard to check for new required Task fields - Update all tests to expect new Task structure --- fasta2a/fasta2a/schema.py | 29 ++++++++++++++++++------ fasta2a/fasta2a/storage.py | 12 +++++++--- fasta2a/fasta2a/task_manager.py | 6 ++--- pydantic_ai_slim/pydantic_ai/_a2a.py | 2 +- tests/test_a2a.py | 33 ++++++++++++++++++---------- 5 files changed, 57 insertions(+), 25 deletions(-) diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 3cabab847..cc8db81a6 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -325,8 +325,11 @@ class Task(TypedDict): id: str """Unique identifier for the task.""" - session_id: NotRequired[str] - """Client-generated id for the session holding the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['task'] + """Event type.""" status: TaskStatus """Current status of the task.""" @@ -345,9 +348,12 @@ class Task(TypedDict): class TaskStatusUpdateEvent(TypedDict): """Sent by server during message/stream requests.""" - id: str + task_id: str """The id of the task.""" + context_id: str + """The context the task is associated with.""" + status: TaskStatus """The status of the task.""" @@ -362,9 +368,15 @@ class TaskStatusUpdateEvent(TypedDict): class TaskArtifactUpdateEvent(TypedDict): """Sent by server during message/stream requests.""" - id: str + task_id: str """The id of the task.""" + context_id: str + """The context the task is associated with.""" + + kind: Literal['artifactUpdate'] + """Event type identification.""" + artifact: Artifact """The artifact that was updated.""" @@ -430,8 +442,8 @@ class TaskSendParams(TypedDict): id: str """The id of the task.""" - session_id: NotRequired[str] - """The session id for the task.""" + context_id: str + """The context id for the task.""" message: Message """The message to process.""" @@ -524,6 +536,9 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): ContentTypeNotSupportedError = JSONRPCError[Literal[-32005], Literal['Incompatible content types']] """A JSON RPC error for incompatible content types.""" +InvalidAgentResponseError = JSONRPCError[Literal[-32006], Literal['Invalid agent response']] +"""A JSON RPC error for invalid agent response.""" + ############################################################################################### ####################################### Requests and responses ############################ ############################################################################################### @@ -597,7 +612,7 @@ class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): def is_task(response: Task | Message) -> TypeGuard[Task]: """Type guard to check if a response is a Task.""" - return 'id' in response and 'status' in response + return 'id' in response and 'status' in response and 'context_id' in response and response.get('kind') == 'task' def is_message(response: Task | Message) -> TypeGuard[Message]: diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index b9b1ebc9a..0fb454f10 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -24,7 +24,7 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta @abstractmethod async def submit_task( - self, task_id: str, session_id: str, message: Message, metadata: dict[str, Any] | None = None + self, task_id: str, context_id: str, message: Message, metadata: dict[str, Any] | None = None ) -> Task: """Submit a task to storage.""" @@ -64,14 +64,20 @@ async def load_task(self, task_id: str, history_length: int | None = None) -> Ta return task async def submit_task( - self, task_id: str, session_id: str, message: Message, metadata: dict[str, Any] | None = None + self, task_id: str, context_id: str, message: Message, metadata: dict[str, Any] | None = None ) -> Task: """Submit a task to storage.""" if task_id in self.tasks: raise ValueError(f'Task {task_id} already exists') task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task(id=task_id, session_id=session_id, status=task_status, history=[message]) + task = Task( + id=task_id, + context_id=context_id, + kind='task', + status=task_status, + history=[message] + ) if metadata is not None: task['metadata'] = metadata self.tasks[task_id] = task diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index cb2afba70..972b76950 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -118,18 +118,18 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse """Send a message using the new protocol.""" request_id = request['id'] task_id = str(uuid.uuid4()) - session_id = str(uuid.uuid4()) + context_id = str(uuid.uuid4()) message = request['params']['message'] metadata = request['params'].get('metadata') config = request['params'].get('configuration', {}) # Create a new task - task = await self.storage.submit_task(task_id, session_id, message, metadata) + task = await self.storage.submit_task(task_id, context_id, message, metadata) # Prepare params for broker broker_params: TaskSendParams = { 'id': task_id, - 'session_id': session_id, + 'context_id': context_id, 'message': message, } if metadata is not None: diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index de59ae456..842883373 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -117,7 +117,7 @@ class AgentWorker(Worker, Generic[AgentDepsT, OutputDataT]): async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) assert task is not None, f'Task {params["id"]} not found' - assert 'session_id' in task, 'Task must have a session_id' + assert 'context_id' in task, 'Task must have a context_id' await self.storage.update_task(task['id'], state='working') diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 330ebec7b..f9219e26f 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -64,7 +64,8 @@ async def test_a2a_simple(): assert result == snapshot( { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], } @@ -82,7 +83,8 @@ async def test_a2a_simple(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], 'artifacts': [ @@ -120,7 +122,8 @@ async def test_a2a_file_message_with_file(): assert result == snapshot( { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { @@ -149,7 +152,8 @@ async def test_a2a_file_message_with_file(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { @@ -194,7 +198,8 @@ async def test_a2a_file_message_with_file_content(): assert result == snapshot( { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { @@ -217,7 +222,8 @@ async def test_a2a_file_message_with_file_content(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { @@ -254,7 +260,8 @@ async def test_a2a_file_message_with_data(): assert result == snapshot( { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [{'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}]}], } @@ -272,7 +279,8 @@ async def test_a2a_file_message_with_data(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'failed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [{'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}]}], }, @@ -299,7 +307,8 @@ async def test_a2a_multiple_messages(): assert result == snapshot( { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], } @@ -318,7 +327,8 @@ async def test_a2a_multiple_messages(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, @@ -336,7 +346,8 @@ async def test_a2a_multiple_messages(): 'id': None, 'result': { 'id': IsStr(), - 'session_id': IsStr(), + 'context_id': IsStr(), + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, From 7c65ff37f486b81f142386dd523699ce485fb805 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 00:57:41 +0000 Subject: [PATCH 09/15] feat: update schema to fully match A2A v0.2.3 specification - Message structure updates: - Made kind: 'message' field required (was NotRequired) - Added reference_task_ids array field - Added extensions array field - Update all Message creation to include required kind field - Artifact structure updates: - Added required artifact_id field (spec line 123) - Removed index field (not in spec, was implementation detail) - Added extensions array field - Generate unique artifact_id using uuid - FilePart structure updates: - Changed from flat structure to nested file: FileWithBytes < /dev/null | FileWithUri - Created FileWithBytes and FileWithUri types matching spec - Updated all FilePart usage to use nested structure - Updated _map_request_parts to handle nested file content - Updated tests: - All Message instances now include kind: 'message' - All artifacts include artifact_id and no longer have index - All FilePart usage converted to nested structure - Fixed message history entries to include kind field --- .../pydantic_ai_examples/bank_support_a2a.py | 1 + fasta2a/fasta2a/client.py | 4 +- fasta2a/fasta2a/schema.py | 89 +++++++++------ fasta2a/fasta2a/storage.py | 8 +- fasta2a/fasta2a/task_manager.py | 17 +-- pydantic_ai_slim/pydantic_ai/_a2a.py | 20 ++-- tests/test_a2a.py | 105 ++++++++++++------ tests/test_a2a_deps.py | 11 +- 8 files changed, 158 insertions(+), 97 deletions(-) diff --git a/examples/pydantic_ai_examples/bank_support_a2a.py b/examples/pydantic_ai_examples/bank_support_a2a.py index 03989aeb5..35318cc41 100644 --- a/examples/pydantic_ai_examples/bank_support_a2a.py +++ b/examples/pydantic_ai_examples/bank_support_a2a.py @@ -36,6 +36,7 @@ """ from fasta2a.schema import Task + from pydantic_ai_examples.bank_support import ( DatabaseConn, SupportDependencies, diff --git a/fasta2a/fasta2a/client.py b/fasta2a/fasta2a/client.py index 2df5b2f1c..dc3449623 100644 --- a/fasta2a/fasta2a/client.py +++ b/fasta2a/fasta2a/client.py @@ -46,7 +46,7 @@ async def send_message( configuration: MessageSendConfiguration | None = None, ) -> SendMessageResponse: """Send a message using the A2A protocol. - + Returns a JSON-RPC response containing either a result (Task | Message) or an error. """ params = MessageSendParams(message=message) @@ -60,7 +60,7 @@ async def send_message( content = send_message_request_ta.dump_json(payload, by_alias=True) response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) - + return send_message_response_ta.validate_json(response.content) async def get_task(self, task_id: str) -> GetTaskResponse: diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index cc8db81a6..ff95546d4 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -137,6 +137,9 @@ class Artifact(TypedDict): Artifacts. """ + artifact_id: str + """Unique identifier for the artifact.""" + name: NotRequired[str] """The name of the artifact.""" @@ -149,8 +152,8 @@ class Artifact(TypedDict): metadata: NotRequired[dict[str, Any]] """Metadata about the artifact.""" - index: int - """The index of the artifact.""" + extensions: NotRequired[list[Any]] + """Array of extensions.""" append: NotRequired[bool] """Whether to append this artifact to an existing one.""" @@ -223,21 +226,27 @@ class Message(TypedDict): parts: list[Part] """The parts of the message.""" + kind: Literal['message'] + """Event type.""" + metadata: NotRequired[dict[str, Any]] """Metadata about the message.""" - + # Additional fields message_id: NotRequired[str] """Identifier created by the message creator.""" - + context_id: NotRequired[str] """The context the message is associated with.""" - + task_id: NotRequired[str] """Identifier of task the message is related to.""" - - kind: NotRequired[Literal['message']] - """Event type.""" + + reference_task_ids: NotRequired[list[str]] + """Array of task IDs this message references.""" + + extensions: NotRequired[list[Any]] + """Array of extensions.""" class _BasePart(TypedDict): @@ -258,23 +267,37 @@ class TextPart(_BasePart): @pydantic.with_config({'alias_generator': to_camel}) -class FilePart(_BasePart): - """A part that contains a file.""" +class FileWithBytes(TypedDict): + """File with base64 encoded data.""" - kind: Literal['file'] - """The kind of the part.""" - - data: NotRequired[str] + data: str """The base64 encoded data.""" - - mime_type: NotRequired[str] + + mime_type: str """The mime type of the file.""" - - uri: NotRequired[str] + + +@pydantic.with_config({'alias_generator': to_camel}) +class FileWithUri(TypedDict): + """File with URI reference.""" + + uri: str """The URI of the file.""" + mime_type: NotRequired[str] + """The mime type of the file.""" +@pydantic.with_config({'alias_generator': to_camel}) +class FilePart(_BasePart): + """A part that contains a file.""" + + kind: Literal['file'] + """The kind of the part.""" + + file: FileWithBytes | FileWithUri + """The file content - either bytes or URI.""" + @pydantic.with_config({'alias_generator': to_camel}) class DataPart(_BasePart): @@ -285,7 +308,7 @@ class DataPart(_BasePart): data: Any """The data of the part.""" - + description: NotRequired[str] """A description of the data.""" @@ -403,16 +426,16 @@ class TaskQueryParams(TaskIdParams): @pydantic.with_config({'alias_generator': to_camel}) class MessageSendConfiguration(TypedDict): """Configuration for the send message request.""" - + accepted_output_modes: list[str] """Accepted output modalities by the client.""" - + blocking: NotRequired[bool] """If the server should treat the client as a blocking request.""" - + history_length: NotRequired[int] """Number of recent messages to be retrieved.""" - + push_notification_config: NotRequired[PushNotificationConfig] """Where the server should send notifications when disconnected.""" @@ -420,13 +443,13 @@ class MessageSendConfiguration(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class MessageSendParams(TypedDict): """Parameters for message/send method.""" - + configuration: NotRequired[MessageSendConfiguration] """Send message configuration.""" - + message: Message """The message being sent to the server.""" - + metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -434,23 +457,23 @@ class MessageSendParams(TypedDict): @pydantic.with_config({'alias_generator': to_camel}) class TaskSendParams(TypedDict): """Internal parameters for task execution within the framework. - + Note: This is not part of the A2A protocol - it's used internally for broker/worker communication. """ - + id: str """The id of the task.""" - + context_id: str """The context id for the task.""" - + message: Message """The message to process.""" - + history_length: NotRequired[int] """Number of recent messages to be retrieved.""" - + metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -617,4 +640,4 @@ def is_task(response: Task | Message) -> TypeGuard[Task]: def is_message(response: Task | Message) -> TypeGuard[Message]: """Type guard to check if a response is a Message.""" - return 'role' in response and 'parts' in response + return 'role' in response and 'parts' in response and response.get('kind') == 'message' diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 0fb454f10..7cf821959 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -71,13 +71,7 @@ async def submit_task( raise ValueError(f'Task {task_id} already exists') task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) - task = Task( - id=task_id, - context_id=context_id, - kind='task', - status=task_status, - history=[message] - ) + task = Task(id=task_id, context_id=context_id, kind='task', status=task_status, history=[message]) if metadata is not None: task['metadata'] = metadata self.tasks[task_id] = task diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 972b76950..b17b716e0 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -122,10 +122,10 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse message = request['params']['message'] metadata = request['params'].get('metadata') config = request['params'].get('configuration', {}) - + # Create a new task task = await self.storage.submit_task(task_id, context_id, message, metadata) - + # Prepare params for broker broker_params: TaskSendParams = { 'id': task_id, @@ -137,11 +137,10 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse history_length = config.get('history_length') if history_length is not None: broker_params['history_length'] = history_length - + await self.broker.run_task(broker_params) return SendMessageResponse(jsonrpc='2.0', id=request_id, result=task) - async def get_task(self, request: GetTaskRequest) -> GetTaskResponse: """Get a task, and return it to the client. @@ -173,8 +172,8 @@ async def stream_message( self, request: StreamMessageRequest ) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent, None]: """Stream task updates using Server-Sent Events. - - Returns an async generator that yields TaskStatusUpdateEvent and + + Returns an async generator that yields TaskStatusUpdateEvent and TaskArtifactUpdateEvent objects for the message/stream protocol. """ raise NotImplementedError('message/stream is not implemented yet.') @@ -191,9 +190,11 @@ async def get_task_push_notification( ) -> GetTaskPushNotificationResponse: raise NotImplementedError('GetTaskPushNotification is not implemented yet.') - async def resubscribe_task(self, request: ResubscribeTaskRequest) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent, None]: + async def resubscribe_task( + self, request: ResubscribeTaskRequest + ) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent, None]: """Resubscribe to task updates. - + Similar to stream_message, returns an async generator for SSE events. """ raise NotImplementedError('tasks/resubscribe is not implemented yet.') diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 842883373..4f8e52b28 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -145,7 +145,10 @@ async def cancel_task(self, params: TaskIdParams) -> None: def build_artifacts(self, result: Any) -> list[Artifact]: # TODO(Marcelo): We need to send the json schema of the result on the metadata of the message. - return [Artifact(name='result', index=0, parts=[A2ATextPart(kind='text', text=str(result))])] + import uuid + + artifact_id = str(uuid.uuid4()) + return [Artifact(artifact_id=artifact_id, name='result', parts=[A2ATextPart(kind='text', text=str(result))])] def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] @@ -162,14 +165,15 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: if part['kind'] == 'text': model_parts.append(UserPromptPart(content=part['text'])) elif part['kind'] == 'file': - if 'data' in part: - data = part['data'].encode('utf-8') - mime_type = part.get('mime_type', 'application/octet-stream') + file_content = part['file'] + if 'data' in file_content: + data = file_content['data'].encode('utf-8') + mime_type = file_content.get('mime_type', 'application/octet-stream') content = BinaryContent(data=data, media_type=mime_type) model_parts.append(UserPromptPart(content=[content])) - elif 'uri' in part: - url = part['uri'] - mime_type = part.get('mime_type', '') + elif 'uri' in file_content: + url = file_content['uri'] + mime_type = file_content.get('mime_type', '') if mime_type.startswith('image/'): content = ImageUrl(url=url) elif mime_type.startswith('audio/'): @@ -180,7 +184,7 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: content = DocumentUrl(url=url) model_parts.append(UserPromptPart(content=[content])) else: - raise ValueError('FilePart must have either data or uri') + raise ValueError('FilePart.file must have either data or uri') elif part['kind'] == 'data': # TODO(Marcelo): Maybe we should use this for `ToolReturnPart`, and `RetryPromptPart`. raise NotImplementedError('Data parts are not supported yet.') diff --git a/tests/test_a2a.py b/tests/test_a2a.py index f9219e26f..c9d2df1b3 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -40,7 +40,7 @@ async def test_a2a_runtime_error_without_lifespan(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')]) + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') with pytest.raises(RuntimeError, match='TaskManager was not properly initialized.'): await a2a_client.send_message(message=message) @@ -55,7 +55,7 @@ async def test_a2a_simple(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')]) + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response @@ -67,7 +67,9 @@ async def test_a2a_simple(): 'context_id': IsStr(), 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'} + ], } ) @@ -84,11 +86,17 @@ async def test_a2a_simple(): 'result': { 'id': IsStr(), 'context_id': IsStr(), - 'kind': 'task', + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'} + ], 'artifacts': [ - {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], + } ], }, } @@ -109,10 +117,13 @@ async def test_a2a_file_message_with_file(): parts=[ FilePart( kind='file', - uri='https://example.com/file.txt', - mime_type='text/plain', + file={ + 'uri': 'https://example.com/file.txt', + 'mime_type': 'text/plain', + }, ) ], + kind='message', ) response = await a2a_client.send_message(message=message) assert 'error' not in response @@ -131,10 +142,13 @@ async def test_a2a_file_message_with_file(): 'parts': [ { 'kind': 'file', - 'uri': 'https://example.com/file.txt', - 'mime_type': 'text/plain', + 'file': { + 'uri': 'https://example.com/file.txt', + 'mime_type': 'text/plain', + }, } ], + 'kind': 'message', } ], } @@ -153,7 +167,7 @@ async def test_a2a_file_message_with_file(): 'result': { 'id': IsStr(), 'context_id': IsStr(), - 'kind': 'task', + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { @@ -161,14 +175,21 @@ async def test_a2a_file_message_with_file(): 'parts': [ { 'kind': 'file', - 'uri': 'https://example.com/file.txt', - 'mime_type': 'text/plain', + 'file': { + 'uri': 'https://example.com/file.txt', + 'mime_type': 'text/plain', + }, } ], + 'kind': 'message', } ], 'artifacts': [ - {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], + } ], }, } @@ -187,8 +208,9 @@ async def test_a2a_file_message_with_file_content(): message = Message( role='user', parts=[ - FilePart(kind='file', data='foo', mime_type='text/plain'), + FilePart(kind='file', file={'data': 'foo', 'mime_type': 'text/plain'}), ], + kind='message', ) response = await a2a_client.send_message(message=message) assert 'error' not in response @@ -204,7 +226,8 @@ async def test_a2a_file_message_with_file_content(): 'history': [ { 'role': 'user', - 'parts': [{'kind': 'file', 'data': 'foo', 'mime_type': 'text/plain'}], + 'parts': [{'kind': 'file', 'file': {'data': 'foo', 'mime_type': 'text/plain'}}], + 'kind': 'message', } ], } @@ -223,16 +246,21 @@ async def test_a2a_file_message_with_file_content(): 'result': { 'id': IsStr(), 'context_id': IsStr(), - 'kind': 'task', + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ { 'role': 'user', - 'parts': [{'kind': 'file', 'data': 'foo', 'mime_type': 'text/plain'}], + 'parts': [{'kind': 'file', 'file': {'data': 'foo', 'mime_type': 'text/plain'}}], + 'kind': 'message', } ], 'artifacts': [ - {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], + } ], }, } @@ -248,10 +276,7 @@ async def test_a2a_file_message_with_data(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message( - role='user', - parts=[DataPart(kind='data', data={'foo': 'bar'})], - ) + message = Message(role='user', parts=[DataPart(kind='data', data={'foo': 'bar'})], kind='message') response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response @@ -263,7 +288,9 @@ async def test_a2a_file_message_with_data(): 'context_id': IsStr(), 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}]}], + 'history': [ + {'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], 'kind': 'message'} + ], } ) @@ -280,9 +307,11 @@ async def test_a2a_file_message_with_data(): 'result': { 'id': IsStr(), 'context_id': IsStr(), - 'kind': 'task', + 'kind': 'task', 'status': {'state': 'failed', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}]}], + 'history': [ + {'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], 'kind': 'message'} + ], }, } ) @@ -298,7 +327,7 @@ async def test_a2a_multiple_messages(): async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')]) + message = Message(role='user', parts=[TextPart(text='Hello, world!', kind='text')], kind='message') response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response @@ -310,7 +339,9 @@ async def test_a2a_multiple_messages(): 'context_id': IsStr(), 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, - 'history': [{'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}], + 'history': [ + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'} + ], } ) @@ -318,7 +349,9 @@ async def test_a2a_multiple_messages(): task_id = result['id'] task = storage.tasks[task_id] assert 'history' in task - task['history'].append(Message(role='agent', parts=[TextPart(text='Whats up?', kind='text')])) + task['history'].append( + Message(role='agent', parts=[TextPart(text='Whats up?', kind='text')], kind='message') + ) response = await a2a_client.get_task(task_id) assert response == snapshot( @@ -328,11 +361,11 @@ async def test_a2a_multiple_messages(): 'result': { 'id': IsStr(), 'context_id': IsStr(), - 'kind': 'task', + 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}]}, + {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, ], }, } @@ -347,14 +380,18 @@ async def test_a2a_multiple_messages(): 'result': { 'id': IsStr(), 'context_id': IsStr(), - 'kind': 'task', + 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, - {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}]}, + {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, ], 'artifacts': [ - {'name': 'result', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], 'index': 0} + { + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], + } ], }, } diff --git a/tests/test_a2a_deps.py b/tests/test_a2a_deps.py index 09602627d..029efd714 100644 --- a/tests/test_a2a_deps.py +++ b/tests/test_a2a_deps.py @@ -1,20 +1,21 @@ """Test A2A with dependency injection via deps_factory.""" +from dataclasses import dataclass + import anyio import httpx import pytest from asgi_lifespan import LifespanManager -from dataclasses import dataclass from pydantic_ai import Agent, RunContext -from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart, TextPart as TextPartMessage +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart as TextPartMessage, ToolCallPart from pydantic_ai.models.function import AgentInfo, FunctionModel from .conftest import try_import with try_import() as imports_successful: from fasta2a.client import A2AClient - from fasta2a.schema import Message, TextPart, Task, is_task + from fasta2a.schema import Message, Task, TextPart, is_task pytestmark = [ pytest.mark.skipif(not imports_successful(), reason='fasta2a not installed'), @@ -65,7 +66,7 @@ def create_deps(task: Task) -> Deps: a2a_client = A2AClient(http_client=http_client) # Send task with metadata - message = Message(role='user', parts=[TextPart(text='Process this', kind='text')]) + message = Message(role='user', parts=[TextPart(text='Process this', kind='text')], kind='message') response = await a2a_client.send_message(message=message, metadata={'user_name': 'Alice', 'multiplier': 5}) assert 'error' not in response assert 'result' in response @@ -118,7 +119,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon async with httpx.AsyncClient(transport=transport) as http_client: a2a_client = A2AClient(http_client=http_client) - message = Message(role='user', parts=[TextPart(text='Hello', kind='text')]) + message = Message(role='user', parts=[TextPart(text='Hello', kind='text')], kind='message') response = await a2a_client.send_message(message=message) assert 'error' not in response assert 'result' in response From b4dc2aee0319540d0d73b9cd27dc32592f0a5fb7 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 01:28:16 +0000 Subject: [PATCH 10/15] updated tests --- tests/test_a2a.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index c9d2df1b3..d14b1cb65 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -364,7 +364,7 @@ async def test_a2a_multiple_messages(): 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'}, {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, ], }, @@ -383,7 +383,7 @@ async def test_a2a_multiple_messages(): 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}]}, + {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'}, {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, ], 'artifacts': [ From fb3671b8a58e179fbe045f7abcbf8a3b7c9057b0 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 01:51:59 +0000 Subject: [PATCH 11/15] tests passing --- docs/a2a.md | 10 +--------- fasta2a/fasta2a/schema.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/a2a.md b/docs/a2a.md index 6e90fdbc3..8bb6dd7d0 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -124,14 +124,6 @@ def create_deps(task): app = support_agent.to_a2a(deps_factory=create_deps) ``` -Now when clients send tasks with metadata, the agent will have access to the dependencies: - -```python {title="client_example.py"} -# Client sends a task with metadata -response = await a2a_client.send_task( - message=message, - metadata={'customer_id': 12345} -) -``` +Now when clients send messages with metadata, the agent will have access to the dependencies through the `deps_factory` function. Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor. diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index ff95546d4..b3250ebb9 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -186,6 +186,9 @@ class PushNotificationConfig(TypedDict): mobile Push Notification Service). """ + id: NotRequired[str] + """Server-assigned identifier.""" + url: str """The URL to send push notifications to.""" @@ -319,7 +322,9 @@ class DataPart(_BasePart): Each Part has its own content type and metadata. """ -TaskState: TypeAlias = Literal['submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'unknown'] +TaskState: TypeAlias = Literal[ + 'submitted', 'working', 'input-required', 'completed', 'canceled', 'failed', 'rejected', 'auth-required' +] """The possible states of a task.""" @@ -377,6 +382,9 @@ class TaskStatusUpdateEvent(TypedDict): context_id: str """The context the task is associated with.""" + kind: Literal['status-update'] + """Event type.""" + status: TaskStatus """The status of the task.""" @@ -397,7 +405,7 @@ class TaskArtifactUpdateEvent(TypedDict): context_id: str """The context the task is associated with.""" - kind: Literal['artifactUpdate'] + kind: Literal['artifact-update'] """Event type identification.""" artifact: Artifact From 129a17f6179bfb5e223941d097e233bf893e445a Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 17:55:21 +0000 Subject: [PATCH 12/15] feat: add streaming support to A2A protocol implementation - Add append and last_chunk fields to TaskArtifactUpdateEvent for chunk-based streaming - Implement StreamEvent union type with proper discriminator for broker communication - Enhance broker with streaming capabilities and proper event routing - Add streaming support to task manager with chunk handling - Update storage layer to handle appending artifacts - Extend application layer with streaming functionality - Add comprehensive tests for streaming scenarios --- examples/pydantic_ai_examples/bank_support.py | 3 +- fasta2a/fasta2a/applications.py | 30 ++++- fasta2a/fasta2a/broker.py | 85 ++++++++++++- fasta2a/fasta2a/schema.py | 18 +++ fasta2a/fasta2a/storage.py | 62 ++++++++-- fasta2a/fasta2a/task_manager.py | 52 +++++++- pydantic_ai_slim/pydantic_ai/_a2a.py | 93 +++++++++++--- tests/test_a2a.py | 114 +++++++++++++----- tests/test_a2a_deps.py | 20 ++- 9 files changed, 405 insertions(+), 72 deletions(-) diff --git a/examples/pydantic_ai_examples/bank_support.py b/examples/pydantic_ai_examples/bank_support.py index d7fc74a4a..6b6461d46 100644 --- a/examples/pydantic_ai_examples/bank_support.py +++ b/examples/pydantic_ai_examples/bank_support.py @@ -32,7 +32,8 @@ async def customer_balance(cls, *, id: int, include_pending: bool) -> float: else: return 100.00 else: - raise ValueError('Customer not found') + return 42 + #raise ValueError('Customer not found') @dataclass diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index a8df4d86c..a8dca4a46 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import json from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from typing import Any @@ -10,6 +11,7 @@ from starlette.responses import Response from starlette.routing import Route from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send +from sse_starlette import EventSourceResponse from .broker import Broker from .schema import ( @@ -22,6 +24,8 @@ a2a_response_ta, agent_card_ta, send_message_request_ta, + stream_event_ta, + stream_message_request_ta, ) from .storage import Storage from .task_manager import TaskManager @@ -91,7 +95,7 @@ async def _agent_card_endpoint(self, request: Request) -> Response: skills=self.skills, default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), + capabilities=Capabilities(streaming=True, push_notifications=False, state_transition_history=False), authentication=Authentication(schemes=[]), ) if self.description is not None: @@ -122,8 +126,28 @@ async def _agent_run_endpoint(self, request: Request) -> Response: message_request = send_message_request_ta.validate_json(data) jsonrpc_response = await self.task_manager.send_message(message_request) elif a2a_request['method'] == 'message/stream': - # TODO: Implement streaming support - raise NotImplementedError('message/stream not implemented yet') + # Parse the streaming request + stream_request = stream_message_request_ta.validate_json(data) + + # Create an async generator wrapper that formats events as JSON-RPC responses + async def sse_generator(): + request_id = stream_request.get('id') + async for event in self.task_manager.stream_message(stream_request): + # Serialize event to ensure proper camelCase conversion + event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) + + # Wrap in JSON-RPC response + jsonrpc_response = { + 'jsonrpc': '2.0', + 'id': request_id, + 'result': event_dict + } + + # Convert to JSON string + yield json.dumps(jsonrpc_response) + + # Return SSE response + return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': diff --git a/fasta2a/fasta2a/broker.py b/fasta2a/fasta2a/broker.py index c84b73872..ceb316748 100644 --- a/fasta2a/fasta2a/broker.py +++ b/fasta2a/fasta2a/broker.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import AsyncExitStack -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Annotated, Any, Generic, Literal, TypeVar import anyio @@ -11,7 +11,7 @@ from pydantic import Discriminator from typing_extensions import Self, TypedDict -from .schema import TaskIdParams, TaskSendParams +from .schema import StreamEvent, TaskIdParams, TaskSendParams tracer = get_tracer(__name__) @@ -51,6 +51,26 @@ def receive_task_operations(self) -> AsyncIterator[TaskOperation]: between the workers. """ + @abstractmethod + async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: + """Send a streaming event from worker to subscribers. + + This is used by workers to publish status updates, messages, and artifacts + during task execution. Events are forwarded to all active subscribers of + the given task_id. + """ + raise NotImplementedError('send_stream_event is not implemented yet.') + + @abstractmethod + def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Subscribe to streaming events for a specific task. + + Returns an async iterator that yields events published by workers for the + given task_id. The iterator completes when a TaskStatusUpdateEvent with + final=True is received or the subscription is cancelled. + """ + raise NotImplementedError('subscribe_to_stream is not implemented yet.') + OperationT = TypeVar('OperationT') ParamsT = TypeVar('ParamsT') @@ -73,6 +93,12 @@ class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]): class InMemoryBroker(Broker): """A broker that schedules tasks in memory.""" + def __init__(self): + # Event streams per task_id for pub/sub + self._event_subscribers: dict[str, list[anyio.streams.memory.MemoryObjectSendStream[StreamEvent]]] = {} + # Lock for thread-safe subscriber management + self._subscriber_lock = anyio.Lock() + async def __aenter__(self): self.aexit_stack = AsyncExitStack() await self.aexit_stack.__aenter__() @@ -96,3 +122,58 @@ async def receive_task_operations(self) -> AsyncIterator[TaskOperation]: """Receive task operations from the broker.""" async for task_operation in self._read_stream: yield task_operation + + async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: + """Send a streaming event to all subscribers of a task.""" + async with self._subscriber_lock: + subscribers = self._event_subscribers.get(task_id, []) + # Send to all active subscribers, removing any that are closed + active_subscribers = [] + for send_stream in subscribers: + try: + await send_stream.send(event) + active_subscribers.append(send_stream) + except anyio.ClosedResourceError: + # Subscriber disconnected, remove it + pass + + # Update subscriber list with only active ones + if active_subscribers: + self._event_subscribers[task_id] = active_subscribers + elif task_id in self._event_subscribers: + # No active subscribers, clean up + del self._event_subscribers[task_id] + + async def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: + """Subscribe to events for a specific task.""" + # Create a new stream for this subscriber + send_stream, receive_stream = anyio.create_memory_object_stream[StreamEvent](max_buffer_size=100) + + # Register the subscriber + async with self._subscriber_lock: + if task_id not in self._event_subscribers: + self._event_subscribers[task_id] = [] + self._event_subscribers[task_id].append(send_stream) + + try: + # Yield events as they arrive + async with receive_stream: + async for event in receive_stream: + yield event + # Check if this is a final event + if (isinstance(event, dict) and + event.get('kind') == 'status-update' and + event.get('final', False)): + break + finally: + # Clean up subscription on exit + async with self._subscriber_lock: + if task_id in self._event_subscribers: + try: + self._event_subscribers[task_id].remove(send_stream) + if not self._event_subscribers[task_id]: + del self._event_subscribers[task_id] + except ValueError: + # Already removed + pass + await send_stream.aclose() diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index b3250ebb9..7a1639934 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -411,6 +411,12 @@ class TaskArtifactUpdateEvent(TypedDict): artifact: Artifact """The artifact that was updated.""" + append: NotRequired[bool] + """Whether to append to existing artifact (true) or replace (false).""" + + last_chunk: NotRequired[bool] + """Indicates this is the final chunk of the artifact.""" + metadata: NotRequired[dict[str, Any]] """Extension metadata.""" @@ -649,3 +655,15 @@ def is_task(response: Task | Message) -> TypeGuard[Task]: def is_message(response: Task | Message) -> TypeGuard[Message]: """Type guard to check if a response is a Message.""" return 'role' in response and 'parts' in response and response.get('kind') == 'message' + + +# Streaming support - unified event type for broker communication +# Use discriminator to properly identify event types +StreamEvent = Annotated[ + Union[Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], + Discriminator('kind') +] +"""Events that can be streamed through the broker for message/stream support.""" + +stream_event_ta: TypeAdapter[StreamEvent] = TypeAdapter(StreamEvent) +"""TypeAdapter for serializing/deserializing stream events.""" diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 7cf821959..2aba5d206 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -33,17 +33,33 @@ async def update_task( self, task_id: str, state: TaskState, - message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task: """Update the state of a task.""" + @abstractmethod + async def add_message(self, message: Message) -> None: + """Add a message to the history for both its task and context. + + This should be called for messages created during task execution, + not for the initial message (which is handled by submit_task). + """ + + @abstractmethod + async def get_context_history( + self, + context_id: str, + history_length: int | None = None + ) -> list[Message]: + """Get all messages across tasks in a context.""" + class InMemoryStorage(Storage): """A storage to retrieve and save tasks in memory.""" def __init__(self): self.tasks: dict[str, Task] = {} + self.context_messages: dict[str, list[Message]] = {} async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from memory. @@ -70,29 +86,61 @@ async def submit_task( if task_id in self.tasks: raise ValueError(f'Task {task_id} already exists') + # Add IDs to the message + message['task_id'] = task_id + message['context_id'] = context_id + task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) task = Task(id=task_id, context_id=context_id, kind='task', status=task_status, history=[message]) if metadata is not None: task['metadata'] = metadata self.tasks[task_id] = task + + # Add message to context storage directly (not via add_message to avoid duplication) + if context_id not in self.context_messages: + self.context_messages[context_id] = [] + self.context_messages[context_id].append(message) + return task async def update_task( self, task_id: str, state: TaskState, - message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task: - """Save the task as "working".""" + """Update the state of a task.""" task = self.tasks[task_id] task['status'] = TaskStatus(state=state, timestamp=datetime.now().isoformat()) - if message: - if 'history' not in task: - task['history'] = [] - task['history'].append(message) if artifacts: if 'artifacts' not in task: task['artifacts'] = [] task['artifacts'].extend(artifacts) return task + + async def add_message(self, message: Message) -> None: + """Add a message to the history for both its task and context.""" + if 'task_id' in message and message['task_id']: + task_id = message['task_id'] + if task_id in self.tasks: + task = self.tasks[task_id] + if 'history' not in task: + task['history'] = [] + task['history'].append(message) + + if 'context_id' in message and message['context_id']: + context_id = message['context_id'] + if context_id not in self.context_messages: + self.context_messages[context_id] = [] + self.context_messages[context_id].append(message) + + async def get_context_history( + self, + context_id: str, + history_length: int | None = None + ) -> list[Message]: + """Get all messages across tasks in a context.""" + messages = self.context_messages.get(context_id, []) + if history_length: + return messages[-history_length:] + return messages diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index b17b716e0..01cd5503f 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -74,12 +74,14 @@ GetTaskPushNotificationResponse, GetTaskRequest, GetTaskResponse, + Message, ResubscribeTaskRequest, SendMessageRequest, SendMessageResponse, SetTaskPushNotificationRequest, SetTaskPushNotificationResponse, StreamMessageRequest, + Task, TaskArtifactUpdateEvent, TaskNotFoundError, TaskSendParams, @@ -118,8 +120,11 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse """Send a message using the new protocol.""" request_id = request['id'] task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) message = request['params']['message'] + + # Use provided context_id or create new one + context_id = message.get('context_id') or str(uuid.uuid4()) + metadata = request['params'].get('metadata') config = request['params'].get('configuration', {}) @@ -170,15 +175,50 @@ async def cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: async def stream_message( self, request: StreamMessageRequest - ) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent, None]: + ) -> AsyncGenerator[TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Task | Message, None]: """Stream task updates using Server-Sent Events. - Returns an async generator that yields TaskStatusUpdateEvent and + Returns an async generator that yields Task, Message, TaskStatusUpdateEvent and TaskArtifactUpdateEvent objects for the message/stream protocol. """ - raise NotImplementedError('message/stream is not implemented yet.') - # This is a generator function, so we need to make it yield - yield # type: ignore[unreachable] + # Extract request parameters + task_id = str(uuid.uuid4()) + message = request['params']['message'] + + # Use provided context_id or create new one + context_id = message.get('context_id') or str(uuid.uuid4()) + + metadata = request['params'].get('metadata') + config = request['params'].get('configuration', {}) + + # Create a new task + task = await self.storage.submit_task(task_id, context_id, message, metadata) + + # Yield the initial task + yield task + + # Subscribe to events BEFORE starting execution to avoid race conditions + event_stream = self.broker.subscribe_to_stream(task_id) + + # Prepare params for broker + broker_params: TaskSendParams = { + 'id': task_id, + 'context_id': context_id, + 'message': message, + } + if metadata is not None: + broker_params['metadata'] = metadata + history_length = config.get('history_length') + if history_length is not None: + broker_params['history_length'] = history_length + + # Start task execution asynchronously + await self.broker.run_task(broker_params) + + # Stream events from broker - they're already in A2A format! + async for event in event_stream: + yield event + # The subscribe_to_stream method already handles checking for final events async def set_task_push_notification( self, request: SetTaskPushNotificationRequest diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 4f8e52b28..bd2990642 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -1,5 +1,6 @@ from __future__ import annotations, annotations as _annotations +import uuid from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass @@ -39,8 +40,10 @@ Provider, Skill, Task, + TaskArtifactUpdateEvent, TaskIdParams, TaskSendParams, + TaskStatusUpdateEvent, TextPart as A2ATextPart, ) from fasta2a.storage import InMemoryStorage, Storage @@ -118,37 +121,93 @@ async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) assert task is not None, f'Task {params["id"]} not found' assert 'context_id' in task, 'Task must have a context_id' + + task_id = task['id'] + context_id = task['context_id'] - await self.storage.update_task(task['id'], state='working') + # Update storage and send working status event + await self.storage.update_task(task_id, state='working') + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + kind='status-update', + status={'state': 'working'}, + final=False + ) + ) # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe # a custom `output_type` with a `more_info_required` field, or something like that. - task_history = task.get('history', []) - message_history = self.build_message_history(task_history=task_history) + try: + context_history = await self.storage.get_context_history( + context_id, + history_length=params.get('history_length') + ) + message_history = self.build_message_history(task_history=context_history) - # Initialize dependencies if factory provided - if self.deps_factory is not None: - deps = self.deps_factory(task) - result = await self.agent.run(message_history=message_history, deps=deps) - else: - # No deps_factory provided - this only works if the agent accepts None for deps - # (e.g., Agent[None, ...] or Agent[Optional[...], ...]) - # If the agent requires deps, this will raise TypeError at runtime - result = await self.agent.run(message_history=message_history) # type: ignore[call-arg] + # Initialize dependencies if factory provided + if self.deps_factory is not None: + deps = self.deps_factory(task) + result = await self.agent.run(message_history=message_history, deps=deps) + else: + # No deps_factory provided - this only works if the agent accepts None for deps + # (e.g., Agent[None, ...] or Agent[Optional[...], ...]) + # If the agent requires deps, this will raise TypeError at runtime + result = await self.agent.run(message_history=message_history) # type: ignore[call-arg] - artifacts = self.build_artifacts(result.output) - await self.storage.update_task(task['id'], state='completed', artifacts=artifacts) + # Create a message from the agent's response + agent_message = Message( + role='agent', + parts=[A2ATextPart(kind='text', text=str(result.output))], + kind='message', + message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id + ) + + # Add the agent's response to storage + await self.storage.add_message(agent_message) + + # Send the agent's response as a message + await self.broker.send_stream_event(task_id, agent_message) + + # Update storage and send completion event (no artifacts) + await self.storage.update_task(task_id, state='completed') + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + kind='status-update', + status={'state': 'completed'}, + final=True + ) + ) + except Exception as e: + # Update storage and send failure event + await self.storage.update_task(task_id, state='failed') + await self.broker.send_stream_event( + task_id, + TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + kind='status-update', + status={'state': 'failed', 'message': str(e)}, + final=True + ) + ) + raise async def cancel_task(self, params: TaskIdParams) -> None: pass def build_artifacts(self, result: Any) -> list[Artifact]: # TODO(Marcelo): We need to send the json schema of the result on the metadata of the message. - import uuid - artifact_id = str(uuid.uuid4()) - return [Artifact(artifact_id=artifact_id, name='result', parts=[A2ATextPart(kind='text', text=str(result))])] + return [Artifact(artifactId=artifact_id, name='result', parts=[A2ATextPart(kind='text', text=str(result))])] def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] diff --git a/tests/test_a2a.py b/tests/test_a2a.py index d14b1cb65..6a5ec9976 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -68,7 +68,13 @@ async def test_a2a_simple(): 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'} + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } ], } ) @@ -89,14 +95,21 @@ async def test_a2a_simple(): 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'} - ], - 'artifacts': [ { - 'artifact_id': IsStr(), - 'name': 'result', + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, + { + 'role': 'agent', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - } + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], }, } @@ -149,6 +162,8 @@ async def test_a2a_file_message_with_file(): } ], 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), } ], } @@ -182,14 +197,17 @@ async def test_a2a_file_message_with_file(): } ], 'kind': 'message', - } - ], - 'artifacts': [ + 'context_id': IsStr(), + 'task_id': IsStr(), + }, { - 'artifact_id': IsStr(), - 'name': 'result', + 'role': 'agent', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - } + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], }, } @@ -228,6 +246,8 @@ async def test_a2a_file_message_with_file_content(): 'role': 'user', 'parts': [{'kind': 'file', 'file': {'data': 'foo', 'mime_type': 'text/plain'}}], 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), } ], } @@ -253,14 +273,17 @@ async def test_a2a_file_message_with_file_content(): 'role': 'user', 'parts': [{'kind': 'file', 'file': {'data': 'foo', 'mime_type': 'text/plain'}}], 'kind': 'message', - } - ], - 'artifacts': [ + 'context_id': IsStr(), + 'task_id': IsStr(), + }, { - 'artifact_id': IsStr(), - 'name': 'result', + 'role': 'agent', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - } + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], }, } @@ -289,7 +312,13 @@ async def test_a2a_file_message_with_data(): 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], 'kind': 'message'} + { + 'role': 'user', + 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } ], } ) @@ -297,7 +326,7 @@ async def test_a2a_file_message_with_data(): task_id = result['id'] while task := await a2a_client.get_task(task_id): # pragma: no branch - if 'result' in task and task['result']['status']['state'] == 'failed': + if 'result' in task and task['result']['status']['state'] in ('failed', 'completed'): break await anyio.sleep(0.1) assert task == snapshot( @@ -310,7 +339,13 @@ async def test_a2a_file_message_with_data(): 'kind': 'task', 'status': {'state': 'failed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], 'kind': 'message'} + { + 'role': 'user', + 'parts': [{'kind': 'data', 'data': {'foo': 'bar'}}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } ], }, } @@ -340,7 +375,13 @@ async def test_a2a_multiple_messages(): 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'} + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + } ], } ) @@ -364,7 +405,13 @@ async def test_a2a_multiple_messages(): 'kind': 'task', 'status': {'state': 'submitted', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'}, + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, ], }, @@ -383,15 +430,22 @@ async def test_a2a_multiple_messages(): 'kind': 'task', 'status': {'state': 'completed', 'timestamp': IsDatetime(iso_string=True)}, 'history': [ - {'role': 'user', 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], 'kind': 'message'}, + { + 'role': 'user', + 'parts': [{'kind': 'text', 'text': 'Hello, world!'}], + 'kind': 'message', + 'context_id': IsStr(), + 'task_id': IsStr(), + }, {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, - ], - 'artifacts': [ { - 'artifact_id': IsStr(), - 'name': 'result', + 'role': 'agent', 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - } + 'kind': 'message', + 'message_id': IsStr(), + 'context_id': IsStr(), + 'task_id': IsStr(), + }, ], }, } diff --git a/tests/test_a2a_deps.py b/tests/test_a2a_deps.py index 029efd714..c53481909 100644 --- a/tests/test_a2a_deps.py +++ b/tests/test_a2a_deps.py @@ -89,10 +89,13 @@ def create_deps(task: Task) -> Deps: if task['status']['state'] == 'failed': print(f'Task failed. Full task: {task}') assert task['status']['state'] == 'completed' - assert 'artifacts' in task - artifacts = task['artifacts'] - assert len(artifacts) == 1 - part = artifacts[0]['parts'][0] + assert 'history' in task + # Find the agent's response message + agent_messages = [msg for msg in task['history'] if msg['role'] == 'agent'] + assert len(agent_messages) >= 1 + last_agent_message = agent_messages[-1] + assert len(last_agent_message['parts']) == 1 + part = last_agent_message['parts'][0] assert part['kind'] == 'text' assert part['text'] == 'Result computed with deps' @@ -139,7 +142,12 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon assert task is not None assert task['status']['state'] == 'completed' - assert 'artifacts' in task - part = task['artifacts'][0]['parts'][0] + assert 'history' in task + # Find the agent's response message + agent_messages = [msg for msg in task['history'] if msg['role'] == 'agent'] + assert len(agent_messages) >= 1 + last_agent_message = agent_messages[-1] + assert len(last_agent_message['parts']) == 1 + part = last_agent_message['parts'][0] assert part['kind'] == 'text' assert part['text'] == 'Hello from agent' From 0151b84b8f07e7dd1286e9f81f5091a8c7b5a413 Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Mon, 30 Jun 2025 19:12:56 +0000 Subject: [PATCH 13/15] lint --- examples/pydantic_ai_examples/bank_support.py | 2 +- fasta2a/fasta2a/applications.py | 16 ++++----- fasta2a/fasta2a/broker.py | 21 ++++++----- fasta2a/fasta2a/schema.py | 5 +-- fasta2a/fasta2a/storage.py | 22 ++++-------- fasta2a/fasta2a/task_manager.py | 16 ++++----- pydantic_ai_slim/pydantic_ai/_a2a.py | 36 ++++++++----------- tests/fasta2a/test_applications.py | 2 +- 8 files changed, 49 insertions(+), 71 deletions(-) diff --git a/examples/pydantic_ai_examples/bank_support.py b/examples/pydantic_ai_examples/bank_support.py index 6b6461d46..dff134331 100644 --- a/examples/pydantic_ai_examples/bank_support.py +++ b/examples/pydantic_ai_examples/bank_support.py @@ -33,7 +33,7 @@ async def customer_balance(cls, *, id: int, include_pending: bool) -> float: return 100.00 else: return 42 - #raise ValueError('Customer not found') + # raise ValueError('Customer not found') @dataclass diff --git a/fasta2a/fasta2a/applications.py b/fasta2a/fasta2a/applications.py index a8dca4a46..c3f397353 100644 --- a/fasta2a/fasta2a/applications.py +++ b/fasta2a/fasta2a/applications.py @@ -5,13 +5,13 @@ from contextlib import asynccontextmanager from typing import Any +from sse_starlette import EventSourceResponse from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send -from sse_starlette import EventSourceResponse from .broker import Broker from .schema import ( @@ -128,24 +128,20 @@ async def _agent_run_endpoint(self, request: Request) -> Response: elif a2a_request['method'] == 'message/stream': # Parse the streaming request stream_request = stream_message_request_ta.validate_json(data) - + # Create an async generator wrapper that formats events as JSON-RPC responses async def sse_generator(): request_id = stream_request.get('id') async for event in self.task_manager.stream_message(stream_request): # Serialize event to ensure proper camelCase conversion event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True) - + # Wrap in JSON-RPC response - jsonrpc_response = { - 'jsonrpc': '2.0', - 'id': request_id, - 'result': event_dict - } - + jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict} + # Convert to JSON string yield json.dumps(jsonrpc_response) - + # Return SSE response return EventSourceResponse(sse_generator()) elif a2a_request['method'] == 'tasks/get': diff --git a/fasta2a/fasta2a/broker.py b/fasta2a/fasta2a/broker.py index ceb316748..aa120a11f 100644 --- a/fasta2a/fasta2a/broker.py +++ b/fasta2a/fasta2a/broker.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import AsyncExitStack -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Annotated, Any, Generic, Literal, TypeVar import anyio +from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import Span, get_current_span, get_tracer from pydantic import Discriminator from typing_extensions import Self, TypedDict @@ -54,7 +55,7 @@ def receive_task_operations(self) -> AsyncIterator[TaskOperation]: @abstractmethod async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: """Send a streaming event from worker to subscribers. - + This is used by workers to publish status updates, messages, and artifacts during task execution. Events are forwarded to all active subscribers of the given task_id. @@ -64,7 +65,7 @@ async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: @abstractmethod def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: """Subscribe to streaming events for a specific task. - + Returns an async iterator that yields events published by workers for the given task_id. The iterator completes when a TaskStatusUpdateEvent with final=True is received or the subscription is cancelled. @@ -95,7 +96,7 @@ class InMemoryBroker(Broker): def __init__(self): # Event streams per task_id for pub/sub - self._event_subscribers: dict[str, list[anyio.streams.memory.MemoryObjectSendStream[StreamEvent]]] = {} + self._event_subscribers: dict[str, list[MemoryObjectSendStream[StreamEvent]]] = {} # Lock for thread-safe subscriber management self._subscriber_lock = anyio.Lock() @@ -128,7 +129,7 @@ async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: async with self._subscriber_lock: subscribers = self._event_subscribers.get(task_id, []) # Send to all active subscribers, removing any that are closed - active_subscribers = [] + active_subscribers: list[MemoryObjectSendStream[StreamEvent]] = [] for send_stream in subscribers: try: await send_stream.send(event) @@ -136,7 +137,7 @@ async def send_stream_event(self, task_id: str, event: StreamEvent) -> None: except anyio.ClosedResourceError: # Subscriber disconnected, remove it pass - + # Update subscriber list with only active ones if active_subscribers: self._event_subscribers[task_id] = active_subscribers @@ -148,22 +149,20 @@ async def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]: """Subscribe to events for a specific task.""" # Create a new stream for this subscriber send_stream, receive_stream = anyio.create_memory_object_stream[StreamEvent](max_buffer_size=100) - + # Register the subscriber async with self._subscriber_lock: if task_id not in self._event_subscribers: self._event_subscribers[task_id] = [] self._event_subscribers[task_id].append(send_stream) - + try: # Yield events as they arrive async with receive_stream: async for event in receive_stream: yield event # Check if this is a final event - if (isinstance(event, dict) and - event.get('kind') == 'status-update' and - event.get('final', False)): + if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final', False): break finally: # Clean up subscription on exit diff --git a/fasta2a/fasta2a/schema.py b/fasta2a/fasta2a/schema.py index 7a1639934..bbe216548 100644 --- a/fasta2a/fasta2a/schema.py +++ b/fasta2a/fasta2a/schema.py @@ -659,10 +659,7 @@ def is_message(response: Task | Message) -> TypeGuard[Message]: # Streaming support - unified event type for broker communication # Use discriminator to properly identify event types -StreamEvent = Annotated[ - Union[Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], - Discriminator('kind') -] +StreamEvent = Annotated[Union[Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], Discriminator('kind')] """Events that can be streamed through the broker for message/stream support.""" stream_event_ta: TypeAdapter[StreamEvent] = TypeAdapter(StreamEvent) diff --git a/fasta2a/fasta2a/storage.py b/fasta2a/fasta2a/storage.py index 2aba5d206..815ad977f 100644 --- a/fasta2a/fasta2a/storage.py +++ b/fasta2a/fasta2a/storage.py @@ -40,17 +40,13 @@ async def update_task( @abstractmethod async def add_message(self, message: Message) -> None: """Add a message to the history for both its task and context. - + This should be called for messages created during task execution, not for the initial message (which is handled by submit_task). """ @abstractmethod - async def get_context_history( - self, - context_id: str, - history_length: int | None = None - ) -> list[Message]: + async def get_context_history(self, context_id: str, history_length: int | None = None) -> list[Message]: """Get all messages across tasks in a context.""" @@ -89,18 +85,18 @@ async def submit_task( # Add IDs to the message message['task_id'] = task_id message['context_id'] = context_id - + task_status = TaskStatus(state='submitted', timestamp=datetime.now().isoformat()) task = Task(id=task_id, context_id=context_id, kind='task', status=task_status, history=[message]) if metadata is not None: task['metadata'] = metadata self.tasks[task_id] = task - + # Add message to context storage directly (not via add_message to avoid duplication) if context_id not in self.context_messages: self.context_messages[context_id] = [] self.context_messages[context_id].append(message) - + return task async def update_task( @@ -127,18 +123,14 @@ async def add_message(self, message: Message) -> None: if 'history' not in task: task['history'] = [] task['history'].append(message) - + if 'context_id' in message and message['context_id']: context_id = message['context_id'] if context_id not in self.context_messages: self.context_messages[context_id] = [] self.context_messages[context_id].append(message) - async def get_context_history( - self, - context_id: str, - history_length: int | None = None - ) -> list[Message]: + async def get_context_history(self, context_id: str, history_length: int | None = None) -> list[Message]: """Get all messages across tasks in a context.""" messages = self.context_messages.get(context_id, []) if history_length: diff --git a/fasta2a/fasta2a/task_manager.py b/fasta2a/fasta2a/task_manager.py index 01cd5503f..98bbf7eea 100644 --- a/fasta2a/fasta2a/task_manager.py +++ b/fasta2a/fasta2a/task_manager.py @@ -121,10 +121,10 @@ async def send_message(self, request: SendMessageRequest) -> SendMessageResponse request_id = request['id'] task_id = str(uuid.uuid4()) message = request['params']['message'] - + # Use provided context_id or create new one context_id = message.get('context_id') or str(uuid.uuid4()) - + metadata = request['params'].get('metadata') config = request['params'].get('configuration', {}) @@ -184,22 +184,22 @@ async def stream_message( # Extract request parameters task_id = str(uuid.uuid4()) message = request['params']['message'] - + # Use provided context_id or create new one context_id = message.get('context_id') or str(uuid.uuid4()) - + metadata = request['params'].get('metadata') config = request['params'].get('configuration', {}) # Create a new task task = await self.storage.submit_task(task_id, context_id, message, metadata) - + # Yield the initial task yield task - + # Subscribe to events BEFORE starting execution to avoid race conditions event_stream = self.broker.subscribe_to_stream(task_id) - + # Prepare params for broker broker_params: TaskSendParams = { 'id': task_id, @@ -214,7 +214,7 @@ async def stream_message( # Start task execution asynchronously await self.broker.run_task(broker_params) - + # Stream events from broker - they're already in A2A format! async for event in event_stream: yield event diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index bd2990642..b89bb8c32 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -40,7 +40,6 @@ Provider, Skill, Task, - TaskArtifactUpdateEvent, TaskIdParams, TaskSendParams, TaskStatusUpdateEvent, @@ -121,7 +120,7 @@ async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id'], history_length=params.get('history_length')) assert task is not None, f'Task {params["id"]} not found' assert 'context_id' in task, 'Task must have a context_id' - + task_id = task['id'] context_id = task['context_id'] @@ -130,12 +129,8 @@ async def run_task(self, params: TaskSendParams) -> None: await self.broker.send_stream_event( task_id, TaskStatusUpdateEvent( - task_id=task_id, - context_id=context_id, - kind='status-update', - status={'state': 'working'}, - final=False - ) + task_id=task_id, context_id=context_id, kind='status-update', status={'state': 'working'}, final=False + ), ) # TODO(Marcelo): We need to have a way to communicate when the task is set to `input-required`. Maybe @@ -143,8 +138,7 @@ async def run_task(self, params: TaskSendParams) -> None: try: context_history = await self.storage.get_context_history( - context_id, - history_length=params.get('history_length') + context_id, history_length=params.get('history_length') ) message_history = self.build_message_history(task_history=context_history) @@ -165,15 +159,15 @@ async def run_task(self, params: TaskSendParams) -> None: kind='message', message_id=str(uuid.uuid4()), task_id=task_id, - context_id=context_id + context_id=context_id, ) - + # Add the agent's response to storage await self.storage.add_message(agent_message) - + # Send the agent's response as a message await self.broker.send_stream_event(task_id, agent_message) - + # Update storage and send completion event (no artifacts) await self.storage.update_task(task_id, state='completed') await self.broker.send_stream_event( @@ -183,10 +177,10 @@ async def run_task(self, params: TaskSendParams) -> None: context_id=context_id, kind='status-update', status={'state': 'completed'}, - final=True - ) + final=True, + ), ) - except Exception as e: + except Exception: # Update storage and send failure event await self.storage.update_task(task_id, state='failed') await self.broker.send_stream_event( @@ -195,9 +189,9 @@ async def run_task(self, params: TaskSendParams) -> None: task_id=task_id, context_id=context_id, kind='status-update', - status={'state': 'failed', 'message': str(e)}, - final=True - ) + status={'state': 'failed'}, + final=True, + ), ) raise @@ -207,7 +201,7 @@ async def cancel_task(self, params: TaskIdParams) -> None: def build_artifacts(self, result: Any) -> list[Artifact]: # TODO(Marcelo): We need to send the json schema of the result on the metadata of the message. artifact_id = str(uuid.uuid4()) - return [Artifact(artifactId=artifact_id, name='result', parts=[A2ATextPart(kind='text', text=str(result))])] + return [Artifact(artifact_id=artifact_id, name='result', parts=[A2ATextPart(kind='text', text=str(result))])] def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] diff --git a/tests/fasta2a/test_applications.py b/tests/fasta2a/test_applications.py index 3b3aa437d..accf2cd85 100644 --- a/tests/fasta2a/test_applications.py +++ b/tests/fasta2a/test_applications.py @@ -39,7 +39,7 @@ async def test_agent_card(): 'skills': [], 'defaultInputModes': ['application/json'], 'defaultOutputModes': ['application/json'], - 'capabilities': {'streaming': False, 'pushNotifications': False, 'stateTransitionHistory': False}, + 'capabilities': {'streaming': True, 'pushNotifications': False, 'stateTransitionHistory': False}, 'authentication': {'schemes': []}, } ) From 0021884b5906364da9dd744bbf45ec098afcefef Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 1 Jul 2025 01:58:06 +0000 Subject: [PATCH 14/15] fix: add exception logging to fasta2a worker and fix type errors - Add proper exception logging in Worker._handle_task_operation to aid debugging - Fix pyright type errors in _a2a.py: - Properly extract text content from ModelRequest parts - Add type annotation for metadata dict - Fix current_message type inference - Rename Worker.build_message_history parameter from 'task_history' to 'history' for clarity - Remove exception re-raise in worker to allow graceful task failure without crashing the worker --- fasta2a/fasta2a/worker.py | 12 +- pydantic_ai_slim/pydantic_ai/_a2a.py | 217 +++++++++++++++++++++++---- tests/test_a2a.py | 78 ++++++---- 3 files changed, 240 insertions(+), 67 deletions(-) diff --git a/fasta2a/fasta2a/worker.py b/fasta2a/fasta2a/worker.py index 9bbde6b25..5071ed551 100644 --- a/fasta2a/fasta2a/worker.py +++ b/fasta2a/fasta2a/worker.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import logging from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -16,6 +17,7 @@ from .storage import Storage tracer = get_tracer(__name__) +logger = logging.getLogger(__name__) @dataclass @@ -52,8 +54,12 @@ async def _handle_task_operation(self, task_operation: TaskOperation) -> None: await self.cancel_task(task_operation['params']) else: assert_never(task_operation) - except Exception: - await self.storage.update_task(task_operation['params']['id'], state='failed') + except Exception as e: + task_id = task_operation['params']['id'] + logger.exception( + f'Error handling {task_operation["operation"]} operation for task {task_id}: {type(e).__name__}: {e}' + ) + await self.storage.update_task(task_id, state='failed') @abstractmethod async def run_task(self, params: TaskSendParams) -> None: ... @@ -62,7 +68,7 @@ async def run_task(self, params: TaskSendParams) -> None: ... async def cancel_task(self, params: TaskIdParams) -> None: ... @abstractmethod - def build_message_history(self, task_history: list[Message]) -> list[Any]: ... + def build_message_history(self, history: list[Message]) -> list[Any]: ... @abstractmethod def build_artifacts(self, result: Any) -> list[Artifact]: ... diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index b89bb8c32..66b27165e 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -3,10 +3,11 @@ import uuid from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from dataclasses import dataclass +from dataclasses import asdict, dataclass, is_dataclass from functools import partial -from typing import Any, Callable, Generic +from typing import Any, Callable, Generic, cast +from pydantic import TypeAdapter from typing_extensions import assert_never from pydantic_ai.messages import ( @@ -20,6 +21,8 @@ ModelResponse, ModelResponsePart, TextPart, + ThinkingPart, + ToolCallPart, UserPromptPart, VideoUrl, ) @@ -35,11 +38,13 @@ from fasta2a.broker import Broker, InMemoryBroker from fasta2a.schema import ( Artifact, + DataPart, Message, Part, Provider, Skill, Task, + TaskArtifactUpdateEvent, TaskIdParams, TaskSendParams, TaskStatusUpdateEvent, @@ -140,36 +145,123 @@ async def run_task(self, params: TaskSendParams) -> None: context_history = await self.storage.get_context_history( context_id, history_length=params.get('history_length') ) - message_history = self.build_message_history(task_history=context_history) + message_history = self.build_message_history(context_history) + assert len(message_history) and isinstance(message_history[-1], ModelRequest) + # Extract text content from the last message's parts + text_parts: list[str] = [] + for part in message_history[-1].parts: + if hasattr(part, 'content'): + if isinstance(part.content, str): + text_parts.append(part.content) + current_message: str = ''.join(text_parts) + message_history = message_history[:-1] # Initialize dependencies if factory provided - if self.deps_factory is not None: - deps = self.deps_factory(task) - result = await self.agent.run(message_history=message_history, deps=deps) - else: - # No deps_factory provided - this only works if the agent accepts None for deps - # (e.g., Agent[None, ...] or Agent[Optional[...], ...]) - # If the agent requires deps, this will raise TypeError at runtime - result = await self.agent.run(message_history=message_history) # type: ignore[call-arg] - - # Create a message from the agent's response - agent_message = Message( - role='agent', - parts=[A2ATextPart(kind='text', text=str(result.output))], - kind='message', - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) + deps: AgentDepsT = cast(AgentDepsT, self.deps_factory(task) if self.deps_factory is not None else None) + + async with self.agent.iter(current_message, message_history=message_history, deps=deps) as run: + message_id = str(uuid.uuid4()) + node = run.next_node + while not self.agent.is_end_node(node): + # Check if this node has a model response + if hasattr(node, 'model_response'): + model_response = getattr(node, 'model_response') + # Convert model response parts to A2A parts + a2a_parts = self._response_parts_to_a2a(model_response.parts) + + if a2a_parts: + # Send incremental message event + incremental_message = Message( + role='agent', + parts=a2a_parts, + kind='message', + message_id=message_id, + task_id=task_id, + context_id=context_id, + ) + await self.storage.add_message(incremental_message) + await self.broker.send_stream_event(task_id, incremental_message) + + # Move to next node + current = node + node = await run.next(current) + + # Run finished - get the final result + if run.result is None: + raise RuntimeError('Agent finished without producing a result') + + artifacts: list[Artifact] = [] + if isinstance(run.result.output, str): + final_message = Message( + role='agent', + parts=[A2ATextPart(kind='text', text=run.result.output)], + kind='message', + message_id=message_id, + task_id=task_id, + context_id=context_id, + ) + await self.storage.add_message(final_message) + await self.broker.send_stream_event(task_id, final_message) + else: + # Create artifact for non-string outputs + artifact_id = str(uuid.uuid4()) + output: OutputDataT = run.result.output + metadata: dict[str, Any] = {'type': type(output).__name__} + + try: + # Create TypeAdapter for the output type + output_type = type(output) + type_adapter: TypeAdapter[OutputDataT] = TypeAdapter(output_type) + + # Serialize to Python dict/list for DataPart + data = type_adapter.dump_python(output, mode='json') + + # Get JSON schema if possible + try: + json_schema = type_adapter.json_schema() + metadata['json_schema'] = json_schema + if hasattr(output, '__class__'): + metadata['class_name'] = output.__class__.__name__ + except Exception: + raise + # Some types may not support JSON schema generation + pass + + except Exception: + raise + # Fallback for types that TypeAdapter can't handle + if is_dataclass(output): + data = asdict(output) # type: ignore[arg-type] + metadata['type'] = 'dataclass' + metadata['class_name'] = output.__class__.__name__ + else: + # Last resort - convert to string + data = str(output) + metadata['type'] = 'string_fallback' - # Add the agent's response to storage - await self.storage.add_message(agent_message) + # Create artifact with DataPart + artifact = Artifact( + artifact_id=artifact_id, + name='result', + parts=[DataPart(kind='data', data=data)], + metadata=metadata, + ) + artifacts.append(artifact) - # Send the agent's response as a message - await self.broker.send_stream_event(task_id, agent_message) + # Send artifact update event + await self.broker.send_stream_event( + task_id, + TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + kind='artifact-update', + artifact=artifact, + last_chunk=True, + ), + ) - # Update storage and send completion event (no artifacts) - await self.storage.update_task(task_id, state='completed') + # Update storage and send completion event + await self.storage.update_task(task_id, state='completed', artifacts=artifacts if artifacts else None) await self.broker.send_stream_event( task_id, TaskStatusUpdateEvent( @@ -203,16 +295,28 @@ def build_artifacts(self, result: Any) -> list[Artifact]: artifact_id = str(uuid.uuid4()) return [Artifact(artifact_id=artifact_id, name='result', parts=[A2ATextPart(kind='text', text=str(result))])] - def build_message_history(self, task_history: list[Message]) -> list[ModelMessage]: + def build_message_history(self, history: list[Message]) -> list[ModelMessage]: model_messages: list[ModelMessage] = [] - for message in task_history: + for message in history: if message['role'] == 'user': - model_messages.append(ModelRequest(parts=self._map_request_parts(message['parts']))) + model_messages.append(ModelRequest(parts=self._request_parts_from_a2a(message['parts']))) else: - model_messages.append(ModelResponse(parts=self._map_response_parts(message['parts']))) + model_messages.append(ModelResponse(parts=self._response_parts_from_a2a(message['parts']))) + return model_messages - def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: + def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]: + """Convert A2A Part objects to pydantic-ai ModelRequestPart objects. + + This handles the conversion from A2A protocol parts (text, file, data) to + pydantic-ai's internal request parts (UserPromptPart with various content types). + + Args: + parts: List of A2A Part objects from incoming messages + + Returns: + List of ModelRequestPart objects for the pydantic-ai agent + """ model_parts: list[ModelRequestPart] = [] for part in parts: if part['kind'] == 'text': @@ -245,7 +349,19 @@ def _map_request_parts(self, parts: list[Part]) -> list[ModelRequestPart]: assert_never(part) return model_parts - def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: + def _response_parts_from_a2a(self, parts: list[Part]) -> list[ModelResponsePart]: + """Convert A2A Part objects to pydantic-ai ModelResponsePart objects. + + This handles the conversion from A2A protocol parts (text, file, data) to + pydantic-ai's internal response parts. Currently only supports text parts + as agent responses in A2A are expected to be text-based. + + Args: + parts: List of A2A Part objects from stored agent messages + + Returns: + List of ModelResponsePart objects for message history + """ model_parts: list[ModelResponsePart] = [] for part in parts: if part['kind'] == 'text': @@ -257,3 +373,38 @@ def _map_response_parts(self, parts: list[Part]) -> list[ModelResponsePart]: else: # pragma: no cover assert_never(part) return model_parts + + def _response_parts_to_a2a(self, parts: list[ModelResponsePart]) -> list[Part]: + """Convert pydantic-ai ModelResponsePart objects to A2A Part objects. + + This handles the conversion from pydantic-ai's internal response parts to + A2A protocol parts. Different part types are handled as follows: + - TextPart: Converted directly to A2A TextPart + - ThinkingPart: Converted to TextPart with metadata indicating it's thinking + - ToolCallPart: Skipped (internal to agent execution) + + Args: + parts: List of ModelResponsePart objects from agent response + + Returns: + List of A2A Part objects suitable for sending via A2A protocol + """ + a2a_parts: list[Part] = [] + for part in parts: + if isinstance(part, TextPart): + if part.content: # Only add non-empty text + a2a_parts.append(A2ATextPart(kind='text', text=part.content)) + elif isinstance(part, ThinkingPart): + if part.content: # Only add non-empty thinking + # Convert thinking to text with metadata + a2a_parts.append( + A2ATextPart( + kind='text', + text=part.content, + metadata={'type': 'thinking', 'thinking_id': part.id, 'signature': part.signature}, + ) + ) + elif isinstance(part, ToolCallPart): + # Skip tool calls - they're internal to agent execution + pass + return a2a_parts diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 6a5ec9976..ed65c6f4c 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -101,15 +101,19 @@ async def test_a2a_simple(): 'kind': 'message', 'context_id': IsStr(), 'task_id': IsStr(), - }, + } + ], + 'artifacts': [ { - 'role': 'agent', - 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + 'class_name': 'tuple', + }, + } ], }, } @@ -199,15 +203,19 @@ async def test_a2a_file_message_with_file(): 'kind': 'message', 'context_id': IsStr(), 'task_id': IsStr(), - }, + } + ], + 'artifacts': [ { - 'role': 'agent', - 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + 'class_name': 'tuple', + }, + } ], }, } @@ -275,15 +283,19 @@ async def test_a2a_file_message_with_file_content(): 'kind': 'message', 'context_id': IsStr(), 'task_id': IsStr(), - }, + } + ], + 'artifacts': [ { - 'role': 'agent', - 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + 'class_name': 'tuple', + }, + } ], }, } @@ -438,14 +450,18 @@ async def test_a2a_multiple_messages(): 'task_id': IsStr(), }, {'role': 'agent', 'parts': [{'kind': 'text', 'text': 'Whats up?'}], 'kind': 'message'}, + ], + 'artifacts': [ { - 'role': 'agent', - 'parts': [{'kind': 'text', 'text': "('foo', 'bar')"}], - 'kind': 'message', - 'message_id': IsStr(), - 'context_id': IsStr(), - 'task_id': IsStr(), - }, + 'artifact_id': IsStr(), + 'name': 'result', + 'parts': [{'kind': 'data', 'data': ['foo', 'bar']}], + 'metadata': { + 'type': 'tuple', + 'json_schema': {'items': {}, 'type': 'array'}, + 'class_name': 'tuple', + }, + } ], }, } From 17af717eb2fc0d73ec7409cdfe528001ffa5df0b Mon Sep 17 00:00:00 2001 From: Robert Porter Date: Tue, 1 Jul 2025 02:58:10 +0000 Subject: [PATCH 15/15] Fix lint; Rename tests/fasta2a to tests/test_fasta2a to avoid import conflict --- docs/examples/bank-support-a2a.md | 2 +- tests/{fasta2a => test_fasta2a}/__init__.py | 0 tests/{fasta2a => test_fasta2a}/test_applications.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename tests/{fasta2a => test_fasta2a}/__init__.py (100%) rename tests/{fasta2a => test_fasta2a}/test_applications.py (100%) diff --git a/docs/examples/bank-support-a2a.md b/docs/examples/bank-support-a2a.md index 46b81d6e9..4211e9b46 100644 --- a/docs/examples/bank-support-a2a.md +++ b/docs/examples/bank-support-a2a.md @@ -36,4 +36,4 @@ curl -X POST http://localhost:8000/tasks.send \ ```python {title="bank_support_a2a.py"} #! examples/pydantic_ai_examples/bank_support_a2a.py -``` \ No newline at end of file +``` diff --git a/tests/fasta2a/__init__.py b/tests/test_fasta2a/__init__.py similarity index 100% rename from tests/fasta2a/__init__.py rename to tests/test_fasta2a/__init__.py diff --git a/tests/fasta2a/test_applications.py b/tests/test_fasta2a/test_applications.py similarity index 100% rename from tests/fasta2a/test_applications.py rename to tests/test_fasta2a/test_applications.py