Skip to content

Commit 392c2f3

Browse files
committed
wip fix gil issue on server.rs
1 parent db9fc48 commit 392c2f3

File tree

1 file changed

+55
-43
lines changed

1 file changed

+55
-43
lines changed

src/server.rs

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::types::HttpMethod;
1616
use crate::types::MiddlewareReturn;
1717
use crate::websockets::start_web_socket;
1818

19+
use core::task;
1920
use std::sync::atomic::AtomicBool;
2021
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
2122
use std::sync::{Arc, RwLock};
@@ -29,7 +30,7 @@ use actix_web::*;
2930

3031
// pyO3 module
3132
use log::{debug, error};
32-
use pyo3::exceptions::PyValueError;
33+
use pyo3::exceptions::{asyncio, PyValueError};
3334
use pyo3::prelude::*;
3435
use pyo3::pycell::PyRef;
3536

@@ -78,12 +79,7 @@ impl Server {
7879
}
7980
}
8081

81-
pub fn start(
82-
&mut self,
83-
py: Python,
84-
socket: PyRef<SocketHeld>,
85-
workers: usize,
86-
) -> PyResult<()> {
82+
pub fn start(&mut self, py: Python, socket: PyRef<SocketHeld>, workers: usize) -> PyResult<()> {
8783
pyo3_log::init();
8884

8985
if STARTED
@@ -113,8 +109,7 @@ impl Server {
113109

114110
let excluded_response_headers_paths = self.excluded_response_headers_paths.clone();
115111

116-
let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop).copy_context(py)?;
117-
let task_locals_copy = task_locals.clone_ref(py);
112+
let task_locals_new = pyo3_async_runtimes::TaskLocals::new(event_loop).copy_context(py)?;
118113

119114
let max_payload_size = env::var(MAX_PAYLOAD_SIZE)
120115
.unwrap_or(DEFAULT_MAX_PAYLOAD_SIZE.to_string())
@@ -129,14 +124,15 @@ impl Server {
129124
thread::spawn(move || {
130125
actix_web::rt::System::new().block_on(async move {
131126
debug!("The number of workers is {}", workers);
132-
execute_startup_handler(startup_handler, &task_locals_copy)
127+
execute_startup_handler(startup_handler, &task_locals_new)
133128
.await
134129
.unwrap();
135130

136131
HttpServer::new(move || {
137132
let mut app = App::new();
138133

139-
let task_locals = task_locals_copy.clone_ref(py);
134+
// let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py).expect("Failed to get Python task locals");
135+
// let task_locals: pyo3_async_runtimes::TaskLocals = asyncio.call_method0("get_running_loop");
140136
let directories = directories.read().unwrap();
141137

142138
// this loop matches three types of directory serving
@@ -170,25 +166,32 @@ impl Server {
170166
.app_data(web::Data::new(global_response_headers.clone()))
171167
.app_data(web::Data::new(excluded_response_headers_paths.clone()));
172168

173-
let web_socket_map = web_socket_router.get_web_socket_map();
174-
for (elem, value) in (web_socket_map.read()).iter() {
175-
let endpoint = elem.clone();
176-
let path_params = value.clone();
177-
let task_locals = task_locals.clone_ref(py);
178-
app = app.route(
179-
&endpoint.clone(),
180-
web::get().to(move |stream: web::Payload, req: HttpRequest| {
181-
let endpoint_copy = endpoint.clone();
182-
start_web_socket(
183-
req,
184-
stream,
185-
path_params.clone(),
186-
task_locals.clone_ref(py),
187-
endpoint_copy,
188-
)
189-
}),
190-
);
191-
}
169+
Python::with_gil(|py| {
170+
let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py)
171+
.expect("Failed to get Python task locals");
172+
173+
let web_socket_map = web_socket_router.get_web_socket_map();
174+
for (elem, value) in (web_socket_map.read()).iter() {
175+
let endpoint = elem.clone();
176+
let path_params = value.clone();
177+
let task_locals = task_locals.clone_ref(py);
178+
app = app.route(
179+
&endpoint.clone(),
180+
web::get().to(move |stream: web::Payload, req: HttpRequest| {
181+
let endpoint_copy = endpoint.clone();
182+
start_web_socket(
183+
req,
184+
stream,
185+
path_params.clone(),
186+
task_locals.clone_ref(py),
187+
endpoint_copy,
188+
)
189+
}),
190+
);
191+
}
192+
193+
Ok(())
194+
});
192195

193196
debug!("Max payload size is {}", max_payload_size);
194197

@@ -202,19 +205,25 @@ impl Server {
202205
global_response_headers,
203206
response_headers_exclude_paths,
204207
req| {
205-
pyo3_async_runtimes::tokio::scope_local(task_locals.clone_ref(py), async move {
206-
index(
207-
router,
208-
payload,
209-
const_router,
210-
middleware_router,
211-
global_request_headers,
212-
global_response_headers,
213-
response_headers_exclude_paths,
214-
req,
215-
)
216-
.await
217-
})
208+
Python::with_gil(|py| {
209+
pyo3_async_runtimes::tokio::scope_local(
210+
task_locals.clone_ref(py),
211+
async move {
212+
index(
213+
router,
214+
payload,
215+
const_router,
216+
middleware_router,
217+
global_request_headers,
218+
global_response_headers,
219+
response_headers_exclude_paths,
220+
req,
221+
)
222+
.await
223+
},
224+
);
225+
Ok(())
226+
});
218227
},
219228
))
220229
})
@@ -250,6 +259,9 @@ impl Server {
250259
if function.is_async {
251260
debug!("Shutdown event handler async");
252261

262+
let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py)
263+
.expect("Failed to get Python task locals");
264+
253265
pyo3_async_runtimes::tokio::run_until_complete(
254266
task_locals.event_loop(py),
255267
pyo3_async_runtimes::into_future_with_locals(

0 commit comments

Comments
 (0)