diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 30c8c4584..ca6620e03 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -204,6 +204,7 @@ jobs:
         enable-cache: true
       - run: uv sync --package pydantic-ai-slim --only-dev
+      - run: rm coverage/.coverage.*-py3.9-*  # Exclude 3.9 coverage as it gets the wrong line numbers, causing invalid failures.
       - run: uv run coverage combine coverage
       - run: uv run coverage html --show-contexts --title "PydanticAI coverage for ${{ github.sha }}"
diff --git a/docs/ag-ui.md b/docs/ag-ui.md
new file mode 100644
index 000000000..6d0df700b
--- /dev/null
+++ b/docs/ag-ui.md
@@ -0,0 +1,265 @@
+# Agent User Interaction (AG-UI) Protocol
+
+The [Agent User Interaction (AG-UI) Protocol](https://docs.ag-ui.com/introduction)
+is an open standard introduced by the
+[CopilotKit](https://webflow.copilotkit.ai/blog/introducing-ag-ui-the-protocol-where-agents-meet-users)
+team that standardises how frontend applications connect to AI agents.
+Think of it as a universal translator for AI-driven systems: no matter what
+language an agent speaks, AG-UI ensures fluent communication.
+
+The team at [Rocket Science](https://www.rocketscience.gg/) contributed the
+[AG-UI integration](#ag-ui-adapter) to make it easy to implement the AG-UI
+protocol with PydanticAI agents.
+
+The integration also includes an [`Agent.to_ag_ui`][pydantic_ai.Agent.to_ag_ui]
+convenience method which simplifies the creation of a
+[`FastAGUI`][pydantic_ai.ag_ui.FastAGUI] app for PydanticAI agents. `FastAGUI`
+is built on top of [Starlette](https://www.starlette.io/), meaning it's fully
+compatible with any ASGI server.
+
+## AG-UI Adapter
+
+The [Adapter][pydantic_ai.ag_ui.Adapter] class is an adapter between
+PydanticAI agents and the AG-UI protocol written in Python. It provides
+support for all aspects of the spec, including:
+
+- [Events](https://docs.ag-ui.com/concepts/events)
+- [Messages](https://docs.ag-ui.com/concepts/messages)
+- [State Management](https://docs.ag-ui.com/concepts/state)
+- [Tools](https://docs.ag-ui.com/concepts/tools)
+
+### Installation
+
+The only dependencies are:
+
+- [ag-ui-protocol](https://docs.ag-ui.com/introduction): to provide the AG-UI
+  types and encoder.
+- [pydantic](https://pydantic.dev): to validate the request/response messages.
+- [pydantic-ai](https://ai.pydantic.dev/): to provide the agent framework.
+
+To run the examples you'll also need:
+
+- [uvicorn](https://www.uvicorn.org/) or another ASGI compatible server
+
+```bash
+pip/uv-add 'uvicorn'
+```
+
+You can install PydanticAI with the `ag-ui` extra to ensure you have all the
+required AG-UI dependencies:
+
+```bash
+pip/uv-add 'pydantic-ai-slim[ag-ui]'
+```
+
+### Quick start
+
+```py {title="agent_to_ag_ui.py" py="3.10"}
+"""Basic example of serving a PydanticAI agent with AG-UI."""
+
+from __future__ import annotations
+
+from pydantic_ai import Agent
+
+agent = Agent('openai:gpt-4.1', instructions='Be fun!')
+app = agent.to_ag_ui()
+```
+
+You can run the example with:
+
+```shell
+uvicorn agent_to_ag_ui:app --host 0.0.0.0 --port 8000
+```
+
+This will expose the agent as an AG-UI server, and you can start sending
+requests to it.
+
+### Design
+
+The adapter receives messages in the form of a
+[`RunAgentInput`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput)
+which describes the details of a request being passed to the agent, including
+messages and state. These are then converted to PydanticAI types and passed to
+the agent, which processes the request.
+
+Results from the agent are converted from PydanticAI types to AG-UI events and
+streamed back to the caller as Server-Sent Events (SSE).
+
+A user request may require multiple round trips between the client UI and the
+PydanticAI server, depending on the tools and events needed.
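+
+As an illustration of the flow (this sketch is not part of the integration),
+the client below POSTs a minimal request to a running server and prints the
+decoded SSE events. The payload field names are assumptions based on the AG-UI
+JSON wire format; consult the AG-UI docs for the exact `RunAgentInput` schema:
+
+```python
+"""Minimal AG-UI client sketch using httpx (assumes a server on localhost:8000)."""
+
+import json
+import uuid
+
+import httpx
+
+payload = {
+    'threadId': str(uuid.uuid4()),
+    'runId': str(uuid.uuid4()),
+    'messages': [{'id': 'msg_1', 'role': 'user', 'content': 'Hello!'}],
+    'tools': [],
+    'context': [],
+    'state': None,
+    'forwardedProps': None,
+}
+
+with httpx.stream('POST', 'http://localhost:8000/', json=payload) as response:
+    for line in response.iter_lines():
+        # Each AG-UI event arrives as a `data: {...}` SSE line.
+        if line.startswith('data: '):
+            print(json.loads(line[len('data: ') :]))
+```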
+
+In addition to the [Adapter][pydantic_ai.ag_ui.Adapter] there is also
+[FastAGUI][pydantic_ai.ag_ui.FastAGUI], a slim wrapper around
+[Starlette](https://www.starlette.io/) that makes it easy to run a PydanticAI
+server with AG-UI support on any ASGI server.
+
+### Features
+
+To expose a PydanticAI agent as an AG-UI server, including state support, you
+can use the [`to_ag_ui`][pydantic_ai.agent.Agent.to_ag_ui] method to create an
+ASGI-compatible server.
+
+In the example below the document state is shared between the UI and server
+using [`StateDeps`][pydantic_ai.ag_ui.StateDeps], which implements the
+[`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol and automatically
+decodes the state contained in
+[`RunAgentInput.state`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput)
+when processing requests.
+
+#### State management
+
+The adapter provides full support for
+[AG-UI state management](https://docs.ag-ui.com/concepts/state), which enables
+real-time synchronization between agents and frontend applications.
+
+```python {title="ag_ui_state.py" py="3.10"}
+"""State example for AG-UI with Pydantic AI."""
+
+from __future__ import annotations
+
+from pydantic import BaseModel
+
+from pydantic_ai import Agent
+from pydantic_ai.ag_ui import StateDeps
+
+
+class DocumentState(BaseModel):
+    """State for the document being written."""
+
+    document: str = ''
+
+
+agent = Agent(
+    'openai:gpt-4.1',
+    instructions='Be fun!',
+    deps_type=StateDeps[DocumentState],
+)
+app = agent.to_ag_ui(deps=StateDeps(DocumentState()))
+```
+
+Since `app` is an ASGI application, it can be used with any ASGI server, e.g.:
+
+```bash
+uvicorn ag_ui_state:app --host 0.0.0.0 --port 8000
+```
+
+Since the goal of [`to_ag_ui`][pydantic_ai.agent.Agent.to_ag_ui] is to be a
+convenience method, it accepts a combination of the arguments required by:
+
+- the [`Adapter`][pydantic_ai.ag_ui.Adapter] constructor
+- the [`Agent.iter`][pydantic_ai.agent.Agent.iter] method
+
+If you want more control you can either use the
+[`agent_to_ag_ui`][pydantic_ai.ag_ui.agent_to_ag_ui] helper function or create
+a [`FastAGUI`][pydantic_ai.ag_ui.FastAGUI] instance directly, both of which
+also provide the ability to customise
+[`Starlette`](https://www.starlette.io/applications/#starlette.applications.Starlette)
+options.
+
+#### Tools
+
+AG-UI tools are seamlessly provided to the PydanticAI agent, enabling rich
+user experiences with frontend user interfaces.
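+
+As a rough sketch of what's involved (the tool here is illustrative, and the
+`Tool` type is assumed to come from the `ag-ui-protocol` package, matching the
+`tool.name`/`tool.description`/`tool.parameters` fields this integration
+reads), a frontend registers a tool with a name, description and JSON Schema
+parameters as part of each request; the adapter surfaces it to the agent while
+execution stays client side:
+
+```python
+from ag_ui.core import Tool
+
+background_tool = Tool(
+    name='background',
+    description='Set the background color of the client window.',
+    parameters={
+        'type': 'object',
+        'properties': {'color': {'type': 'string', 'description': 'A CSS color'}},
+        'required': ['color'],
+    },
+)
+```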
+
+#### Events
+
+The adapter makes it possible for PydanticAI tools to send
+[AG-UI events](https://docs.ag-ui.com/concepts/events) simply by defining a
+tool which returns a type based on
+[`BaseEvent`](https://docs.ag-ui.com/sdk/js/core/events#baseevent). This
+allows for custom events and state updates.
+
+```python {title="ag_ui_tool_events.py" py="3.10"}
+"""Tool events example for AG-UI with Pydantic AI."""
+
+from __future__ import annotations
+
+from ag_ui.core import CustomEvent, EventType, StateSnapshotEvent
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.ag_ui import StateDeps
+
+
+class DocumentState(BaseModel):
+    """State for the document being written."""
+
+    document: str = ''
+
+
+agent = Agent(
+    'openai:gpt-4.1',
+    instructions='Be fun!',
+    deps_type=StateDeps[DocumentState],
+)
+app = agent.to_ag_ui(deps=StateDeps(DocumentState()))
+
+
+@agent.tool
+def update_state(ctx: RunContext[StateDeps[DocumentState]]) -> StateSnapshotEvent:
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot=ctx.deps.state,
+    )
+
+
+@agent.tool_plain
+def custom_events() -> list[CustomEvent]:
+    return [
+        CustomEvent(
+            type=EventType.CUSTOM,
+            name='count',
+            value=1,
+        ),
+        CustomEvent(
+            type=EventType.CUSTOM,
+            name='count',
+            value=2,
+        ),
+    ]
+```
+
+### Examples
+
+For more examples of how to use [`Adapter`][pydantic_ai.ag_ui.Adapter] see
+[`pydantic_ai_ag_ui_examples`](https://github.com/pydantic/pydantic-ai/tree/main/examples/pydantic_ai_ag_ui_examples),
+which includes a working server for the
+[AG-UI Dojo](https://docs.ag-ui.com/tutorials/debugging#the-ag-ui-dojo). It
+can be run from a clone of the repo, or with the `pydantic-ai-examples`
+package installed:
+
+```bash
+pip/uv-add pydantic-ai-examples
+```
+
+You can run the server directly, which supports command line flags:
+
+```shell
+python -m pydantic_ai_ag_ui_examples.dojo_server --help
+usage: dojo_server.py [-h] [--port PORT] [--reload] [--no-reload] [--log-level {critical,error,warning,info,debug,trace}]
+
+PydanticAI AG-UI Dojo server
+
+options:
+  -h, --help            show this help message and exit
+  --port PORT, -p PORT  Port to run the server on (default: 9000)
+  --reload              Enable auto-reload (default: True)
+  --no-reload           Disable auto-reload
+  --log-level {critical,error,warning,info,debug,trace}
+                        Adapter log level (default: info)
+```
+
+Run with adapter debug logging:
+
+```shell
+python -m pydantic_ai_ag_ui_examples.dojo_server --log-level debug
+```
+
+Using uvicorn:
+
+```shell
+uvicorn pydantic_ai_ag_ui_examples.dojo_server:app --port 9000
+```
diff --git a/docs/api/ag_ui.md b/docs/api/ag_ui.md
new file mode 100644
index 000000000..bb0ffd429
--- /dev/null
+++ b/docs/api/ag_ui.md
@@ -0,0 +1,3 @@
+# `pydantic_ai.ag_ui`
+
+::: pydantic_ai.ag_ui
diff --git a/docs/install.md b/docs/install.md
index 6d621ada5..3d803c729 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -56,6 +56,7 @@ pip/uv-add "pydantic-ai-slim[openai]"
 * `cohere` - installs `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"}
 * `duckduckgo` - installs `duckduckgo-search` [PyPI ↗](https://pypi.org/project/duckduckgo-search){:target="_blank"}
 * `tavily` - installs `tavily-python` [PyPI ↗](https://pypi.org/project/tavily-python){:target="_blank"}
+* `ag-ui` - installs `ag-ui-protocol` [PyPI ↗](https://pypi.org/project/ag-ui-protocol){:target="_blank"}

See the [models](models/index.md) documentation for information on which
optional dependencies are required for each model.
diff --git a/examples/pydantic_ai_ag_ui_examples/README.md b/examples/pydantic_ai_ag_ui_examples/README.md
new file mode 100644
index 000000000..5a6c0b12a
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/README.md
@@ -0,0 +1,156 @@
+# PydanticAI
+
+Implementation of the AG-UI protocol for PydanticAI.
+
+## Prerequisites
+
+This example uses a PydanticAI agent with an OpenAI model and the AG-UI Dojo.
+
+1. An [OpenAI API key](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key)
+2. A clone of the [AG-UI protocol repository](https://github.com/ag-ui-protocol/ag-ui)
+
+## Running
+
+To run this integration you need to:
+
+1. Make a copy of `.env-sample` as `.env` in the `typescript-sdk/integrations/pydantic-ai` directory
+2. Open it in your editor and set `OPENAI_API_KEY` to a valid OpenAI key
+3. Open a terminal in the `typescript-sdk/integrations/pydantic-ai` directory of the `ag-ui` repo
+4. Install the `pydantic-ai-examples` package
+
+    ```shell
+    pip/uv-add pydantic-ai-examples
+    ```
+
+5. Run the example dojo server
+
+    ```shell
+    python -m pydantic_ai_ag_ui_examples.dojo_server
+    ```
+
+6. Open another terminal in the root directory of the `ag-ui` repository clone
+7. Start the AG-UI Dojo:
+
+    ```shell
+    cd typescript-sdk
+    pnpm install && pnpm run dev
+    ```
+
+8. Finally, visit [http://localhost:3000/pydantic-ai](http://localhost:3000/pydantic-ai)
+
+## Feature Demos
+
+### [Agentic Chat](http://localhost:3000/pydantic-ai/feature/agentic_chat)
+
+This demonstrates a basic agent interaction including PydanticAI server side
+tools and AG-UI client side tools.
+
+#### Agent Tools
+
+- `time` - PydanticAI tool to check the current time for a time zone
+- `background` - AG-UI tool to set the background color of the client window
+
+#### Agent Prompts
+
+```text
+What is the time in New York?
+```
+
+```text
+Change the background to blue
+```
+
+A complex example which mixes both AG-UI and PydanticAI tools:
+
+```text
+Perform the following steps, waiting for the response of each step before continuing:
+1. Get the time
+2. Set the background to red
+3. Get the time
+4. Report how long the background set took by diffing the two times
+```
+
+### [Agentic Generative UI](http://localhost:3000/pydantic-ai/feature/agentic_generative_ui)
+
+Demonstrates a long-running task where the agent sends updates to the frontend
+to let the user know what's happening.
+
+#### Plan Prompts
+
+```text
+Create a plan for breakfast and execute it
+```
+
+### [Human in the Loop](http://localhost:3000/pydantic-ai/feature/human_in_the_loop)
+
+Demonstrates a simple human-in-the-loop workflow where the agent comes up with
+a plan and the user can approve it using checkboxes.
+
+#### Task Planning Tools
+
+- `generate_task_steps` - AG-UI tool to generate and confirm steps
+
+#### Task Planning Prompt
+
+```text
+Generate a list of steps for cleaning a car for me to review
+```
+
+### [Predictive State Updates](http://localhost:3000/pydantic-ai/feature/predictive_state_updates)
+
+Demonstrates how to use the predictive state updates feature to update the
+state of the UI based on agent responses, including user interaction via user
+confirmation.
+
+#### Story Tools
+
+- `write_document` - AG-UI tool to write the document to a window
+- `document_predict_state` - PydanticAI tool that enables document state
+  prediction for the `write_document` tool
+
+This also shows how to use custom instructions based on shared state
+information, as sketched below.
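+
+As a rough sketch of that pattern (the model name and state fields here are
+illustrative; see `predictive_state_updates.py` for the full version), an
+instructions function can read the validated state from the deps and fold it
+into the prompt:
+
+```python
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.ag_ui import StateDeps
+
+
+class DocumentState(BaseModel):
+    document: str = ''
+
+
+agent = Agent('openai:gpt-4o-mini', deps_type=StateDeps[DocumentState])
+
+
+@agent.instructions
+def story_instructions(ctx: RunContext[StateDeps[DocumentState]]) -> str:
+    # The AG-UI adapter validates and injects `RunAgentInput.state` into
+    # `ctx.deps.state` before the run starts.
+    return f'This is the current document:\n\n{ctx.deps.state.document}'
+```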
+
+#### Story Example
+
+Starting document text
+
+```markdown
+Bruce was a good dog,
+```
+
+Agent prompt
+
+```text
+Help me complete my story about Bruce the dog, it should be no longer than a sentence.
+```
+
+### [Shared State](http://localhost:3000/pydantic-ai/feature/shared_state)
+
+Demonstrates how to use the shared state between the UI and the agent.
+
+State sent to the agent is detected by a function-based instruction, which
+validates the data using a custom Pydantic model before using it to create the
+instructions for the agent to follow. The result is sent to the client using
+an AG-UI tool.
+
+#### Recipe Tools
+
+- `display_recipe` - AG-UI tool to display the recipe in a graphical format
+
+#### Recipe Example
+
+1. Customise the basic settings of your recipe
+2. Click `Improve with AI`
+
+### [Tool Based Generative UI](http://localhost:3000/pydantic-ai/feature/tool_based_generative_ui)
+
+Demonstrates customised rendering for tool output with user confirmation.
+
+#### Haiku Tools
+
+- `generate_haiku` - AG-UI tool to display a haiku in English and Japanese
+
+#### Haiku Prompt
+
+```text
+Generate a haiku about formula 1
+```
diff --git a/examples/pydantic_ai_ag_ui_examples/__init__.py b/examples/pydantic_ai_ag_ui_examples/__init__.py
new file mode 100644
index 000000000..2652b3500
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/__init__.py
@@ -0,0 +1 @@
+"""Example API Server for an AG-UI compatible Pydantic AI Agent UI."""
diff --git a/examples/pydantic_ai_ag_ui_examples/api/__init__.py b/examples/pydantic_ai_ag_ui_examples/api/__init__.py
new file mode 100644
index 000000000..d17cab009
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/__init__.py
@@ -0,0 +1,19 @@
+"""Example API for an AG-UI compatible Pydantic AI Agent UI."""
+
+from __future__ import annotations
+
+from .agentic_chat import app as agentic_chat_app
+from .agentic_generative_ui import app as agentic_generative_ui_app
+from .human_in_the_loop import app as human_in_the_loop_app
+from .predictive_state_updates import app as predictive_state_updates_app
+from .shared_state import app as shared_state_app
+from .tool_based_generative_ui import app as tool_based_generative_ui_app
+
+__all__: list[str] = [
+    'agentic_chat_app',
+    'agentic_generative_ui_app',
+    'human_in_the_loop_app',
+    'predictive_state_updates_app',
+    'shared_state_app',
+    'tool_based_generative_ui_app',
+]
diff --git a/examples/pydantic_ai_ag_ui_examples/api/agent.py b/examples/pydantic_ai_ag_ui_examples/api/agent.py
new file mode 100644
index 000000000..ddc5a29d2
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/agent.py
@@ -0,0 +1,35 @@
+"""Create a Pydantic AI agent and AG-UI adapter."""
+
+from __future__ import annotations
+
+from dotenv import load_dotenv
+
+from pydantic_ai import Agent
+from pydantic_ai.ag_ui import FastAGUI
+from pydantic_ai.tools import AgentDepsT
+
+
+def agent(
+    model: str = 'openai:gpt-4o-mini',
+    deps: AgentDepsT = None,
+    instructions: str | None = None,
+) -> FastAGUI[AgentDepsT, str]:
+    """Create a Pydantic AI agent with AG-UI adapter.
+
+    Args:
+        model: The model to use for the agent.
+        deps: Optional dependencies for the agent.
+        instructions: Optional instructions for the agent.
+
+    Returns:
+        An instance of FastAGUI with the agent and adapter.
+    """
+    # Ensure environment variables are loaded.
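+    # For these examples a local `.env` providing OPENAI_API_KEY (see the
+    # README) is sufficient.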
+    load_dotenv()
+
+    return Agent(
+        model,
+        output_type=str,
+        instructions=instructions,
+        deps_type=type(deps),
+    ).to_ag_ui(deps=deps)
diff --git a/examples/pydantic_ai_ag_ui_examples/api/agentic_chat.py b/examples/pydantic_ai_ag_ui_examples/api/agentic_chat.py
new file mode 100644
index 000000000..a369d1ef7
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/agentic_chat.py
@@ -0,0 +1,26 @@
+"""Agentic Chat feature."""
+
+from __future__ import annotations
+
+from datetime import datetime
+from zoneinfo import ZoneInfo
+
+from pydantic_ai.ag_ui import FastAGUI
+
+from .agent import agent
+
+app: FastAGUI = agent()
+
+
+@app.adapter.agent.tool_plain
+async def current_time(timezone: str = 'UTC') -> str:
+    """Get the current time in ISO format.
+
+    Args:
+        timezone: The timezone to use.
+
+    Returns:
+        The current time in ISO format string.
+    """
+    tz: ZoneInfo = ZoneInfo(timezone)
+    return datetime.now(tz=tz).isoformat()
diff --git a/examples/pydantic_ai_ag_ui_examples/api/agentic_generative_ui.py b/examples/pydantic_ai_ag_ui_examples/api/agentic_generative_ui.py
new file mode 100644
index 000000000..cdc131095
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/agentic_generative_ui.py
@@ -0,0 +1,121 @@
+"""Agentic Generative UI feature."""
+
+from __future__ import annotations
+
+from enum import StrEnum
+from typing import Any, Literal
+
+from ag_ui.core import EventType, StateDeltaEvent, StateSnapshotEvent
+from pydantic import BaseModel, Field
+
+from pydantic_ai.ag_ui import FastAGUI
+
+from .agent import agent
+
+app: FastAGUI = agent(
+    instructions="""When planning use tools only, without any other messages.
+IMPORTANT:
+- Use the `create_plan` tool to set the initial state of the steps
+- Use the `update_plan_step` tool to update the status of each step
+- Do NOT repeat the plan or summarise it in a message
+- Do NOT confirm the creation or updates in a message
+- Do NOT ask the user for additional information or next steps
+
+Only one plan can be active at a time, so do not call the `create_plan` tool
+again until all the steps in the current plan are completed.
+"""
+)
+
+
+class StepStatus(StrEnum):
+    """The status of a step in a plan."""
+
+    PENDING = 'pending'
+    COMPLETED = 'completed'
+
+
+class Step(BaseModel):
+    """Represents a step in a plan."""
+
+    description: str = Field(description='The description of the step')
+    status: StepStatus = Field(
+        default=StepStatus.PENDING,
+        description='The status of the step (e.g., pending, completed)',
+    )
+
+
+class Plan(BaseModel):
+    """Represents a plan with multiple steps."""
+
+    steps: list[Step] = Field(
+        default_factory=lambda: list[Step](), description='The steps in the plan'
+    )
+
+
+class JSONPatchOp(BaseModel):
+    """A class representing a JSON Patch operation (RFC 6902)."""
+
+    op: Literal['add', 'remove', 'replace', 'move', 'copy', 'test'] = Field(
+        ...,
+        description='The operation to perform: add, remove, replace, move, copy, or test',
+    )
+    path: str = Field(..., description='JSON Pointer (RFC 6901) to the target location')
+    value: Any = Field(
+        default=None,
+        description='The value to apply (for add, replace operations)',
+    )
+    from_: str | None = Field(
+        default=None,
+        alias='from',
+        description='Source path (for move, copy operations)',
+    )
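+
+
+# For illustration: `update_plan_step(0, status=StepStatus.COMPLETED)` below
+# produces a JSON Patch delta like
+#   [{'op': 'replace', 'path': '/steps/0/status', 'value': 'completed'}]
+# which the AG-UI client applies to its copy of the plan state.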
+@app.adapter.agent.tool_plain
+def create_plan(steps: list[str]) -> StateSnapshotEvent:
+    """Create a plan with multiple steps.
+
+    Args:
+        steps: List of step descriptions to create the plan.
+
+    Returns:
+        StateSnapshotEvent containing the initial state of the steps.
+    """
+    plan: Plan = Plan(
+        steps=[Step(description=step) for step in steps],
+    )
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot=plan.model_dump(),
+    )
+
+
+@app.adapter.agent.tool_plain
+def update_plan_step(
+    index: int, description: str | None = None, status: StepStatus | None = None
+) -> StateDeltaEvent:
+    """Update the plan with new steps or changes.
+
+    Args:
+        index: The index of the step to update.
+        description: The new description for the step.
+        status: The new status for the step.
+
+    Returns:
+        StateDeltaEvent containing the changes made to the plan.
+    """
+    changes: list[JSONPatchOp] = []
+    if description is not None:
+        changes.append(
+            JSONPatchOp(
+                op='replace', path=f'/steps/{index}/description', value=description
+            )
+        )
+    if status is not None:
+        changes.append(
+            JSONPatchOp(op='replace', path=f'/steps/{index}/status', value=status.value)
+        )
+    return StateDeltaEvent(
+        type=EventType.STATE_DELTA,
+        delta=changes,
+    )
diff --git a/examples/pydantic_ai_ag_ui_examples/api/human_in_the_loop.py b/examples/pydantic_ai_ag_ui_examples/api/human_in_the_loop.py
new file mode 100644
index 000000000..e27b85c30
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/human_in_the_loop.py
@@ -0,0 +1,20 @@
+"""Human in the Loop Feature.
+
+No special handling is required for this feature.
+"""
+
+from __future__ import annotations
+
+from pydantic_ai.ag_ui import FastAGUI
+
+from .agent import agent
+
+app: FastAGUI = agent(
+    instructions="""When planning tasks use tools only, without any other messages.
+IMPORTANT:
+- Use the `generate_task_steps` tool to display the suggested steps to the user
+- Never repeat the plan, or send a message detailing steps
+- If accepted, confirm the creation of the plan and the number of selected (enabled) steps only
+- If not accepted, ask the user for more information, DO NOT use the `generate_task_steps` tool again
+"""
+)
diff --git a/examples/pydantic_ai_ag_ui_examples/api/predictive_state_updates.py b/examples/pydantic_ai_ag_ui_examples/api/predictive_state_updates.py
new file mode 100644
index 000000000..c82d3647c
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/predictive_state_updates.py
@@ -0,0 +1,83 @@
+"""Predictive State feature."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from ag_ui.core import CustomEvent, EventType
+from pydantic import BaseModel
+
+from pydantic_ai.ag_ui import FastAGUI, StateDeps
+
+from .agent import agent
+
+if TYPE_CHECKING:  # pragma: no cover
+    from pydantic_ai import RunContext
+
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class DocumentState(BaseModel):
+    """State for the document being written."""
+
+    document: str = ''
+
+
+app: FastAGUI = agent(deps=StateDeps(DocumentState()))
+
+
+# Tools which return AG-UI events will be sent to the client as part of the
+# event stream; single events and iterables of events are supported.
+@app.adapter.agent.tool_plain
+def document_predict_state() -> list[CustomEvent]:
+    """Enable document state prediction.
+
+    Returns:
+        A list containing the custom event that enables state prediction.
+    """
+    _LOGGER.info('enabling document state prediction')
+    return [
+        CustomEvent(
+            type=EventType.CUSTOM,
+            name='PredictState',
+            value=[
+                {
+                    'state_key': 'document',
+                    'tool': 'write_document',
+                    'tool_argument': 'document',
+                },
+            ],
+        ),
+    ]
+
+
+@app.adapter.agent.instructions()
+def story_instructions(ctx: RunContext[StateDeps[DocumentState]]) -> str:
+    """Provide instructions for writing document if present.
+
+    Args:
+        ctx: The run context containing document state information.
+
+    Returns:
+        Instructions string for the document writing agent.
+    """
+    _LOGGER.info('story instructions document=%s', ctx.deps.state.document)
+
+    return f"""You are a helpful assistant for writing documents.
+
+Before you start writing, you MUST call the `document_predict_state`
+tool to enable state prediction.
+
+To present the document to the user for review, you MUST use the
+`write_document` tool.
+
+When you have written the document, DO NOT repeat it as a message.
+If accepted, briefly summarize the changes you made, 2 sentences
+max, otherwise ask the user to clarify what they want to change.
+
+This is the current document:
+
+{ctx.deps.state.document}
+"""
diff --git a/examples/pydantic_ai_ag_ui_examples/api/shared_state.py b/examples/pydantic_ai_ag_ui_examples/api/shared_state.py
new file mode 100644
index 000000000..3ab8bcd0a
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/api/shared_state.py
@@ -0,0 +1,145 @@
+"""Shared State feature."""
+
+from __future__ import annotations
+
+import json
+import logging
+from enum import StrEnum
+from typing import TYPE_CHECKING
+
+from ag_ui.core import EventType, StateSnapshotEvent
+from pydantic import BaseModel, Field
+
+from pydantic_ai.ag_ui import FastAGUI, StateDeps
+
+from .agent import agent
+
+if TYPE_CHECKING:  # pragma: no cover
+    from pydantic_ai import RunContext
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class SkillLevel(StrEnum):
+    """The level of skill required for the recipe."""
+
+    BEGINNER = 'Beginner'
+    INTERMEDIATE = 'Intermediate'
+    ADVANCED = 'Advanced'
+
+
+class SpecialPreferences(StrEnum):
+    """Special preferences for the recipe."""
+
+    HIGH_PROTEIN = 'High Protein'
+    LOW_CARB = 'Low Carb'
+    SPICY = 'Spicy'
+    BUDGET_FRIENDLY = 'Budget-Friendly'
+    ONE_POT_MEAL = 'One-Pot Meal'
+    VEGETARIAN = 'Vegetarian'
+    VEGAN = 'Vegan'
+
+
+class CookingTime(StrEnum):
+    """The cooking time of the recipe."""
+
+    FIVE_MIN = '5 min'
+    FIFTEEN_MIN = '15 min'
+    THIRTY_MIN = '30 min'
+    FORTY_FIVE_MIN = '45 min'
+    SIXTY_PLUS_MIN = '60+ min'
+
+
+class Ingredient(BaseModel):
+    """A class representing an ingredient in a recipe."""
+
+    icon: str = Field(
+        default='ingredient',
+        description="The icon emoji (not emoji code like '\x1f35e', but the actual emoji like 🥕) of the ingredient",
+    )
+    name: str
+    amount: str
+
+
+class Recipe(BaseModel):
+    """A class representing a recipe."""
+
+    skill_level: SkillLevel = Field(
+        default=SkillLevel.BEGINNER,
+        description='The skill level required for the recipe',
+    )
+    special_preferences: list[SpecialPreferences] = Field(
+        default_factory=lambda: list[SpecialPreferences](),
+        description='Any special preferences for the recipe',
+    )
+    cooking_time: CookingTime = Field(
+        default=CookingTime.FIVE_MIN, description='The cooking time of the recipe'
+    )
+    ingredients: list[Ingredient] = Field(
+        default_factory=lambda: list[Ingredient](),
+        description='Ingredients for the recipe',
+    )
+    instructions: list[str] = Field(
+        default_factory=lambda: list[str](),
description='Instructions for the recipe' + ) + + +class RecipeSnapshot(BaseModel): + """A class representing the state of the recipe.""" + + recipe: Recipe = Field( + default_factory=Recipe, description='The current state of the recipe' + ) + + +app: FastAGUI = agent(deps=StateDeps(RecipeSnapshot())) + + +@app.adapter.agent.tool_plain +def display_recipe(recipe: Recipe) -> StateSnapshotEvent: + """Display the recipe to the user. + + Args: + recipe: The recipe to display. + + Returns: + StateSnapshotEvent containing the recipe snapshot. + """ + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot={'recipe': recipe}, + ) + + +@app.adapter.agent.instructions +def recipe_instructions(ctx: RunContext[StateDeps[RecipeSnapshot]]) -> str: + """Instructions for the recipe generation agent. + + Args: + ctx: The run context containing recipe state information. + + Returns: + Instructions string for the recipe generation agent. + """ + _LOGGER.info('recipe instructions recipe=%s', ctx.deps.state.recipe) + + return f"""You are a helpful assistant for creating recipes. + +IMPORTANT: +- Create a complete recipe using the existing ingredients +- Append new ingredients to the existing ones +- Use the `display_recipe` tool to present the recipe to the user +- Do NOT repeat the recipe in the message, use the tool instead + +Once you have created the updated recipe and displayed it to the user, +summarise the changes in one sentence, don't describe the recipe in +detail or send it as a message to the user. + +The structure of a recipe is as follows: + +{json.dumps(Recipe.model_json_schema(), indent=2)} + +The current state of the recipe is: + +{ctx.deps.state.recipe.model_dump_json(indent=2)} +""" diff --git a/examples/pydantic_ai_ag_ui_examples/api/tool_based_generative_ui.py b/examples/pydantic_ai_ag_ui_examples/api/tool_based_generative_ui.py new file mode 100644 index 000000000..9d04040f5 --- /dev/null +++ b/examples/pydantic_ai_ag_ui_examples/api/tool_based_generative_ui.py @@ -0,0 +1,12 @@ +"""Tool Based Generative UI feature. + +No special handling is required for this feature. 
+"""
+
+from __future__ import annotations
+
+from pydantic_ai.ag_ui import FastAGUI
+
+from .agent import agent
+
+app: FastAGUI = agent()
diff --git a/examples/pydantic_ai_ag_ui_examples/basic.py b/examples/pydantic_ai_ag_ui_examples/basic.py
new file mode 100644
index 000000000..48828535c
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/basic.py
@@ -0,0 +1,26 @@
+"""Basic example of using Agent.to_ag_ui served with Uvicorn."""
+
+from __future__ import annotations
+
+from pydantic_ai import Agent
+
+agent: Agent[None, str] = Agent(
+    'openai:gpt-4o-mini',
+    instructions='You are a helpful assistant.',
+)
+app = agent.to_ag_ui()
+
+if __name__ == '__main__':
+    import uvicorn
+
+    from .cli import Args, parse_args
+
+    args: Args = parse_args()
+
+    uvicorn.run(
+        'pydantic_ai_ag_ui_examples.basic:app',
+        port=args.port,
+        reload=args.reload,
+        log_level=args.log_level,
+        log_config=args.log_config(),
+    )
diff --git a/examples/pydantic_ai_ag_ui_examples/cli/__init__.py b/examples/pydantic_ai_ag_ui_examples/cli/__init__.py
new file mode 100644
index 000000000..e4a3ba3cb
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/cli/__init__.py
@@ -0,0 +1,10 @@
+"""Command line interface for the PydanticAI AG-UI servers."""
+
+from __future__ import annotations
+
+from .args import Args, parse_args
+
+__all__ = [
+    'Args',
+    'parse_args',
+]
diff --git a/examples/pydantic_ai_ag_ui_examples/cli/args.py b/examples/pydantic_ai_ag_ui_examples/cli/args.py
new file mode 100644
index 000000000..ceb3476fc
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/cli/args.py
@@ -0,0 +1,75 @@
+"""CLI argument parser for the PydanticAI AG-UI servers."""
+
+from __future__ import annotations
+
+import argparse
+from dataclasses import dataclass
+from typing import Any
+
+from uvicorn.config import LOGGING_CONFIG
+
+
+@dataclass
+class Args:
+    """Custom namespace for command line arguments."""
+
+    port: int
+    reload: bool
+    log_level: str
+    loggers: list[str]
+
+    def log_config(self) -> dict[str, Any]:
+        """Return the logging configuration based on the log level."""
+        log_config: dict[str, Any] = LOGGING_CONFIG.copy()
+        for logger in self.loggers:
+            log_config['loggers'][logger] = {
+                'handlers': ['default'],
+                'level': self.log_level.upper(),
+                'propagate': False,
+            }
+
+        return log_config
+
+
+def parse_args() -> Args:
+    """Parse command line arguments for the PydanticAI AG-UI servers.
+
+    Returns:
+        Args: A dataclass containing the parsed command line arguments.
+    """
+    parser: argparse.ArgumentParser = argparse.ArgumentParser(
+        description='PydanticAI AG-UI Dojo server'
+    )
+    parser.add_argument(
+        '--port',
+        '-p',
+        type=int,
+        default=9000,
+        help='Port to run the server on (default: 9000)',
+    )
+    parser.add_argument(
+        '--reload',
+        action='store_true',
+        default=True,
+        help='Enable auto-reload (default: True)',
+    )
+    parser.add_argument(
+        '--no-reload', dest='reload', action='store_false', help='Disable auto-reload'
+    )
+    parser.add_argument(
+        '--log-level',
+        choices=['critical', 'error', 'warning', 'info', 'debug', 'trace'],
+        default='info',
+        help='Adapter log level (default: info)',
+    )
+    parser.add_argument(
+        '--loggers',
+        nargs='*',
+        default=[
+            'pydantic_ai.ag_ui',
+        ],
+        help='Logger names to configure (default: the AG-UI adapter logger)',
+    )
+
+    args: argparse.Namespace = parser.parse_args()
+    return Args(**vars(args))
diff --git a/examples/pydantic_ai_ag_ui_examples/dojo_server.py b/examples/pydantic_ai_ag_ui_examples/dojo_server.py
new file mode 100644
index 000000000..597c034e6
--- /dev/null
+++ b/examples/pydantic_ai_ag_ui_examples/dojo_server.py
@@ -0,0 +1,56 @@
+"""Example usage of the AG-UI adapter for PydanticAI.
+
+This provides a FastAPI application that demonstrates how to use the
+PydanticAI agent with the AG-UI protocol. It includes examples for
+each of the AG-UI dojo features:
+- Agentic Chat
+- Human in the Loop
+- Agentic Generative UI
+- Tool Based Generative UI
+- Shared State
+- Predictive State Updates
+"""
+
+from __future__ import annotations
+
+from fastapi import FastAPI
+
+from .api import (
+    agentic_chat_app,
+    agentic_generative_ui_app,
+    human_in_the_loop_app,
+    predictive_state_updates_app,
+    shared_state_app,
+    tool_based_generative_ui_app,
+)
+
+app = FastAPI(title='PydanticAI AG-UI server')
+app.mount('/agentic_chat', agentic_chat_app, 'Agentic Chat')
+app.mount('/agentic_generative_ui', agentic_generative_ui_app, 'Agentic Generative UI')
+app.mount('/human_in_the_loop', human_in_the_loop_app, 'Human in the Loop')
+app.mount(
+    '/predictive_state_updates',
+    predictive_state_updates_app,
+    'Predictive State Updates',
+)
+app.mount('/shared_state', shared_state_app, 'Shared State')
+app.mount(
+    '/tool_based_generative_ui',
+    tool_based_generative_ui_app,
+    'Tool Based Generative UI',
+)
+
+
+if __name__ == '__main__':
+    import uvicorn
+
+    from .cli import Args, parse_args
+
+    args: Args = parse_args()
+
+    uvicorn.run(
+        'pydantic_ai_ag_ui_examples.dojo_server:app',
+        port=args.port,
+        reload=args.reload,
+        log_config=args.log_config(),
+    )
diff --git a/examples/pydantic_ai_ag_ui_examples/py.typed b/examples/pydantic_ai_ag_ui_examples/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/pyproject.toml b/examples/pyproject.toml
index bb5dcd9ef..04d6240bf 100644
--- a/examples/pyproject.toml
+++ b/examples/pyproject.toml
@@ -42,7 +42,7 @@ requires-python = ">=3.9"
 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [
-    "pydantic-ai-slim[openai,vertexai,groq,anthropic]=={{ version }}",
+    "pydantic-ai-slim[openai,vertexai,groq,anthropic,ag-ui]=={{ version }}",
     "pydantic-evals=={{ version }}",
     "asyncpg>=0.30.0",
     "fastapi>=0.115.4",
@@ -57,7 +57,10 @@ dependencies = [
 ]
 [tool.hatch.build.targets.wheel]
-packages = ["pydantic_ai_examples"]
+packages = [
+    "pydantic_ai_ag_ui_examples",
+    "pydantic_ai_examples",
+]
 [tool.uv.sources]
 pydantic-ai-slim = { workspace = true }
diff --git a/fasta2a/pyproject.toml b/fasta2a/pyproject.toml
index 2abe809aa..93495d93f
100644 --- a/fasta2a/pyproject.toml +++ b/fasta2a/pyproject.toml @@ -54,7 +54,7 @@ logfire = ["logfire>=2.3"] [project.urls] Homepage = "https://ai.pydantic.dev/a2a/fasta2a" -Source = "https://github.com/pydantic/fasta2a" +Source = "https://github.com/pydantic/pydantic-ai" Documentation = "https://ai.pydantic.dev/a2a" Changelog = "https://github.com/pydantic/pydantic-ai/releases" diff --git a/mkdocs.yml b/mkdocs.yml index 44b1548f1..d86854e64 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -47,6 +47,7 @@ nav: - mcp/server.md - mcp/run-python.md - A2A: a2a.md + - AG-UI: ag-ui.md - cli.md - Examples: - examples/index.md @@ -62,6 +63,7 @@ nav: - examples/question-graph.md - examples/slack-lead-qualifier.md - API Reference: + - api/ag_ui.md - api/agent.md - api/tools.md - api/common_tools.md diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 3c9c8f0dc..3378bc345 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -621,7 +621,7 @@ async def process_function_tools( # noqa: C901 result_data = await toolset.call_tool(call, run_context) except exceptions.UnexpectedModelBehavior as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) - raise e + raise # pragma: no cover except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) yield _messages.FunctionToolCallEvent(call) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py new file mode 100644 index 000000000..eecd8e6c8 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -0,0 +1,871 @@ +"""Provides an AG-UI protocol adapter for the PydanticAI agent. + +This package provides seamless integration between pydantic-ai agents and ag-ui +for building interactive AI applications with streaming event-based communication. 
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import uuid
+from collections.abc import Iterable, Mapping, Sequence
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Final,
+    Generic,
+    Protocol,
+    TypeVar,
+    runtime_checkable,
+)
+
+try:
+    from ag_ui.core import (
+        AssistantMessage,
+        BaseEvent,
+        DeveloperMessage,
+        EventType,
+        Message,
+        RunAgentInput,
+        RunErrorEvent,
+        RunFinishedEvent,
+        RunStartedEvent,
+        State,
+        SystemMessage,
+        TextMessageContentEvent,
+        TextMessageEndEvent,
+        TextMessageStartEvent,
+        ThinkingTextMessageContentEvent,
+        ThinkingTextMessageEndEvent,
+        ThinkingTextMessageStartEvent,
+        ToolCallArgsEvent,
+        ToolCallEndEvent,
+        ToolCallResultEvent,
+        ToolCallStartEvent,
+        ToolMessage,
+        UserMessage,
+    )
+    from ag_ui.encoder import EventEncoder
+except ImportError as e:  # pragma: no cover
+    raise ImportError(
+        'Please install the `ag-ui-protocol` package to use the `Agent.to_ag_ui()` method, '
+        'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
+    ) from e
+
+try:
+    from starlette.applications import Starlette
+    from starlette.middleware import Middleware
+    from starlette.requests import Request
+    from starlette.responses import Response, StreamingResponse
+    from starlette.routing import BaseRoute
+    from starlette.types import ExceptionHandler, Lifespan
+except ImportError as e:  # pragma: no cover
+    raise ImportError(
+        'Please install the `starlette` package to use the `Agent.to_ag_ui()` method, '
+        'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
+    ) from e
+
+from pydantic import BaseModel, ValidationError
+
+from ._agent_graph import CallToolsNode, ModelRequestNode
+from .agent import Agent, RunOutputDataT
+from .messages import (
+    AgentStreamEvent,
+    FinalResultEvent,
+    FunctionToolResultEvent,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    PartDeltaEvent,
+    PartStartEvent,
+    SystemPromptPart,
+    TextPart,
+    TextPartDelta,
+    ThinkingPart,
+    ThinkingPartDelta,
+    ToolCallPart,
+    ToolCallPartDelta,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from .models import KnownModelName, Model
+from .output import DeferredToolCalls, OutputDataT, OutputSpec
+from .result import AgentStream
+from .settings import ModelSettings
+from .tools import AgentDepsT, ToolDefinition
+from .toolsets import AbstractToolset
+from .toolsets.deferred import DeferredToolset
+from .usage import Usage, UsageLimits
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from ag_ui.encoder import EventEncoder
+
+    from pydantic_graph.nodes import End
+
+    from ._agent_graph import AgentNode
+    from .agent import AgentRun
+    from .result import FinalResult
+
+# Variables.
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+# Constants.
+SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
+"""Content type header value for Server-Sent Events (SSE)."""
+
+
+class FastAGUI(Generic[AgentDepsT, OutputDataT], Starlette):
+    """ASGI application for running PydanticAI agents with AG-UI protocol support."""
+
+    def __init__(
+        self,
+        *,
+        # Adapter for the agent.
+        adapter: Adapter[AgentDepsT, OutputDataT],
+        # Agent.iter parameters.
+ output_type: OutputSpec[OutputDataT] | None = None, + model: Model | KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[FastAGUI[AgentDepsT, OutputDataT]] | None = None, + ) -> None: + """Initialize the FastAGUI application. + + Args: + adapter: The adapter to use for running the agent. + + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has + no output validators since output validators would expect an argument that matches the agent's + output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette application will always + automatically include two middleware classes. `ServerErrorMiddleware` is added as the very + outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled + exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, or exception class types onto + callables which handle the exceptions. Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or async functions. + on_startup: A list of callables to run on application startup. Startup handler callables do not + take any arguments, and may be either standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do + not take any arguments, and may be either standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. + This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or + the other, not both. 
+        """
+        super().__init__(
+            debug=debug,
+            routes=routes,
+            middleware=middleware,
+            exception_handlers=exception_handlers,
+            on_startup=on_startup,
+            on_shutdown=on_shutdown,
+            lifespan=lifespan,
+        )
+        self.adapter: Adapter[AgentDepsT, OutputDataT] = adapter
+
+        async def endpoint(request: Request) -> Response | StreamingResponse:
+            """Endpoint to run the agent with the provided input data."""
+            accept: str = request.headers.get('accept', SSE_CONTENT_TYPE)
+            try:
+                input_data: RunAgentInput = RunAgentInput.model_validate_json(await request.body())
+            except ValidationError as e:  # pragma: no cover
+                _LOGGER.error('invalid request: %s', e)
+                return Response(
+                    content=e.json(),
+                    media_type='application/json',
+                    status_code=400,
+                )
+
+            return StreamingResponse(
+                adapter.run(
+                    input_data,
+                    accept,
+                    output_type=output_type,
+                    model=model,
+                    deps=deps,
+                    model_settings=model_settings,
+                    usage_limits=usage_limits,
+                    usage=usage,
+                    infer_name=infer_name,
+                    toolsets=toolsets,
+                ),
+                media_type=SSE_CONTENT_TYPE,
+            )
+
+        self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')
+
+
+def agent_to_ag_ui(
+    *,
+    # Adapter parameters.
+    agent: Agent[AgentDepsT, OutputDataT],
+    # Agent.iter parameters.
+    output_type: OutputSpec[OutputDataT] | None = None,
+    model: Model | KnownModelName | str | None = None,
+    deps: AgentDepsT = None,
+    model_settings: ModelSettings | None = None,
+    usage_limits: UsageLimits | None = None,
+    usage: Usage | None = None,
+    infer_name: bool = True,
+    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+    # Starlette parameters.
+    debug: bool = False,
+    routes: Sequence[BaseRoute] | None = None,
+    middleware: Sequence[Middleware] | None = None,
+    exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
+    on_startup: Sequence[Callable[[], Any]] | None = None,
+    on_shutdown: Sequence[Callable[[], Any]] | None = None,
+    lifespan: Lifespan[FastAGUI[AgentDepsT, OutputDataT]] | None = None,
+) -> FastAGUI[AgentDepsT, OutputDataT]:
+    """Create a FastAGUI server from an agent.
+
+    Args:
+        agent: The PydanticAI agent to adapt for the AG-UI protocol.
+
+        output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
+            no output validators since output validators would expect an argument that matches the agent's
+            output type.
+        model: Optional model to use for this run, required if `model` was not set when creating the agent.
+        deps: Optional dependencies to use for this run.
+        model_settings: Optional settings to use for this model's request.
+        usage_limits: Optional limits on model request count or token usage.
+        usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+        infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+        toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
+
+        debug: Boolean indicating if debug tracebacks should be returned on errors.
+        routes: A list of routes to serve incoming HTTP and WebSocket requests.
+        middleware: A list of middleware to run for every request. A starlette application will always
+            automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
+            outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
+            `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
+            exception cases occurring in the routing or endpoints.
+        exception_handlers: A mapping of either integer status codes, or exception class types onto
+            callables which handle the exceptions. Exception handler callables should be of the form
+            `handler(request, exc) -> response` and may be either standard functions, or async functions.
+        on_startup: A list of callables to run on application startup. Startup handler callables do not
+            take any arguments, and may be either standard functions, or async functions.
+        on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
+            not take any arguments, and may be either standard functions, or async functions.
+        lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
+            This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
+            the other, not both.
+    """
+    adapter: Adapter[AgentDepsT, OutputDataT] = Adapter(agent=agent)
+
+    return FastAGUI(
+        adapter=adapter,
+        # Agent.iter parameters
+        output_type=output_type,
+        model=model,
+        deps=deps,
+        model_settings=model_settings,
+        usage_limits=usage_limits,
+        usage=usage,
+        infer_name=infer_name,
+        toolsets=toolsets,
+        # Starlette
+        debug=debug,
+        routes=routes,
+        middleware=middleware,
+        exception_handlers=exception_handlers,
+        on_startup=on_startup,
+        on_shutdown=on_shutdown,
+        lifespan=lifespan,
+    )
+
+
+@dataclass(repr=False)
+class Adapter(Generic[AgentDepsT, OutputDataT]):
+    """An agent adapter providing AG-UI protocol support for PydanticAI agents.
+
+    This class manages agent runs, tool calls and state storage, providing an
+    adapter for running agents with Server-Sent Event (SSE) streaming
+    responses using the AG-UI protocol.
+
+    Examples:
+        This is an example of basic usage with FastAGUI.
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4.1', instructions='Be fun!')
+        app = agent.to_ag_ui()
+        ```
+
+        PydanticAI tools which return AG-UI events will be sent to the client
+        as part of the event stream; single events and event iterables are
+        supported.
+        ```python
+        from ag_ui.core import CustomEvent, EventType, StateSnapshotEvent
+        from pydantic import BaseModel
+
+        from pydantic_ai import Agent, RunContext
+        from pydantic_ai.ag_ui import StateDeps
+
+
+        class DocumentState(BaseModel):
+            document: str
+
+
+        agent = Agent(
+            'openai:gpt-4.1', instructions='Be fun!', deps_type=StateDeps[DocumentState]
+        )
+
+
+        @agent.tool
+        def update_state(ctx: RunContext[StateDeps[DocumentState]]) -> StateSnapshotEvent:
+            return StateSnapshotEvent(
+                type=EventType.STATE_SNAPSHOT,
+                snapshot=ctx.deps.state,
+            )
+
+
+        @agent.tool_plain
+        def custom_events() -> list[CustomEvent]:
+            return [
+                CustomEvent(
+                    type=EventType.CUSTOM,
+                    name='count',
+                    value=1,
+                ),
+                CustomEvent(
+                    type=EventType.CUSTOM,
+                    name='count',
+                    value=2,
+                ),
+            ]
+        ```
+
+    Args:
+        agent: The PydanticAI `Agent` to adapt.
+    """
+
+    agent: Agent[AgentDepsT, OutputDataT] = field(repr=False)
+    _logger: logging.Logger = field(default=_LOGGER, repr=False, init=False)
+
+    async def run(
+        self,
+        run_input: RunAgentInput,
+        accept: str = SSE_CONTENT_TYPE,
+        *,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
+        model: Model | KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
+        usage: Usage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+    ) -> AsyncGenerator[str, None]:
+        """Run the agent with streaming response using AG-UI protocol events.
+
+        The first two arguments are specific to `Adapter`; the rest map directly to the `Agent.iter` method.
+
+        Args:
+            run_input: The AG-UI run input containing thread_id, run_id, messages, etc.
+            accept: The accept header value for the run.
+
+            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
+                output validators since output validators would expect an argument that matches the agent's output type.
+            model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
+
+        Yields:
+            Streaming SSE-formatted event chunks.
+        """
+        self._logger.debug('starting run: %s', json.dumps(run_input.model_dump(), indent=2))
+
+        encoder: EventEncoder = EventEncoder(accept=accept)
+        if run_input.tools:
+            # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
+            # PydanticAI events and actual AG-UI tool names, preventing the tool from being called. If any
+            # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
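+            # Frontend tools become a DeferredToolset: the model can request
+            # them, but execution happens client side and the result comes
+            # back as a `ToolMessage` on a subsequent run.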
+ toolset: AbstractToolset[AgentDepsT] = DeferredToolset[AgentDepsT]( + [ + ToolDefinition( + name=tool.name, + description=tool.description, + parameters_json_schema=tool.parameters, + ) + for tool in run_input.tools + ] + ) + toolsets = [toolset] if toolsets is None else [toolset] + list(toolsets) + + try: + yield encoder.encode( + RunStartedEvent( + type=EventType.RUN_STARTED, + thread_id=run_input.thread_id, + run_id=run_input.run_id, + ), + ) + + if not run_input.messages: + raise _NoMessagesError + + if isinstance(deps, StateHandler): + deps.state = run_input.state + + history: _History = _convert_history(run_input.messages) + + run: AgentRun[AgentDepsT, Any] + async with self.agent.iter( + user_prompt=None, + output_type=[output_type or self.agent.output_type, DeferredToolCalls], + message_history=history.messages, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as run: + async for event in self._agent_stream(run, history): + yield encoder.encode(event) + except _RunError as e: + self._logger.exception('agent run') + yield encoder.encode( + RunErrorEvent(type=EventType.RUN_ERROR, message=e.message, code=e.code), + ) + except Exception as e: # pragma: no cover + self._logger.exception('unexpected error in agent run') + yield encoder.encode( + RunErrorEvent(type=EventType.RUN_ERROR, message=str(e), code='run_error'), + ) + else: + yield encoder.encode( + RunFinishedEvent( + type=EventType.RUN_FINISHED, + thread_id=run_input.thread_id, + run_id=run_input.run_id, + ), + ) + + self._logger.debug('done thread_id=%s run_id=%s', run_input.thread_id, run_input.run_id) + + async def _tool_result_event( + self, + result: ToolReturnPart, + prompt_message_id: str, + ) -> AsyncGenerator[BaseEvent, None]: + """Convert a tool call result to AG-UI events. + + Args: + result: The tool call result to process. + prompt_message_id: The message ID of the prompt that initiated the tool call. + + Yields: + AG-UI Server-Sent Events (SSE). + """ + yield ToolCallResultEvent( + message_id=prompt_message_id, + type=EventType.TOOL_CALL_RESULT, + role=Role.TOOL.value, + tool_call_id=result.tool_call_id, + content=result.model_response_str(), + ) + + # Now check for AG-UI events returned by the tool calls. + content: Any = result.content + if isinstance(content, BaseEvent): + self._logger.debug('ag-ui event: %s', content) + yield content + elif isinstance(content, (str, bytes)): # pragma: no branch + # Avoid iterable check for strings and bytes. + pass + elif isinstance(content, Iterable): # pragma: no branch + item: Any + for item in content: # type: ignore[reportUnknownMemberType] + if isinstance(item, BaseEvent): # pragma: no branch + self._logger.debug('ag-ui event: %s', item) + yield item + + async def _agent_stream( + self, + run: AgentRun[AgentDepsT, Any], + history: _History, + ) -> AsyncGenerator[BaseEvent, None]: + """Run the agent streaming responses using AG-UI protocol events. + + Args: + run: The agent run to process. + history: The history of messages and tool calls to use for the run. + + Yields: + AG-UI Server-Sent Events (SSE). + """ + node: AgentNode[AgentDepsT, Any] | End[FinalResult[Any]] + msg: BaseEvent + async for node in run: + self._logger.debug('processing node=%r', node) + if isinstance(node, CallToolsNode): + # Handle tool results. 
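+                # Tool returns are forwarded as TOOL_CALL_RESULT events; any
+                # AG-UI events returned by the tool itself are passed through
+                # verbatim (see `_tool_result_event`).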
+ async with node.stream(run.ctx) as handle_stream: + async for event in handle_stream: + if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart): + async for msg in self._tool_result_event(event.result, history.prompt_message_id): + yield msg + elif isinstance(node, ModelRequestNode): + # Handle model requests. + stream_ctx: _RequestStreamContext = _RequestStreamContext() + request_stream: AgentStream[AgentDepsT] + async with node.stream(run.ctx) as request_stream: + agent_event: AgentStreamEvent + async for agent_event in request_stream: + async for msg in self._agent_event(stream_ctx, agent_event): + yield msg + + if stream_ctx.part_end: # pragma: no branch + yield stream_ctx.part_end + stream_ctx.part_end = None + + async def _agent_event( + self, + stream_ctx: _RequestStreamContext, + agent_event: AgentStreamEvent, + ) -> AsyncGenerator[BaseEvent, None]: + """Handle an agent event and yield AG-UI protocol events. + + Args: + stream_ctx: The request stream context to manage state. + agent_event: The agent event to process. + + Yields: + AG-UI Server-Sent Events (SSE) based on the agent event. + """ + self._logger.debug('agent_event: %s', agent_event) + if isinstance(agent_event, PartStartEvent): + if stream_ctx.part_end: + # End the previous part. + yield stream_ctx.part_end + stream_ctx.part_end = None + + part: ModelResponsePart = agent_event.part + if isinstance(part, TextPart): + message_id: str = stream_ctx.new_message_id() + yield TextMessageStartEvent( + type=EventType.TEXT_MESSAGE_START, + message_id=message_id, + role=Role.ASSISTANT.value, + ) + stream_ctx.part_end = TextMessageEndEvent( + type=EventType.TEXT_MESSAGE_END, + message_id=message_id, + ) + if part.content: + yield TextMessageContentEvent( # pragma: no cover + type=EventType.TEXT_MESSAGE_CONTENT, + message_id=message_id, + delta=part.content, + ) + elif isinstance(part, ToolCallPart): # pragma: no branch + stream_ctx.last_tool_call_id = part.tool_call_id + yield ToolCallStartEvent( + type=EventType.TOOL_CALL_START, + tool_call_id=part.tool_call_id, + tool_call_name=part.tool_name, + ) + stream_ctx.part_end = ToolCallEndEvent( + type=EventType.TOOL_CALL_END, + tool_call_id=part.tool_call_id, + ) + + elif isinstance(part, ThinkingPart): # pragma: no branch + yield ThinkingTextMessageStartEvent( + type=EventType.THINKING_TEXT_MESSAGE_START, + ) + if part.content: # pragma: no branch + yield ThinkingTextMessageContentEvent( + type=EventType.THINKING_TEXT_MESSAGE_CONTENT, + delta=part.content, + ) + stream_ctx.part_end = ThinkingTextMessageEndEvent( + type=EventType.THINKING_TEXT_MESSAGE_END, + ) + + elif isinstance(agent_event, PartDeltaEvent): + if isinstance(agent_event.delta, TextPartDelta): + yield TextMessageContentEvent( + type=EventType.TEXT_MESSAGE_CONTENT, + message_id=stream_ctx.message_id, + delta=agent_event.delta.content_delta, + ) + elif isinstance(agent_event.delta, ToolCallPartDelta): # pragma: no branch + yield ToolCallArgsEvent( + type=EventType.TOOL_CALL_ARGS, + tool_call_id=agent_event.delta.tool_call_id + or stream_ctx.last_tool_call_id + or 'unknown', # Should never be unknown, but just in case. 
+                    delta=agent_event.delta.args_delta
+                    if isinstance(agent_event.delta.args_delta, str)
+                    else json.dumps(agent_event.delta.args_delta),
+                )
+            elif isinstance(agent_event.delta, ThinkingPartDelta):  # pragma: no cover
+                yield ThinkingTextMessageContentEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
+                    delta=agent_event.delta.content_delta or '',
+                )
+        elif isinstance(agent_event, FinalResultEvent):
+            # No equivalent AG-UI event yet.
+            pass
+
+
+@dataclass
+class _History:
+    """A simple history representation for AG-UI protocol."""
+
+    prompt_message_id: str  # The ID of the last user message.
+    messages: list[ModelMessage]
+
+
+def _convert_history(messages: list[Message]) -> _History:
+    """Convert an AG-UI history to a PydanticAI one.
+
+    Args:
+        messages: List of AG-UI messages to convert.
+
+    Returns:
+        A `_History` holding the ID of the last user message and the
+        converted PydanticAI model messages.
+    """
+    msg: Message
+    prompt_message_id: str = ''
+    result: list[ModelMessage] = []
+    tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
+    for msg in messages:
+        if isinstance(msg, UserMessage):
+            prompt_message_id = msg.id
+            result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
+        elif isinstance(msg, AssistantMessage):
+            if msg.tool_calls:
+                for tool_call in msg.tool_calls:
+                    tool_calls[tool_call.id] = tool_call.function.name
+
+                result.append(
+                    ModelResponse(
+                        parts=[
+                            ToolCallPart(
+                                tool_name=tool_call.function.name,
+                                tool_call_id=tool_call.id,
+                                args=tool_call.function.arguments,
+                            )
+                            for tool_call in msg.tool_calls
+                        ]
+                    )
+                )
+
+            if msg.content:
+                result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
+        elif isinstance(msg, SystemMessage):
+            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
+        elif isinstance(msg, ToolMessage):
+            result.append(
+                ModelRequest(
+                    parts=[
+                        ToolReturnPart(
+                            tool_name=tool_calls.get(msg.tool_call_id, 'unknown'),
+                            content=msg.content,
+                            tool_call_id=msg.tool_call_id,
+                        )
+                    ]
+                )
+            )
+        elif isinstance(msg, DeveloperMessage):  # pragma: no branch
+            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
+
+    return _History(
+        prompt_message_id=prompt_message_id,
+        messages=result,
+    )
+
+
+__all__ = [
+    'Adapter',
+    'SSE_CONTENT_TYPE',
+    'StateDeps',
+    'StateHandler',
+    'FastAGUI',
+    'agent_to_ag_ui',
+]
+
+
+# Enums.
+# TODO(steve): Remove this and all uses once https://github.com/ag-ui-protocol/ag-ui/pull/49 is merged.
+class Role(str, Enum):
+    """Enum for message roles in AG-UI protocol."""
+
+    ASSISTANT = 'assistant'
+    USER = 'user'
+    DEVELOPER = 'developer'
+    SYSTEM = 'system'
+    TOOL = 'tool'
+
+
+# Exceptions.
+@dataclass
+class _RunError(Exception):
+    """Exception raised for errors during agent runs."""
+
+    message: str
+    code: str
+
+    def __str__(self) -> str:
+        return self.message
+
+
+@dataclass
+class _NoMessagesError(_RunError):
+    """Exception raised when no messages are found in the input."""
+
+    message: str = 'no messages found in the input'
+    code: str = 'no_messages'
+
+
+@dataclass
+class StateNotSetError(_RunError, AttributeError):
+    """Exception raised when the state has not been set."""
+
+    message: str = 'state is not set'
+    code: str = 'state_not_set'
+
+
+@dataclass
+class InvalidStateError(_RunError, ValidationError):
+    """Exception raised when an invalid state is provided."""
+
+    message: str = 'invalid state provided'
+    code: str = 'invalid_state'
+
+
+# Protocols.
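+# Any deps object exposing a read/write `state` property satisfies this protocol;
+# `StateDeps` below is the provided implementation for pydantic models. A custom
+# handler could look roughly like this (hypothetical sketch, not part of the API):
+#
+#     class MyDeps:
+#         def __init__(self) -> None:
+#             self._state: State = None
+#
+#         @property
+#         def state(self) -> State:
+#             return self._state
+#
+#         @state.setter
+#         def state(self, state: State) -> None:
+#             self._state = state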
+@runtime_checkable
+class StateHandler(Protocol):
+    """Protocol for state handlers in agent runs."""
+
+    @property
+    def state(self) -> State:
+        """Get the current state of the agent run."""
+        ...
+
+    @state.setter
+    def state(self, state: State) -> None:
+        """Set the state of the agent run.
+
+        This method is called to update the state of the agent run with the
+        provided state.
+
+        Args:
+            state: The run state.
+
+        Raises:
+            InvalidStateError: If `state` does not match the expected model.
+        """
+        ...
+
+
+StateT = TypeVar('StateT', bound=BaseModel)
+"""Type variable for the state type, which must be a subclass of `BaseModel`."""
+
+
+class StateDeps(Generic[StateT]):
+    """Provides AG-UI state management.
+
+    This class manages the state of an agent run, typed by a specific state
+    model which must be a subclass of `BaseModel`.
+
+    The state is set via the `state` setter by the `Adapter` when the run starts.
+
+    Implements the `StateHandler` protocol.
+    """
+
+    def __init__(self, default: StateT) -> None:
+        """Initialize with the provided default state."""
+        self._state = default
+
+    @property
+    def state(self) -> StateT:
+        """Get the current state of the agent run.
+
+        Returns:
+            The current run state.
+        """
+        return self._state
+
+    @state.setter
+    def state(self, state: State) -> None:
+        """Set the state of the agent run.
+
+        This method is called to update the state of the agent run with the
+        provided state.
+
+        Implements the `StateHandler` protocol.
+
+        Args:
+            state: The run state, which must be `None` or validate against the
+                state model type.
+
+        Raises:
+            InvalidStateError: If `state` does not validate.
+        """
+        if state is None:
+            # If state is None, we keep the current state, which will be the default state.
+            return
+
+        try:
+            self._state = type(self._state).model_validate(state)
+        except ValidationError as e:  # pragma: no cover
+            raise InvalidStateError from e
+
+
+@dataclass(repr=False)
+class _RequestStreamContext:
+    """Data class to hold request stream context."""
+
+    message_id: str = ''
+    last_tool_call_id: str | None = None
+    part_end: BaseEvent | None = None
+
+    def new_message_id(self) -> str:
+        """Generate a new message ID for the request stream.
+
+        Assigns a new UUID to the `message_id` and returns it.
+
+        Returns:
+            A new message ID.
+ """ + self.message_id = str(uuid.uuid4()) + return self.message_id diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c46699fd0..c8a00d619 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,7 +5,7 @@ import json import warnings from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy @@ -54,6 +54,7 @@ from .toolsets.combined import CombinedToolset from .toolsets.function import FunctionToolset from .toolsets.prepared import PreparedToolset +from .usage import Usage, UsageLimits # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -64,7 +65,7 @@ if TYPE_CHECKING: from starlette.middleware import Middleware - from starlette.routing import Route + from starlette.routing import BaseRoute, Route from starlette.types import ExceptionHandler, Lifespan from fasta2a.applications import FastA2A @@ -73,6 +74,7 @@ from fasta2a.storage import Storage from pydantic_ai.mcp import MCPServer + from .ag_ui import FastAGUI __all__ = ( 'Agent', @@ -1853,6 +1855,110 @@ async def run_mcp_servers( async with self: yield + def to_ag_ui( + self, + *, + # Agent.iter parameters + output_type: OutputSpec[OutputDataT] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[FastAGUI[AgentDepsT, OutputDataT]] | None = None, + ) -> FastAGUI[AgentDepsT, OutputDataT]: + """Convert the agent to an Adapter instance. + + This allows you to use the agent with a compatible AG-UI frontend. + + The first two arguments are specific to `Adapter` the rest map directly to the `Agent.iter` method. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + app = agent.to_ag_ui() + ``` + + The `app` is an ASGI application that can be used with any ASGI server. + + To run the application, you can use the following command: + + ```bash + uvicorn app:app --host 0.0.0.0 --port 8000 + ``` + Args: + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has + no output validators since output validators would expect an argument that matches the agent's + output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. 
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
+
+            debug: Boolean indicating if debug tracebacks should be returned on errors.
+            routes: A list of routes to serve incoming HTTP and WebSocket requests.
+            middleware: A list of middleware to run for every request. A starlette application will always
+                automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
+                outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
+                `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
+                exception cases occurring in the routing or endpoints.
+            exception_handlers: A mapping of either integer status codes, or exception class types onto
+                callables which handle the exceptions. Exception handler callables should be of the form
+                `handler(request, exc) -> response` and may be either standard functions, or async functions.
+            on_startup: A list of callables to run on application startup. Startup handler callables do not
+                take any arguments, and may be either standard functions, or async functions.
+            on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
+                not take any arguments, and may be either standard functions, or async functions.
+            lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
+                This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
+                the other, not both.
+
+        Returns:
+            A FastAGUI ASGI application that serves the agent via the AG-UI protocol.
+        """
+        try:
+            from .ag_ui import agent_to_ag_ui
+        except ImportError as e:  # pragma: no cover
+            raise ImportError(
+                'Please install the `ag-ui-protocol` and `starlette` packages to use the `Agent.to_ag_ui()` method; '
+                'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
+            ) from e
+
+        return agent_to_ag_ui(
+            agent=self,
+            # Agent.iter parameters
+            output_type=output_type,
+            model=model,
+            deps=deps,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            usage=usage,
+            infer_name=infer_name,
+            toolsets=toolsets,
+            # Starlette
+            debug=debug,
+            routes=routes,
+            middleware=middleware,
+            exception_handlers=exception_handlers,
+            on_startup=on_startup,
+            on_shutdown=on_shutdown,
+            lifespan=lifespan,
+        )
+
     def to_a2a(
         self,
         *,
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index a8de70274..0fd15df07 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -423,7 +423,7 @@ def _get_python_type(cls, value: dict[str, Any]) -> str:
         if value_type == 'object':
             additional_properties = value.get('additionalProperties', {})
             if isinstance(additional_properties, bool):
-                return 'bool'  # pragma: no cover
+                return 'bool'
             additional_properties_type = additional_properties.get('type')
             if (
                 additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index 87a0c79c0..4d69841df 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -6,10 +6,10 @@
 from contextlib import asynccontextmanager
 from dataclasses import InitVar, dataclass, field
 from datetime import date, datetime, timedelta
-from typing import Any, Literal
+from typing import Any, Literal, Union
 
 import pydantic_core
-from typing_extensions import assert_never
+from typing_extensions import TypeAlias, assert_never
 
 from .. import _utils
 from ..messages import (
@@ -45,6 +45,55 @@ class _WrappedToolOutput:
     value: Any | None
 
 
+@dataclass
+class TestToolCallPart:
+    """Represents a tool call in the test model."""
+
+    # NOTE: Avoid test discovery by pytest.
+    __test__ = False
+
+    call_tools: list[str] | Literal['all'] = 'all'
+    deltas: bool = False
+
+
+@dataclass
+class TestTextPart:
+    """Represents a text part in the test model."""
+
+    # NOTE: Avoid test discovery by pytest.
+    __test__ = False
+
+    text: str
+
+
+@dataclass
+class TestThinkingPart:
+    """Represents a thinking part in the test model.
+
+    This is used to simulate the model thinking about the response.
+    """
+
+    # NOTE: Avoid test discovery by pytest.
+    __test__ = False
+
+    content: str = 'Thinking...'
+
+
+TestPart: TypeAlias = Union[TestTextPart, TestToolCallPart, TestThinkingPart]
+"""A part of the test model response."""
+
+
+@dataclass
+class TestNode:
+    """A node in the test model."""
+
+    # NOTE: Avoid test discovery by pytest.
+    __test__ = False
+
+    parts: list[TestPart]
+    id: str = field(default_factory=_utils.generate_tool_call_id)
+
 
 @dataclass
 class TestModel(Model):
     """A model specifically for testing purposes.
@@ -63,6 +112,10 @@ class TestModel(Model):
     call_tools: list[str] | Literal['all'] = 'all'
     """List of tools to call. If `'all'`, all tools will be called."""
+    tool_call_deltas: set[str] = field(default_factory=set)
+    """A set of tool call names which should result in tool call part deltas."""
+    custom_response_nodes: list[TestNode] | None = None
+    """A list of nodes which defines a custom model response."""
     custom_output_text: str | None = None
     """If set, this text is returned as the final output."""
     custom_output_args: Any | None = None
@@ -102,7 +155,10 @@ async def request_stream(
         model_response = self._request(messages, model_settings, model_request_parameters)
 
         yield TestStreamedResponse(
-            _model_name=self._model_name, _structured_response=model_response, _messages=messages
+            _model_name=self._model_name,
+            _structured_response=model_response,
+            _messages=messages,
+            _tool_call_deltas=self.tool_call_deltas,
         )
 
     @property
@@ -141,14 +197,65 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap
             if k := output_tool.outer_typed_dict_key:
                 return _WrappedToolOutput({k: self.custom_output_args})
-            else:
-                return _WrappedToolOutput(self.custom_output_args)
+
+            return _WrappedToolOutput(self.custom_output_args)
         elif model_request_parameters.allow_text_output:
             return _WrappedTextOutput(None)
-        elif model_request_parameters.output_tools:
+        elif model_request_parameters.output_tools:  # pragma: no branch
             return _WrappedToolOutput(None)
         else:
-            return _WrappedTextOutput(None)
+            return _WrappedTextOutput(None)  # pragma: no cover
+
+    def _node_response(
+        self,
+        messages: list[ModelMessage],
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse | None:
+        """Returns a ModelResponse based on configured nodes.
+
+        Args:
+            messages: The messages sent to the model.
+            model_request_parameters: The parameters for the model request.
+
+        Returns:
+            The response from the model, or `None` if no nodes are configured or
+            all nodes have already been processed.
+        """
+        if not self.custom_response_nodes:
+            # No nodes configured, follow the default behaviour.
+            return None
+
+        # Pick up where we left off by counting the number of ModelResponse messages in the stream.
+        # This allows us to stream the response in chunks, simulating a real model response.
+        node: TestNode
+        count: int = sum(isinstance(m, ModelResponse) for m in messages)
+        if count < len(self.custom_response_nodes):
+            node = self.custom_response_nodes[count]
+            assert node.parts, 'Node parts should not be empty.'
+
+            parts: list[ModelResponsePart] = []
+            part: TestPart
+            for part in node.parts:
+                if isinstance(part, TestTextPart):  # pragma: no branch
+                    assert model_request_parameters.allow_text_output, (  # pragma: no cover
+                        'Plain response not allowed, but `part` is a `TestTextPart`.'
+                    )
+                    parts.append(TextPart(part.text))  # pragma: no cover
+                elif isinstance(part, TestToolCallPart):  # pragma: no branch
+                    tool_calls = self._get_tool_calls(model_request_parameters)
+                    if part.call_tools == 'all':  # pragma: no branch
+                        parts.extend(
+                            ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls
+                        )  # pragma: no cover
+                    else:
+                        parts.extend(
+                            ToolCallPart(name, self.gen_tool_args(args))
+                            for name, args in tool_calls
+                            if name in part.call_tools
+                        )
+                elif isinstance(part, TestThinkingPart):  # pragma: no branch
+                    parts.append(ThinkingPart(content=part.content))
+            return ModelResponse(vendor_id=node.id, parts=parts, model_name=self._model_name)
 
     def _request(
         self,
@@ -156,17 +263,18 @@
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
-        tool_calls = self._get_tool_calls(model_request_parameters)
-        output_wrapper = self._get_output(model_request_parameters)
-        output_tools = model_request_parameters.output_tools
+        if (response := self._node_response(messages, model_request_parameters)) is not None:
+            return response
 
-        # if there are tools, the first thing we want to do is call all of them
+        tool_calls = self._get_tool_calls(model_request_parameters)
         if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
                 parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
                 model_name=self._model_name,
             )
 
+        output_wrapper = self._get_output(model_request_parameters)
+        output_tools = model_request_parameters.output_tools
         if messages:  # pragma: no branch
             last_message = messages[-1]
             assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
@@ -232,6 +340,7 @@ class TestStreamedResponse(StreamedResponse):
     _model_name: str
     _structured_response: ModelResponse
     _messages: InitVar[Iterable[ModelMessage]]
+    _tool_call_deltas: set[str]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
 
     def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -253,12 +362,33 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             self._usage += _get_string_usage(word)
             yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
         elif isinstance(part, ToolCallPart):
-            yield self._parts_manager.handle_tool_call_part(
-                vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
-            )
-        elif isinstance(part, ThinkingPart):  # pragma: no cover
-            # NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel.
-            assert False, "This should be unreachable — we don't generate ThinkingPart on TestModel."
+            if part.tool_name in self._tool_call_deltas:
+                # Start with empty tool call delta.
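+                # (The first delta carries the tool name with empty args so the
+                # parts manager registers the call before any argument chunks.)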
+                event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=i, tool_name=part.tool_name, args='', tool_call_id=part.tool_call_id
+                )
+                if event is not None:  # pragma: no branch
+                    yield event
+
+                # Stream the args as a JSON string in chunks.
+                args_json = pydantic_core.to_json(part.args).decode()
+                *chunks, last_chunk = args_json.split(',') if ',' in args_json else [args_json]
+                chunks = [f'{chunk},' for chunk in chunks] if chunks else []
+                if last_chunk:  # pragma: no branch
+                    chunks.append(last_chunk)
+
+                for chunk in chunks:
+                    event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=i, tool_name=None, args=chunk, tool_call_id=part.tool_call_id
+                    )
+                    if event is not None:  # pragma: no branch
+                        yield event
+            else:
+                yield self._parts_manager.handle_tool_call_part(
+                    vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
+                )
+        elif isinstance(part, ThinkingPart):
+            yield self._parts_manager.handle_thinking_delta(vendor_part_id=i, content=part.content)
         else:
             assert_never(part)
diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml
index 99706c2c6..b82198289 100644
--- a/pydantic_ai_slim/pyproject.toml
+++ b/pydantic_ai_slim/pyproject.toml
@@ -80,6 +80,8 @@
 mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
 evals = ["pydantic-evals=={{ version }}"]
 # A2A
 a2a = ["fasta2a=={{ version }}"]
+# AG UI Adapter
+ag-ui = ["ag-ui-protocol>=0.1.7", "starlette>=0.45.3"]
 
 [dependency-groups]
 dev = [
diff --git a/pyproject.toml b/pyproject.toml
index ea96212e3..ad3fc4fed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@
 requires-python = ">=3.9"
 
 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [
-    "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals,a2a]=={{ version }}",
+    "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals,a2a,ag-ui]=={{ version }}",
 ]
 
 [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]
@@ -198,6 +198,8 @@
 filterwarnings = [
     "error",
     # Issue with python-multipart - we don't want to bump the minimum version of starlette.
     "ignore::PendingDeprecationWarning:starlette",
+    # mistralai accesses model_fields on the instance, which is deprecated in Pydantic 2.11.
+    "ignore:Accessing the 'model_fields' attribute",
     # boto3
     "ignore::DeprecationWarning:botocore.*",
     "ignore::RuntimeWarning:pydantic_ai.mcp",
@@ -226,6 +228,15 @@
 omit = [
     "pydantic_ai_slim/pydantic_ai/ext/aci.py",  # aci-sdk requires Python 3.10+ so cannot be added as an (optional) dependency
 ]
 branch = true
+disable_warnings = ["include-ignored"]
+
+[tool.coverage.paths]
+# Allow CI run assets to be downloaded and replicated locally.
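+# `coverage combine` treats the prefixes listed under `source` as equivalent,
+# mapping files recorded on CI runners back onto the local checkout.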
+source = [
+    ".",
+    "/home/runner/work/pydantic-ai/pydantic-ai",
+    "/System/Volumes/Data/home/runner/work/pydantic-ai/pydantic-ai"
+]
 
 # https://coverage.readthedocs.io/en/latest/config.html#report
 [tool.coverage.report]
diff --git a/tests/conftest.py b/tests/conftest.py
index ce95301d3..475fd4f23 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import importlib.util
+import logging
 import os
 import re
 import secrets
@@ -28,6 +29,11 @@
 
 __all__ = 'IsDatetime', 'IsFloat', 'IsNow', 'IsStr', 'IsInt', 'IsInstance', 'TestEnv', 'ClientWithHandler', 'try_import'
 
+# Configure the VCR logger to WARNING as it is too verbose by default;
+# specifically, it logs every request and response, including binary
+# content in Cassette.append, which causes log downloads from
+# GitHub Actions to fail.
+logging.getLogger('vcr.cassette').setLevel(logging.WARNING)
 
 pydantic_ai.models.ALLOW_MODEL_REQUESTS = False
diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py
new file mode 100644
index 000000000..f58886e0f
--- /dev/null
+++ b/tests/test_ag_ui.py
@@ -0,0 +1,896 @@
+"""Tests for AG-UI implementation."""
+
+# pyright: reportPossiblyUnboundVariable=none
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import re
+import uuid
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from itertools import count
+from typing import Any, Final, Literal
+
+import httpx
+import pytest
+from asgi_lifespan import LifespanManager
+from pydantic import BaseModel
+
+from pydantic_ai import Agent
+from pydantic_ai.models.test import TestModel, TestNode, TestThinkingPart, TestToolCallPart
+
+has_ag_ui: bool = False
+with contextlib.suppress(ImportError):
+    from ag_ui.core import (
+        AssistantMessage,
+        CustomEvent,
+        DeveloperMessage,
+        EventType,
+        FunctionCall,
+        Message,
+        RunAgentInput,
+        StateSnapshotEvent,
+        SystemMessage,
+        Tool,
+        ToolCall,
+        ToolMessage,
+        UserMessage,
+    )
+
+    from pydantic_ai.ag_ui import (
+        SSE_CONTENT_TYPE,
+        Adapter,
+        Role,
+        StateDeps,
+    )
+
+    has_ag_ui = True
+
+
+pytestmark = [
+    pytest.mark.anyio,
+    pytest.mark.skipif(not has_ag_ui, reason='ag-ui-protocol not installed'),
+]
+
+
+# Type aliases.
+_MockUUID = Callable[[], str]
+
+# Constants.
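+# EXPECTED_EVENTS mirrors TestModel's default streamed response, the text
+# 'success (no tool calls)', bracketed by RUN_STARTED/RUN_FINISHED events.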
+THREAD_ID_PREFIX: Final[str] = 'thread_'
+RUN_ID_PREFIX: Final[str] = 'run_'
+EXPECTED_EVENTS: Final[list[str]] = [
+    '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
+    '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000003","role":"assistant"}',
+    '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000003","delta":"success "}',
+    '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000003","delta":"(no "}',
+    '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000003","delta":"tool "}',
+    '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000003","delta":"calls)"}',
+    '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000003"}',
+    '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
+]
+UUID_PATTERN: Final[re.Pattern[str]] = re.compile(r'\d{8}-\d{4}-\d{4}-\d{4}-\d{12}')
+
+
+class StateInt(BaseModel):
+    """Example state class for testing purposes."""
+
+    value: int = 0
+
+
+def get_weather(name: str = 'get_weather') -> Tool:
+    """Create an AG-UI weather tool definition for tests."""
+    return Tool(
+        name=name,
+        description='Get the weather for a given location',
+        parameters={
+            'type': 'object',
+            'properties': {
+                'location': {
+                    'type': 'string',
+                    'description': 'The location to get the weather for',
+                },
+            },
+            'required': ['location'],
+        },
+    )
+
+
+@pytest.fixture
+async def adapter() -> Adapter[StateDeps[StateInt], str]:
+    """Fixture to create an Adapter instance for testing.
+
+    Returns:
+        An Adapter instance configured for testing.
+    """
+    return await create_adapter([])
+
+
+async def create_adapter(
+    call_tools: list[str] | Literal['all'],
+) -> Adapter[StateDeps[StateInt], str]:
+    """Create an Adapter instance for testing.
+
+    Args:
+        call_tools: List of tool names to enable, or 'all' for all tools.
+
+    Returns:
+        An Adapter instance configured with the specified tools.
+    """
+    return Adapter(
+        agent=Agent(
+            model=TestModel(
+                call_tools=call_tools,
+                tool_call_deltas={'get_weather_parts', 'current_time'},
+            ),
+            deps_type=StateDeps[StateInt],  # type: ignore[reportUnknownArgumentType]
+            tools=[send_snapshot, send_custom, current_time],
+        ),
+    )
+
+
+@pytest.fixture
+def mock_uuid(monkeypatch: pytest.MonkeyPatch) -> _MockUUID:
+    """Mock UUID generation for consistent test results.
+
+    This fixture replaces the uuid.uuid4 function with a mock that generates
+    sequential UUIDs for testing purposes. This ensures that UUIDs are
+    predictable and consistent across test runs.
+
+    Args:
+        monkeypatch: The pytest monkeypatch fixture to modify uuid.uuid4.
+
+    Returns:
+        A function that generates a mock UUID.
+    """
+    counter = count(1)
+
+    def _fake_uuid() -> str:
+        """Generate a fake UUID string with sequential numbering.
+
+        Returns:
+            A fake UUID string in the format '00000000-0000-0000-0000-{counter:012d}'.
+        """
+        return f'00000000-0000-0000-0000-{next(counter):012d}'
+
+    def _fake_uuid4() -> uuid.UUID:
+        """Generate a fake UUID object using the fake UUID string.
+
+        Returns:
+            A UUID object created from the fake UUID string.
+        """
+        return uuid.UUID(_fake_uuid())
+
+    # Due to how ToolCallPart uses generate_tool_call_id with field default_factory,
+    # we have to patch uuid.uuid4 directly instead of the generate function. This
+    # also covers how we generate message IDs.
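+    # Tests use both pieces: the returned `_fake_uuid` builds deterministic
+    # thread/run IDs directly, while the patched `uuid.uuid4` covers IDs
+    # generated inside the library.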
+    monkeypatch.setattr('uuid.uuid4', _fake_uuid4)
+
+    return _fake_uuid
+
+
+def assert_events(events: list[str], expected_events: list[str], *, loose: bool = False) -> None:
+    """Assert that the received SSE events match the expected payloads, one-to-one."""
+    expected: str
+    event: str
+    for event, expected in zip(events, expected_events):
+        if loose:
+            expected = normalize_uuids(expected)
+            event = normalize_uuids(event)
+        assert event == f'data: {expected}\n\n'
+    assert len(events) == len(expected_events)
+
+
+def normalize_uuids(text: str) -> str:
+    """Normalize UUIDs in the given text to a fixed format.
+
+    Args:
+        text: The input text containing UUIDs.
+
+    Returns:
+        The text with UUIDs replaced by a fixed UUID.
+    """
+    return UUID_PATTERN.sub('00000000-0000-0000-0000-000000000001', text)
+
+
+def current_time() -> str:
+    """Get the current time in ISO format.
+
+    Returns:
+        The current UTC time in ISO format string.
+    """
+    return '21T12:08:45.485981+00:00'
+
+
+async def send_snapshot() -> StateSnapshotEvent:
+    """Send a state snapshot to the client.
+
+    Returns:
+        A StateSnapshotEvent.
+    """
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot={'key': 'value'},
+    )
+
+
+async def send_custom() -> list[CustomEvent]:
+    """Send custom events to the client.
+
+    Returns:
+        A list of CustomEvents.
+    """
+    return [
+        CustomEvent(
+            type=EventType.CUSTOM,
+            name='custom_event1',
+            value={'key1': 'value1'},
+        ),
+        CustomEvent(
+            type=EventType.CUSTOM,
+            name='custom_event2',
+            value={'key2': 'value2'},
+        ),
+    ]
+
+
+@dataclass(frozen=True)
+class Run:
+    """Test parameter class for Adapter.run method tests.
+
+    Args:
+        messages: List of messages for the run input.
+        state: State object for the run input.
+        context: Context list for the run input.
+        tools: List of tools for the run input.
+        forwarded_props: Forwarded properties for the run input.
+        nodes: List of TestNode instances for the run input.
+    """
+
+    messages: list[Message]
+    state: Any = None
+    context: list[Any] = field(default_factory=lambda: list[Any]())
+    tools: list[Tool] = field(default_factory=lambda: list[Tool]())
+    nodes: list[TestNode] | None = None
+    forwarded_props: Any = None
+
+    def run_input(self, *, thread_id: str, run_id: str) -> RunAgentInput:
+        """Create a RunAgentInput instance for the test case.
+
+        Args:
+            thread_id: The thread ID for the run.
+            run_id: The run ID for the run.
+
+        Returns:
+            A RunAgentInput instance with the test case parameters.
+        """
+        return RunAgentInput(
+            thread_id=thread_id,
+            run_id=run_id,
+            messages=self.messages,
+            state=self.state,
+            context=self.context,
+            tools=self.tools,
+            forwarded_props=self.forwarded_props,
+        )
+
+
+@dataclass(frozen=True)
+class AdapterRunTest:
+    """Test parameter class for Adapter.run method tests.
+
+    Args:
+        id: Name of the test case.
+        runs: List of Run instances for the test case.
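+        call_tools: List of tool names to enable on the TestModel.
+        expected_events: The SSE event payloads the runs are expected to emit.
+        expected_state: Expected final state value after the runs, if any.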
+ """ + + id: str + runs: list[Run] + call_tools: list[str] = field(default_factory=lambda: list[str]()) + expected_events: list[str] = field(default_factory=lambda: list(EXPECTED_EVENTS)) + expected_state: int | None = None + + +# Test parameter data +def tc_parameters() -> list[AdapterRunTest]: + if not has_ag_ui: # pragma: no branch + return [AdapterRunTest(id='skipped', runs=[])] + + return [ + AdapterRunTest( + id='basic_user_message', + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Hello, how are you?', + ), + ], + ), + ], + ), + AdapterRunTest( + id='empty_messages', + runs=[ + Run(messages=[]), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"RUN_ERROR","message":"no messages found in the input","code":"no_messages"}', + ], + ), + AdapterRunTest( + id='multiple_messages', + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='First message', + ), + AssistantMessage( + id='msg_2', + role=Role.ASSISTANT.value, + content='Assistant response', + ), + SystemMessage( + id='msg_3', + role=Role.SYSTEM.value, + content='System message', + ), + DeveloperMessage( + id='msg_4', + role=Role.DEVELOPER.value, + content='Developer note', + ), + UserMessage( + id='msg_5', + role=Role.USER.value, + content='Second message', + ), + ], + ), + ], + ), + AdapterRunTest( + id='messages_with_history', + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='First message', + ), + UserMessage( + id='msg_2', + role=Role.USER.value, + content='Second message', + ), + ], + ), + ], + ), + AdapterRunTest( + id='tool_ag_ui', + call_tools=['get_weather'], + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather for Paris', + ), + ], + tools=[get_weather()], + ), + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather for Paris', + ), + AssistantMessage( + id='msg_2', + role=Role.ASSISTANT.value, + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='get_weather', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + role=Role.TOOL.value, + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + ], + tools=[get_weather()], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"get_weather"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000004"}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000005","role":"assistant"}', + 
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000005","delta":"{\\"get_weather\\":\\"Tool "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000005","delta":"result\\"}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000005"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000004"}', + ], + ), + AdapterRunTest( + id='tool_ag_ui_multiple', + call_tools=['get_weather', 'get_weather_parts'], + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather and get_weather_parts for Paris', + ), + ], + tools=[get_weather(), get_weather('get_weather_parts')], + ), + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather for Paris', + ), + AssistantMessage( + id='msg_2', + role=Role.ASSISTANT.value, + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='get_weather', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + role=Role.TOOL.value, + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + AssistantMessage( + id='msg_4', + role=Role.ASSISTANT.value, + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='get_weather_parts', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_5', + role=Role.TOOL.value, + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + ], + tools=[get_weather(), get_weather('get_weather_parts')], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"get_weather"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000004","toolCallName":"get_weather_parts"}', + '{"type":"TOOL_CALL_ARGS","toolCallId":"pyd_ai_00000000000000000000000000000004","delta":"{\\"location\\":\\"a\\"}"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000004"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000005"}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000006","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"{\\"get_weather\\":\\"Tool "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"result\\",\\"get_weather_parts\\":\\"Tool "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"result\\"}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000006"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000005"}', + ], + ), + 
AdapterRunTest( + id='tool_ag_ui_parts', + call_tools=['get_weather_parts'], + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather_parts for Paris', + ), + ], + tools=[get_weather('get_weather_parts')], + ), + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather_parts for Paris', + ), + AssistantMessage( + id='msg_2', + role=Role.ASSISTANT.value, + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='get_weather_parts', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + role=Role.TOOL.value, + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + ], + tools=[get_weather('get_weather_parts')], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"get_weather_parts"}', + '{"type":"TOOL_CALL_ARGS","toolCallId":"pyd_ai_00000000000000000000000000000003","delta":"{\\"location\\":\\"a\\"}"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000004"}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000005","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000005","delta":"{\\"get_weather_parts\\":\\"Tool "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000005","delta":"result\\"}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000005"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000004"}', + ], + ), + AdapterRunTest( + id='tool_local_single_event', + call_tools=['send_snapshot'], + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call send_snapshot', + ), + ], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"send_snapshot"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"TOOL_CALL_RESULT","messageId":"msg_1","toolCallId":"pyd_ai_00000000000000000000000000000003","content":"{\\"type\\":\\"STATE_SNAPSHOT\\",\\"timestamp\\":null,\\"raw_event\\":null,\\"snapshot\\":{\\"key\\":\\"value\\"}}","role":"tool"}', + '{"type":"STATE_SNAPSHOT","snapshot":{"key":"value"}}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000004","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"{\\"send_snapshot\\":{\\"type\\":\\"STATE_SNAPSHOT\\",\\"timestam"}', + 
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"p\\":null,\\"rawEvent\\":null,\\"snapshot\\":{\\"key\\":\\"value\\"}}}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000004"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + ], + ), + AdapterRunTest( + id='tool_local_multiple_events', + call_tools=['send_custom'], + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call send_custom', + ), + ], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"send_custom"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"TOOL_CALL_RESULT","messageId":"msg_1","toolCallId":"pyd_ai_00000000000000000000000000000003","content":"[{\\"type\\":\\"CUSTOM\\",\\"timestamp\\":null,\\"raw_event\\":null,\\"name\\":\\"custom_event1\\",\\"value\\":{\\"key1\\":\\"value1\\"}},{\\"type\\":\\"CUSTOM\\",\\"timestamp\\":null,\\"raw_event\\":null,\\"name\\":\\"custom_event2\\",\\"value\\":{\\"key2\\":\\"value2\\"}}]","role":"tool"}', + '{"type":"CUSTOM","name":"custom_event1","value":{"key1":"value1"}}', + '{"type":"CUSTOM","name":"custom_event2","value":{"key2":"value2"}}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000004","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"{\\"send_custom\\":[{\\"type\\":\\"CUSTOM\\",\\"timestamp\\":null,\\"rawEvent\\":null,\\"name\\":\\"custom_event1\\",\\"value\\":{\\"key1\\":\\"va"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"lue1\\"}},{\\"type\\":\\"CUSTOM\\",\\"timestamp\\":null,\\"rawEvent\\":null,\\"name\\":\\"custom_event2\\",\\"value\\":{\\"key2\\":\\"value2\\"}}]}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000004"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + ], + ), + AdapterRunTest( + id='tool_local_parts', + call_tools=['current_time'], + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call current_time', + ), + ], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"current_time"}', + '{"type":"TOOL_CALL_ARGS","toolCallId":"pyd_ai_00000000000000000000000000000003","delta":"{}"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"TOOL_CALL_RESULT","messageId":"msg_1","toolCallId":"pyd_ai_00000000000000000000000000000003","content":"21T12:08:45.485981+00:00","role":"tool"}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000004","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"{\\"current_time\\":\\"21T1"}', + 
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"2:08:45.485981+00:00\\"}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000004"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + ], + ), + AdapterRunTest( + id='tool_local_then_ag_ui', + call_tools=['current_time', 'get_weather'], + runs=[ + Run( + nodes=[ + TestNode( + parts=[ + TestToolCallPart(call_tools=['current_time']), + TestThinkingPart(content='Thinking about the weather'), + ], + ), + TestNode( + parts=[TestToolCallPart(call_tools=['get_weather'])], + ), + ], + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please tell me the time and then call get_weather for Paris', + ), + ], + tools=[get_weather()], + ), + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Please call get_weather for Paris', + ), + AssistantMessage( + id='msg_2', + role=Role.ASSISTANT.value, + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='current_time', + arguments='{}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + role=Role.TOOL.value, + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + AssistantMessage( + id='msg_4', + role=Role.ASSISTANT.value, + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000004', + type='function', + function=FunctionCall( + name='get_weather', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_5', + role=Role.TOOL.value, + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000004', + ), + ], + tools=[get_weather()], + ), + ], + expected_events=[ + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"current_time"}', + '{"type":"TOOL_CALL_ARGS","toolCallId":"pyd_ai_00000000000000000000000000000003","delta":"{}"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}', + '{"type":"THINKING_TEXT_MESSAGE_START"}', + '{"type":"THINKING_TEXT_MESSAGE_CONTENT","delta":"Thinking about the weather"}', + '{"type":"THINKING_TEXT_MESSAGE_END"}', + '{"type":"TOOL_CALL_RESULT","messageId":"msg_1","toolCallId":"pyd_ai_00000000000000000000000000000003","content":"21T12:08:45.485981+00:00","role":"tool"}', + '{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000004","toolCallName":"get_weather"}', + '{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000004"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}', + '{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000005"}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000006","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"{\\"current_time\\":\\"Tool "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"result\\",\\"get_weather\\":\\"Tool "}', + 
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"result\\"}"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000006"}', + '{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000005"}', + ], + ), + AdapterRunTest( + id='request_with_state', + runs=[ + Run( + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Hello, how are you?', + ), + ], + state={'value': 42}, + ), + ], + expected_state=42, + ), + ] + + +@pytest.mark.parametrize('tc', tc_parameters(), ids=lambda tc: tc.id) +async def test_run_method(mock_uuid: _MockUUID, tc: AdapterRunTest) -> None: + """Test the Adapter.run method with various scenarios. + + Args: + mock_uuid: The mock UUID generator fixture. + tc: The test case parameters. + """ + + run: Run + events: list[str] = [] + thread_id: str = f'{THREAD_ID_PREFIX}{mock_uuid()}' + adapter: Adapter[StateDeps[StateInt], str] = await create_adapter(tc.call_tools) + deps: StateDeps[StateInt] = StateDeps(StateInt()) + for run in tc.runs: + if run.nodes is not None: + assert isinstance(adapter.agent.model, TestModel), ( + 'Agent model is not TestModel' + 'data: {"type":"TOOL_CALL_RESULT","messageId":"msg_1","toolCallId":"pyd_ai_00000000000000000000000000000003","content":"21T12:08:45.485981+00:00","role":"tool"}\n\n' + ) + adapter.agent.model.custom_response_nodes = run.nodes + + run_input: RunAgentInput = run.run_input( + thread_id=thread_id, + run_id=f'{RUN_ID_PREFIX}{mock_uuid()}', + ) + + events.extend([event async for event in adapter.run(run_input, deps=deps)]) + + assert_events(events, tc.expected_events) + if tc.expected_state is not None: + assert deps.state.value == tc.expected_state + + +async def test_concurrent_runs(mock_uuid: _MockUUID, adapter: Adapter[None, str]) -> None: + """Test concurrent execution of multiple runs.""" + + async def collect_events(run_input: RunAgentInput) -> list[str]: + """Collect all events from an adapter run. + + Args: + run_input: The input configuration for the adapter run. + + Returns: + List of all events generated by the adapter run. 
+ """ + return [event async for event in adapter.run(run_input)] + + concurrent_tasks: list[asyncio.Task[list[str]]] = [] + + for i in range(20): + run_input: RunAgentInput = RunAgentInput( + thread_id=f'{THREAD_ID_PREFIX}{mock_uuid()}', + run_id=f'{RUN_ID_PREFIX}{mock_uuid()}', + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id=f'msg_{i}', + role=Role.USER.value, + content=f'Message {i}', + ), + ], + state=None, + context=[], + tools=[], + forwarded_props=None, + ) + + task = asyncio.create_task(collect_events(run_input)) + concurrent_tasks.append(task) + + results = await asyncio.gather(*concurrent_tasks) + + for events in results: + assert_events(events, EXPECTED_EVENTS, loose=True) + assert len(events) == len(EXPECTED_EVENTS) + + +@pytest.mark.anyio +async def test_to_ag_ui(mock_uuid: _MockUUID) -> None: + """Test the agent.to_ag_ui method.""" + + agent: Agent[None, str] = Agent(model=TestModel()) + app = agent.to_ag_ui() + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as client: + client.base_url = 'http://localhost:8000' + run_input: RunAgentInput = RunAgentInput( + state=None, + thread_id=f'{THREAD_ID_PREFIX}test_thread', + run_id=f'{RUN_ID_PREFIX}test_run', + messages=[ # pyright: ignore[reportArgumentType] + UserMessage( + id='msg_1', + role=Role.USER.value, + content='Hello, world!', + ), + ], + tools=[], + context=[], + forwarded_props=None, + ) + events: list[str] + async with client.stream( + 'POST', + '/', + content=run_input.model_dump_json(), + headers={'Content-Type': 'application/json', 'Accept': SSE_CONTENT_TYPE}, + ) as response: + assert response.status_code == 200, f'Unexpected status code: {response.status_code}' + events = [line + '\n\n' async for line in response.aiter_lines() if line.startswith('data: ')] + + assert events, 'No parts received from the server' + expected: list[str] = [ + '{"type":"RUN_STARTED","threadId":"thread_test_thread","runId":"run_test_run"}', + '{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000001","role":"assistant"}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000001","delta":"success "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000001","delta":"(no "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000001","delta":"tool "}', + '{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000001","delta":"calls)"}', + '{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000001"}', + '{"type":"RUN_FINISHED","threadId":"thread_test_thread","runId":"run_test_run"}', + ] + assert_events(events, expected) diff --git a/tests/test_agent.py b/tests/test_agent.py index fd8f5e538..45bbd1b39 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3606,7 +3606,7 @@ async def only_if_plan_presented( async def test_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio - except ImportError: + except ImportError: # pragma: no cover return server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) @@ -3626,7 +3626,7 @@ async def test_context_manager(): def test_set_mcp_sampling_model(): try: from pydantic_ai.mcp import MCPServerStdio - except ImportError: + except ImportError: # pragma: no cover return test_model = TestModel() diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index ba2ec479c..623f2502f 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ 
-482,7 +482,7 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef async def test_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio - except ImportError: + except ImportError: # pragma: no cover pytest.skip('mcp is not installed') server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) diff --git a/uv.lock b/uv.lock index f715314a2..628954d84 100644 --- a/uv.lock +++ b/uv.lock @@ -28,6 +28,18 @@ members = [ "pydantic-graph", ] +[[package]] +name = "ag-ui-protocol" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/c0/f2d24d92be950dd6b12f66dbde5fb839dd01e8af34d3a0305b2309a68907/ag_ui_protocol-0.1.7.tar.gz", hash = "sha256:0e93fd9f7c74d52afbd824d6e9738bd3422e859503905ba7582481cbc3c67ab2", size = 4446, upload-time = "2025-06-26T09:37:08.895Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/c3/c216f5ad4d78f4030a63fec23f00a71f984f10275ccfc7d3902c3c34b7cd/ag_ui_protocol-0.1.7-py3-none-any.whl", hash = "sha256:8c821662ca6e9852569022f449b9f7aeb3f16aa75390fa8c28ceae2cce642baa", size = 6165, upload-time = "2025-06-26T09:37:07.755Z" }, +] + [[package]] name = "aiofiles" version = "23.2.1" @@ -2962,7 +2974,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -2996,7 +3008,7 @@ lint = [ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["examples", "logfire"] @@ -3030,7 +3042,7 @@ dependencies = [ { name = "logfire", extra = ["asyncpg", "fastapi", "httpx", "sqlite3"] }, { name = "mcp", extra = ["cli"], marker = "python_full_version >= '3.10'" }, { name = "modal" }, - { name = "pydantic-ai-slim", extra = ["anthropic", "groq", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "groq", "openai", "vertexai"] }, { name = "pydantic-evals" }, { name = "python-multipart" }, { name = "rich" }, @@ -3046,7 +3058,7 @@ requires-dist = [ { name = "logfire", extras = ["asyncpg", "fastapi", "httpx", "sqlite3"], specifier = ">=2.6" }, { name = "mcp", extras = ["cli"], marker = "python_full_version >= '3.10'", specifier = ">=1.4.1" }, { name = "modal", specifier = ">=1.0.4" }, - { name = "pydantic-ai-slim", extras = ["anthropic", "groq", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "groq", "openai", "vertexai"], editable = "pydantic_ai_slim" }, { name = "pydantic-evals", editable = "pydantic_evals" }, { name = "python-multipart", specifier = ">=0.0.17" }, { name = "rich", specifier = ">=13.9.2" }, 
@@ -3071,6 +3083,10 @@ dependencies = [ a2a = [ { name = "fasta2a" }, ] +ag-ui = [ + { name = "ag-ui-protocol" }, + { name = "starlette" }, +] anthropic = [ { name = "anthropic" }, ] @@ -3139,6 +3155,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.7" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.37.24" }, @@ -3163,10 +3180,11 @@ requires-dist = [ { name = "pydantic-graph", editable = "pydantic_graph" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.2" }, { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, + { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [