Skip to content

Commit 4d72d31

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Add type checking to handle different response type of genai API client
Fixes #1514 PiperOrigin-RevId: 773838035
1 parent 742478f commit 4d72d31

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import json
1718
import logging
1819
import re
1920
from typing import Any
@@ -87,6 +88,7 @@ async def create_session(
8788
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
8889
request_dict=session_json_dict,
8990
)
91+
api_response = _convert_api_response(api_response)
9092
logger.info(f'Create Session response {api_response}')
9193

9294
session_id = api_response['name'].split('/')[-3]
@@ -100,6 +102,7 @@ async def create_session(
100102
path=f'operations/{operation_id}',
101103
request_dict={},
102104
)
105+
lro_response = _convert_api_response(lro_response)
103106

104107
if lro_response.get('done', None):
105108
break
@@ -118,6 +121,7 @@ async def create_session(
118121
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
119122
request_dict={},
120123
)
124+
get_session_api_response = _convert_api_response(get_session_api_response)
121125

122126
update_timestamp = isoparse(
123127
get_session_api_response['updateTime']
@@ -149,6 +153,7 @@ async def get_session(
149153
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
150154
request_dict={},
151155
)
156+
get_session_api_response = _convert_api_response(get_session_api_response)
152157

153158
session_id = get_session_api_response['name'].split('/')[-1]
154159
update_timestamp = isoparse(
@@ -167,9 +172,12 @@ async def get_session(
167172
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
168173
request_dict={},
169174
)
175+
list_events_api_response = _convert_api_response(list_events_api_response)
170176

171177
# Handles empty response case
172-
if list_events_api_response.get('httpHeaders', None):
178+
if not list_events_api_response or list_events_api_response.get(
179+
'httpHeaders', None
180+
):
173181
return session
174182

175183
session.events += [
@@ -226,9 +234,10 @@ async def list_sessions(
226234
path=path,
227235
request_dict={},
228236
)
237+
api_response = _convert_api_response(api_response)
229238

230239
# Handles empty response case
231-
if api_response.get('httpHeaders', None):
240+
if not api_response or api_response.get('httpHeaders', None):
232241
return ListSessionsResponse()
233242

234243
sessions = []
@@ -303,6 +312,13 @@ def _get_api_client(self):
303312
return client._api_client
304313

305314

315+
def _convert_api_response(api_response):
316+
"""Converts the API response to a JSON object based on the type."""
317+
if hasattr(api_response, 'body'):
318+
return json.loads(api_response.body)
319+
return api_response
320+
321+
306322
def _convert_event_to_json(event: Event) -> Dict[str, Any]:
307323
metadata_json = {
308324
'partial': event.partial,

0 commit comments

Comments
 (0)