@@ -29,9 +29,11 @@ use actix_web::*;
29
29
30
30
// pyO3 module
31
31
use log:: { debug, error} ;
32
- use pyo3:: exceptions:: PyValueError ;
32
+ use once_cell:: sync:: OnceCell ;
33
+ use pyo3:: exceptions:: { asyncio, PyValueError } ;
33
34
use pyo3:: prelude:: * ;
34
35
use pyo3:: pycell:: PyRef ;
36
+ use pyo3_async_runtimes:: TaskLocals ;
35
37
36
38
const MAX_PAYLOAD_SIZE : & str = "ROBYN_MAX_PAYLOAD_SIZE" ;
37
39
const DEFAULT_MAX_PAYLOAD_SIZE : usize = 1_000_000 ; // 1Mb
@@ -80,12 +82,14 @@ impl Server {
80
82
81
83
pub fn start (
82
84
& mut self ,
83
- py : Python ,
85
+ _py : Python ,
84
86
socket : PyRef < SocketHeld > ,
85
87
workers : usize ,
86
88
) -> PyResult < ( ) > {
87
89
pyo3_log:: init ( ) ;
88
90
91
+ static TASK_LOCALS : OnceCell < TaskLocals > = OnceCell :: new ( ) ;
92
+
89
93
if STARTED
90
94
. compare_exchange ( false , true , SeqCst , Relaxed )
91
95
. is_err ( )
@@ -104,17 +108,18 @@ impl Server {
104
108
let global_response_headers = self . global_response_headers . clone ( ) ;
105
109
let directories = self . directories . clone ( ) ;
106
110
107
- let asyncio = py . import ( "asyncio" ) ?;
111
+ let asyncio = _py . import ( "asyncio" ) ?;
108
112
let event_loop = asyncio. call_method0 ( "new_event_loop" ) ?;
109
- asyncio. call_method1 ( "set_event_loop" , ( event_loop, ) ) ?;
113
+ asyncio. call_method1 ( "set_event_loop" , ( event_loop. clone ( ) , ) ) ?;
110
114
111
115
let startup_handler = self . startup_handler . clone ( ) ;
112
116
let shutdown_handler = self . shutdown_handler . clone ( ) ;
113
117
114
118
let excluded_response_headers_paths = self . excluded_response_headers_paths . clone ( ) ;
115
119
116
- let task_locals = pyo3_async_runtimes:: TaskLocals :: new ( event_loop) . copy_context ( py) ?;
117
- let task_locals_copy = task_locals. clone_ref ( py) ;
120
+ let _ = TASK_LOCALS . get_or_try_init ( || {
121
+ Python :: with_gil ( |py| pyo3_async_runtimes:: TaskLocals :: new ( event_loop. clone ( ) . into ( ) ) . copy_context ( py) )
122
+ } ) ;
118
123
119
124
let max_payload_size = env:: var ( MAX_PAYLOAD_SIZE )
120
125
. unwrap_or ( DEFAULT_MAX_PAYLOAD_SIZE . to_string ( ) )
@@ -129,14 +134,15 @@ impl Server {
129
134
thread:: spawn ( move || {
130
135
actix_web:: rt:: System :: new ( ) . block_on ( async move {
131
136
debug ! ( "The number of workers is {}" , workers) ;
132
- execute_startup_handler ( startup_handler, & task_locals_copy)
137
+
138
+ let task_locals = Python :: with_gil ( |py| TASK_LOCALS . get ( ) . unwrap ( ) . clone_ref ( py) ) ;
139
+ execute_startup_handler ( startup_handler, & task_locals)
133
140
. await
134
141
. unwrap ( ) ;
135
142
136
143
HttpServer :: new ( move || {
137
144
let mut app = App :: new ( ) ;
138
145
139
- let task_locals = task_locals_copy. clone_ref ( py) ;
140
146
let directories = directories. read ( ) . unwrap ( ) ;
141
147
142
148
// this loop matches three types of directory serving
@@ -174,16 +180,16 @@ impl Server {
174
180
for ( elem, value) in ( web_socket_map. read ( ) ) . iter ( ) {
175
181
let endpoint = elem. clone ( ) ;
176
182
let path_params = value. clone ( ) ;
177
- let task_locals = task_locals. clone_ref ( py) ;
178
183
app = app. route (
179
184
& endpoint. clone ( ) ,
180
185
web:: get ( ) . to ( move |stream : web:: Payload , req : HttpRequest | {
181
186
let endpoint_copy = endpoint. clone ( ) ;
187
+ let task_locals = Python :: with_gil ( |py| TASK_LOCALS . get ( ) . unwrap ( ) . clone_ref ( py) ) ;
182
188
start_web_socket (
183
189
req,
184
190
stream,
185
191
path_params. clone ( ) ,
186
- task_locals. clone_ref ( py ) ,
192
+ task_locals,
187
193
endpoint_copy,
188
194
)
189
195
} ) ,
@@ -202,7 +208,8 @@ impl Server {
202
208
global_response_headers,
203
209
response_headers_exclude_paths,
204
210
req| {
205
- pyo3_async_runtimes:: tokio:: scope_local ( task_locals. clone_ref ( py) , async move {
211
+ let task_locals = Python :: with_gil ( |py| TASK_LOCALS . get ( ) . unwrap ( ) . clone_ref ( py) ) ;
212
+ pyo3_async_runtimes:: tokio:: scope_local ( task_locals, async move {
206
213
index (
207
214
router,
208
215
payload,
@@ -250,11 +257,13 @@ impl Server {
250
257
if function. is_async {
251
258
debug ! ( "Shutdown event handler async" ) ;
252
259
260
+ let task_locals = Python :: with_gil ( |py| TASK_LOCALS . get ( ) . unwrap ( ) . clone_ref ( py) ) ;
261
+
253
262
pyo3_async_runtimes:: tokio:: run_until_complete (
254
- task_locals. event_loop ( py ) ,
263
+ task_locals. event_loop ( _py ) ,
255
264
pyo3_async_runtimes:: into_future_with_locals (
256
- & task_locals. clone_ref ( py ) ,
257
- function. handler . bind ( py ) . call0 ( ) ?,
265
+ & task_locals. clone_ref ( _py ) ,
266
+ function. handler . bind ( _py ) . call0 ( ) ?,
258
267
)
259
268
. unwrap ( ) ,
260
269
)
0 commit comments