Skip to content

Commit 0399ca2

Browse files
sindrigunnarszshehov
authored andcommitted
fix: scenario where a user can access another users events given the same session id
test: refine test_session_state in test_session_state to catch event leakage fix: revert app name to my_app for test_session_state test style: fix pyink style warnings fix: add app_name to filter as per suggestion from rpedela-recurly
1 parent 60ca2e6 commit 0399ca2

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,9 @@ async def get_session(
459459

460460
storage_events = (
461461
session_factory.query(StorageEvent)
462+
.filter(StorageEvent.app_name == app_name)
462463
.filter(StorageEvent.session_id == storage_session.id)
464+
.filter(StorageEvent.user_id == user_id)
463465
.filter(timestamp_filter)
464466
.order_by(StorageEvent.timestamp.desc())
465467
.limit(

tests/unittests/sessions/test_session_service.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ async def test_session_state(service_type):
126126
app_name = 'my_app'
127127
user_id_1 = 'user1'
128128
user_id_2 = 'user2'
129+
user_id_malicious = 'malicious'
129130
session_id_11 = 'session11'
130131
session_id_12 = 'session12'
131132
session_id_2 = 'session2'
@@ -148,6 +149,10 @@ async def test_session_state(service_type):
148149
app_name=app_name, user_id=user_id_2, session_id=session_id_2
149150
)
150151

152+
await session_service.create_session(
153+
app_name=app_name, user_id=user_id_malicious, session_id=session_id_11
154+
)
155+
151156
assert session_11.state.get('key11') == 'value11'
152157

153158
event = Event(
@@ -196,6 +201,13 @@ async def test_session_state(service_type):
196201
assert session_11.state.get('user:key1') == 'value1'
197202
assert not session_11.state.get('temp:key')
198203

204+
# Make sure a malicious user can obtain a session and events not belonging to them
205+
session_mismatch = await session_service.get_session(
206+
app_name=app_name, user_id=user_id_malicious, session_id=session_id_11
207+
)
208+
209+
assert len(session_mismatch.events) == 0
210+
199211

200212
@pytest.mark.asyncio
201213
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)