Skip to content

Commit 7681ae5

Browse files
committed
fixes task_locals clone issue on server.rs ;___;
1 parent db9fc48 commit 7681ae5

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/server.rs

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ use actix_web::*;
2929

3030
// pyO3 module
3131
use log::{debug, error};
32-
use pyo3::exceptions::PyValueError;
32+
use once_cell::sync::OnceCell;
33+
use pyo3::exceptions::{asyncio, PyValueError};
3334
use pyo3::prelude::*;
3435
use pyo3::pycell::PyRef;
36+
use pyo3_async_runtimes::TaskLocals;
3537

3638
const MAX_PAYLOAD_SIZE: &str = "ROBYN_MAX_PAYLOAD_SIZE";
3739
const DEFAULT_MAX_PAYLOAD_SIZE: usize = 1_000_000; // 1Mb
@@ -80,12 +82,14 @@ impl Server {
8082

8183
pub fn start(
8284
&mut self,
83-
py: Python,
85+
_py: Python,
8486
socket: PyRef<SocketHeld>,
8587
workers: usize,
8688
) -> PyResult<()> {
8789
pyo3_log::init();
8890

91+
static TASK_LOCALS: OnceCell<TaskLocals> = OnceCell::new();
92+
8993
if STARTED
9094
.compare_exchange(false, true, SeqCst, Relaxed)
9195
.is_err()
@@ -104,17 +108,18 @@ impl Server {
104108
let global_response_headers = self.global_response_headers.clone();
105109
let directories = self.directories.clone();
106110

107-
let asyncio = py.import("asyncio")?;
111+
let asyncio = _py.import("asyncio")?;
108112
let event_loop = asyncio.call_method0("new_event_loop")?;
109-
asyncio.call_method1("set_event_loop", (event_loop,))?;
113+
asyncio.call_method1("set_event_loop", (event_loop.clone(),))?;
110114

111115
let startup_handler = self.startup_handler.clone();
112116
let shutdown_handler = self.shutdown_handler.clone();
113117

114118
let excluded_response_headers_paths = self.excluded_response_headers_paths.clone();
115119

116-
let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop).copy_context(py)?;
117-
let task_locals_copy = task_locals.clone_ref(py);
120+
let _ = TASK_LOCALS.get_or_try_init(|| {
121+
Python::with_gil(|py| pyo3_async_runtimes::TaskLocals::new(event_loop.clone().into()).copy_context(py))
122+
});
118123

119124
let max_payload_size = env::var(MAX_PAYLOAD_SIZE)
120125
.unwrap_or(DEFAULT_MAX_PAYLOAD_SIZE.to_string())
@@ -129,14 +134,15 @@ impl Server {
129134
thread::spawn(move || {
130135
actix_web::rt::System::new().block_on(async move {
131136
debug!("The number of workers is {}", workers);
132-
execute_startup_handler(startup_handler, &task_locals_copy)
137+
138+
let task_locals = Python::with_gil(|py| TASK_LOCALS.get().unwrap().clone_ref(py));
139+
execute_startup_handler(startup_handler, &task_locals)
133140
.await
134141
.unwrap();
135142

136143
HttpServer::new(move || {
137144
let mut app = App::new();
138145

139-
let task_locals = task_locals_copy.clone_ref(py);
140146
let directories = directories.read().unwrap();
141147

142148
// this loop matches three types of directory serving
@@ -174,16 +180,16 @@ impl Server {
174180
for (elem, value) in (web_socket_map.read()).iter() {
175181
let endpoint = elem.clone();
176182
let path_params = value.clone();
177-
let task_locals = task_locals.clone_ref(py);
178183
app = app.route(
179184
&endpoint.clone(),
180185
web::get().to(move |stream: web::Payload, req: HttpRequest| {
181186
let endpoint_copy = endpoint.clone();
187+
let task_locals = Python::with_gil(|py| TASK_LOCALS.get().unwrap().clone_ref(py));
182188
start_web_socket(
183189
req,
184190
stream,
185191
path_params.clone(),
186-
task_locals.clone_ref(py),
192+
task_locals,
187193
endpoint_copy,
188194
)
189195
}),
@@ -202,7 +208,8 @@ impl Server {
202208
global_response_headers,
203209
response_headers_exclude_paths,
204210
req| {
205-
pyo3_async_runtimes::tokio::scope_local(task_locals.clone_ref(py), async move {
211+
let task_locals = Python::with_gil(|py| TASK_LOCALS.get().unwrap().clone_ref(py));
212+
pyo3_async_runtimes::tokio::scope_local(task_locals, async move {
206213
index(
207214
router,
208215
payload,
@@ -250,11 +257,13 @@ impl Server {
250257
if function.is_async {
251258
debug!("Shutdown event handler async");
252259

260+
let task_locals = Python::with_gil(|py| TASK_LOCALS.get().unwrap().clone_ref(py));
261+
253262
pyo3_async_runtimes::tokio::run_until_complete(
254-
task_locals.event_loop(py),
263+
task_locals.event_loop(_py),
255264
pyo3_async_runtimes::into_future_with_locals(
256-
&task_locals.clone_ref(py),
257-
function.handler.bind(py).call0()?,
265+
&task_locals.clone_ref(_py),
266+
function.handler.bind(_py).call0()?,
258267
)
259268
.unwrap(),
260269
)

0 commit comments

Comments
 (0)