Skip to content

Commit 3d5199b

Browse files
authored
Bugfix: record instructions properly on agent run span when using structured output (#1740)
1 parent 7ad3be9 commit 3d5199b

File tree

3 files changed

+152
-9
lines changed

3 files changed

+152
-9
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,48 @@ def base_url(self) -> str | None:
324324
"""The base URL for the provider API, if available."""
325325
return None
326326

@staticmethod
def _get_instructions(messages: list[ModelMessage]) -> str | None:
    """Return the instructions from the relevant ``ModelRequest``, if any.

    Walks ``messages`` newest-first. Normally the most recent ``ModelRequest``
    carries the instructions. However, when a result tool produces output, a
    "mock" ``ModelRequest`` containing only the tool-return part is appended to
    the history; that synthetic request carries no instructions, so in that
    case we fall back to the request immediately before it (the one that
    actually triggered the response).
    """
    recent_requests: list[ModelRequest] = []
    for message in reversed(messages):
        if isinstance(message, ModelRequest):
            recent_requests.append(message)
            if len(recent_requests) == 2:
                break
            if message.instructions is not None:
                return message.instructions

    # With fewer than two requests collected (and no instructions returned
    # above), there are definitely no instructions to report.
    if len(recent_requests) < 2:
        return None

    newest_request, previous_request = recent_requests

    # If the newest request consists solely of tool-return / retry-prompt
    # parts, treat it as the synthetic request generated for a result tool and
    # use the instructions from the request before it.
    #
    # NOTE(review): a manually crafted history whose newest request has only
    # tool-return parts *and* deliberately changed instructions would be
    # mis-attributed here. That edge case can be worked around by inserting an
    # empty ModelRequest immediately before the tool-return request, since only
    # the two most recent requests are inspected. If this causes pain in a real
    # use case, open a GitHub issue to discuss alternatives.
    if all(part.part_kind in ('tool-return', 'retry-prompt') for part in newest_request.parts):
        return previous_request.instructions

    return None
332369

333370

334371
@dataclass

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,13 @@ def event_to_dict(event: Event) -> dict[str, Any]:
273273
@staticmethod
274274
def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
275275
events: list[Event] = []
276-
last_model_request: ModelRequest | None = None
276+
instructions = InstrumentedModel._get_instructions(messages)
277+
if instructions is not None:
278+
events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'}))
279+
277280
for message_index, message in enumerate(messages):
278281
message_events: list[Event] = []
279282
if isinstance(message, ModelRequest):
280-
last_model_request = message
281283
for part in message.parts:
282284
if hasattr(part, 'otel_event'):
283285
message_events.append(part.otel_event())
@@ -289,10 +291,7 @@ def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
289291
**(event.attributes or {}),
290292
}
291293
events.extend(message_events)
292-
if last_model_request and last_model_request.instructions:
293-
events.insert(
294-
0, Event('gen_ai.system.message', body={'content': last_model_request.instructions, 'role': 'system'})
295-
)
294+
296295
for event in events:
297296
event.body = InstrumentedModel.serialize_any(event.body)
298297
return events

tests/test_logfire.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,113 @@ async def my_ret(x: int) -> str:
238238
)
239239

240240

@pytest.mark.skipif(not logfire_installed, reason='logfire not installed')
def test_instructions_with_structured_output(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
    """Instructions must appear on both spans when structured output adds a mock tool-return request."""

    @dataclass
    class MyOutput:
        content: str

    my_agent = Agent(model=TestModel(), instructions='Here are some instructions', instrument=True)

    result = my_agent.run_sync('Hello', output_type=MyOutput)
    assert result.output == snapshot(MyOutput(content='a'))

    summary = get_logfire_summary()

    # Agent-run span: the system message with the instructions must lead the
    # recorded events, even though the final ModelRequest is the synthetic
    # tool-return request created for the result tool.
    assert summary.attributes[0] == snapshot(
        {
            'model_name': 'test',
            'agent_name': 'my_agent',
            'logfire.msg': 'my_agent run',
            'logfire.span_type': 'span',
            'gen_ai.usage.input_tokens': 51,
            'gen_ai.usage.output_tokens': 5,
            'all_messages_events': IsJson(
                snapshot(
                    [
                        {
                            'content': 'Here are some instructions',
                            'role': 'system',
                            'event.name': 'gen_ai.system.message',
                        },
                        {
                            'content': 'Hello',
                            'role': 'user',
                            'gen_ai.message.index': 0,
                            'event.name': 'gen_ai.user.message',
                        },
                        {
                            'role': 'assistant',
                            'tool_calls': [
                                {
                                    'id': IsStr(),
                                    'type': 'function',
                                    'function': {'name': 'final_result', 'arguments': {'content': 'a'}},
                                }
                            ],
                            'gen_ai.message.index': 1,
                            'event.name': 'gen_ai.assistant.message',
                        },
                        {
                            'content': 'Final result processed.',
                            'role': 'tool',
                            'id': IsStr(),
                            'name': 'final_result',
                            'gen_ai.message.index': 2,
                            'event.name': 'gen_ai.tool.message',
                        },
                    ]
                )
            ),
            'final_result': '{"content": "a"}',
            'logfire.json_schema': IsJson(
                snapshot(
                    {
                        'type': 'object',
                        'properties': {'all_messages_events': {'type': 'array'}, 'final_result': {'type': 'object'}},
                    }
                )
            ),
        }
    )

    # Chat span: the same instructions are emitted as the leading system
    # message ahead of the user message and the model's choice.
    chat_span_attributes = summary.attributes[1]
    assert chat_span_attributes['events'] == snapshot(
        IsJson(
            snapshot(
                [
                    {
                        'content': 'Here are some instructions',
                        'role': 'system',
                        'gen_ai.system': 'test',
                        'event.name': 'gen_ai.system.message',
                    },
                    {
                        'event.name': 'gen_ai.user.message',
                        'content': 'Hello',
                        'role': 'user',
                        'gen_ai.message.index': 0,
                        'gen_ai.system': 'test',
                    },
                    {
                        'event.name': 'gen_ai.choice',
                        'index': 0,
                        'message': {
                            'role': 'assistant',
                            'tool_calls': [
                                {
                                    'id': IsStr(),
                                    'type': 'function',
                                    'function': {'name': 'final_result', 'arguments': {'content': 'a'}},
                                }
                            ],
                        },
                        'gen_ai.system': 'test',
                    },
                ]
            )
        )
    )
241348
def test_instrument_all():
242349
model = TestModel()
243350
agent = Agent()

0 commit comments

Comments
 (0)