Skip to content

Commit 0f3a20d

Browse files
authored
fix: add instructions on further requests besides UserPromptNode (#1503)
1 parent 4d261d5 commit 0f3a20d

File tree

4 files changed

+268
-14
lines changed

4 files changed

+268
-14
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import asyncio
44
import dataclasses
55
import json
6-
from collections.abc import AsyncIterator, Iterator, Sequence
6+
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
77
from contextlib import asynccontextmanager, contextmanager
88
from contextvars import ContextVar
99
from dataclasses import field
10-
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
10+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
1111

1212
from opentelemetry.trace import Span, Tracer
1313
from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -87,6 +87,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
8787
usage_limits: _usage.UsageLimits
8888
max_result_retries: int
8989
end_strategy: EndStrategy
90+
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
9091

9192
output_schema: _output.OutputSchema[OutputDataT] | None
9293
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
@@ -141,7 +142,9 @@ async def _get_first_message(
141142
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
142143
) -> _messages.ModelRequest:
143144
run_context = build_run_context(ctx)
144-
history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
145+
history, next_message = await self._prepare_messages(
146+
self.user_prompt, ctx.state.message_history, ctx.deps.get_instructions, run_context
147+
)
145148
ctx.state.message_history = history
146149
run_context.messages = history
147150

@@ -155,6 +158,7 @@ async def _prepare_messages(
155158
self,
156159
user_prompt: str | Sequence[_messages.UserContent] | None,
157160
message_history: list[_messages.ModelMessage] | None,
161+
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
158162
run_context: RunContext[DepsT],
159163
) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
160164
try:
@@ -169,7 +173,7 @@ async def _prepare_messages(
169173
ctx_messages.used = True
170174

171175
parts: list[_messages.ModelRequestPart] = []
172-
instructions = await self._instructions(run_context)
176+
instructions = await get_instructions(run_context)
173177
if message_history:
174178
# Shallow copy messages
175179
messages.extend(message_history)
@@ -210,15 +214,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod
210214
messages.append(_messages.SystemPromptPart(prompt))
211215
return messages
212216

213-
async def _instructions(self, run_context: RunContext[DepsT]) -> str | None:
214-
if self.instructions is None and not self.instructions_functions:
215-
return None
216-
217-
instructions = self.instructions or ''
218-
for instructions_runner in self.instructions_functions:
219-
instructions += await instructions_runner.run(run_context)
220-
return instructions
221-
222217

223218
async def _prepare_request_parameters(
224219
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -479,7 +474,11 @@ async def _handle_tool_calls(
479474
else:
480475
if tool_responses:
481476
parts.extend(tool_responses)
482-
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
477+
run_context = build_run_context(ctx)
478+
instructions = await ctx.deps.get_instructions(run_context)
479+
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
480+
_messages.ModelRequest(parts=parts, instructions=instructions)
481+
)
483482

484483
def _handle_final_result(
485484
self,

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,15 @@ async def main():
620620
},
621621
)
622622

623+
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
624+
if self._instructions is None and not self._instructions_functions:
625+
return None
626+
627+
instructions = self._instructions or ''
628+
for instructions_runner in self._instructions_functions:
629+
instructions += await instructions_runner.run(run_context)
630+
return instructions
631+
623632
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
624633
user_deps=deps,
625634
prompt=user_prompt,
@@ -635,6 +644,7 @@ async def main():
635644
mcp_servers=self._mcp_servers,
636645
run_span=run_span,
637646
tracer=tracer,
647+
get_instructions=get_instructions,
638648
)
639649
start_node = _agent_graph.UserPromptNode[AgentDepsT](
640650
user_prompt=user_prompt,
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '419'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.openai.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: You are a helpful assistant.
20+
role: system
21+
- content: What is the temperature in Tokyo?
22+
role: user
23+
model: gpt-4.1-mini
24+
n: 1
25+
stream: false
26+
tool_choice: auto
27+
tools:
28+
- function:
29+
description: ''
30+
name: get_temperature
31+
parameters:
32+
additionalProperties: false
33+
properties:
34+
city:
35+
type: string
36+
required:
37+
- city
38+
type: object
39+
strict: true
40+
type: function
41+
uri: https://api.openai.com/v1/chat/completions
42+
response:
43+
headers:
44+
access-control-expose-headers:
45+
- X-Request-ID
46+
alt-svc:
47+
- h3=":443"; ma=86400
48+
connection:
49+
- keep-alive
50+
content-length:
51+
- '1089'
52+
content-type:
53+
- application/json
54+
openai-organization:
55+
- pydantic-28gund
56+
openai-processing-ms:
57+
- '490'
58+
openai-version:
59+
- '2020-10-01'
60+
strict-transport-security:
61+
- max-age=31536000; includeSubDomains; preload
62+
transfer-encoding:
63+
- chunked
64+
parsed_body:
65+
choices:
66+
- finish_reason: tool_calls
67+
index: 0
68+
logprobs: null
69+
message:
70+
annotations: []
71+
content: null
72+
refusal: null
73+
role: assistant
74+
tool_calls:
75+
- function:
76+
arguments: '{"city":"Tokyo"}'
77+
name: get_temperature
78+
id: call_bhZkmIKKItNGJ41whHUHB7p9
79+
type: function
80+
created: 1744810634
81+
id: chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq
82+
model: gpt-4.1-mini-2025-04-14
83+
object: chat.completion
84+
service_tier: default
85+
system_fingerprint: fp_38647f5e19
86+
usage:
87+
completion_tokens: 15
88+
completion_tokens_details:
89+
accepted_prediction_tokens: 0
90+
audio_tokens: 0
91+
reasoning_tokens: 0
92+
rejected_prediction_tokens: 0
93+
prompt_tokens: 50
94+
prompt_tokens_details:
95+
audio_tokens: 0
96+
cached_tokens: 0
97+
total_tokens: 65
98+
status:
99+
code: 200
100+
message: OK
101+
- request:
102+
headers:
103+
accept:
104+
- application/json
105+
accept-encoding:
106+
- gzip, deflate
107+
connection:
108+
- keep-alive
109+
content-length:
110+
- '665'
111+
content-type:
112+
- application/json
113+
cookie:
114+
- __cf_bm=x.H2GlMeh.t_Q.gVlCXrh3.ggn9lKjhmUeG_ToNThLs-1744810635-1.0.1.1-tiHwqGvBw3eEy_y9_q5nx7B.7YCbLb9cXdDj6DklLmtFllOFe708mKwYvGd8fY2y5bO2NOagULipA7MxfwW9P0hlnRSiJZbZBO9tjrUweFc;
115+
_cfuvid=VlHcJdsIsxGEt2lddKu_5Am_lfyYndl9JB2Ezy.aygo-1744810635187-0.0.1.1-604800000
116+
host:
117+
- api.openai.com
118+
method: POST
119+
parsed_body:
120+
messages:
121+
- content: You are a helpful assistant.
122+
role: system
123+
- content: What is the temperature in Tokyo?
124+
role: user
125+
- role: assistant
126+
tool_calls:
127+
- function:
128+
arguments: '{"city":"Tokyo"}'
129+
name: get_temperature
130+
id: call_bhZkmIKKItNGJ41whHUHB7p9
131+
type: function
132+
- content: '20.0'
133+
role: tool
134+
tool_call_id: call_bhZkmIKKItNGJ41whHUHB7p9
135+
model: gpt-4.1-mini
136+
n: 1
137+
stream: false
138+
tool_choice: auto
139+
tools:
140+
- function:
141+
description: ''
142+
name: get_temperature
143+
parameters:
144+
additionalProperties: false
145+
properties:
146+
city:
147+
type: string
148+
required:
149+
- city
150+
type: object
151+
strict: true
152+
type: function
153+
uri: https://api.openai.com/v1/chat/completions
154+
response:
155+
headers:
156+
access-control-expose-headers:
157+
- X-Request-ID
158+
alt-svc:
159+
- h3=":443"; ma=86400
160+
connection:
161+
- keep-alive
162+
content-length:
163+
- '867'
164+
content-type:
165+
- application/json
166+
openai-organization:
167+
- pydantic-28gund
168+
openai-processing-ms:
169+
- '949'
170+
openai-version:
171+
- '2020-10-01'
172+
strict-transport-security:
173+
- max-age=31536000; includeSubDomains; preload
174+
transfer-encoding:
175+
- chunked
176+
parsed_body:
177+
choices:
178+
- finish_reason: stop
179+
index: 0
180+
logprobs: null
181+
message:
182+
annotations: []
183+
content: The temperature in Tokyo is currently 20.0 degrees Celsius.
184+
refusal: null
185+
role: assistant
186+
created: 1744810635
187+
id: chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP
188+
model: gpt-4.1-mini-2025-04-14
189+
object: chat.completion
190+
service_tier: default
191+
system_fingerprint: fp_38647f5e19
192+
usage:
193+
completion_tokens: 15
194+
completion_tokens_details:
195+
accepted_prediction_tokens: 0
196+
audio_tokens: 0
197+
reasoning_tokens: 0
198+
rejected_prediction_tokens: 0
199+
prompt_tokens: 75
200+
prompt_tokens_details:
201+
audio_tokens: 0
202+
cached_tokens: 0
203+
total_tokens: 90
204+
status:
205+
code: 200
206+
message: OK
207+
version: 1

tests/models/test_openai.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,3 +1215,41 @@ async def test_openai_model_without_system_prompt(allow_model_requests: None, op
12151215
assert result.output == snapshot(
12161216
"That's right—I am a potato! A spud of many talents, here to help you out. How can this humble potato be of service today?"
12171217
)
1218+
1219+
1220+
@pytest.mark.vcr()
1221+
async def test_openai_instructions_with_tool_calls_keep_instructions(allow_model_requests: None, openai_api_key: str):
1222+
m = OpenAIModel('gpt-4.1-mini', provider=OpenAIProvider(api_key=openai_api_key))
1223+
agent = Agent(m, instructions='You are a helpful assistant.')
1224+
1225+
@agent.tool_plain
1226+
async def get_temperature(city: str) -> float:
1227+
return 20.0
1228+
1229+
result = await agent.run('What is the temperature in Tokyo?')
1230+
assert result.all_messages() == snapshot(
1231+
[
1232+
ModelRequest(
1233+
parts=[UserPromptPart(content='What is the temperature in Tokyo?', timestamp=IsDatetime())],
1234+
instructions='You are a helpful assistant.',
1235+
),
1236+
ModelResponse(
1237+
parts=[ToolCallPart(tool_name='get_temperature', args='{"city":"Tokyo"}', tool_call_id=IsStr())],
1238+
model_name='gpt-4.1-mini-2025-04-14',
1239+
timestamp=IsDatetime(),
1240+
),
1241+
ModelRequest(
1242+
parts=[
1243+
ToolReturnPart(
1244+
tool_name='get_temperature', content=20.0, tool_call_id=IsStr(), timestamp=IsDatetime()
1245+
)
1246+
],
1247+
instructions='You are a helpful assistant.',
1248+
),
1249+
ModelResponse(
1250+
parts=[TextPart(content='The temperature in Tokyo is currently 20.0 degrees Celsius.')],
1251+
model_name='gpt-4.1-mini-2025-04-14',
1252+
timestamp=IsDatetime(),
1253+
),
1254+
]
1255+
)

0 commit comments

Comments (0)