Skip to content

Commit 3bc5219

Browse files
Add to_cli() method to Agent (#1642)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent cb37b74 commit 3bc5219

File tree

5 files changed

+230
-11
lines changed

5 files changed

+230
-11
lines changed

clai/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Either way, running `clai` will start an interactive session where you can chat
5353
## Help
5454

5555
```
56-
usage: clai [-h] [-m [MODEL]] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
56+
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
5757
5858
PydanticAI CLI v...
5959
@@ -69,6 +69,8 @@ options:
6969
-h, --help show this help message and exit
7070
-m [MODEL], --model [MODEL]
7171
Model to use, in format "<provider>:<model>" e.g. "openai:gpt-4o" or "anthropic:claude-3-7-sonnet-latest". Defaults to "openai:gpt-4o".
72+
-a AGENT, --agent AGENT
73+
Custom Agent to use, in format "module:variable", e.g. "mymodule.submodule:my_agent"
7274
-l, --list-models List all available models and exit
7375
-t [CODE_THEME], --code-theme [CODE_THEME]
7476
Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.

docs/cli.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,47 @@ uvx clai --model anthropic:claude-3-7-sonnet-latest
6464
```
6565

6666
(a full list of models available can be printed with `uvx clai --list-models`)
67+
68+
### Custom Agents
69+
70+
You can specify a custom agent using the `--agent` flag with a module path and variable name:
71+
72+
```python {title="custom_agent.py" test="skip"}
73+
from pydantic_ai import Agent
74+
75+
agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.')
76+
```
77+
78+
Then run:
79+
80+
```bash
81+
uvx clai --agent custom_agent:agent "What's the weather today?"
82+
```
83+
84+
The format must be `module:variable` where:
85+
86+
- `module` is the importable Python module path
87+
- `variable` is the name of the Agent instance in that module
88+
89+
90+
Additionally, you can directly launch CLI mode from an `Agent` instance using `Agent.to_cli_sync()`:
91+
92+
```python {title="agent_to_cli_sync.py" test="skip" hl_lines=4}
93+
from pydantic_ai import Agent
94+
95+
agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.')
96+
agent.to_cli_sync()
97+
```
98+
99+
You can also use the async interface with `Agent.to_cli()`:
100+
101+
```python {title="agent_to_cli.py" test="skip" hl_lines=6}
102+
from pydantic_ai import Agent
103+
104+
agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.')
105+
106+
async def main():
107+
await agent.to_cli()
108+
```
109+
110+
_(You'll need to add `asyncio.run(main())` to run `main`)_

pydantic_ai_slim/pydantic_ai/_cli.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import asyncio
5+
import importlib
56
import sys
67
from asyncio import CancelledError
78
from collections.abc import Sequence
@@ -12,6 +13,9 @@
1213

1314
from typing_inspection.introspection import get_literal_values
1415

16+
from pydantic_ai.result import OutputDataT
17+
from pydantic_ai.tools import AgentDepsT
18+
1519
from . import __version__
1620
from .agent import Agent
1721
from .exceptions import UserError
@@ -123,6 +127,11 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
123127
# e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
124128
qualified_model_names = [n for n in get_literal_values(KnownModelName.__value__) if ':' in n]
125129
arg.completer = argcomplete.ChoicesCompleter(qualified_model_names) # type: ignore[reportPrivateUsage]
130+
parser.add_argument(
131+
'-a',
132+
'--agent',
133+
help='Custom Agent to use, in format "module:variable", e.g. "mymodule.submodule:my_agent"',
134+
)
126135
parser.add_argument(
127136
'-l',
128137
'--list-models',
@@ -155,8 +164,22 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
155164
console.print(f' {model}', highlight=False)
156165
return 0
157166

167+
agent: Agent[None, str] = cli_agent
168+
if args.agent:
169+
try:
170+
module_path, variable_name = args.agent.split(':')
171+
module = importlib.import_module(module_path)
172+
agent = getattr(module, variable_name)
173+
if not isinstance(agent, Agent):
174+
console.print(f'[red]Error: {args.agent} is not an Agent instance[/red]')
175+
return 1
176+
console.print(f'[green]Using custom agent:[/green] [magenta]{args.agent}[/magenta]', highlight=False)
177+
except ValueError:
178+
console.print('[red]Error: Agent must be specified in "module:variable" format[/red]')
179+
return 1
180+
158181
try:
159-
cli_agent.model = infer_model(args.model)
182+
agent.model = infer_model(args.model)
160183
except UserError as e:
161184
console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
162185
return 1
@@ -171,21 +194,27 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
171194

172195
if prompt := cast(str, args.prompt):
173196
try:
174-
asyncio.run(ask_agent(cli_agent, prompt, stream, console, code_theme))
197+
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
175198
except KeyboardInterrupt:
176199
pass
177200
return 0
178201

179202
# doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
180203
session: PromptSession[Any] = PromptSession(history=FileHistory(str(PROMPT_HISTORY_PATH)))
181204
try:
182-
return asyncio.run(run_chat(session, stream, cli_agent, console, code_theme, prog_name))
205+
return asyncio.run(run_chat(session, stream, agent, console, code_theme, prog_name))
183206
except KeyboardInterrupt: # pragma: no cover
184207
return 0
185208

186209

187210
async def run_chat(
188-
session: PromptSession[Any], stream: bool, agent: Agent, console: Console, code_theme: str, prog_name: str
211+
session: PromptSession[Any],
212+
stream: bool,
213+
agent: Agent[AgentDepsT, OutputDataT],
214+
console: Console,
215+
code_theme: str,
216+
prog_name: str,
217+
deps: AgentDepsT = None,
189218
) -> int:
190219
multiline = False
191220
messages: list[ModelMessage] = []
@@ -207,30 +236,31 @@ async def run_chat(
207236
return exit_value
208237
else:
209238
try:
210-
messages = await ask_agent(agent, text, stream, console, code_theme, messages)
239+
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
211240
except CancelledError: # pragma: no cover
212241
console.print('[dim]Interrupted[/dim]')
213242

214243

215244
async def ask_agent(
216-
agent: Agent,
245+
agent: Agent[AgentDepsT, OutputDataT],
217246
prompt: str,
218247
stream: bool,
219248
console: Console,
220249
code_theme: str,
250+
deps: AgentDepsT = None,
221251
messages: list[ModelMessage] | None = None,
222252
) -> list[ModelMessage]:
223253
status = Status('[dim]Working on it…[/dim]', console=console)
224254

225255
if not stream:
226256
with status:
227-
result = await agent.run(prompt, message_history=messages)
228-
content = result.output
257+
result = await agent.run(prompt, message_history=messages, deps=deps)
258+
content = str(result.output)
229259
console.print(Markdown(content, code_theme=code_theme))
230260
return result.all_messages()
231261

232262
with status, ExitStack() as stack:
233-
async with agent.iter(prompt, message_history=messages) as agent_run:
263+
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
234264
live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis')
235265
async for node in agent_run:
236266
if Agent.is_model_request_node(node):

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from opentelemetry.trace import NoOpTracer, use_span
1414
from pydantic.json_schema import GenerateJsonSchema
15-
from typing_extensions import Literal, Never, TypeIs, TypeVar, deprecated
15+
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
1616

1717
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
1818
from pydantic_graph._utils import get_event_loop
@@ -1688,6 +1688,51 @@ async def run_mcp_servers(self) -> AsyncIterator[None]:
16881688
finally:
16891689
await exit_stack.aclose()
16901690

1691+
async def to_cli(self: Self, deps: AgentDepsT = None) -> None:
1692+
"""Run the agent in a CLI chat interface.
1693+
1694+
Example:
1695+
```python {title="agent_to_cli.py" test="skip"}
1696+
from pydantic_ai import Agent
1697+
1698+
agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.')
1699+
1700+
async def main():
1701+
await agent.to_cli()
1702+
```
1703+
"""
1704+
from prompt_toolkit import PromptSession
1705+
from prompt_toolkit.history import FileHistory
1706+
from rich.console import Console
1707+
1708+
from pydantic_ai._cli import PROMPT_HISTORY_PATH, run_chat
1709+
1710+
# TODO(Marcelo): We need to refactor the CLI code to be able to be able to just pass `agent`, `deps` and
1711+
# `prog_name` from here.
1712+
1713+
session: PromptSession[Any] = PromptSession(history=FileHistory(str(PROMPT_HISTORY_PATH)))
1714+
await run_chat(
1715+
session=session,
1716+
stream=True,
1717+
agent=self,
1718+
deps=deps,
1719+
console=Console(),
1720+
code_theme='monokai',
1721+
prog_name='pydantic-ai',
1722+
)
1723+
1724+
def to_cli_sync(self: Self, deps: AgentDepsT = None) -> None:
1725+
"""Run the agent in a CLI chat interface with the non-async interface.
1726+
1727+
```python {title="agent_to_cli_sync.py" test="skip"}
1728+
from pydantic_ai import Agent
1729+
1730+
agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.')
1731+
agent.to_cli_sync()
1732+
```
1733+
"""
1734+
return get_event_loop().run_until_complete(self.to_cli(deps=deps))
1735+
16911736

16921737
@dataclasses.dataclass(repr=False)
16931738
class AgentRun(Generic[AgentDepsT, OutputDataT]):

tests/test_cli.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from io import StringIO
23
from typing import Any
34

@@ -36,6 +37,72 @@ def test_invalid_model(capfd: CaptureFixture[str]):
3637
)
3738

3839

40+
def test_agent_flag(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
41+
env.set('OPENAI_API_KEY', 'test')
42+
43+
# Create a dynamic module using types.ModuleType
44+
import types
45+
46+
test_module = types.ModuleType('test_module')
47+
48+
# Create and add agent to the module
49+
test_agent = Agent()
50+
test_agent.model = TestModel(custom_output_text='Hello from custom agent')
51+
setattr(test_module, 'custom_agent', test_agent)
52+
53+
# Register the module in sys.modules
54+
sys.modules['test_module'] = test_module
55+
56+
try:
57+
# Mock ask_agent to avoid actual execution but capture the agent
58+
mock_ask = mocker.patch('pydantic_ai._cli.ask_agent')
59+
60+
# Test CLI with custom agent
61+
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 0
62+
63+
# Verify the output contains the custom agent message
64+
assert 'Using custom agent: test_module:custom_agent' in capfd.readouterr().out
65+
66+
# Verify ask_agent was called with our custom agent
67+
mock_ask.assert_called_once()
68+
assert mock_ask.call_args[0][0] is test_agent
69+
70+
finally:
71+
# Clean up by removing the module from sys.modules
72+
if 'test_module' in sys.modules:
73+
del sys.modules['test_module']
74+
75+
76+
def test_agent_flag_non_agent(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
77+
env.set('OPENAI_API_KEY', 'test')
78+
79+
# Create a dynamic module using types.ModuleType
80+
import types
81+
82+
test_module = types.ModuleType('test_module')
83+
84+
# Create and add agent to the module
85+
test_agent = 'Not an Agent object'
86+
setattr(test_module, 'custom_agent', test_agent)
87+
88+
# Register the module in sys.modules
89+
sys.modules['test_module'] = test_module
90+
91+
try:
92+
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 1
93+
assert 'is not an Agent' in capfd.readouterr().out
94+
95+
finally:
96+
# Clean up by removing the module from sys.modules
97+
if 'test_module' in sys.modules:
98+
del sys.modules['test_module']
99+
100+
101+
def test_agent_flag_bad_module_variable_path(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
102+
assert cli(['--agent', 'bad_path', 'hello']) == 1
103+
assert 'Agent must be specified in "module:variable" format' in capfd.readouterr().out
104+
105+
39106
def test_list_models(capfd: CaptureFixture[str]):
40107
assert cli(['--list-models']) == 0
41108
output = capfd.readouterr().out.splitlines()
@@ -153,3 +220,34 @@ def test_code_theme_dark(mocker: MockerFixture, env: TestEnv):
153220
mock_run_chat.assert_awaited_once_with(
154221
IsInstance(PromptSession), True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai'
155222
)
223+
224+
225+
def test_agent_to_cli_sync(mocker: MockerFixture, env: TestEnv):
226+
env.set('OPENAI_API_KEY', 'test')
227+
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
228+
cli_agent.to_cli_sync()
229+
mock_run_chat.assert_awaited_once_with(
230+
session=IsInstance(PromptSession),
231+
stream=True,
232+
agent=IsInstance(Agent),
233+
console=IsInstance(Console),
234+
code_theme='monokai',
235+
prog_name='pydantic-ai',
236+
deps=None,
237+
)
238+
239+
240+
@pytest.mark.anyio
241+
async def test_agent_to_cli_async(mocker: MockerFixture, env: TestEnv):
242+
env.set('OPENAI_API_KEY', 'test')
243+
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
244+
await cli_agent.to_cli()
245+
mock_run_chat.assert_awaited_once_with(
246+
session=IsInstance(PromptSession),
247+
stream=True,
248+
agent=IsInstance(Agent),
249+
console=IsInstance(Console),
250+
code_theme='monokai',
251+
prog_name='pydantic-ai',
252+
deps=None,
253+
)

0 commit comments

Comments
 (0)