From 2f6f37796e8803635462d953f7995189c13924ef Mon Sep 17 00:00:00 2001 From: soundTricker Date: Thu, 3 Jul 2025 17:36:10 +0900 Subject: [PATCH] fix: Handle HttpResponse objects in VertexAI session service API calls --- .../adk/sessions/vertex_ai_session_service.py | 1 + .../test_vertex_ai_session_service.py | 86 ++++++++++++------- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index eac13367b..31186693d 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -236,6 +236,7 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}', request_dict={}, ) + list_events_api_response = _convert_api_response(list_events_api_response) session.events += [ _from_api_event(event) for event in list_events_api_response['sessionEvents'] diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 52fa42c91..0473bf4f2 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import re import this from typing import Any @@ -155,7 +156,6 @@ ], ) - SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' @@ -169,10 +169,25 @@ class MockApiClient: """Mocks the API Client.""" - def __init__(self) -> None: - """Initializes MockClient.""" - this.session_dict: dict[str, Any] = {} - this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {} + def __init__(self, return_as_http_response: bool = False) -> None: + """Initializes MockClient. + + Args: + return_as_http_response: If True, the mock client will return + `types.HttpResponse` objects. Otherwise, it will return raw dicts. + """ + self.session_dict: dict[str, Any] = {} + self.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {} + self.return_as_http_response = return_as_http_response + + def _maybe_wrap_in_response(self, data: dict[str, Any]) -> Any: + """Wraps the data in an HttpResponse if configured to do so.""" + if self.return_as_http_response: + return types.HttpResponse( + headers={}, + body=json.dumps(data).encode('utf-8'), + ) + return data async def async_request( self, http_method: str, path: str, request_dict: dict[str, Any] @@ -184,36 +199,40 @@ async def async_request( if match: session_id = match.group(2) if session_id in self.session_dict: - return self.session_dict[session_id] + return self._maybe_wrap_in_response(self.session_dict[session_id]) else: raise ValueError(f'Session not found: {session_id}') elif re.match(SESSIONS_REGEX, path): match = re.match(SESSIONS_REGEX, path) - return { + response_data = { 'sessions': [ session for session in self.session_dict.values() if session['userId'] == match.group(2) ], } + return self._maybe_wrap_in_response(response_data) elif re.match(EVENTS_REGEX, path): match = re.match(EVENTS_REGEX, path) if match: session_id = match.group(2) - if match.group(3): - return {'sessionEvents': MOCK_EVENT_JSON_3} + if match.group(3): # pageToken is present + return self._maybe_wrap_in_response( + {'sessionEvents': MOCK_EVENT_JSON_3} + ) events_tuple = self.event_dict.get(session_id, ([], None)) response = {'sessionEvents': events_tuple[0]} if events_tuple[1]: response['nextPageToken'] = events_tuple[1] - return response + return self._maybe_wrap_in_response(response) elif re.match(LRO_REGEX, path): # Mock long-running operation as completed - return { + response_data = { 'name': path, 'done': True, 'response': self.session_dict['4'], # Return the created session } + return self._maybe_wrap_in_response(response_data) else: raise ValueError(f'Unsupported path: {path}') elif http_method == 'POST': @@ -228,7 +247,7 @@ async def async_request( 'sessionState': request_dict.get('session_state', {}), 'updateTime': '2024-12-12T12:12:12.123456Z', } - return { + response_data = { 'name': ( 'projects/test_project/locations/test_location/' 'reasoningEngines/123/sessions/' @@ -237,6 +256,7 @@ async def async_request( ), 'done': False, } + return self._maybe_wrap_in_response(response_data) elif http_method == 'DELETE': match = re.match(SESSION_REGEX, path) if match: @@ -259,8 +279,17 @@ def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None): @pytest.fixture -def mock_get_api_client(): - api_client = MockApiClient() +def mock_get_api_client(request): + """Patches _get_api_client to return a mock client. + + This fixture is parameterized indirectly. The parameter determines whether the + mock client returns raw dicts (False) or HttpResponse objects (True). + + Args: + request: The pytest request object, used for indirect parameterization. + """ + return_as_http_response = getattr(request, 'param', False) + api_client = MockApiClient(return_as_http_response=return_as_http_response) api_client.session_dict = { '1': MOCK_SESSION_JSON_1, '2': MOCK_SESSION_JSON_2, @@ -273,14 +302,14 @@ def mock_get_api_client(): with mock.patch( 'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client', return_value=api_client, - ): - yield + ) as mock_patch: + yield mock_patch @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') +@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True) @pytest.mark.parametrize('agent_engine_id', [None, '123']) -async def test_get_empty_session(agent_engine_id): +async def test_get_empty_session(agent_engine_id, mock_get_api_client): if agent_engine_id: session_service = mock_vertex_ai_session_service(agent_engine_id) else: @@ -293,9 +322,9 @@ async def test_get_empty_session(agent_engine_id): @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') +@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True) @pytest.mark.parametrize('agent_engine_id', [None, '123']) -async def test_get_another_user_session(agent_engine_id): +async def test_get_another_user_session(agent_engine_id, mock_get_api_client): if agent_engine_id: session_service = mock_vertex_ai_session_service(agent_engine_id) else: @@ -308,8 +337,8 @@ async def test_get_another_user_session(agent_engine_id): @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_get_and_delete_session(): +@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True) +async def test_get_and_delete_session(mock_get_api_client): session_service = mock_vertex_ai_session_service() assert ( @@ -330,8 +359,8 @@ async def test_get_and_delete_session(): @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_get_session_with_page_token(): +@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True) +async def test_get_session_with_page_token(mock_get_api_client): session_service = mock_vertex_ai_session_service() assert ( @@ -343,8 +372,8 @@ async def test_get_session_with_page_token(): @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_list_sessions(): +@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True) +async def test_list_sessions(mock_get_api_client): session_service = mock_vertex_ai_session_service() sessions = await session_service.list_sessions(app_name='123', user_id='user') assert len(sessions.sessions) == 2 @@ -353,8 +382,8 @@ async def test_list_sessions(): @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') -async def test_create_session(): +@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True) +async def test_create_session(mock_get_api_client): session_service = mock_vertex_ai_session_service() state = {'key': 'value'} @@ -373,7 +402,6 @@ async def test_create_session(): @pytest.mark.asyncio -@pytest.mark.usefixtures('mock_get_api_client') async def test_create_session_with_custom_session_id(): session_service = mock_vertex_ai_session_service()