Skip to content

feat: add pagination and state options to list_sessions #1825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/google/adk/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import logging

from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .in_memory_session_service import InMemorySessionService
from .session import Session
from .state import State
Expand All @@ -24,7 +27,10 @@

__all__ = [
'BaseSessionService',
'GetSessionConfig',
'InMemorySessionService',
'ListSessionsConfig',
'ListSessionsResponse',
'Session',
'State',
'VertexAiSessionService',
Expand Down
14 changes: 12 additions & 2 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,20 @@ class GetSessionConfig(BaseModel):
after_timestamp: Optional[float] = None


class ListSessionsConfig(BaseModel):
"""The configuration of listing sessions."""

max_sessions: Optional[int] = None
"""Maximum number of sessions to return. If not specified, all sessions are returned."""

include_state: bool = False
"""Whether to include the state data in the response. Default is False for performance."""


class ListSessionsResponse(BaseModel):
"""The response of listing sessions.

The events and states are not set within each Session object.
The events and states are not set within each Session object unless explicitly requested.
"""

sessions: list[Session] = Field(default_factory=list)
Expand Down Expand Up @@ -81,7 +91,7 @@ async def get_session(

@abc.abstractmethod
async def list_sessions(
self, *, app_name: str, user_id: str
self, *, app_name: str, user_id: str, config: Optional[ListSessionsConfig] = None
) -> ListSessionsResponse:
"""Lists all the sessions."""

Expand Down
21 changes: 17 additions & 4 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
Expand Down Expand Up @@ -498,22 +499,34 @@ async def get_session(

@override
async def list_sessions(
self, *, app_name: str, user_id: str
self, *, app_name: str, user_id: str, config: Optional[ListSessionsConfig] = None
) -> ListSessionsResponse:
with self.database_session_factory() as session_factory:
results = (
query = (
session_factory.query(StorageSession)
.filter(StorageSession.app_name == app_name)
.filter(StorageSession.user_id == user_id)
.all()
.order_by(StorageSession.update_time.desc())
)

# Apply pagination if specified
if config and config.max_sessions:
query = query.limit(config.max_sessions)

results = query.all()

sessions = []
for storage_session in results:
# Determine whether to include state
session_state = {}
if config and config.include_state:
session_state = storage_session.state

session = Session(
app_name=app_name,
user_id=user_id,
id=storage_session.id,
state={},
state=session_state,
last_update_time=storage_session.update_timestamp_tz,
)
sessions.append(session)
Expand Down
30 changes: 23 additions & 7 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
Expand Down Expand Up @@ -201,18 +202,18 @@ def _merge_state(

@override
async def list_sessions(
self, *, app_name: str, user_id: str
self, *, app_name: str, user_id: str, config: Optional[ListSessionsConfig] = None
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
return self._list_sessions_impl(app_name=app_name, user_id=user_id, config=config)

def list_sessions_sync(
self, *, app_name: str, user_id: str
self, *, app_name: str, user_id: str, config: Optional[ListSessionsConfig] = None
) -> ListSessionsResponse:
logger.warning('Deprecated. Please migrate to the async method.')
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
return self._list_sessions_impl(app_name=app_name, user_id=user_id, config=config)

def _list_sessions_impl(
self, *, app_name: str, user_id: str
self, *, app_name: str, user_id: str, config: Optional[ListSessionsConfig] = None
) -> ListSessionsResponse:
empty_response = ListSessionsResponse()
if app_name not in self.sessions:
Expand All @@ -221,10 +222,25 @@ def _list_sessions_impl(
return empty_response

sessions_without_events = []
for session in self.sessions[app_name][user_id].values():
all_sessions = list(self.sessions[app_name][user_id].values())

# Sort by last_update_time in descending order to get most recent first
all_sessions.sort(key=lambda s: s.last_update_time, reverse=True)

# Apply pagination if specified
if config and config.max_sessions:
all_sessions = all_sessions[:config.max_sessions]

for session in all_sessions:
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session.state = {}

# Determine whether to include state
if config and config.include_state:
copied_session.state = session.state
else:
copied_session.state = {}

sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)

Expand Down
24 changes: 21 additions & 3 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ..events.event_actions import EventActions
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .session import Session

Expand Down Expand Up @@ -263,7 +264,7 @@ async def get_session(

@override
async def list_sessions(
self, *, app_name: str, user_id: str
self, *, app_name: str, user_id: str, config: Optional[ListSessionsConfig] = None
) -> ListSessionsResponse:
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
Expand All @@ -284,13 +285,30 @@ async def list_sessions(
if not api_response or api_response.get('httpHeaders', None):
return ListSessionsResponse()

api_sessions = api_response['sessions']

# Sort by updateTime in descending order to get most recent first
api_sessions.sort(key=lambda s: s.get('updateTime', ''), reverse=True)

# Apply pagination if specified
if config and config.max_sessions:
api_sessions = api_sessions[:config.max_sessions]

sessions = []
for api_session in api_response['sessions']:
for api_session in api_sessions:
session_state = {}
if config and config.include_state:
# For Vertex AI, we'd need to fetch individual session details to get state
# This is a performance trade-off - we can either make individual calls
# or keep the current behavior. For now, we'll keep it empty but document it.
# In a real implementation, you might want to batch these calls or cache them.
pass

session = Session(
app_name=app_name,
user_id=user_id,
id=api_session['name'].split('/')[-1],
state={},
state=session_state,
last_update_time=isoparse(api_session['updateTime']).timestamp(),
)
sessions.append(session)
Expand Down