1
1
from __future__ import annotations as _annotations
2
2
3
3
import base64
4
- import re
5
4
from collections .abc import AsyncIterator , Sequence
6
5
from contextlib import asynccontextmanager
7
- from copy import deepcopy
8
6
from dataclasses import dataclass , field , replace
9
7
from datetime import datetime
10
8
from typing import Annotated , Any , Literal , Protocol , Union , cast
46
44
check_allow_model_requests ,
47
45
get_user_agent ,
48
46
)
47
+ from .json_schema import JsonSchema , WalkJsonSchema
49
48
50
49
LatestGeminiModelNames = Literal [
51
50
'gemini-1.5-flash' ,
@@ -156,7 +155,7 @@ async def request_stream(
156
155
157
156
def customize_request_parameters (self , model_request_parameters : ModelRequestParameters ) -> ModelRequestParameters :
158
157
def _customize_tool_def (t : ToolDefinition ):
159
- return replace (t , parameters_json_schema = _GeminiJsonSchema (t .parameters_json_schema ).simplify ())
158
+ return replace (t , parameters_json_schema = _GeminiJsonSchema (t .parameters_json_schema ).walk ())
160
159
161
160
return ModelRequestParameters (
162
161
function_tools = [_customize_tool_def (tool ) for tool in model_request_parameters .function_tools ],
@@ -760,7 +759,7 @@ class _GeminiPromptFeedback(TypedDict):
760
759
_gemini_streamed_response_ta = pydantic .TypeAdapter (list [_GeminiResponse ], config = pydantic .ConfigDict (defer_build = True ))
761
760
762
761
763
- class _GeminiJsonSchema :
762
+ class _GeminiJsonSchema ( WalkJsonSchema ) :
764
763
"""Transforms the JSON Schema from Pydantic to be suitable for Gemini.
765
764
766
765
Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
@@ -771,72 +770,58 @@ class _GeminiJsonSchema:
771
770
* gemini doesn't allow `$defs` — we need to inline the definitions where possible
772
771
"""
773
772
774
- def __init__ (self , schema : _utils .ObjectJsonSchema ):
775
- self .schema = deepcopy (schema )
776
- self .defs = self .schema .pop ('$defs' , {})
773
+ def __init__ (self , schema : JsonSchema ):
774
+ super ().__init__ (schema , prefer_inlined_defs = True , simplify_nullable_unions = True )
777
775
778
- def simplify (self ) -> dict [str , Any ]:
779
- self ._simplify (self .schema , refs_stack = ())
780
- return self .schema
781
-
782
- def _simplify (self , schema : dict [str , Any ], refs_stack : tuple [str , ...]) -> None :
776
+ def transform (self , schema : JsonSchema ) -> JsonSchema :
783
777
schema .pop ('title' , None )
784
778
schema .pop ('default' , None )
785
779
schema .pop ('$schema' , None )
780
+ if (const := schema .pop ('const' , None )) is not None : # pragma: no cover
781
+ # Gemini doesn't support const, but it does support enum with a single value
782
+ schema ['enum' ] = [const ]
783
+ schema .pop ('discriminator' , None )
784
+ schema .pop ('examples' , None )
785
+
786
+ # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
787
+ # where we add notes about these properties to the field description?
786
788
schema .pop ('exclusiveMaximum' , None )
787
789
schema .pop ('exclusiveMinimum' , None )
788
- if ref := schema .pop ('$ref' , None ):
789
- # noinspection PyTypeChecker
790
- key = re .sub (r'^#/\$defs/' , '' , ref )
791
- if key in refs_stack :
792
- raise UserError ('Recursive `$ref`s in JSON Schema are not supported by Gemini' )
793
- refs_stack += (key ,)
794
- schema_def = self .defs [key ]
795
- self ._simplify (schema_def , refs_stack )
796
- schema .update (schema_def )
797
- return
798
-
799
- if any_of := schema .get ('anyOf' ):
800
- for item_schema in any_of :
801
- self ._simplify (item_schema , refs_stack )
802
- if len (any_of ) == 2 and {'type' : 'null' } in any_of :
803
- for item_schema in any_of :
804
- if item_schema != {'type' : 'null' }:
805
- schema .clear ()
806
- schema .update (item_schema )
807
- schema ['nullable' ] = True
808
- return
809
790
810
791
type_ = schema .get ('type' )
792
+ if 'oneOf' in schema and 'type' not in schema : # pragma: no cover
793
+ # This gets hit when we have a discriminated union
794
+ # Gemini returns an API error in this case even though it says in its error message it shouldn't...
795
+ # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
796
+ schema ['anyOf' ] = schema .pop ('oneOf' )
811
797
812
- if type_ == 'object' :
813
- self ._object (schema , refs_stack )
814
- elif type_ == 'array' :
815
- return self ._array (schema , refs_stack )
816
- elif type_ == 'string' and (fmt := schema .pop ('format' , None )):
798
+ if type_ == 'string' and (fmt := schema .pop ('format' , None )):
817
799
description = schema .get ('description' )
818
800
if description :
819
801
schema ['description' ] = f'{ description } (format: { fmt } )'
820
802
else :
821
803
schema ['description' ] = f'Format: { fmt } '
822
804
823
- def _object (self , schema : dict [str , Any ], refs_stack : tuple [str , ...]) -> None :
824
- ad_props = schema .pop ('additionalProperties' , None )
825
- if ad_props :
826
- raise UserError ('Additional properties in JSON Schema are not supported by Gemini' )
827
-
828
- if properties := schema .get ('properties' ): # pragma: no branch
829
- for value in properties .values ():
830
- self ._simplify (value , refs_stack )
831
-
832
- def _array (self , schema : dict [str , Any ], refs_stack : tuple [str , ...]) -> None :
833
- if prefix_items := schema .get ('prefixItems' ):
834
- # TODO I think this not is supported by Gemini, maybe we should raise an error?
835
- for prefix_item in prefix_items :
836
- self ._simplify (prefix_item , refs_stack )
837
-
838
- if items_schema := schema .get ('items' ): # pragma: no branch
839
- self ._simplify (items_schema , refs_stack )
805
+ if '$ref' in schema :
806
+ raise UserError (f'Recursive `$ref`s in JSON Schema are not supported by Gemini: { schema ["$ref" ]} ' )
807
+
808
+ if 'prefixItems' in schema :
809
+ # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
810
+ prefix_items = schema .pop ('prefixItems' )
811
+ items = schema .get ('items' )
812
+ unique_items = [items ] if items is not None else []
813
+ for item in prefix_items :
814
+ if item not in unique_items :
815
+ unique_items .append (item )
816
+ if len (unique_items ) > 1 : # pragma: no cover
817
+ schema ['items' ] = {'anyOf' : unique_items }
818
+ elif len (unique_items ) == 1 :
819
+ schema ['items' ] = unique_items [0 ]
820
+ schema .setdefault ('minItems' , len (prefix_items ))
821
+ if items is None :
822
+ schema .setdefault ('maxItems' , len (prefix_items ))
823
+
824
+ return schema
840
825
841
826
842
827
def _ensure_decodeable (content : bytearray ) -> bytearray :
0 commit comments