Skip to content

Commit babdf82

Browse files
authored
Do a better job of inferring openai strict mode (#1511)
1 parent 9733f32 commit babdf82

File tree

2 files changed

+124
-26
lines changed

2 files changed

+124
-26
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import base64
4+
import re
45
import warnings
56
from collections.abc import AsyncIterable, AsyncIterator, Sequence
67
from contextlib import asynccontextmanager
@@ -932,6 +933,31 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
932933
)
933934

934935

936+
_STRICT_INCOMPATIBLE_KEYS = [
937+
'minLength',
938+
'maxLength',
939+
'pattern',
940+
'format',
941+
'minimum',
942+
'maximum',
943+
'multipleOf',
944+
'patternProperties',
945+
'unevaluatedProperties',
946+
'propertyNames',
947+
'minProperties',
948+
'maxProperties',
949+
'unevaluatedItems',
950+
'contains',
951+
'minContains',
952+
'maxContains',
953+
'minItems',
954+
'maxItems',
955+
'uniqueItems',
956+
]
957+
958+
_sentinel = object()
959+
960+
935961
@dataclass
936962
class _OpenAIJsonSchema(WalkJsonSchema):
937963
"""Recursively handle the schema to make it compatible with OpenAI strict mode.
@@ -946,28 +972,64 @@ def __init__(self, schema: JsonSchema, strict: bool | None):
946972
super().__init__(schema)
947973
self.strict = strict
948974
self.is_strict_compatible = True
975+
self.root_ref = schema.get('$ref')
976+
977+
def walk(self) -> JsonSchema:
    """Walk the schema and, for a root-level `$ref`, inline the referenced definition.

    Returns:
        The schema produced by the parent walker. When the root schema was a
        `$ref`, that key is removed and the matching `$defs` entry is merged
        into the root.
    """
    # Note: OpenAI does not support anyOf at the root in strict mode.
    # No check is needed here: pydantic_ai._utils.check_object_json_schema ensures
    # the root schema either has type 'object' or is recursive.
    walked = super().walk()

    root_ref = self.root_ref
    if root_ref is None:
        return walked

    # Recursive models: the root is a `$ref`, which has to be tweaked for strict mode.
    # This should never change the semantics of the schema, so it is applied
    # unconditionally (independent of `self.strict`).
    # References equal to `self.root_ref` are rewritten to '#' in `transform`.
    walked.pop('$ref', None)
    definition_name = re.sub(r'^#/\$defs/', '', root_ref)
    inlined_definition = self.defs.get(definition_name) or {}
    walked.update(inlined_definition)
    return walked
949991

950-
def transform(self, schema: JsonSchema) -> JsonSchema:
992+
def transform(self, schema: JsonSchema) -> JsonSchema: # noqa C901
951993
# Remove unnecessary keys
952994
schema.pop('title', None)
953995
schema.pop('default', None)
954996
schema.pop('$schema', None)
955997
schema.pop('discriminator', None)
956998

957-
# Remove incompatible keys, but note their impact in the description provided to the LLM
999+
if schema_ref := schema.get('$ref'):
1000+
if schema_ref == self.root_ref:
1001+
schema['$ref'] = '#'
1002+
if len(schema) > 1:
1003+
# OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf".
1004+
# So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf":
1005+
schema['anyOf'] = [{'$ref': schema.pop('$ref')}]
1006+
1007+
# Track strict-incompatible keys
1008+
incompatible_values: dict[str, Any] = {}
1009+
for key in _STRICT_INCOMPATIBLE_KEYS:
1010+
value = schema.get(key, _sentinel)
1011+
if value is not _sentinel:
1012+
incompatible_values[key] = value
9581013
description = schema.get('description')
959-
min_length = schema.pop('minLength', None)
960-
max_length = schema.pop('maxLength', None)
961-
if description is not None:
962-
notes = list[str]()
963-
if min_length is not None: # pragma: no cover
964-
notes.append(f'min_length={min_length}')
965-
if max_length is not None: # pragma: no cover
966-
notes.append(f'max_length={max_length}')
967-
if notes: # pragma: no cover
968-
schema['description'] = f'{description} ({", ".join(notes)})'
1014+
if incompatible_values:
1015+
if self.strict is True:
1016+
notes: list[str] = []
1017+
for key, value in incompatible_values.items():
1018+
schema.pop(key)
1019+
notes.append(f'{key}={value}')
1020+
notes_string = ', '.join(notes)
1021+
schema['description'] = notes_string if not description else f'{description} ({notes_string})'
1022+
elif self.strict is None:
1023+
self.is_strict_compatible = False
9691024

9701025
schema_type = schema.get('type')
1026+
if 'oneOf' in schema:
1027+
# OpenAI does not support oneOf in strict mode
1028+
if self.strict is True:
1029+
schema['anyOf'] = schema.pop('oneOf')
1030+
else:
1031+
self.is_strict_compatible = False
1032+
9711033
if schema_type == 'object':
9721034
if self.strict is True:
9731035
# additional properties are disallowed

tests/models/test_openai.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Sequence
55
from dataclasses import dataclass, field
66
from datetime import datetime, timezone
7+
from enum import Enum
78
from functools import cached_property
89
from typing import Annotated, Any, Callable, Literal, Union, cast
910

@@ -730,9 +731,15 @@ class MyDefaultDc:
730731
x: int = 1
731732

732733

734+
# Two-member string enum used as the type of `MyRecursiveDc.my_enum`.
# NOTE(review): deliberately left without a class docstring — a docstring may be
# surfaced as a 'description' in the generated JSON schema and change the
# expected snapshots; confirm before adding one.
class MyEnum(Enum):
    a = 'a'
    b = 'b'
737+
738+
733739
@dataclass
734740
class MyRecursiveDc:
735741
field: MyRecursiveDc | None
742+
my_enum: MyEnum = Field(description='my enum')
736743

737744

738745
@dataclass
@@ -826,9 +833,13 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
826833
},
827834
'type': 'object',
828835
},
836+
'MyEnum': {'enum': ['a', 'b'], 'type': 'string'},
829837
'MyRecursiveDc': {
830-
'properties': {'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]}},
831-
'required': ['field'],
838+
'properties': {
839+
'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]},
840+
'my_enum': {'description': 'my enum', 'anyOf': [{'$ref': '#/$defs/MyEnum'}]},
841+
},
842+
'required': ['field', 'my_enum'],
832843
'type': 'object',
833844
},
834845
},
@@ -857,11 +868,15 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
857868
'additionalProperties': False,
858869
'required': ['field'],
859870
},
871+
'MyEnum': {'enum': ['a', 'b'], 'type': 'string'},
860872
'MyRecursiveDc': {
861-
'properties': {'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]}},
873+
'properties': {
874+
'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]},
875+
'my_enum': {'description': 'my enum', 'anyOf': [{'$ref': '#/$defs/MyEnum'}]},
876+
},
862877
'type': 'object',
863878
'additionalProperties': False,
864-
'required': ['field'],
879+
'required': ['field', 'my_enum'],
865880
},
866881
},
867882
'additionalProperties': False,
@@ -998,7 +1013,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
9981013
}
9991014
},
10001015
'additionalProperties': False,
1001-
'properties': {'x': {'oneOf': [{'type': 'integer'}, {'$ref': '#/$defs/MyDefaultDc'}]}},
1016+
'properties': {'x': {'anyOf': [{'type': 'integer'}, {'$ref': '#/$defs/MyDefaultDc'}]}},
10021017
'required': ['x'],
10031018
'type': 'object',
10041019
}
@@ -1079,12 +1094,15 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
10791094
{
10801095
'additionalProperties': False,
10811096
'properties': {
1082-
'x': {'maxItems': 1, 'minItems': 1, 'prefixItems': [{'type': 'integer'}], 'type': 'array'},
1097+
'x': {
1098+
'prefixItems': [{'type': 'integer'}],
1099+
'type': 'array',
1100+
'description': 'minItems=1, maxItems=1',
1101+
},
10831102
'y': {
1084-
'maxItems': 1,
1085-
'minItems': 1,
10861103
'prefixItems': [{'type': 'string'}],
10871104
'type': 'array',
1105+
'description': 'minItems=1, maxItems=1',
10881106
},
10891107
},
10901108
'required': ['x', 'y'],
@@ -1160,28 +1178,46 @@ class MyModel(BaseModel):
11601178
'MyModel': {
11611179
'additionalProperties': False,
11621180
'properties': {
1163-
'my_discriminated_union': {'oneOf': [{'$ref': '#/$defs/Apple'}, {'$ref': '#/$defs/Banana'}]},
1181+
'my_discriminated_union': {'anyOf': [{'$ref': '#/$defs/Apple'}, {'$ref': '#/$defs/Banana'}]},
11641182
'my_list': {'items': {'type': 'number'}, 'type': 'array'},
11651183
'my_patterns': {
11661184
'additionalProperties': False,
1167-
'patternProperties': {'^my-pattern$': {'type': 'string'}},
1185+
'description': "patternProperties={'^my-pattern$': {'type': 'string'}}",
11681186
'type': 'object',
11691187
'properties': {},
11701188
'required': [],
11711189
},
1172-
'my_recursive': {'anyOf': [{'$ref': '#/$defs/MyModel'}, {'type': 'null'}]},
1190+
'my_recursive': {'anyOf': [{'$ref': '#'}, {'type': 'null'}]},
11731191
'my_tuple': {
1174-
'maxItems': 1,
1175-
'minItems': 1,
11761192
'prefixItems': [{'type': 'integer'}],
11771193
'type': 'array',
1194+
'description': 'minItems=1, maxItems=1',
11781195
},
11791196
},
11801197
'required': ['my_recursive', 'my_patterns', 'my_tuple', 'my_list', 'my_discriminated_union'],
11811198
'type': 'object',
11821199
},
11831200
},
1184-
'$ref': '#/$defs/MyModel',
1201+
'properties': {
1202+
'my_recursive': {'anyOf': [{'$ref': '#'}, {'type': 'null'}]},
1203+
'my_patterns': {
1204+
'type': 'object',
1205+
'description': "patternProperties={'^my-pattern$': {'type': 'string'}}",
1206+
'additionalProperties': False,
1207+
'properties': {},
1208+
'required': [],
1209+
},
1210+
'my_tuple': {
1211+
'prefixItems': [{'type': 'integer'}],
1212+
'type': 'array',
1213+
'description': 'minItems=1, maxItems=1',
1214+
},
1215+
'my_list': {'items': {'type': 'number'}, 'type': 'array'},
1216+
'my_discriminated_union': {'anyOf': [{'$ref': '#/$defs/Apple'}, {'$ref': '#/$defs/Banana'}]},
1217+
},
1218+
'required': ['my_recursive', 'my_patterns', 'my_tuple', 'my_list', 'my_discriminated_union'],
1219+
'type': 'object',
1220+
'additionalProperties': False,
11851221
}
11861222
)
11871223

0 commit comments

Comments
 (0)