Skip to content

Commit c6e1e82

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Minor improvement to session service
- Add missing override. - Add warning to failed actions. - Remove unused import. - Remove unused fields. - Add type checking. PiperOrigin-RevId: 767209697
1 parent 54ed031 commit c6e1e82

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
import copy
1617
import logging
@@ -223,6 +224,7 @@ def _list_sessions_impl(
223224
sessions_without_events.append(copied_session)
224225
return ListSessionsResponse(sessions=sessions_without_events)
225226

227+
@override
226228
async def delete_session(
227229
self, *, app_name: str, user_id: str, session_id: str
228230
) -> None:
@@ -247,7 +249,7 @@ def _delete_session_impl(
247249
)
248250
is None
249251
):
250-
return None
252+
return
251253

252254
self.sessions[app_name][user_id].pop(session_id)
253255

@@ -261,11 +263,20 @@ async def append_event(self, session: Session, event: Event) -> Event:
261263
app_name = session.app_name
262264
user_id = session.user_id
263265
session_id = session.id
266+
267+
def _warning(message: str) -> None:
268+
logger.warning(
269+
f'Failed to append event to session {session_id}: {message}'
270+
)
271+
264272
if app_name not in self.sessions:
273+
_warning(f'app_name {app_name} not in sessions')
265274
return event
266275
if user_id not in self.sessions[app_name]:
276+
_warning(f'user_id {user_id} not in sessions[app_name]')
267277
return event
268278
if session_id not in self.sessions[app_name][user_id]:
279+
_warning(f'session_id {session_id} not in sessions[app_name][user_id]')
269280
return event
270281

271282
if event.actions and event.actions.state_delta:

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import asyncio
1717
import logging
1818
import re
19-
import time
2019
from typing import Any
20+
from typing import Dict
2121
from typing import Optional
2222
import urllib.parse
2323

@@ -50,9 +50,6 @@ def __init__(
5050
self.project = project
5151
self.location = location
5252

53-
client = genai.Client(vertexai=True, project=project, location=location)
54-
self.api_client = client._api_client
55-
5653
@override
5754
async def create_session(
5855
self,
@@ -86,6 +83,7 @@ async def create_session(
8683
operation_id = api_response['name'].split('/')[-1]
8784

8885
max_retry_attempt = 5
86+
lro_response = None
8987
while max_retry_attempt >= 0:
9088
lro_response = await api_client.async_request(
9189
http_method='GET',
@@ -99,6 +97,11 @@ async def create_session(
9997
await asyncio.sleep(1)
10098
max_retry_attempt -= 1
10199

100+
if lro_response is None or not lro_response.get('done', None):
101+
raise TimeoutError(
102+
f'Timeout waiting for operation {operation_id} to complete.'
103+
)
104+
102105
# Get session resource
103106
get_session_api_response = await api_client.async_request(
104107
http_method='GET',
@@ -235,11 +238,15 @@ async def delete_session(
235238
) -> None:
236239
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
237240
api_client = _get_api_client(self.project, self.location)
238-
await api_client.async_request(
239-
http_method='DELETE',
240-
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
241-
request_dict={},
242-
)
241+
try:
242+
await api_client.async_request(
243+
http_method='DELETE',
244+
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
245+
request_dict={},
246+
)
247+
except Exception as e:
248+
logger.error(f'Error deleting session {session_id}: {e}')
249+
raise e
243250

244251
@override
245252
async def append_event(self, session: Session, event: Event) -> Event:
@@ -266,7 +273,7 @@ def _get_api_client(project: str, location: str):
266273
return client._api_client
267274

268275

269-
def _convert_event_to_json(event: Event):
276+
def _convert_event_to_json(event: Event) -> Dict[str, Any]:
270277
metadata_json = {
271278
'partial': event.partial,
272279
'turn_complete': event.turn_complete,
@@ -318,7 +325,7 @@ def _convert_event_to_json(event: Event):
318325
return event_json
319326

320327

321-
def _from_api_event(api_event: dict) -> Event:
328+
def _from_api_event(api_event: Dict[str, Any]) -> Event:
322329
event_actions = EventActions()
323330
if api_event.get('actions', None):
324331
event_actions = EventActions(

0 commit comments

Comments
 (0)