Skip to content

Commit d212e50

Browse files
DeanChensjcopybara-github
authored andcommitted
feat:Make VertexAiSessionService true async.
PiperOrigin-RevId: 762547133
1 parent 79681e3 commit d212e50

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
import logging
1516
import re
16-
import time
1717
from typing import Any
1818
from typing import Optional
1919

@@ -69,7 +69,8 @@ async def create_session(
6969
if state:
7070
session_json_dict['session_state'] = state
7171

72-
api_response = self.api_client.request(
72+
api_client = _get_api_client(self.project, self.location)
73+
api_response = await api_client.async_request(
7374
http_method='POST',
7475
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
7576
request_dict=session_json_dict,
@@ -81,7 +82,7 @@ async def create_session(
8182

8283
max_retry_attempt = 5
8384
while max_retry_attempt >= 0:
84-
lro_response = self.api_client.request(
85+
lro_response = await api_client.async_request(
8586
http_method='GET',
8687
path=f'operations/{operation_id}',
8788
request_dict={},
@@ -90,11 +91,11 @@ async def create_session(
9091
if lro_response.get('done', None):
9192
break
9293

93-
time.sleep(1)
94+
await asyncio.sleep(1)
9495
max_retry_attempt -= 1
9596

9697
# Get session resource
97-
get_session_api_response = self.api_client.request(
98+
get_session_api_response = await api_client.async_request(
9899
http_method='GET',
99100
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
100101
request_dict={},
@@ -124,7 +125,8 @@ async def get_session(
124125
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
125126

126127
# Get session resource
127-
get_session_api_response = self.api_client.request(
128+
api_client = _get_api_client(self.project, self.location)
129+
get_session_api_response = await api_client.async_request(
128130
http_method='GET',
129131
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
130132
request_dict={},
@@ -142,7 +144,7 @@ async def get_session(
142144
last_update_time=update_timestamp,
143145
)
144146

145-
list_events_api_response = self.api_client.request(
147+
list_events_api_response = await api_client.async_request(
146148
http_method='GET',
147149
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
148150
request_dict={},
@@ -181,7 +183,8 @@ async def list_sessions(
181183
) -> ListSessionsResponse:
182184
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
183185

184-
api_response = self.api_client.request(
186+
api_client = _get_api_client(self.project, self.location)
187+
api_response = await api_client.async_request(
185188
http_method='GET',
186189
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
187190
request_dict={},
@@ -207,7 +210,8 @@ async def delete_session(
207210
self, *, app_name: str, user_id: str, session_id: str
208211
) -> None:
209212
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
210-
self.api_client.request(
213+
api_client = _get_api_client(self.project, self.location)
214+
await api_client.async_request(
211215
http_method='DELETE',
212216
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
213217
request_dict={},
@@ -219,15 +223,25 @@ async def append_event(self, session: Session, event: Event) -> Event:
219223
await super().append_event(session=session, event=event)
220224

221225
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
222-
self.api_client.request(
226+
api_client = _get_api_client(self.project, self.location)
227+
await api_client.async_request(
223228
http_method='POST',
224229
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
225230
request_dict=_convert_event_to_json(event),
226231
)
227-
228232
return event
229233

230234

235+
def _get_api_client(project: str, location: str):
236+
"""Instantiates an API client for the given project and location.
237+
238+
It needs to be instantiated inside each request so that the event loop
239+
management.
240+
"""
241+
client = genai.Client(vertexai=True, project=project, location=location)
242+
return client._api_client
243+
244+
231245
def _convert_event_to_json(event: Event):
232246
metadata_json = {
233247
'partial': event.partial,

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import this
1717
from typing import Any
18+
from unittest import mock
1819

1920
from dateutil.parser import isoparse
2021
from google.adk.events import Event
@@ -123,7 +124,9 @@ def __init__(self) -> None:
123124
this.session_dict: dict[str, Any] = {}
124125
this.event_dict: dict[str, list[Any]] = {}
125126

126-
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
127+
async def async_request(
128+
self, http_method: str, path: str, request_dict: dict[str, Any]
129+
):
127130
"""Mocks the API Client request method."""
128131
if http_method == 'GET':
129132
if re.match(SESSION_REGEX, path):
@@ -194,22 +197,31 @@ def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
194197

195198
def mock_vertex_ai_session_service():
196199
"""Creates a mock Vertex AI Session service for testing."""
197-
service = VertexAiSessionService(
200+
return VertexAiSessionService(
198201
project='test-project', location='test-location'
199202
)
200-
service.api_client = MockApiClient()
201-
service.api_client.session_dict = {
203+
204+
205+
@pytest.fixture
206+
def mock_get_api_client():
207+
api_client = MockApiClient()
208+
api_client.session_dict = {
202209
'1': MOCK_SESSION_JSON_1,
203210
'2': MOCK_SESSION_JSON_2,
204211
'3': MOCK_SESSION_JSON_3,
205212
}
206-
service.api_client.event_dict = {
213+
api_client.event_dict = {
207214
'1': MOCK_EVENT_JSON,
208215
}
209-
return service
216+
with mock.patch(
217+
"google.adk.sessions.vertex_ai_session_service._get_api_client",
218+
return_value=api_client,
219+
):
220+
yield
210221

211222

212223
@pytest.mark.asyncio
224+
@pytest.mark.usefixtures('mock_get_api_client')
213225
async def test_get_empty_session():
214226
session_service = mock_vertex_ai_session_service()
215227
with pytest.raises(ValueError) as excinfo:
@@ -220,6 +232,7 @@ async def test_get_empty_session():
220232

221233

222234
@pytest.mark.asyncio
235+
@pytest.mark.usefixtures('mock_get_api_client')
223236
async def test_get_and_delete_session():
224237
session_service = mock_vertex_ai_session_service()
225238

@@ -241,6 +254,7 @@ async def test_get_and_delete_session():
241254

242255

243256
@pytest.mark.asyncio
257+
@pytest.mark.usefixtures('mock_get_api_client')
244258
async def test_list_sessions():
245259
session_service = mock_vertex_ai_session_service()
246260
sessions = await session_service.list_sessions(app_name='123', user_id='user')
@@ -250,6 +264,7 @@ async def test_list_sessions():
250264

251265

252266
@pytest.mark.asyncio
267+
@pytest.mark.usefixtures('mock_get_api_client')
253268
async def test_create_session():
254269
session_service = mock_vertex_ai_session_service()
255270

@@ -269,6 +284,7 @@ async def test_create_session():
269284

270285

271286
@pytest.mark.asyncio
287+
@pytest.mark.usefixtures('mock_get_api_client')
272288
async def test_create_session_with_custom_session_id():
273289
session_service = mock_vertex_ai_session_service()
274290

0 commit comments

Comments
 (0)