Skip to content

Commit 1282f32

Browse files
zaidhaantboserDouweM
authored
Add support for non-string enums in Gemini (#1564)
Co-authored-by: Thomas Boser <thomasboser@gmail.com> Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent 240b012 commit 1282f32

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,12 @@ def transform(self, schema: JsonSchema) -> JsonSchema:
831831
schema.pop('exclusiveMaximum', None)
832832
schema.pop('exclusiveMinimum', None)
833833

834+
# Gemini only supports string enums, so we need to convert any enum values to strings.
835+
# Pydantic will take care of transforming the transformed string values to the correct type.
836+
if enum := schema.get('enum'):
837+
schema['type'] = 'string'
838+
schema['enum'] = [str(val) for val in enum]
839+
834840
type_ = schema.get('type')
835841
if 'oneOf' in schema and 'type' not in schema: # pragma: no cover
836842
# This gets hit when we have a discriminated union

tests/models/test_gemini.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import AsyncIterator, Callable, Sequence
77
from dataclasses import dataclass
88
from datetime import timezone
9+
from enum import IntEnum
910
from typing import Annotated
1011

1112
import httpx
@@ -228,6 +229,64 @@ class Locations(BaseModel):
228229
)
229230

230231

232+
async def test_json_def_enum(allow_model_requests: None):
233+
class ProgressEnum(IntEnum):
234+
DONE = 100
235+
ALMOST_DONE = 80
236+
IN_PROGRESS = 60
237+
BARELY_STARTED = 40
238+
NOT_STARTED = 20
239+
240+
class QueryDetails(BaseModel):
241+
progress: list[ProgressEnum] | None = None
242+
243+
json_schema = QueryDetails.model_json_schema()
244+
assert json_schema == snapshot(
245+
{
246+
'$defs': {'ProgressEnum': {'enum': [100, 80, 60, 40, 20], 'title': 'ProgressEnum', 'type': 'integer'}},
247+
'properties': {
248+
'progress': {
249+
'anyOf': [{'items': {'$ref': '#/$defs/ProgressEnum'}, 'type': 'array'}, {'type': 'null'}],
250+
'default': None,
251+
'title': 'Progress',
252+
}
253+
},
254+
'title': 'QueryDetails',
255+
'type': 'object',
256+
}
257+
)
258+
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
259+
output_tool = ToolDefinition(
260+
'result',
261+
'This is the tool for the final Result',
262+
json_schema,
263+
)
264+
mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool])
265+
mrp = m.customize_request_parameters(mrp)
266+
267+
# This tests that the enum values are properly converted to strings for Gemini
268+
assert m._get_tools(mrp) == snapshot(
269+
_GeminiTools(
270+
function_declarations=[
271+
_GeminiFunction(
272+
name='result',
273+
description='This is the tool for the final Result',
274+
parameters={
275+
'properties': {
276+
'progress': {
277+
'items': {'enum': ['100', '80', '60', '40', '20'], 'type': 'string'},
278+
'type': 'array',
279+
'nullable': True,
280+
}
281+
},
282+
'type': 'object',
283+
},
284+
)
285+
]
286+
)
287+
)
288+
289+
231290
async def test_json_def_replaced_any_of(allow_model_requests: None):
232291
class Location(BaseModel):
233292
lat: float

0 commit comments

Comments
 (0)