-
Notifications
You must be signed in to change notification settings - Fork 167
feat: Adding stand-alone support for RESTful API serving #297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
03a0fbe
661db10
f7bbfc4
2482059
45b148e
97e9d1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,5 +70,6 @@ sse | |
tagwords | ||
taskupdate | ||
testuuid | ||
Tful | ||
typeerror | ||
vulnz |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""A2A REST Applications.""" | ||
|
||
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication | ||
from a2a.server.apps.rest.rest_app import RESTApplication | ||
|
||
__all__ = [ | ||
'A2ARESTFastAPIApplication', | ||
'RESTApplication', | ||
] |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,83 @@ | ||||||
import logging | ||||||
|
||||||
from typing import Any | ||||||
|
||||||
from fastapi import FastAPI, Request, Response, APIRouter | ||||||
|
||||||
from a2a.server.apps.jsonrpc.jsonrpc_app import ( | ||||||
CallContextBuilder, | ||||||
) | ||||||
from a2a.server.apps.rest.rest_app import ( | ||||||
RESTApplication, | ||||||
) | ||||||
from a2a.server.request_handlers.request_handler import RequestHandler | ||||||
from a2a.types import AgentCard | ||||||
|
||||||
|
||||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
|
||||||
class A2ARESTFastAPIApplication: | ||||||
"""A FastAPI application implementing the A2A protocol server REST endpoints. | ||||||
Handles incoming REST requests, routes them to the appropriate | ||||||
handler methods, and manages response generation including Server-Sent Events | ||||||
(SSE). | ||||||
""" | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
agent_card: AgentCard, | ||||||
http_handler: RequestHandler, | ||||||
context_builder: CallContextBuilder | None = None, | ||||||
): | ||||||
"""Initializes the A2ARESTFastAPIApplication. | ||||||
Args: | ||||||
agent_card: The AgentCard describing the agent's capabilities. | ||||||
http_handler: The handler instance responsible for processing A2A | ||||||
requests via http. | ||||||
extended_agent_card: An optional, distinct AgentCard to be served | ||||||
at the authenticated extended card endpoint. | ||||||
context_builder: The CallContextBuilder used to construct the | ||||||
ServerCallContext passed to the http_handler. If None, no | ||||||
ServerCallContext is passed. | ||||||
""" | ||||||
self._handler = RESTApplication( | ||||||
agent_card=agent_card, | ||||||
http_handler=http_handler, | ||||||
context_builder=context_builder, | ||||||
) | ||||||
|
||||||
def build( | ||||||
self, | ||||||
agent_card_url: str = '/.well-known/agent.json', | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Also add import: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH |
||||||
rpc_url: str = '', | ||||||
**kwargs: Any, | ||||||
) -> FastAPI: | ||||||
"""Builds and returns the FastAPI application instance. | ||||||
Args: | ||||||
agent_card_url: The URL for the agent card endpoint. | ||||||
rpc_url: The URL for the A2A JSON-RPC endpoint. | ||||||
extended_agent_card_url: The URL for the authenticated extended agent card endpoint. | ||||||
**kwargs: Additional keyword arguments to pass to the FastAPI constructor. | ||||||
Returns: | ||||||
A configured FastAPI application instance. | ||||||
""" | ||||||
app = FastAPI(**kwargs) | ||||||
router = APIRouter() | ||||||
for route, callback in self._handler.routes().items(): | ||||||
router.add_api_route( | ||||||
f'{rpc_url}{route[0]}', | ||||||
callback, | ||||||
methods=[route[1]] | ||||||
) | ||||||
|
||||||
@router.get(f'{rpc_url}{agent_card_url}') | ||||||
async def get_agent_card(request: Request) -> Response: | ||||||
return await self._handler._handle_get_agent_card(request) | ||||||
|
||||||
app.include_router(router) | ||||||
return app |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
import contextlib | ||
import json | ||
import logging | ||
import traceback | ||
import functools | ||
|
||
from abc import ABC, abstractmethod | ||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable | ||
from typing import Any, Tuple, Callable | ||
from fastapi import FastAPI | ||
from pydantic import BaseModel, ValidationError | ||
|
||
from sse_starlette.sse import EventSourceResponse | ||
from starlette.applications import Starlette | ||
from starlette.authentication import BaseUser | ||
from starlette.requests import Request | ||
from starlette.responses import JSONResponse, Response | ||
|
||
from a2a.auth.user import UnauthenticatedUser | ||
from a2a.auth.user import User as A2AUser | ||
from a2a.server.context import ServerCallContext | ||
from a2a.server.request_handlers.rest_handler import ( | ||
RESTHandler, | ||
) | ||
from a2a.server.request_handlers.request_handler import RequestHandler | ||
from a2a.types import ( | ||
A2AError, | ||
AgentCard, | ||
JSONParseError, | ||
UnsupportedOperationError, | ||
InternalError, | ||
InvalidRequestError, | ||
) | ||
from a2a.utils.errors import MethodNotImplementedError | ||
from a2a.server.apps.jsonrpc import ( | ||
CallContextBuilder, | ||
StarletteUserProxy, | ||
DefaultCallContextBuilder | ||
) | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class RESTApplication: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be called |
||
"""Base class for A2A REST applications. | ||
Defines REST requests processors and the routes to attach them too, as well as | ||
manages response generation including Server-Sent Events (SSE). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
agent_card: AgentCard, | ||
http_handler: RequestHandler, | ||
context_builder: CallContextBuilder | None = None, | ||
): | ||
"""Initializes the RESTApplication. | ||
Args: | ||
agent_card: The AgentCard describing the agent's capabilities. | ||
http_handler: The handler instance responsible for processing A2A | ||
requests via http. | ||
context_builder: The CallContextBuilder used to construct the | ||
ServerCallContext passed to the http_handler. If None, no | ||
ServerCallContext is passed. | ||
""" | ||
self.agent_card = agent_card | ||
self.handler = RESTHandler( | ||
agent_card=agent_card, request_handler=http_handler | ||
) | ||
self._context_builder = context_builder or DefaultCallContextBuilder() | ||
|
||
def _generate_error_response(self, error) -> JSONResponse: | ||
"""Creates a JSONResponse for a errors. | ||
Logs the error based on its type. | ||
Args: | ||
error: The Error object. | ||
Returns: | ||
A `JSONResponse` object formatted as a JSON error response. | ||
""" | ||
log_level = ( | ||
logging.ERROR | ||
if isinstance(error, InternalError) | ||
else logging.WARNING | ||
) | ||
logger.log( | ||
log_level, | ||
'Request Error: ' | ||
f"Code={error.code}, Message='{error.message}'" | ||
f'{", Data=" + str(error.data) if error.data else ""}', | ||
) | ||
return JSONResponse( | ||
'{"message": ' + error.message + '}', | ||
status_code=404, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would not be 404. Default value would be 500. But we should check for error and attach 4XX or 5XX HTTP error codes. |
||
) | ||
|
||
def _handle_error(self, error: Exception) -> JSONResponse: | ||
traceback.print_exc() | ||
if isinstance(error, MethodNotImplementedError): | ||
return self._generate_error_response(UnsupportedOperationError()) | ||
elif isinstance(error, json.decoder.JSONDecodeError): | ||
return self._generate_error_response( | ||
JSONParseError(message=str(error)) | ||
) | ||
elif isinstance(error, ValidationError): | ||
return self._generate_error_response( | ||
InvalidRequestError(data=json.loads(error.json())), | ||
) | ||
pstephengoogle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
logger.error(f'Unhandled exception: {error}') | ||
return self._generate_error_response( | ||
InternalError(message=str(error)) | ||
) | ||
|
||
async def _handle_request( | ||
self, | ||
method: Callable[[Request, ServerCallContext], Awaitable[str]], | ||
request: Request | ||
) -> JSONResponse: | ||
try: | ||
call_context = self._context_builder.build(request) | ||
response = await method(request, call_context) | ||
return JSONResponse(content=response) | ||
except Exception as e: | ||
return self._handle_error(e) | ||
|
||
async def _handle_streaming_request( | ||
self, | ||
method: Callable[[Request, ServerCallContext], AsyncIterator[str]], | ||
request: Request | ||
) -> EventSourceResponse: | ||
try: | ||
call_context = self._context_builder.build(request) | ||
async def event_generator( | ||
stream: AsyncGenerator[str], | ||
) -> AsyncGenerator[dict[str, str]]: | ||
async for item in stream: | ||
yield {'data': item} | ||
return EventSourceResponse(event_generator(method(request, call_context))) | ||
except Exception as e: | ||
# Since the stream has started, we can't return a JSONResponse. | ||
# Instead, we runt the error handling logic (provides logging) | ||
# and reraise the error and let server framework manage | ||
self._handle_error(e) | ||
raise e | ||
|
||
|
||
async def _handle_get_agent_card(self, request: Request) -> JSONResponse: | ||
"""Handles GET requests for the agent card endpoint. | ||
Args: | ||
request: The incoming Starlette Request object. | ||
Returns: | ||
A JSONResponse containing the agent card data. | ||
""" | ||
# The public agent card is a direct serialization of the agent_card | ||
# provided at initialization. | ||
return JSONResponse( | ||
self.agent_card.model_dump(mode='json', exclude_none=True) | ||
) | ||
|
||
async def handle_authenticated_agent_card(self, request: Request) -> JSONResponse: | ||
"""Hook for per credential agent card response. | ||
If a dynamic card is needed based on the credentials provided in the request | ||
override this method and return the customized content. | ||
Args: | ||
request: The incoming Starlette Request object. | ||
Returns: | ||
A JSONResponse containing the authenticated card. | ||
""" | ||
if not self.agent_card.supportsAuthenticatedExtendedCard: | ||
return JSONResponse( | ||
'{"detail": "Authenticated card not supported"}', status_code=404 | ||
) | ||
return JSONResponse( | ||
self.agent_card.model_dump(mode='json', exclude_none=True) | ||
) | ||
|
||
def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: | ||
routes = { | ||
('/v1/message:send', 'POST'): ( | ||
swapydapy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
functools.partial( | ||
self._handle_request, | ||
self.handler.on_message_send), | ||
), | ||
('/v1/message:stream', 'POST'): ( | ||
functools.partial( | ||
self._handle_streaming_request, | ||
self.handler.on_message_send_stream), | ||
), | ||
('/v1/tasks/{id}:subscribe', 'POST'): ( | ||
functools.partial( | ||
self._handle_streaming_request, | ||
self.handler.on_resubscribe_to_task), | ||
), | ||
('/v1/tasks/{id}', 'GET'): ( | ||
functools.partial( | ||
self._handle_request, | ||
self.handler.on_get_task), | ||
), | ||
('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( | ||
functools.partial( | ||
self._handle_request, | ||
self.handler.get_push_notification), | ||
), | ||
('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( | ||
functools.partial( | ||
self._handle_request, | ||
self.handler.set_push_notification), | ||
), | ||
('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( | ||
functools.partial( | ||
self._handle_request, | ||
self.handler.list_push_notifications), | ||
), | ||
('/v1/tasks', 'GET'): ( | ||
functools.partial( | ||
self._handle_request, | ||
self.handler.list_tasks), | ||
), | ||
} | ||
if self.agent_card.supportsAuthenticatedExtendedCard: | ||
routes['/v1/card'] = ( | ||
self.handle_authenticated_agent_card, | ||
'GET') | ||
return routes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this param should be called request_handler, right? As it is the base request handler, there's nothing related to HTTP.