@@ -16,6 +16,7 @@ use crate::types::HttpMethod;
16
16
use crate :: types:: MiddlewareReturn ;
17
17
use crate :: websockets:: start_web_socket;
18
18
19
+ use core:: task;
19
20
use std:: sync:: atomic:: AtomicBool ;
20
21
use std:: sync:: atomic:: Ordering :: { Relaxed , SeqCst } ;
21
22
use std:: sync:: { Arc , RwLock } ;
@@ -29,7 +30,7 @@ use actix_web::*;
29
30
30
31
// pyO3 module
31
32
use log:: { debug, error} ;
32
- use pyo3:: exceptions:: PyValueError ;
33
+ use pyo3:: exceptions:: { asyncio , PyValueError } ;
33
34
use pyo3:: prelude:: * ;
34
35
use pyo3:: pycell:: PyRef ;
35
36
@@ -78,12 +79,7 @@ impl Server {
78
79
}
79
80
}
80
81
81
- pub fn start (
82
- & mut self ,
83
- py : Python ,
84
- socket : PyRef < SocketHeld > ,
85
- workers : usize ,
86
- ) -> PyResult < ( ) > {
82
+ pub fn start ( & mut self , py : Python , socket : PyRef < SocketHeld > , workers : usize ) -> PyResult < ( ) > {
87
83
pyo3_log:: init ( ) ;
88
84
89
85
if STARTED
@@ -113,8 +109,7 @@ impl Server {
113
109
114
110
let excluded_response_headers_paths = self . excluded_response_headers_paths . clone ( ) ;
115
111
116
- let task_locals = pyo3_async_runtimes:: TaskLocals :: new ( event_loop) . copy_context ( py) ?;
117
- let task_locals_copy = task_locals. clone_ref ( py) ;
112
+ let task_locals_new = pyo3_async_runtimes:: TaskLocals :: new ( event_loop) . copy_context ( py) ?;
118
113
119
114
let max_payload_size = env:: var ( MAX_PAYLOAD_SIZE )
120
115
. unwrap_or ( DEFAULT_MAX_PAYLOAD_SIZE . to_string ( ) )
@@ -129,14 +124,15 @@ impl Server {
129
124
thread:: spawn ( move || {
130
125
actix_web:: rt:: System :: new ( ) . block_on ( async move {
131
126
debug ! ( "The number of workers is {}" , workers) ;
132
- execute_startup_handler ( startup_handler, & task_locals_copy )
127
+ execute_startup_handler ( startup_handler, & task_locals_new )
133
128
. await
134
129
. unwrap ( ) ;
135
130
136
131
HttpServer :: new ( move || {
137
132
let mut app = App :: new ( ) ;
138
133
139
- let task_locals = task_locals_copy. clone_ref ( py) ;
134
+ // let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py).expect("Failed to get Python task locals");
135
+ // let task_locals: pyo3_async_runtimes::TaskLocals = asyncio.call_method0("get_running_loop");
140
136
let directories = directories. read ( ) . unwrap ( ) ;
141
137
142
138
// this loop matches three types of directory serving
@@ -170,25 +166,32 @@ impl Server {
170
166
. app_data ( web:: Data :: new ( global_response_headers. clone ( ) ) )
171
167
. app_data ( web:: Data :: new ( excluded_response_headers_paths. clone ( ) ) ) ;
172
168
173
- let web_socket_map = web_socket_router. get_web_socket_map ( ) ;
174
- for ( elem, value) in ( web_socket_map. read ( ) ) . iter ( ) {
175
- let endpoint = elem. clone ( ) ;
176
- let path_params = value. clone ( ) ;
177
- let task_locals = task_locals. clone_ref ( py) ;
178
- app = app. route (
179
- & endpoint. clone ( ) ,
180
- web:: get ( ) . to ( move |stream : web:: Payload , req : HttpRequest | {
181
- let endpoint_copy = endpoint. clone ( ) ;
182
- start_web_socket (
183
- req,
184
- stream,
185
- path_params. clone ( ) ,
186
- task_locals. clone_ref ( py) ,
187
- endpoint_copy,
188
- )
189
- } ) ,
190
- ) ;
191
- }
169
+ Python :: with_gil ( |py| {
170
+ let task_locals = pyo3_async_runtimes:: tokio:: get_current_locals ( py)
171
+ . expect ( "Failed to get Python task locals" ) ;
172
+
173
+ let web_socket_map = web_socket_router. get_web_socket_map ( ) ;
174
+ for ( elem, value) in ( web_socket_map. read ( ) ) . iter ( ) {
175
+ let endpoint = elem. clone ( ) ;
176
+ let path_params = value. clone ( ) ;
177
+ let task_locals = task_locals. clone_ref ( py) ;
178
+ app = app. route (
179
+ & endpoint. clone ( ) ,
180
+ web:: get ( ) . to ( move |stream : web:: Payload , req : HttpRequest | {
181
+ let endpoint_copy = endpoint. clone ( ) ;
182
+ start_web_socket (
183
+ req,
184
+ stream,
185
+ path_params. clone ( ) ,
186
+ task_locals. clone_ref ( py) ,
187
+ endpoint_copy,
188
+ )
189
+ } ) ,
190
+ ) ;
191
+ }
192
+
193
+ Ok ( ( ) )
194
+ } ) ;
192
195
193
196
debug ! ( "Max payload size is {}" , max_payload_size) ;
194
197
@@ -202,19 +205,25 @@ impl Server {
202
205
global_response_headers,
203
206
response_headers_exclude_paths,
204
207
req| {
205
- pyo3_async_runtimes:: tokio:: scope_local ( task_locals. clone_ref ( py) , async move {
206
- index (
207
- router,
208
- payload,
209
- const_router,
210
- middleware_router,
211
- global_request_headers,
212
- global_response_headers,
213
- response_headers_exclude_paths,
214
- req,
215
- )
216
- . await
217
- } )
208
+ Python :: with_gil ( |py| {
209
+ pyo3_async_runtimes:: tokio:: scope_local (
210
+ task_locals. clone_ref ( py) ,
211
+ async move {
212
+ index (
213
+ router,
214
+ payload,
215
+ const_router,
216
+ middleware_router,
217
+ global_request_headers,
218
+ global_response_headers,
219
+ response_headers_exclude_paths,
220
+ req,
221
+ )
222
+ . await
223
+ } ,
224
+ ) ;
225
+ Ok ( ( ) )
226
+ } ) ;
218
227
} ,
219
228
) )
220
229
} )
@@ -250,6 +259,9 @@ impl Server {
250
259
if function. is_async {
251
260
debug ! ( "Shutdown event handler async" ) ;
252
261
262
+ let task_locals = pyo3_async_runtimes:: tokio:: get_current_locals ( py)
263
+ . expect ( "Failed to get Python task locals" ) ;
264
+
253
265
pyo3_async_runtimes:: tokio:: run_until_complete (
254
266
task_locals. event_loop ( py) ,
255
267
pyo3_async_runtimes:: into_future_with_locals (
0 commit comments