Skip to content

Commit 2aed548

Browse files
committed
Fix scoped async destruction of partially polled futures
1 parent 6a77b5f commit 2aed548

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

src/lua.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,9 +1787,11 @@ impl Lua {
17871787
})?,
17881788
)?;
17891789

1790+
// We set `poll` variable in the env table to be able to destroy upvalues
17901791
self.load(
17911792
r#"
1792-
local poll = get_poll(...)
1793+
poll = get_poll(...)
1794+
local poll = poll
17931795
while true do
17941796
ready, res = poll()
17951797
if ready then

src/scope.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti, Value};
2323
#[cfg(feature = "async")]
2424
use {
2525
crate::types::AsyncCallback,
26-
futures_core::future::Future,
26+
futures_core::future::{Future, LocalBoxFuture},
2727
futures_util::future::{self, TryFutureExt},
28-
std::os::raw::c_char,
2928
};
3029

3130
/// Constructed by the [`Lua::scope`] method, allows temporarily creating Lua userdata and
@@ -420,12 +419,11 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
420419
#[cfg(any(feature = "lua51", feature = "luajit"))]
421420
ffi::lua_getfenv(state, -1);
422421

423-
// Then, get the get_poll() closure using the corresponding key
424-
let key = "get_poll";
425-
ffi::lua_pushlstring(state, key.as_ptr() as *const c_char, key.len());
422+
// Second, get the `get_poll()` closure using the corresponding key
423+
ffi::lua_pushstring(state, cstr!("get_poll"));
426424
ffi::lua_rawget(state, -2);
427425

428-
// Finally, destroy all upvalues
426+
// Destroy all upvalues
429427
ffi::lua_getupvalue(state, -1, 1);
430428
let ud1 = take_userdata::<AsyncCallback>(state);
431429
ffi::lua_pushnil(state);
@@ -437,8 +435,25 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
437435
ffi::lua_setupvalue(state, -2, 2);
438436

439437
ffi::lua_pop(state, 1);
438+
let mut data: Vec<Box<dyn Any>> = vec![Box::new(ud1), Box::new(ud2)];
439+
440+
// Finally, get polled future and destroy it
441+
ffi::lua_pushstring(state, cstr!("poll"));
442+
if ffi::lua_rawget(state, -2) == ffi::LUA_TFUNCTION {
443+
ffi::lua_getupvalue(state, -1, 1);
444+
let ud3 = take_userdata::<LocalBoxFuture<Result<MultiValue>>>(state);
445+
ffi::lua_pushnil(state);
446+
ffi::lua_setupvalue(state, -2, 1);
447+
data.push(Box::new(ud3));
448+
449+
ffi::lua_getupvalue(state, -1, 2);
450+
let ud4 = take_userdata::<Lua>(state);
451+
ffi::lua_pushnil(state);
452+
ffi::lua_setupvalue(state, -2, 2);
453+
data.push(Box::new(ud4));
454+
}
440455

441-
vec![Box::new(ud1), Box::new(ud2)]
456+
data
442457
}));
443458

444459
Ok(f)

tests/async.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ use std::time::Duration;
2222
use futures_timer::Delay;
2323
use futures_util::stream::TryStreamExt;
2424

25-
use mlua::{Error, Function, Lua, Result, Table, TableExt, UserData, UserDataMethods};
25+
use mlua::{
26+
Error, Function, Lua, Result, Table, TableExt, Thread, UserData, UserDataMethods, Value,
27+
};
2628

2729
#[tokio::test]
2830
async fn test_async_function() -> Result<()> {
@@ -332,11 +334,18 @@ async fn test_async_scope() -> Result<()> {
332334
let _ = f.call_async::<u64, ()>(10).await?;
333335
assert_eq!(Rc::strong_count(rc), 1);
334336

337+
// Create future in partialy polled state (Poll::Pending)
338+
let g = lua.create_thread(f)?;
339+
g.resume::<u64, ()>(10)?;
340+
lua.globals().set("g", g)?;
341+
assert_eq!(Rc::strong_count(rc), 2);
342+
335343
Ok(())
336344
});
337345

338346
assert_eq!(Rc::strong_count(rc), 1);
339347
let _ = fut.await?;
348+
assert_eq!(Rc::strong_count(rc), 1);
340349

341350
match lua
342351
.globals()
@@ -351,6 +360,14 @@ async fn test_async_scope() -> Result<()> {
351360
r => panic!("improper return for destructed function: {:?}", r),
352361
};
353362

363+
match lua.globals().get::<_, Thread>("g")?.resume::<_, Value>(()) {
364+
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
365+
Error::CallbackDestructed => {}
366+
e => panic!("expected `CallbackDestructed` error cause, got {:?}", e),
367+
},
368+
r => panic!("improper return for destructed function: {:?}", r),
369+
};
370+
354371
Ok(())
355372
}
356373

0 commit comments

Comments
 (0)