Commit bd89845

eliassoares and DouweM authored
Add logprobs to OpenAI model settings and response (#1238)
Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent 96f676c commit bd89845

4 files changed: +94 −5 lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 7 additions & 0 deletions
@@ -562,6 +562,13 @@ class ModelResponse:
     kind: Literal['response'] = 'response'
     """Message type identifier, this is available on all parts as a discriminator."""
 
+    vendor_details: dict[str, Any] | None = field(default=None, repr=False)
+    """Additional vendor-specific details in a serializable format.
+
+    This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
+    For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
+    """
+
     def otel_events(self) -> list[Event]:
         """Return OpenTelemetry events for the response."""
         result: list[Event] = []
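For orientation, a minimal sketch of reading the new field off a response, assuming a ModelResponse from any completed run (the helper name is illustrative, not part of this commit):

from typing import Any

from pydantic_ai.messages import ModelResponse


def get_vendor_logprobs(response: ModelResponse) -> list[dict[str, Any]] | None:
    # vendor_details defaults to None, so guard before indexing into it
    if response.vendor_details is None:
        return None
    # For OpenAI models, this commit stores the serialized logprobs under 'logprobs'
    return response.vendor_details.get('logprobs')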

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 34 additions & 1 deletion
@@ -104,6 +104,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
     result in faster responses and fewer tokens used on reasoning in a response.
     """
 
+    openai_logprobs: bool
+    """Include log probabilities in the response."""
+
+    openai_top_logprobs: int
+    """Include log probabilities of the top n tokens in the response."""
+
     openai_user: str
     """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.
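A sketch of how these settings might be enabled on a run, assuming an OpenAI-backed agent (the model name and prompt are illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModelSettings

agent = Agent('openai:gpt-4o')
result = agent.run_sync(
    'Say hello',
    # openai_logprobs requests per-token log probabilities; openai_top_logprobs
    # additionally requests the n most likely alternatives at each position
    model_settings=OpenAIModelSettings(openai_logprobs=True, openai_top_logprobs=3),
)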
@@ -287,6 +293,8 @@ async def _completions_create(
             frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
             logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
             reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
+            logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
+            top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
             user=model_settings.get('openai_user', NOT_GIVEN),
             extra_headers=extra_headers,
             extra_body=model_settings.get('extra_body'),
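Settings left unset fall through to NOT_GIVEN, the OpenAI SDK sentinel that omits the parameter from the request body entirely, so the API's server-side defaults still apply.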
@@ -301,12 +309,37 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
+        vendor_details: dict[str, Any] | None = None
+
+        # Add logprobs to vendor_details if available
+        if choice.logprobs is not None and choice.logprobs.content:
+            # Convert logprobs to a serializable format
+            vendor_details = {
+                'logprobs': [
+                    {
+                        'token': lp.token,
+                        'bytes': lp.bytes,
+                        'logprob': lp.logprob,
+                        'top_logprobs': [
+                            {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                        ],
+                    }
+                    for lp in choice.logprobs.content
+                ],
+            }
+
         if choice.message.content is not None:
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            vendor_details=vendor_details,
+        )
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""

tests/models/test_openai.py

Lines changed: 52 additions & 4 deletions
@@ -40,7 +40,7 @@
 with try_import() as imports_successful:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI
     from openai.types import chat
-    from openai.types.chat.chat_completion import Choice
+    from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
     from openai.types.chat.chat_completion_chunk import (
         Choice as ChunkChoice,
         ChoiceDelta,
@@ -49,6 +49,7 @@
     )
     from openai.types.chat.chat_completion_message import ChatCompletionMessage
     from openai.types.chat.chat_completion_message_tool_call import Function
+    from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
     from openai.types.completion_usage import CompletionUsage, PromptTokensDetails
 
     from pydantic_ai.models.openai import (
@@ -129,10 +130,15 @@ def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str
         raise RuntimeError('Not a MockOpenAI instance')
 
 
-def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
+def completion_message(
+    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
+) -> chat.ChatCompletion:
+    choices = [Choice(finish_reason='stop', index=0, message=message)]
+    if logprobs:
+        choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
     return chat.ChatCompletion(
         id='123',
-        choices=[Choice(finish_reason='stop', index=0, message=message)],
+        choices=choices,
         created=1704067200,  # 2024-01-01
         model='gpt-4o-123',
         object='chat.completion',
@@ -141,7 +147,9 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
 
 
 async def test_request_simple_success(allow_model_requests: None):
-    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
+    c = completion_message(
+        ChatCompletionMessage(content='world', role='assistant'),
+    )
     mock_client = MockOpenAI.create_mock(c)
     m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
@@ -1543,3 +1551,43 @@ async def get_temperature(city: str) -> float:
             ),
         ]
     )
+
+
+@pytest.mark.vcr()
+async def test_openai_instructions_with_logprobs(allow_model_requests: None):
+    # Create a mock response with logprobs
+    c = completion_message(
+        ChatCompletionMessage(content='world', role='assistant'),
+        logprobs=ChoiceLogprobs(
+            content=[
+                ChatCompletionTokenLogprob(
+                    token='world', logprob=-0.6931, top_logprobs=[], bytes=[119, 111, 114, 108, 100]
+                )
+            ],
+        ),
+    )
+
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIModel(
+        'gpt-4o',
+        provider=OpenAIProvider(openai_client=mock_client),
+    )
+    agent = Agent(
+        m,
+        instructions='You are a helpful assistant.',
+    )
+    result = await agent.run(
+        'What is the capital of Minas Gerais?',
+        model_settings=OpenAIModelSettings(openai_logprobs=True),
+    )
+    messages = result.all_messages()
+    response = cast(Any, messages[1])
+    assert response.vendor_details is not None
+    assert response.vendor_details['logprobs'] == [
+        {
+            'token': 'world',
+            'logprob': -0.6931,
+            'bytes': [119, 111, 114, 108, 100],
+            'top_logprobs': [],
+        }
+    ]
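As a sanity check on the mocked value: -0.6931 is ln(0.5) to four decimal places, so the single mocked token carries a probability of about one half:

import math

assert round(math.exp(-0.6931), 3) == 0.5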

tests/test_agent.py

Lines changed: 1 addition & 0 deletions
@@ -1757,6 +1757,7 @@ def test_binary_content_all_messages_json():
                 'model_name': 'test',
                 'timestamp': IsStr(),
                 'kind': 'response',
+                'vendor_details': None,
             },
         ]
     )
