Skip to content

fix: Handle HttpResponse objects in VertexAI session service API calls (fix #1772) #1773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ async def get_session(
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
request_dict={},
)
list_events_api_response = _convert_api_response(list_events_api_response)
session.events += [
_from_api_event(event)
for event in list_events_api_response['sessionEvents']
Expand Down
86 changes: 57 additions & 29 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import this
from typing import Any
Expand Down Expand Up @@ -155,7 +156,6 @@
],
)


SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
Expand All @@ -169,10 +169,25 @@
class MockApiClient:
"""Mocks the API Client."""

def __init__(self) -> None:
"""Initializes MockClient."""
this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
def __init__(self, return_as_http_response: bool = False) -> None:
"""Initializes MockClient.

Args:
return_as_http_response: If True, the mock client will return
`types.HttpResponse` objects. Otherwise, it will return raw dicts.
"""
self.session_dict: dict[str, Any] = {}
self.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
self.return_as_http_response = return_as_http_response

def _maybe_wrap_in_response(self, data: dict[str, Any]) -> Any:
"""Wraps the data in an HttpResponse if configured to do so."""
if self.return_as_http_response:
return types.HttpResponse(
headers={},
body=json.dumps(data).encode('utf-8'),
)
return data

async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
Expand All @@ -184,36 +199,40 @@ async def async_request(
if match:
session_id = match.group(2)
if session_id in self.session_dict:
return self.session_dict[session_id]
return self._maybe_wrap_in_response(self.session_dict[session_id])
else:
raise ValueError(f'Session not found: {session_id}')
elif re.match(SESSIONS_REGEX, path):
match = re.match(SESSIONS_REGEX, path)
return {
response_data = {
'sessions': [
session
for session in self.session_dict.values()
if session['userId'] == match.group(2)
],
}
return self._maybe_wrap_in_response(response_data)
elif re.match(EVENTS_REGEX, path):
match = re.match(EVENTS_REGEX, path)
if match:
session_id = match.group(2)
if match.group(3):
return {'sessionEvents': MOCK_EVENT_JSON_3}
if match.group(3): # pageToken is present
return self._maybe_wrap_in_response(
{'sessionEvents': MOCK_EVENT_JSON_3}
)
events_tuple = self.event_dict.get(session_id, ([], None))
response = {'sessionEvents': events_tuple[0]}
if events_tuple[1]:
response['nextPageToken'] = events_tuple[1]
return response
return self._maybe_wrap_in_response(response)
elif re.match(LRO_REGEX, path):
# Mock long-running operation as completed
return {
response_data = {
'name': path,
'done': True,
'response': self.session_dict['4'], # Return the created session
}
return self._maybe_wrap_in_response(response_data)
else:
raise ValueError(f'Unsupported path: {path}')
elif http_method == 'POST':
Expand All @@ -228,7 +247,7 @@ async def async_request(
'sessionState': request_dict.get('session_state', {}),
'updateTime': '2024-12-12T12:12:12.123456Z',
}
return {
response_data = {
'name': (
'projects/test_project/locations/test_location/'
'reasoningEngines/123/sessions/'
Expand All @@ -237,6 +256,7 @@ async def async_request(
),
'done': False,
}
return self._maybe_wrap_in_response(response_data)
elif http_method == 'DELETE':
match = re.match(SESSION_REGEX, path)
if match:
Expand All @@ -259,8 +279,17 @@ def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):


@pytest.fixture
def mock_get_api_client():
api_client = MockApiClient()
def mock_get_api_client(request):
"""Patches _get_api_client to return a mock client.

This fixture is parameterized indirectly. The parameter determines whether the
mock client returns raw dicts (False) or HttpResponse objects (True).

Args:
request: The pytest request object, used for indirect parameterization.
"""
return_as_http_response = getattr(request, 'param', False)
api_client = MockApiClient(return_as_http_response=return_as_http_response)
api_client.session_dict = {
'1': MOCK_SESSION_JSON_1,
'2': MOCK_SESSION_JSON_2,
Expand All @@ -273,14 +302,14 @@ def mock_get_api_client():
with mock.patch(
'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client',
return_value=api_client,
):
yield
) as mock_patch:
yield mock_patch


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True)
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
async def test_get_empty_session(agent_engine_id):
async def test_get_empty_session(agent_engine_id, mock_get_api_client):
if agent_engine_id:
session_service = mock_vertex_ai_session_service(agent_engine_id)
else:
Expand All @@ -293,9 +322,9 @@ async def test_get_empty_session(agent_engine_id):


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True)
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
async def test_get_another_user_session(agent_engine_id):
async def test_get_another_user_session(agent_engine_id, mock_get_api_client):
if agent_engine_id:
session_service = mock_vertex_ai_session_service(agent_engine_id)
else:
Expand All @@ -308,8 +337,8 @@ async def test_get_another_user_session(agent_engine_id):


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_and_delete_session():
@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True)
async def test_get_and_delete_session(mock_get_api_client):
session_service = mock_vertex_ai_session_service()

assert (
Expand All @@ -330,8 +359,8 @@ async def test_get_and_delete_session():


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_session_with_page_token():
@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True)
async def test_get_session_with_page_token(mock_get_api_client):
session_service = mock_vertex_ai_session_service()

assert (
Expand All @@ -343,8 +372,8 @@ async def test_get_session_with_page_token():


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions():
@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True)
async def test_list_sessions(mock_get_api_client):
session_service = mock_vertex_ai_session_service()
sessions = await session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2
Expand All @@ -353,8 +382,8 @@ async def test_list_sessions():


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session():
@pytest.mark.parametrize('mock_get_api_client', [False, True], indirect=True)
async def test_create_session(mock_get_api_client):
session_service = mock_vertex_ai_session_service()

state = {'key': 'value'}
Expand All @@ -373,7 +402,6 @@ async def test_create_session():


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session_with_custom_session_id():
session_service = mock_vertex_ai_session_service()

Expand Down