|
6 | 6 | from collections.abc import AsyncIterator, Callable, Sequence
|
7 | 7 | from dataclasses import dataclass
|
8 | 8 | from datetime import timezone
|
| 9 | +from enum import IntEnum |
9 | 10 | from typing import Annotated
|
10 | 11 |
|
11 | 12 | import httpx
|
@@ -228,6 +229,64 @@ class Locations(BaseModel):
|
228 | 229 | )
|
229 | 230 |
|
230 | 231 |
|
| 232 | +async def test_json_def_enum(allow_model_requests: None): |
| 233 | + class ProgressEnum(IntEnum): |
| 234 | + DONE = 100 |
| 235 | + ALMOST_DONE = 80 |
| 236 | + IN_PROGRESS = 60 |
| 237 | + BARELY_STARTED = 40 |
| 238 | + NOT_STARTED = 20 |
| 239 | + |
| 240 | + class QueryDetails(BaseModel): |
| 241 | + progress: list[ProgressEnum] | None = None |
| 242 | + |
| 243 | + json_schema = QueryDetails.model_json_schema() |
| 244 | + assert json_schema == snapshot( |
| 245 | + { |
| 246 | + '$defs': {'ProgressEnum': {'enum': [100, 80, 60, 40, 20], 'title': 'ProgressEnum', 'type': 'integer'}}, |
| 247 | + 'properties': { |
| 248 | + 'progress': { |
| 249 | + 'anyOf': [{'items': {'$ref': '#/$defs/ProgressEnum'}, 'type': 'array'}, {'type': 'null'}], |
| 250 | + 'default': None, |
| 251 | + 'title': 'Progress', |
| 252 | + } |
| 253 | + }, |
| 254 | + 'title': 'QueryDetails', |
| 255 | + 'type': 'object', |
| 256 | + } |
| 257 | + ) |
| 258 | + m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg')) |
| 259 | + output_tool = ToolDefinition( |
| 260 | + 'result', |
| 261 | + 'This is the tool for the final Result', |
| 262 | + json_schema, |
| 263 | + ) |
| 264 | + mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) |
| 265 | + mrp = m.customize_request_parameters(mrp) |
| 266 | + |
| 267 | + # This tests that the enum values are properly converted to strings for Gemini |
| 268 | + assert m._get_tools(mrp) == snapshot( |
| 269 | + _GeminiTools( |
| 270 | + function_declarations=[ |
| 271 | + _GeminiFunction( |
| 272 | + name='result', |
| 273 | + description='This is the tool for the final Result', |
| 274 | + parameters={ |
| 275 | + 'properties': { |
| 276 | + 'progress': { |
| 277 | + 'items': {'enum': ['100', '80', '60', '40', '20'], 'type': 'string'}, |
| 278 | + 'type': 'array', |
| 279 | + 'nullable': True, |
| 280 | + } |
| 281 | + }, |
| 282 | + 'type': 'object', |
| 283 | + }, |
| 284 | + ) |
| 285 | + ] |
| 286 | + ) |
| 287 | + ) |
| 288 | + |
| 289 | + |
231 | 290 | async def test_json_def_replaced_any_of(allow_model_requests: None):
|
232 | 291 | class Location(BaseModel):
|
233 | 292 | lat: float
|
|
0 commit comments