Skip to content

Commit b6e73e8

Browse files
author
SMKRV
committed
fix(api_client): correct Google Gemini API integration
- Change JSON field names from camelCase to snake_case as required by Gemini API (generation_config, max_output_tokens, system_instruction)
- Improve message handling to ensure proper role alternation (user/model)
- Add safety checks for empty contents and ensure first message is always from user
- Implement robust error handling and response parsing
- Handle edge cases where candidatesTokenCount might be returned as a list

Fixes #6
1 parent 440c734 commit b6e73e8

File tree

1 file changed

+88
-21
lines changed

1 file changed

+88
-21
lines changed

custom_components/ha_text_ai/api_client.py

Lines changed: 88 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -254,45 +254,112 @@ async def _create_gemini_completion(
254254
# Extract API key from headers (Bearer token)
255255
api_key = self.headers.get("Authorization", "").replace("Bearer ", "")
256256
url = f"{self.endpoint}/models/{model}:generateContent?key={api_key}"
257-
257+
258258
# Convert messages to Gemini format
259259
contents = []
260260
system_instruction = ""
261-
261+
262+
# Process messages and ensure proper role alternation
263+
current_role = None
264+
current_content = ""
265+
262266
for msg in messages:
263267
if msg['role'] == 'system':
264268
system_instruction += msg['content'] + "\n"
265269
else:
266-
contents.append({
267-
"role": "user" if msg['role'] == 'user' else "model",
268-
"parts": [{"text": msg['content']}]
269-
})
270+
role = "user" if msg['role'] == 'user' else "model"
271+
272+
# If same role as previous, combine content
273+
if role == current_role:
274+
current_content += "\n" + msg['content']
275+
else:
276+
# Add previous message if exists
277+
if current_role is not None:
278+
contents.append({
279+
"role": current_role,
280+
"parts": [{"text": current_content}]
281+
})
282+
# Start new message
283+
current_role = role
284+
current_content = msg['content']
285+
286+
# Add the last message if exists
287+
if current_role is not None:
288+
contents.append({
289+
"role": current_role,
290+
"parts": [{"text": current_content}]
291+
})
270292

293+
# Ensure contents starts with user message if not empty
294+
if contents and contents[0]["role"] != "user":
295+
# Add a placeholder user message if needed
296+
contents.insert(0, {
297+
"role": "user",
298+
"parts": [{"text": "I need your assistance."}]
299+
})
300+
301+
# Ensure contents is not empty
302+
if not contents:
303+
contents.append({
304+
"role": "user",
305+
"parts": [{"text": "I need your assistance."}]
306+
})
307+
308+
# Create payload with snake_case keys as required by Gemini API
271309
payload = {
272310
"contents": contents,
273-
"generationConfig": {
311+
"generation_config": { # Changed from camelCase to snake_case
274312
"temperature": temperature,
275-
"maxOutputTokens": max_tokens
313+
"max_output_tokens": max_tokens # Changed from camelCase to snake_case
276314
}
277315
}
316+
278317
if system_instruction:
279-
payload["systemInstruction"] = {
280-
"parts": [{"text": system_instruction}]
318+
payload["system_instruction"] = { # Changed from camelCase to snake_case
319+
"parts": [{"text": system_instruction.strip()}]
281320
}
282321

283-
data = await self._make_request(url, payload)
284-
return {
285-
"choices": [{
286-
"message": {
287-
"content": data["candidates"][0]["content"]["parts"][0]["text"]
322+
try:
323+
data = await self._make_request(url, payload)
324+
325+
# Safely extract response data with fallbacks
326+
candidates = data.get("candidates", [])
327+
if not candidates:
328+
raise HomeAssistantError("Gemini API returned no candidates")
329+
330+
content = candidates[0].get("content", {})
331+
parts = content.get("parts", [])
332+
if not parts:
333+
raise HomeAssistantError("Gemini API response contains no content parts")
334+
335+
answer_text = parts[0].get("text", "")
336+
337+
# Safely extract usage data
338+
usage = data.get("usageMetadata", {})
339+
prompt_tokens = usage.get("promptTokenCount", 0)
340+
completion_tokens = usage.get("candidatesTokenCount", 0)
341+
342+
# Handle case where candidatesTokenCount might be a list
343+
if isinstance(completion_tokens, list):
344+
completion_tokens = sum(completion_tokens)
345+
346+
total_tokens = usage.get("totalTokenCount", prompt_tokens + completion_tokens)
347+
348+
return {
349+
"choices": [{
350+
"message": {
351+
"content": answer_text
352+
}
353+
}],
354+
"usage": {
355+
"prompt_tokens": prompt_tokens,
356+
"completion_tokens": completion_tokens,
357+
"total_tokens": total_tokens
288358
}
289-
}],
290-
"usage": {
291-
"prompt_tokens": data["usageMetadata"]["promptTokenCount"],
292-
"completion_tokens": data["usageMetadata"]["candidatesTokenCount"],
293-
"total_tokens": data["usageMetadata"]["totalTokenCount"]
294359
}
295-
}
360+
except Exception as e:
361+
_LOGGER.error(f"Gemini API error: {str(e)}")
362+
raise HomeAssistantError(f"Gemini API error: {str(e)}")
296363

297364
async def shutdown(self) -> None:
298365
"""Shutdown API client."""

0 commit comments

Comments (0)