|
1 | 1 | from __future__ import annotations as _annotations
|
2 | 2 |
|
3 |
| -from collections.abc import AsyncIterator, Sequence |
4 |
| -from contextlib import asynccontextmanager |
5 | 3 | from typing import Any
|
6 | 4 |
|
7 |
| -from starlette.applications import Starlette |
| 5 | +import httpx |
| 6 | +from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication |
| 7 | +from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler |
| 8 | +from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier |
| 9 | +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore |
| 10 | +from a2a.types import AgentCapabilities, AgentCard, AgentProvider |
8 | 11 | from starlette.middleware import Middleware
|
9 |
| -from starlette.requests import Request |
10 |
| -from starlette.responses import Response |
11 | 12 | from starlette.routing import Route
|
12 | 13 | from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send
|
13 | 14 |
|
14 |
| -from .broker import Broker |
15 |
| -from .schema import ( |
16 |
| - AgentCard, |
17 |
| - Authentication, |
18 |
| - Capabilities, |
19 |
| - Provider, |
20 |
| - Skill, |
21 |
| - a2a_request_ta, |
22 |
| - a2a_response_ta, |
23 |
| - agent_card_ta, |
24 |
| -) |
| 15 | +from .schema import Skill |
25 | 16 | from .storage import Storage
|
26 |
| -from .task_manager import TaskManager |
| 17 | +from .worker import Worker |
27 | 18 |
|
28 | 19 |
|
29 |
| -class FastA2A(Starlette): |
30 |
| - """The main class for the FastA2A library.""" |
| 20 | +class FastA2A: |
| 21 | + """ |
| 22 | + The main class for the FastA2A library. It provides a simple way to create |
| 23 | + an A2A server by wrapping the Google A2A SDK. |
| 24 | + """ |
31 | 25 |
|
32 | 26 | def __init__(
|
33 | 27 | self,
|
34 | 28 | *,
|
35 |
| - storage: Storage, |
36 |
| - broker: Broker, |
| 29 | + worker: Worker, |
| 30 | + storage: Storage | None = None, |
37 | 31 | # Agent card
|
38 |
| - name: str | None = None, |
39 |
| - url: str = 'http://localhost:8000', |
40 |
| - version: str = '1.0.0', |
| 32 | + name: str = "Agent", |
| 33 | + url: str = "http://localhost:8000", |
| 34 | + version: str = "1.0.0", |
41 | 35 | description: str | None = None,
|
42 |
| - provider: Provider | None = None, |
| 36 | + provider: AgentProvider | None = None, |
43 | 37 | skills: list[Skill] | None = None,
|
44 | 38 | # Starlette
|
45 | 39 | debug: bool = False,
|
46 |
| - routes: Sequence[Route] | None = None, |
47 |
| - middleware: Sequence[Middleware] | None = None, |
| 40 | + routes: list[Route] | None = None, |
| 41 | + middleware: list[Middleware] | None = None, |
48 | 42 | exception_handlers: dict[Any, ExceptionHandler] | None = None,
|
49 |
| - lifespan: Lifespan[FastA2A] | None = None, |
| 43 | + lifespan: Lifespan | None = None, |
50 | 44 | ):
|
51 |
| - if lifespan is None: |
52 |
| - lifespan = _default_lifespan |
| 45 | + """ |
| 46 | + Initializes the FastA2A application. |
| 47 | +
|
| 48 | + Args: |
| 49 | + worker: An implementation of `fasta2a.Worker` (which is an `a2a.server.agent_execution.AgentExecutor`). |
| 50 | + storage: An implementation of `fasta2a.Storage` (which is an `a2a.server.tasks.TaskStore`). |
| 51 | + Defaults to `InMemoryTaskStore`. |
| 52 | + name: The human-readable name of the agent. |
| 53 | + url: The URL where the agent is hosted. |
| 54 | + version: The version of the agent. |
| 55 | + description: A human-readable description of the agent. |
| 56 | + provider: The service provider of the agent. |
| 57 | + skills: A list of skills the agent can perform. |
| 58 | + debug: Starlette's debug flag. |
| 59 | + routes: A list of additional Starlette routes. |
| 60 | + middleware: A list of Starlette middleware. |
| 61 | + exception_handlers: A dictionary of Starlette exception handlers. |
| 62 | + lifespan: A Starlette lifespan context manager. |
| 63 | + """ |
| 64 | + self.agent_card = AgentCard( |
| 65 | + name=name, |
| 66 | + url=url, |
| 67 | + version=version, |
| 68 | + description=description or "A FastA2A Agent", |
| 69 | + provider=provider, |
| 70 | + skills=skills or [], |
| 71 | + capabilities=AgentCapabilities( |
| 72 | + streaming=True, pushNotifications=True, stateTransitionHistory=True |
| 73 | + ), |
| 74 | + defaultInputModes=["application/json"], |
| 75 | + defaultOutputModes=["application/json"], |
| 76 | + securitySchemes={}, |
| 77 | + ) |
53 | 78 |
|
54 |
| - super().__init__( |
| 79 | + self.storage = storage or InMemoryTaskStore() |
| 80 | + self.worker = worker |
| 81 | + |
| 82 | + # The SDK's DefaultRequestHandler uses httpx to send push notifications |
| 83 | + http_client = httpx.AsyncClient() |
| 84 | + push_notifier = InMemoryPushNotifier(httpx_client) |
| 85 | + |
| 86 | + request_handler = DefaultRequestHandler( |
| 87 | + agent_executor=self.worker, |
| 88 | + task_store=self.storage, |
| 89 | + push_notifier=push_notifier, |
| 90 | + ) |
| 91 | + |
| 92 | + a2a_app = A2AStarletteApplication( |
| 93 | + agent_card=self.agent_card, |
| 94 | + http_handler=request_handler, |
| 95 | + ) |
| 96 | + |
| 97 | + self.app = a2a_app.build( |
55 | 98 | debug=debug,
|
56 | 99 | routes=routes,
|
57 | 100 | middleware=middleware,
|
58 | 101 | exception_handlers=exception_handlers,
|
59 | 102 | lifespan=lifespan,
|
60 | 103 | )
|
61 | 104 |
|
62 |
| - self.name = name or 'Agent' |
63 |
| - self.url = url |
64 |
| - self.version = version |
65 |
| - self.description = description |
66 |
| - self.provider = provider |
67 |
| - self.skills = skills or [] |
68 |
| - # NOTE: For now, I don't think there's any reason to support any other input/output modes. |
69 |
| - self.default_input_modes = ['application/json'] |
70 |
| - self.default_output_modes = ['application/json'] |
71 |
| - |
72 |
| - self.task_manager = TaskManager(broker=broker, storage=storage) |
73 |
| - |
74 |
| - # Setup |
75 |
| - self._agent_card_json_schema: bytes | None = None |
76 |
| - self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS']) |
77 |
| - self.router.add_route('/', self._agent_run_endpoint, methods=['POST']) |
78 |
| - |
79 | 105 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
80 |
| - if scope['type'] == 'http' and not self.task_manager.is_running: |
81 |
| - raise RuntimeError('TaskManager was not properly initialized.') |
82 |
| - await super().__call__(scope, receive, send) |
83 |
| - |
84 |
| - async def _agent_card_endpoint(self, request: Request) -> Response: |
85 |
| - if self._agent_card_json_schema is None: |
86 |
| - agent_card = AgentCard( |
87 |
| - name=self.name, |
88 |
| - url=self.url, |
89 |
| - version=self.version, |
90 |
| - skills=self.skills, |
91 |
| - default_input_modes=self.default_input_modes, |
92 |
| - default_output_modes=self.default_output_modes, |
93 |
| - capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), |
94 |
| - authentication=Authentication(schemes=[]), |
95 |
| - ) |
96 |
| - if self.description is not None: |
97 |
| - agent_card['description'] = self.description |
98 |
| - if self.provider is not None: |
99 |
| - agent_card['provider'] = self.provider |
100 |
| - self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) |
101 |
| - return Response(content=self._agent_card_json_schema, media_type='application/json') |
102 |
| - |
103 |
| - async def _agent_run_endpoint(self, request: Request) -> Response: |
104 |
| - """This is the main endpoint for the A2A server. |
105 |
| -
|
106 |
| - Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. |
107 |
| -
|
108 |
| - 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. |
109 |
| - Never a "completed" on the first message. |
110 |
| - 2. There are three possible ends for the task: |
111 |
| - 2.1. The task was "completed" successfully. |
112 |
| - 2.2. The task was "canceled". |
113 |
| - 2.3. The task "failed". |
114 |
| - 3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`. |
115 |
| - """ |
116 |
| - data = await request.body() |
117 |
| - a2a_request = a2a_request_ta.validate_json(data) |
118 |
| - |
119 |
| - if a2a_request['method'] == 'tasks/send': |
120 |
| - jsonrpc_response = await self.task_manager.send_task(a2a_request) |
121 |
| - elif a2a_request['method'] == 'tasks/get': |
122 |
| - jsonrpc_response = await self.task_manager.get_task(a2a_request) |
123 |
| - elif a2a_request['method'] == 'tasks/cancel': |
124 |
| - jsonrpc_response = await self.task_manager.cancel_task(a2a_request) |
125 |
| - else: |
126 |
| - raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') |
127 |
| - return Response( |
128 |
| - content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json' |
129 |
| - ) |
130 |
| - |
131 |
| - |
132 |
| -@asynccontextmanager |
133 |
| -async def _default_lifespan(app: FastA2A) -> AsyncIterator[None]: |
134 |
| - async with app.task_manager: |
135 |
| - yield |
| 106 | + await self.app(scope, receive, send) |
0 commit comments