Skip to content

Commit 044ca13

Browse files
Switch gemini request to camelCase as required by API (#1456)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
1 parent cc18937 commit 044ca13

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,11 @@ async def _make_request(
203203

204204
request_data = _GeminiRequest(contents=contents)
205205
if sys_prompt_parts:
206-
request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
206+
request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
207207
if tools is not None:
208208
request_data['tools'] = tools
209209
if tool_config is not None:
210-
request_data['tool_config'] = tool_config
210+
request_data['toolConfig'] = tool_config
211211

212212
generation_config: _GeminiGenerationConfig = {}
213213
if model_settings:
@@ -222,9 +222,9 @@ async def _make_request(
222222
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
223223
generation_config['frequency_penalty'] = frequency_penalty
224224
if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
225-
request_data['safety_settings'] = gemini_safety_settings
225+
request_data['safetySettings'] = gemini_safety_settings
226226
if generation_config:
227-
request_data['generation_config'] = generation_config
227+
request_data['generationConfig'] = generation_config
228228

229229
headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
230230
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
@@ -450,17 +450,19 @@ class _GeminiRequest(TypedDict):
450450
See <https://ai.google.dev/api/generate-content#request-body> for API docs.
451451
"""
452452

453+
# Note: Even though Google supposedly supports camelCase and snake_case, we've had user report misbehavior
454+
# when using snake_case, which is why this typeddict now uses camelCase. And anyway, the plan is to replace this
455+
# with an official google SDK in the near future anyway.
453456
contents: list[_GeminiContent]
454457
tools: NotRequired[_GeminiTools]
455-
tool_config: NotRequired[_GeminiToolConfig]
456-
safety_settings: NotRequired[list[GeminiSafetySettings]]
457-
# we don't implement `generationConfig`, instead we use a named tool for the response
458-
system_instruction: NotRequired[_GeminiTextContent]
458+
toolConfig: NotRequired[_GeminiToolConfig]
459+
safetySettings: NotRequired[list[GeminiSafetySettings]]
460+
systemInstruction: NotRequired[_GeminiTextContent]
459461
"""
460462
Developer generated system instructions, see
461463
<https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
462464
"""
463-
generation_config: NotRequired[_GeminiGenerationConfig]
465+
generationConfig: NotRequired[_GeminiGenerationConfig]
464466

465467

466468
class GeminiSafetySettings(TypedDict):

tests/models/test_gemini.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ async def test_empty_text_ignored():
837837

838838
async def test_model_settings(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None) -> None:
839839
def handler(request: httpx.Request) -> httpx.Response:
840-
generation_config = json.loads(request.content)['generation_config']
840+
generation_config = json.loads(request.content)['generationConfig']
841841
assert generation_config == {
842842
'max_output_tokens': 1,
843843
'temperature': 0.1,
@@ -886,7 +886,7 @@ async def test_safety_settings_unsafe(
886886
try:
887887

888888
def handler(request: httpx.Request) -> httpx.Response:
889-
safety_settings = json.loads(request.content)['safety_settings']
889+
safety_settings = json.loads(request.content)['safetySettings']
890890
assert safety_settings == [
891891
{'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
892892
{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
@@ -928,7 +928,7 @@ async def test_safety_settings_safe(
928928
client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
929929
) -> None:
930930
def handler(request: httpx.Request) -> httpx.Response:
931-
safety_settings = json.loads(request.content)['safety_settings']
931+
safety_settings = json.loads(request.content)['safetySettings']
932932
assert safety_settings == [
933933
{'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
934934
{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},

0 commit comments

Comments
 (0)