Skip to content

Commit cc18937

Browse files
authored
Generalize the JSON schema transformations (#1481)
1 parent 7260360 commit cc18937

File tree

5 files changed

+271
-216
lines changed

5 files changed

+271
-216
lines changed

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 40 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from __future__ import annotations as _annotations
22

33
import base64
4-
import re
54
from collections.abc import AsyncIterator, Sequence
65
from contextlib import asynccontextmanager
7-
from copy import deepcopy
86
from dataclasses import dataclass, field, replace
97
from datetime import datetime
108
from typing import Annotated, Any, Literal, Protocol, Union, cast
@@ -46,6 +44,7 @@
4644
check_allow_model_requests,
4745
get_user_agent,
4846
)
47+
from .json_schema import JsonSchema, WalkJsonSchema
4948

5049
LatestGeminiModelNames = Literal[
5150
'gemini-1.5-flash',
@@ -156,7 +155,7 @@ async def request_stream(
156155

157156
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
158157
def _customize_tool_def(t: ToolDefinition):
159-
return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).simplify())
158+
return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
160159

161160
return ModelRequestParameters(
162161
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
@@ -760,7 +759,7 @@ class _GeminiPromptFeedback(TypedDict):
760759
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
761760

762761

763-
class _GeminiJsonSchema:
762+
class _GeminiJsonSchema(WalkJsonSchema):
764763
"""Transforms the JSON Schema from Pydantic to be suitable for Gemini.
765764
766765
Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
@@ -771,72 +770,58 @@ class _GeminiJsonSchema:
771770
* gemini doesn't allow `$defs` — we need to inline the definitions where possible
772771
"""
773772

774-
def __init__(self, schema: _utils.ObjectJsonSchema):
775-
self.schema = deepcopy(schema)
776-
self.defs = self.schema.pop('$defs', {})
773+
def __init__(self, schema: JsonSchema):
774+
super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
777775

778-
def simplify(self) -> dict[str, Any]:
779-
self._simplify(self.schema, refs_stack=())
780-
return self.schema
781-
782-
def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
776+
def transform(self, schema: JsonSchema) -> JsonSchema:
783777
schema.pop('title', None)
784778
schema.pop('default', None)
785779
schema.pop('$schema', None)
780+
if (const := schema.pop('const', None)) is not None: # pragma: no cover
781+
# Gemini doesn't support const, but it does support enum with a single value
782+
schema['enum'] = [const]
783+
schema.pop('discriminator', None)
784+
schema.pop('examples', None)
785+
786+
# TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
787+
# where we add notes about these properties to the field description?
786788
schema.pop('exclusiveMaximum', None)
787789
schema.pop('exclusiveMinimum', None)
788-
if ref := schema.pop('$ref', None):
789-
# noinspection PyTypeChecker
790-
key = re.sub(r'^#/\$defs/', '', ref)
791-
if key in refs_stack:
792-
raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
793-
refs_stack += (key,)
794-
schema_def = self.defs[key]
795-
self._simplify(schema_def, refs_stack)
796-
schema.update(schema_def)
797-
return
798-
799-
if any_of := schema.get('anyOf'):
800-
for item_schema in any_of:
801-
self._simplify(item_schema, refs_stack)
802-
if len(any_of) == 2 and {'type': 'null'} in any_of:
803-
for item_schema in any_of:
804-
if item_schema != {'type': 'null'}:
805-
schema.clear()
806-
schema.update(item_schema)
807-
schema['nullable'] = True
808-
return
809790

810791
type_ = schema.get('type')
792+
if 'oneOf' in schema and 'type' not in schema: # pragma: no cover
793+
# This gets hit when we have a discriminated union
794+
# Gemini returns an API error in this case even though it says in its error message it shouldn't...
795+
# Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
796+
schema['anyOf'] = schema.pop('oneOf')
811797

812-
if type_ == 'object':
813-
self._object(schema, refs_stack)
814-
elif type_ == 'array':
815-
return self._array(schema, refs_stack)
816-
elif type_ == 'string' and (fmt := schema.pop('format', None)):
798+
if type_ == 'string' and (fmt := schema.pop('format', None)):
817799
description = schema.get('description')
818800
if description:
819801
schema['description'] = f'{description} (format: {fmt})'
820802
else:
821803
schema['description'] = f'Format: {fmt}'
822804

823-
def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
824-
ad_props = schema.pop('additionalProperties', None)
825-
if ad_props:
826-
raise UserError('Additional properties in JSON Schema are not supported by Gemini')
827-
828-
if properties := schema.get('properties'): # pragma: no branch
829-
for value in properties.values():
830-
self._simplify(value, refs_stack)
831-
832-
def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
833-
if prefix_items := schema.get('prefixItems'):
834-
# TODO I think this not is supported by Gemini, maybe we should raise an error?
835-
for prefix_item in prefix_items:
836-
self._simplify(prefix_item, refs_stack)
837-
838-
if items_schema := schema.get('items'): # pragma: no branch
839-
self._simplify(items_schema, refs_stack)
805+
if '$ref' in schema:
806+
raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')
807+
808+
if 'prefixItems' in schema:
809+
# prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
810+
prefix_items = schema.pop('prefixItems')
811+
items = schema.get('items')
812+
unique_items = [items] if items is not None else []
813+
for item in prefix_items:
814+
if item not in unique_items:
815+
unique_items.append(item)
816+
if len(unique_items) > 1: # pragma: no cover
817+
schema['items'] = {'anyOf': unique_items}
818+
elif len(unique_items) == 1:
819+
schema['items'] = unique_items[0]
820+
schema.setdefault('minItems', len(prefix_items))
821+
if items is None:
822+
schema.setdefault('maxItems', len(prefix_items))
823+
824+
return schema
840825

841826

842827
def _ensure_decodeable(content: bytearray) -> bytearray:
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import re
2+
from abc import ABC, abstractmethod
3+
from copy import deepcopy
4+
from dataclasses import dataclass
5+
from typing import Any, Literal
6+
7+
from pydantic_ai.exceptions import UserError
8+
9+
JsonSchema = dict[str, Any]
10+
11+
12+
@dataclass(init=False)
class WalkJsonSchema(ABC):
    """Walks a JSON schema, applying transformations to it at each level.

    Subclasses implement `transform`, which is called on every (sub-)schema after its
    children (object properties, array items, union members) have been handled.
    """

    def __init__(
        self, schema: JsonSchema, *, prefer_inlined_defs: bool = False, simplify_nullable_unions: bool = False
    ):
        """Prepare a walker for `schema`.

        Args:
            schema: The schema to walk. It is deep-copied, so the caller's dict is never mutated.
            prefer_inlined_defs: If true, `$ref`s are replaced by the referenced definition wherever
                the reference is not recursive.
            simplify_nullable_unions: If true, a two-member union of a schema and `{'type': 'null'}`
                is collapsed into that schema with `nullable` set to true.
        """
        self.schema = deepcopy(schema)
        self.prefer_inlined_defs = prefer_inlined_defs
        self.simplify_nullable_unions = simplify_nullable_unions

        self.defs: dict[str, JsonSchema] = self.schema.pop('$defs', {})
        # Keys of the definitions currently being inlined, outermost first.
        self.refs_stack = tuple[str, ...]()
        # Keys of definitions that were found to reference themselves (directly or indirectly).
        self.recursive_refs = set[str]()

    @abstractmethod
    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Make changes to the schema."""
        return schema

    def walk(self) -> JsonSchema:
        """Handle the whole schema and return the transformed result.

        When definitions are not inlined they are handled too and re-attached under `$defs`.
        When they are inlined but some refs are recursive, the result is wrapped in a
        `$defs` + `$ref` structure, because a recursive definition cannot be unpacked.
        """
        handled = self._handle(deepcopy(self.schema))

        if not self.prefer_inlined_defs and self.defs:
            handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()}

        elif self.recursive_refs:  # pragma: no cover
            # If we are preferring inlined defs and there are recursive refs, we _have_ to use a $defs+$ref structure
            # We try to use whatever the original root key was, but if it is already in use,
            # we modify it to avoid collisions.
            defs = {key: self.defs[key] for key in self.recursive_refs}
            root_ref = self.schema.get('$ref')
            root_key = None if root_ref is None else self._def_key(root_ref)
            if root_key is None:
                root_key = self.schema.get('title', 'root')
            while root_key in defs:
                # Modify the root key until it is not already in use
                root_key = f'{root_key}_root'

            defs[root_key] = handled
            return {'$defs': defs, '$ref': f'#/$defs/{root_key}'}

        return handled

    @staticmethod
    def _def_key(ref: str) -> str:
        """Return the `$defs` key that a `#/$defs/<key>` reference points at."""
        return re.sub(r'^#/\$defs/', '', ref)

    def _handle(self, schema: JsonSchema) -> JsonSchema:
        """Resolve refs (when inlining), recurse into children, then apply `transform`.

        Raises:
            UserError: If a `$ref` points at a definition that does not exist.
        """
        # The stack must describe only the chain of definitions we are currently *inside*.
        # Remember it here and restore it on the way out, otherwise a definition used by two
        # sibling schemas would wrongly be reported as recursive on its second use.
        enclosing_refs = self.refs_stack
        try:
            if self.prefer_inlined_defs:
                while ref := schema.get('$ref'):
                    key = self._def_key(ref)
                    if key in self.refs_stack:
                        self.recursive_refs.add(key)
                        break  # recursive ref can't be unpacked
                    self.refs_stack += (key,)
                    def_schema = self.defs.get(key)
                    if def_schema is None:  # pragma: no cover
                        raise UserError(f'Could not find $ref definition for {key}')
                    # NOTE: the definition object itself is handled (not a copy), so a definition
                    # that is inlined in several places is passed to `transform` once per use.
                    schema = def_schema

            # Handle the schema based on its type / structure
            type_ = schema.get('type')
            if type_ == 'object':
                schema = self._handle_object(schema)
            elif type_ == 'array':
                schema = self._handle_array(schema)
            elif type_ is None:
                schema = self._handle_union(schema, 'anyOf')
                schema = self._handle_union(schema, 'oneOf')

            # Apply the base transform
            return self.transform(schema)
        finally:
            self.refs_stack = enclosing_refs

    def _handle_object(self, schema: JsonSchema) -> JsonSchema:
        """Handle every sub-schema of an object schema, updating `schema` in place."""
        if properties := schema.get('properties'):
            schema['properties'] = {key: self._handle(value) for key, value in properties.items()}

        # A boolean `additionalProperties` is not a schema, so there is nothing to walk into.
        additional_properties = schema.get('additionalProperties')
        if additional_properties is not None and not isinstance(additional_properties, bool):  # pragma: no cover
            schema['additionalProperties'] = self._handle(additional_properties)

        if (pattern_properties := schema.get('patternProperties')) is not None:
            schema['patternProperties'] = {key: self._handle(value) for key, value in pattern_properties.items()}

        return schema

    def _handle_array(self, schema: JsonSchema) -> JsonSchema:
        """Handle the `prefixItems` and `items` sub-schemas of an array schema in place."""
        if prefix_items := schema.get('prefixItems'):
            schema['prefixItems'] = [self._handle(item) for item in prefix_items]

        if items := schema.get('items'):
            schema['items'] = self._handle(items)

        return schema

    def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf']) -> JsonSchema:
        """Handle the members of an `anyOf` / `oneOf` union.

        Returns `schema` unchanged when it has no members of `union_kind`. When the union
        collapses to a single member, that member is returned on its own; otherwise a
        shallow copy of `schema` with the handled members is returned.
        """
        members = schema.get(union_kind)
        if not members:
            return schema

        handled = [self._handle(member) for member in members]

        # convert nullable unions to nullable types
        if self.simplify_nullable_unions:
            handled = self._simplify_nullable_union(handled)

        if len(handled) == 1:
            # In this case, no need to retain the union
            return handled[0]

        # If we have keys besides the union kind (such as title or discriminator), keep them without modifications
        schema = schema.copy()
        schema[union_kind] = handled
        return schema

    @staticmethod
    def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]:
        """Collapse a `[X, null]` union into `[X with nullable=True]`.

        Any list that is not a two-member union containing `{'type': 'null'}` is returned as is.
        """
        # TODO: Should we move this to relevant subclasses? Or is it worth keeping here to make reuse easier?
        null_schema = {'type': 'null'}
        if len(cases) == 2 and null_schema in cases:
            # Find the non-null schema
            non_null_schema = next((item for item in cases if item != null_schema), None)
            if non_null_schema is None:  # pragma: no cover
                # they are both null, so just return one of them
                return [cases[0]]
            if not non_null_schema:  # pragma: no cover
                # The empty schema already accepts every value (null included), so it is the
                # whole union regardless of which position it was in.
                return [non_null_schema]
            # Create a new schema based on the non-null part, mark as nullable
            new_schema = deepcopy(non_null_schema)
            new_schema['nullable'] = True
            return [new_schema]

        return cases  # pragma: no cover

0 commit comments

Comments
 (0)