Skip to content

Commit 1b0b5a5

Browse files
authored
Merge pull request #23 from h0rn3t/multi_sessions_wip
WIP: multi_sessions
2 parents 083197a + 0c74aaf commit 1b0b5a5

File tree

3 files changed

+64
-38
lines changed

3 files changed

+64
-38
lines changed

fastapi_async_sqlalchemy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
__all__ = ["db", "SQLAlchemyMiddleware"]
44

5-
__version__ = "0.7.0.dev1"
5+
__version__ = "0.7.0.dev2"

fastapi_async_sqlalchemy/middleware.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
from asyncio import Task
32
from contextvars import ContextVar
43
from typing import Dict, Optional, Union
54

@@ -22,6 +21,9 @@ def create_middleware_and_session_proxy():
2221
_Session: Optional[async_sessionmaker] = None
2322
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
2423
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
24+
_task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar(
25+
"_task_session_ctx", default=None
26+
)
2527
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
2628
# Usage of context vars inside closures is not recommended, since they are not properly
2729
# garbage collected, but in our use case context var is created on program startup and
@@ -90,28 +92,26 @@ async def execute_query(query):
9092
```
9193
"""
9294
commit_on_exit = _commit_on_exit_ctx.get()
93-
task: Task = asyncio.current_task() # type: ignore
94-
if not hasattr(task, "_db_session"):
95-
task._db_session = _Session() # type: ignore
96-
97-
def cleanup(future):
98-
session = getattr(task, "_db_session", None)
99-
if session:
100-
101-
async def do_cleanup():
102-
try:
103-
if future.exception():
104-
await session.rollback()
105-
else:
106-
if commit_on_exit:
107-
await session.commit()
108-
finally:
109-
await session.close()
110-
111-
asyncio.create_task(do_cleanup())
112-
113-
task.add_done_callback(cleanup)
114-
return task._db_session # type: ignore
95+
session = _task_session_ctx.get()
96+
if session is None:
97+
session = _Session()
98+
_task_session_ctx.set(session)
99+
100+
async def cleanup():
101+
try:
102+
if commit_on_exit:
103+
await session.commit()
104+
except Exception:
105+
await session.rollback()
106+
raise
107+
finally:
108+
await session.close()
109+
_task_session_ctx.set(None)
110+
111+
task = asyncio.current_task()
112+
if task is not None:
113+
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
114+
return session
115115
else:
116116
session = _session.get()
117117
if session is None:
@@ -139,23 +139,24 @@ async def __aenter__(self):
139139
if self.multi_sessions:
140140
self.multi_sessions_token = _multi_sessions_ctx.set(True)
141141
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
142-
143-
self.token = _session.set(_Session(**self.session_args))
142+
else:
143+
self.token = _session.set(_Session(**self.session_args))
144144
return type(self)
145145

146146
async def __aexit__(self, exc_type, exc_value, traceback):
147-
session = _session.get()
148-
try:
149-
if exc_type is not None:
150-
await session.rollback()
151-
elif self.commit_on_exit:
152-
await session.commit()
153-
finally:
154-
await session.close()
155-
_session.reset(self.token)
156-
if self.multi_sessions_token is not None:
157-
_multi_sessions_ctx.reset(self.multi_sessions_token)
158-
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
147+
if self.multi_sessions:
148+
_multi_sessions_ctx.reset(self.multi_sessions_token)
149+
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
150+
else:
151+
session = _session.get()
152+
try:
153+
if exc_type is not None:
154+
await session.rollback()
155+
elif self.commit_on_exit:
156+
await session.commit()
157+
finally:
158+
await session.close()
159+
_session.reset(self.token)
159160

160161
return SQLAlchemyMiddleware, DBSession
161162

tests/test_session.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,28 @@ async def execute_query(query):
173173

174174
res = await asyncio.gather(*tasks)
175175
assert len(res) == 6
176+
177+
178+
@pytest.mark.asyncio
179+
async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware):
180+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
181+
182+
async with db(multi_sessions=True, commit_on_exit=True):
183+
await db.session.execute(
184+
text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)")
185+
)
186+
187+
async def insert_data(value):
188+
await db.session.execute(
189+
text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value}
190+
)
191+
await db.session.flush()
192+
193+
tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)]
194+
195+
result_ids = await asyncio.gather(*tasks)
196+
assert len(result_ids) == 10
197+
198+
records = await db.session.execute(text("SELECT * FROM my_model"))
199+
records = records.scalars().all()
200+
assert len(records) == 10

0 commit comments

Comments
 (0)