1
1
import asyncio
2
- from asyncio import Task
3
2
from contextvars import ContextVar
4
3
from typing import Dict , Optional , Union
5
4
@@ -22,6 +21,9 @@ def create_middleware_and_session_proxy():
22
21
_Session : Optional [async_sessionmaker ] = None
23
22
_session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
24
23
_multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
24
+ _task_session_ctx : ContextVar [Optional [AsyncSession ]] = ContextVar (
25
+ "_task_session_ctx" , default = None
26
+ )
25
27
_commit_on_exit_ctx : ContextVar [bool ] = ContextVar ("_commit_on_exit_ctx" , default = False )
26
28
# Usage of context vars inside closures is not recommended, since they are not properly
27
29
# garbage collected, but in our use case context var is created on program startup and
@@ -90,28 +92,26 @@ async def execute_query(query):
90
92
```
91
93
"""
92
94
commit_on_exit = _commit_on_exit_ctx .get ()
93
- task : Task = asyncio .current_task () # type: ignore
94
- if not hasattr (task , "_db_session" ):
95
- task ._db_session = _Session () # type: ignore
96
-
97
- def cleanup (future ):
98
- session = getattr (task , "_db_session" , None )
99
- if session :
100
-
101
- async def do_cleanup ():
102
- try :
103
- if future .exception ():
104
- await session .rollback ()
105
- else :
106
- if commit_on_exit :
107
- await session .commit ()
108
- finally :
109
- await session .close ()
110
-
111
- asyncio .create_task (do_cleanup ())
112
-
113
- task .add_done_callback (cleanup )
114
- return task ._db_session # type: ignore
95
+ session = _task_session_ctx .get ()
96
+ if session is None :
97
+ session = _Session ()
98
+ _task_session_ctx .set (session )
99
+
100
+ async def cleanup ():
101
+ try :
102
+ if commit_on_exit :
103
+ await session .commit ()
104
+ except Exception :
105
+ await session .rollback ()
106
+ raise
107
+ finally :
108
+ await session .close ()
109
+ _task_session_ctx .set (None )
110
+
111
+ task = asyncio .current_task ()
112
+ if task is not None :
113
+ task .add_done_callback (lambda t : asyncio .create_task (cleanup ()))
114
+ return session
115
115
else :
116
116
session = _session .get ()
117
117
if session is None :
@@ -139,23 +139,24 @@ async def __aenter__(self):
139
139
if self .multi_sessions :
140
140
self .multi_sessions_token = _multi_sessions_ctx .set (True )
141
141
self .commit_on_exit_token = _commit_on_exit_ctx .set (self .commit_on_exit )
142
-
143
- self .token = _session .set (_Session (** self .session_args ))
142
+ else :
143
+ self .token = _session .set (_Session (** self .session_args ))
144
144
return type (self )
145
145
146
146
async def __aexit__ (self , exc_type , exc_value , traceback ):
147
- session = _session .get ()
148
- try :
149
- if exc_type is not None :
150
- await session .rollback ()
151
- elif self .commit_on_exit :
152
- await session .commit ()
153
- finally :
154
- await session .close ()
155
- _session .reset (self .token )
156
- if self .multi_sessions_token is not None :
157
- _multi_sessions_ctx .reset (self .multi_sessions_token )
158
- _commit_on_exit_ctx .reset (self .commit_on_exit_token )
147
+ if self .multi_sessions :
148
+ _multi_sessions_ctx .reset (self .multi_sessions_token )
149
+ _commit_on_exit_ctx .reset (self .commit_on_exit_token )
150
+ else :
151
+ session = _session .get ()
152
+ try :
153
+ if exc_type is not None :
154
+ await session .rollback ()
155
+ elif self .commit_on_exit :
156
+ await session .commit ()
157
+ finally :
158
+ await session .close ()
159
+ _session .reset (self .token )
159
160
160
161
return SQLAlchemyMiddleware , DBSession
161
162
0 commit comments