From 6cfa2267c42cf24f59570ddbad692842a9ff8852 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 04:22:09 -0400 Subject: [PATCH 01/46] mlua continuation yield --- mlua-sys/src/luau/lua.rs | 5 ++ src/error.rs | 2 +- src/state.rs | 56 +++++++++++++- src/state/extra.rs | 5 ++ src/state/raw.rs | 67 +++++++++++++++- src/state/util.rs | 161 +++++++++++++++++++++++++++++++++++++++ src/thread.rs | 19 +++++ src/types.rs | 6 ++ src/util/mod.rs | 1 + src/util/types.rs | 12 +++ tests/luau.rs | 3 + tests/luau/cont.rs | 59 ++++++++++++++ 12 files changed, 390 insertions(+), 6 deletions(-) create mode 100644 tests/luau/cont.rs diff --git a/mlua-sys/src/luau/lua.rs b/mlua-sys/src/luau/lua.rs index 8a55eef1..d898534c 100644 --- a/mlua-sys/src/luau/lua.rs +++ b/mlua-sys/src/luau/lua.rs @@ -426,6 +426,11 @@ pub unsafe fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, nup: c_int) lua_pushcclosurek(L, f, ptr::null(), nup, None) } +#[inline(always)] +pub unsafe fn lua_pushcclosurec(L: *mut lua_State, f: lua_CFunction, cont: lua_Continuation, nup: c_int) { + lua_pushcclosurek(L, f, ptr::null(), nup, Some(cont)) +} + #[inline(always)] pub unsafe fn lua_pushcclosured(L: *mut lua_State, f: lua_CFunction, debugname: *const c_char, nup: c_int) { lua_pushcclosurek(L, f, debugname, nup, None) diff --git a/src/error.rs b/src/error.rs index 1f243967..483171e5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -321,7 +321,7 @@ impl fmt::Display for Error { Error::WithContext { context, cause } => { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") - } + }, } } } diff --git a/src/state.rs b/src/state.rs index 4be9ba4c..a01f5332 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,5 +1,6 @@ use std::any::TypeId; use std::cell::{BorrowError, BorrowMutError, RefCell}; +use std::convert::Infallible; use std::marker::PhantomData; use std::ops::Deref; use std::os::raw::{c_char, c_int}; @@ -17,7 +18,7 @@ use crate::scope::Scope; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; -use crate::thread::Thread; +use crate::thread::{Thread, LuauContinuationStatus}; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, @@ -1288,6 +1289,32 @@ impl Lua { })) } + /// Same as ``create_function`` but with support for Luau continuations + /// + /// Note that yieldable luau continuations are not currently supported at this time + #[cfg(feature = "luau")] + pub fn create_function_with_luau_continuation(&self, func: F, cont: FC) -> Result + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + FC: Fn(&Lua, LuauContinuationStatus, AC) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + AC: FromLuaMulti, + R: IntoLuaMulti, + RC: IntoLuaMulti, + { + (self.lock()).create_callback_with_luau_continuation( + Box::new(move |rawlua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) + }), + Box::new(move |rawlua, nargs, status| unsafe { + let args = AC::from_stack_args(nargs, 1, None, rawlua)?; + let status = LuauContinuationStatus::from_status(status); + cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) + }), + ) + } + /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// /// This is a version of [`Lua::create_function`] that accepts a `FnMut` argument. @@ -2103,6 +2130,33 @@ impl Lua { pub(crate) unsafe fn raw_lua(&self) -> &RawLua { &*self.raw.data_ptr() } + + /// Yields arguments + /// + /// If this function cannot yield, it will raise a runtime error. + /// + /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield + /// or not until it reaches the Lua state. + /// + /// Unsafe and should only be used in a function with a luau continuation for now + pub unsafe fn yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + let raw = self.lock(); + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + if !raw.is_yieldable() { + return Err(Error::runtime("cannot yield across Rust/Lua boundary.")) + } + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; + } + Ok(()) + } + + /// Checks if Lua is currently allowed to yield. + #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + self.lock().is_yieldable() + } } impl WeakLua { diff --git a/src/state/extra.rs b/src/state/extra.rs index 12133639..445c232c 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -18,6 +18,7 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -93,6 +94,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: MultiValue, } impl Drop for ExtraData { @@ -194,6 +198,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, + yielded_values: MultiValue::with_capacity(0), })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index cc4ce1cf..341520ec 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,15 +12,15 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop}; +use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::traits::IntoLua; use crate::types::{ - AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, - MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, + AppDataRef, AppDataRefMut, Callback, LuauContinuation, CallbackUpvalue, LuauContinuationUpvalue, DestructedUserdata, + Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, @@ -198,6 +198,8 @@ impl RawLua { init_internal_metatable::>>(state, None)?; init_internal_metatable::(state, None)?; init_internal_metatable::(state, None)?; + #[cfg(feature = "luau")] + init_internal_metatable::(state, None)?; #[cfg(not(feature = "luau"))] init_internal_metatable::(state, None)?; #[cfg(feature = "async")] @@ -1156,10 +1158,11 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); + #[cfg(feature = "luau")] match (*upvalue).data { Some(ref func) => func(rawlua, nargs), None => Err(Error::CallbackDestructed), @@ -1188,6 +1191,56 @@ impl RawLua { } } + // Creates a Function out of a Callback and a continuation containing a 'static Fn. + #[cfg(feature = "luau")] + pub(crate) fn create_callback_with_luau_continuation(&self, func: Callback, cont: LuauContinuation) -> Result { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } + + unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, LuauContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + })?; + } else { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + #[cfg(feature = "async")] pub(crate) fn create_async_callback(&self, func: AsyncCallback) -> Result { // Ensure that the coroutine library is loaded @@ -1348,6 +1401,12 @@ impl RawLua { pub(crate) unsafe fn set_waker(&self, waker: NonNull) -> NonNull { mem::replace(&mut (*self.extra.get()).waker, waker) } + + #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + unsafe { ffi::lua_isyieldable(self.state()) != 0 } + } } // Uses 3 stack spaces diff --git a/src/state/util.rs b/src/state/util.rs index c3c79302..97cb2f60 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,7 +1,9 @@ +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; +use crate::IntoLuaMulti; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; @@ -38,7 +40,140 @@ where } let nargs = ffi::lua_gettop(state); + + enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, + } + + impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = (*extra).ref_thread; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread, state, index); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, index); + (*extra).ref_free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = (*extra).ref_thread; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread, 1); + let index = ref_stack_pop(extra); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } + } + + // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory + // to store a wrapped failure (error or panic) *before* we proceed. + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + + match catch_unwind(AssertUnwindSafe(|| { + let rawlua = (*extra).raw_lua(); + let _guard = StateGuard::new(rawlua, state); + f(extra, nargs) + })) { + Ok(Ok(r)) => { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { + let wrapped_error = prealloc_failure.r#use(state, extra); + + if !wrap_error { + ptr::write(wrapped_error, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) + } + + // Build `CallbackError` with traceback + let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, ptr::null(), 0); + let traceback = util::to_string(state, -1); + ffi::lua_pop(state, 1); + traceback + } else { + "".to_string() + }; + let cause = Arc::new(err); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::CallbackError { traceback, cause }), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } + Err(p) => { + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Panic(Some(p))); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) + } + } +} + +// An yieldable version of `callback_error_ext` +pub(crate) unsafe fn callback_error_ext_yieldable( + state: *mut ffi::lua_State, + mut extra: *mut ExtraData, + wrap_error: bool, + f: F, +) -> c_int +where + F: FnOnce(*mut ExtraData, c_int) -> Result, +{ + println!("callback_error_ext_yieldable"); + + if extra.is_null() { + extra = ExtraData::get(state); + } + + let nargs = ffi::lua_gettop(state); + enum PreallocatedFailure { New(*mut WrappedFailure), Reserved, @@ -108,6 +243,32 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + + if !values.is_empty() { + println!("YIELD {:?}", values); + match values.push_into_stack_multi(raw) { + Ok(nargs) => { + println!("YIELDARGS {}", nargs); + return ffi::lua_yield(state, nargs); + }, + Err(err) => { + let wrapped_error = prealloc_failure.r#use(state, extra); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::external(err.to_string())), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + + } + } + } + + // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r diff --git a/src/thread.rs b/src/thread.rs index da7962b0..0657b3e7 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -26,6 +26,25 @@ use { }, }; +/// Luau continuation final status +#[cfg(feature = "luau")] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum LuauContinuationStatus { + Ok, + Yielded, + Error, +} + +impl LuauContinuationStatus { + pub(crate) fn from_status(status: c_int) -> Self { + match status { + ffi::LUA_YIELD => Self::Yielded, + ffi::LUA_OK => Self::Ok, + _ => Self::Error, + } + } +} + /// Status of a Lua thread (coroutine). #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ThreadStatus { diff --git a/src/types.rs b/src/types.rs index 2589ea6e..c1efed0e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -40,6 +40,11 @@ pub(crate) type Callback = Box Result + Send + #[cfg(not(feature = "send"))] pub(crate) type Callback = Box Result + 'static>; +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type LuauContinuation = Box Result + Send + 'static>; +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type LuauContinuation = Box Result + 'static>; + pub(crate) type ScopedCallback<'s> = Box Result + 's>; pub(crate) struct Upvalue { @@ -48,6 +53,7 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; +pub(crate) type LuauContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] pub(crate) type AsyncCallback = diff --git a/src/util/mod.rs b/src/util/mod.rs index f5fbae52..5dd6b19c 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -75,6 +75,7 @@ impl Drop for StackGuard { unsafe { let top = ffi::lua_gettop(self.state); if top < self.top { + println!("top={}, self.top={}", top, self.top); mlua_panic!("{} too many stack values popped", self.top - top) } if top > self.top { diff --git a/src/util/types.rs b/src/util/types.rs index 8bc9d8b2..16945009 100644 --- a/src/util/types.rs +++ b/src/util/types.rs @@ -3,6 +3,9 @@ use std::os::raw::c_void; use crate::types::{Callback, CallbackUpvalue}; +#[cfg(feature = "luau")] +use crate::types::LuauContinuationUpvalue; + #[cfg(feature = "async")] use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; @@ -34,6 +37,15 @@ impl TypeKey for CallbackUpvalue { } } +#[cfg(feature = "luau")] +impl TypeKey for LuauContinuationUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static LUAU_CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; + &LUAU_CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + #[cfg(not(feature = "luau"))] impl TypeKey for crate::types::HookCallback { #[inline(always)] diff --git a/tests/luau.rs b/tests/luau.rs index 20fbee6a..7b62bae3 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -449,3 +449,6 @@ fn test_typeof_error() -> Result<()> { #[path = "luau/require.rs"] mod require; + +#[path = "luau/cont.rs"] +mod cont; diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs new file mode 100644 index 00000000..9cb7e89f --- /dev/null +++ b/tests/luau/cont.rs @@ -0,0 +1,59 @@ +#[cfg(feature = "luau")] +use mlua::Lua; + +#[test] +fn test_luau_continuation() { + let lua = Lua::new(); + + let cont_func = lua.create_function_with_luau_continuation( + |lua, a: u64| Ok(a + 1), + |lua, status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + } + ).expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func).expect("Failed to call cont_func"), + 2 + ); + + // does not work yet + /*let always_yield = lua.create_function(|lua, ()| { + unsafe { lua.yield_args((42, "69420")) } + }).unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420")));*/ + + // Trigger the continuation + let cont_func = lua.create_function_with_luau_continuation( + |lua, a: u64| { + unsafe { + match lua.yield_args(a) { + Ok(()) => println!("yield_args called"), + Err(e) => println!("{:?}", e) + } + } + Ok(()) + }, + |lua, status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + } + ).expect("Failed to create cont_func"); + + let luau_func = lua.load(" + local cont_func = ... + local res = cont_func(1) + return res + ").into_function().expect("Failed to create function"); + let th = lua.create_thread(luau_func).expect("Failed to create luau thread"); + + let v = th.resume::(cont_func).expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 40); +} \ No newline at end of file From e01e45fe21ae8d7a3d80f53ef080b963036e73c1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:15:25 -0400 Subject: [PATCH 02/46] add missing xmove --- src/state/util.rs | 1 + src/thread.rs | 1 + tests/luau/cont.rs | 15 +++++++++------ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 97cb2f60..957c069f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -251,6 +251,7 @@ where match values.push_into_stack_multi(raw) { Ok(nargs) => { println!("YIELDARGS {}", nargs); + ffi::lua_xmove(raw.state(), state, nargs); return ffi::lua_yield(state, nargs); }, Err(err) => { diff --git a/src/thread.rs b/src/thread.rs index 0657b3e7..4ad040af 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -234,6 +234,7 @@ impl Thread { let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); #[cfg(feature = "luau")] let ret = ffi::lua_resumex(thread_state, state, nargs, &mut nresults as *mut c_int); + match ret { ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)), ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)), diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 9cb7e89f..87c1ca78 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -21,12 +21,15 @@ fn test_luau_continuation() { ); // does not work yet - /*let always_yield = lua.create_function(|lua, ()| { - unsafe { lua.yield_args((42, "69420")) } - }).unwrap(); + let always_yield = lua.create_function( + |lua, ()| { + unsafe { lua.yield_args((42, "69420".to_string()))? } + Ok(()) + }) + .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420")));*/ + assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420"))); // Trigger the continuation let cont_func = lua.create_function_with_luau_continuation( @@ -48,12 +51,12 @@ fn test_luau_continuation() { let luau_func = lua.load(" local cont_func = ... local res = cont_func(1) - return res + return res + 1 ").into_function().expect("Failed to create function"); let th = lua.create_thread(luau_func).expect("Failed to create luau thread"); let v = th.resume::(cont_func).expect("Failed to resume"); let v = th.resume::(v).expect("Failed to load continuation"); - assert_eq!(v, 40); + assert_eq!(v, 41); } \ No newline at end of file From bd6b65b20da638c7a52270b3d75be186f44b2210 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:18:11 -0400 Subject: [PATCH 03/46] remove comment --- tests/luau/cont.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 87c1ca78..04122035 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -20,7 +20,6 @@ fn test_luau_continuation() { 2 ); - // does not work yet let always_yield = lua.create_function( |lua, ()| { unsafe { lua.yield_args((42, "69420".to_string()))? } From c6f1ff250a99f0754c2cf0bdf1801f8066f80729 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:19:58 -0400 Subject: [PATCH 04/46] add a f32 to test stack --- tests/luau/cont.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 04122035..7e6fc6ee 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -22,13 +22,13 @@ fn test_luau_continuation() { let always_yield = lua.create_function( |lua, ()| { - unsafe { lua.yield_args((42, "69420".to_string()))? } + unsafe { lua.yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }) .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420"))); + assert_eq!(thread.resume::<(i32, String, f32)>(()).unwrap(), (42, String::from("69420"), 45.6)); // Trigger the continuation let cont_func = lua.create_function_with_luau_continuation( From 875d6e8ca6d4310bbb208d95dc161c0038f1ffb2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:57:48 -0400 Subject: [PATCH 05/46] cleanup --- src/state.rs | 19 +++++++++++++------ src/state/raw.rs | 25 +++++++++++++++++++++++-- src/state/util.rs | 13 +++++++------ src/thread.rs | 1 + src/types.rs | 1 + tests/luau/cont.rs | 25 ++++++++++++++++++++++--- 6 files changed, 67 insertions(+), 17 deletions(-) diff --git a/src/state.rs b/src/state.rs index a01f5332..8113dd92 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,5 @@ use std::any::TypeId; use std::cell::{BorrowError, BorrowMutError, RefCell}; -use std::convert::Infallible; use std::marker::PhantomData; use std::ops::Deref; use std::os::raw::{c_char, c_int}; @@ -18,7 +17,13 @@ use crate::scope::Scope; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; -use crate::thread::{Thread, LuauContinuationStatus}; +use crate::thread::Thread; + +#[cfg(feature = "luau")] +use crate::thread::LuauContinuationStatus; +#[cfg(feature = "luau")] +use std::convert::Infallible; + use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, @@ -2131,15 +2136,17 @@ impl Lua { &*self.raw.data_ptr() } - /// Yields arguments + /// Sets the yields arguments. Note that Ok(()) must be returned for the Rust function + /// to actually yield. This method is mostly useful with Luau continuations /// /// If this function cannot yield, it will raise a runtime error. /// /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// - /// Unsafe and should only be used in a function with a luau continuation for now - pub unsafe fn yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + /// Potentially unsafe at this time. Use with caution + #[cfg(feature = "luau")] // todo: support non-luau set_yield_args, the groundwork is here + pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] if !raw.is_yieldable() { @@ -2154,7 +2161,7 @@ impl Lua { /// Checks if Lua is currently allowed to yield. #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] #[inline] - pub(crate) fn is_yieldable(&self) -> bool { + pub fn is_yieldable(&self) -> bool { self.lock().is_yieldable() } } diff --git a/src/state/raw.rs b/src/state/raw.rs index 341520ec..5d2ba91c 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,16 +12,22 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext, ref_stack_pop}; +#[cfg(feature = "luau")] +use crate::state::util::callback_error_ext_yieldable; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::traits::IntoLua; use crate::types::{ - AppDataRef, AppDataRefMut, Callback, LuauContinuation, CallbackUpvalue, LuauContinuationUpvalue, DestructedUserdata, + AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; + +#[cfg(feature = "luau")] +use crate::types::{LuauContinuation, LuauContinuationUpvalue}; + use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, UserDataStorage, @@ -1156,6 +1162,21 @@ impl RawLua { // Creates a Function out of a Callback containing a 'static Fn. pub(crate) fn create_callback(&self, func: Callback) -> Result { + #[cfg(not(feature = "luau"))] + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } + + #[cfg(feature = "luau")] unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { diff --git a/src/state/util.rs b/src/state/util.rs index 957c069f..230a26f9 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -3,6 +3,8 @@ use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; + +#[cfg(feature = "luau")] use crate::IntoLuaMulti; use crate::error::{Error, Result}; @@ -110,7 +112,7 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + // Ensure yielded values are cleared take(&mut extra.as_mut().unwrap_unchecked().yielded_values); // Return unused `WrappedFailure` to the pool @@ -156,7 +158,10 @@ where } } -// An yieldable version of `callback_error_ext` +/// An yieldable version of `callback_error_ext` +/// +/// Outside of Luau, this does the same thing as ``callback_error_ext`` right now +#[cfg(feature = "luau")] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, @@ -166,8 +171,6 @@ pub(crate) unsafe fn callback_error_ext_yieldable( where F: FnOnce(*mut ExtraData, c_int) -> Result, { - println!("callback_error_ext_yieldable"); - if extra.is_null() { extra = ExtraData::get(state); } @@ -247,10 +250,8 @@ where let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); if !values.is_empty() { - println!("YIELD {:?}", values); match values.push_into_stack_multi(raw) { Ok(nargs) => { - println!("YIELDARGS {}", nargs); ffi::lua_xmove(raw.state(), state, nargs); return ffi::lua_yield(state, nargs); }, diff --git a/src/thread.rs b/src/thread.rs index 4ad040af..6b834081 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -35,6 +35,7 @@ pub enum LuauContinuationStatus { Error, } +#[cfg(feature = "luau")] impl LuauContinuationStatus { pub(crate) fn from_status(status: c_int) -> Self { match status { diff --git a/src/types.rs b/src/types.rs index c1efed0e..3b6b0d8f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -53,6 +53,7 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; +#[cfg(feature = "luau")] pub(crate) type LuauContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 7e6fc6ee..80154d03 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -22,7 +22,7 @@ fn test_luau_continuation() { let always_yield = lua.create_function( |lua, ()| { - unsafe { lua.yield_args((42, "69420".to_string(), 45.6))? } + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }) .unwrap(); @@ -34,8 +34,8 @@ fn test_luau_continuation() { let cont_func = lua.create_function_with_luau_continuation( |lua, a: u64| { unsafe { - match lua.yield_args(a) { - Ok(()) => println!("yield_args called"), + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), Err(e) => println!("{:?}", e) } } @@ -58,4 +58,23 @@ fn test_luau_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 41); + + let always_yield = lua.create_function_with_luau_continuation( + |lua, ()| { + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + Ok(()) + }, + |lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + } + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread.resume::(mv).unwrap_err().to_string().starts_with("a3")); } \ No newline at end of file From 2a67b268f8b4983c7e2664ba6aacc059a3509a90 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 07:04:18 -0400 Subject: [PATCH 06/46] deal with warnings --- src/state.rs | 2 -- tests/luau/cont.rs | 9 +++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/state.rs b/src/state.rs index 8113dd92..cca8a56c 100644 --- a/src/state.rs +++ b/src/state.rs @@ -21,8 +21,6 @@ use crate::thread::Thread; #[cfg(feature = "luau")] use crate::thread::LuauContinuationStatus; -#[cfg(feature = "luau")] -use std::convert::Infallible; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 80154d03..d7d021cc 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -6,8 +6,8 @@ fn test_luau_continuation() { let lua = Lua::new(); let cont_func = lua.create_function_with_luau_continuation( - |lua, a: u64| Ok(a + 1), - |lua, status, a: u64| { + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 2) } @@ -20,6 +20,7 @@ fn test_luau_continuation() { 2 ); + // basic yield test before we go any further let always_yield = lua.create_function( |lua, ()| { unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } @@ -41,7 +42,7 @@ fn test_luau_continuation() { } Ok(()) }, - |lua, status, a: u64| { + |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 39) } @@ -64,7 +65,7 @@ fn test_luau_continuation() { unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }, - |lua, _, mv: mlua::MultiValue| { + |_lua, _, mv: mlua::MultiValue| { println!("Reached second continuation"); if mv.is_empty() { return Ok(mv); From d23f058549da3f0681b60aa0c78798eee08f7ffa Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 07:10:24 -0400 Subject: [PATCH 07/46] update warning --- src/state.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/state.rs b/src/state.rs index cca8a56c..e6f7f8d6 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2142,7 +2142,10 @@ impl Lua { /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// - /// Potentially unsafe at this time. Use with caution + /// Potentially unsafe at this time. Use with caution. + /// + /// This method only supports Luau for now as proper Rust yielding in other Lua variants is + /// more complicated. This limitation may be lifted in the future. #[cfg(feature = "luau")] // todo: support non-luau set_yield_args, the groundwork is here pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); From 62330e78e09d3d9dbc63042e0f57e5a5f8ea5e5e Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 10:34:08 -0400 Subject: [PATCH 08/46] allow yield on non-luau as well --- src/state.rs | 9 +++------ src/state/raw.rs | 24 ++++-------------------- src/state/util.rs | 5 ----- tests/thread.rs | 15 +++++++++++++++ 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/state.rs b/src/state.rs index e6f7f8d6..bfa73907 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2134,8 +2134,9 @@ impl Lua { &*self.raw.data_ptr() } - /// Sets the yields arguments. Note that Ok(()) must be returned for the Rust function - /// to actually yield. This method is mostly useful with Luau continuations + /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function + /// to actually yield. This method is mostly useful with Luau continuations and Rust-Rust + /// yields /// /// If this function cannot yield, it will raise a runtime error. /// @@ -2143,10 +2144,6 @@ impl Lua { /// or not until it reaches the Lua state. /// /// Potentially unsafe at this time. Use with caution. - /// - /// This method only supports Luau for now as proper Rust yielding in other Lua variants is - /// more complicated. This limitation may be lifted in the future. - #[cfg(feature = "luau")] // todo: support non-luau set_yield_args, the groundwork is here pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] diff --git a/src/state/raw.rs b/src/state/raw.rs index 5d2ba91c..a9216a95 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,9 +12,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop}; -#[cfg(feature = "luau")] -use crate::state::util::callback_error_ext_yieldable; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +#[cfg(not(feature = "luau"))] +use crate::state::util::callback_error_ext; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -1162,28 +1162,12 @@ impl RawLua { // Creates a Function out of a Callback containing a 'static Fn. pub(crate) fn create_callback(&self, func: Callback) -> Result { - #[cfg(not(feature = "luau"))] - unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => func(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) - } - - #[cfg(feature = "luau")] unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); - #[cfg(feature = "luau")] match (*upvalue).data { Some(ref func) => func(rawlua, nargs), None => Err(Error::CallbackDestructed), @@ -1230,7 +1214,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); diff --git a/src/state/util.rs b/src/state/util.rs index 230a26f9..702ee6f4 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -3,8 +3,6 @@ use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; - -#[cfg(feature = "luau")] use crate::IntoLuaMulti; use crate::error::{Error, Result}; @@ -159,9 +157,6 @@ where } /// An yieldable version of `callback_error_ext` -/// -/// Outside of Luau, this does the same thing as ``callback_error_ext`` right now -#[cfg(feature = "luau")] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, diff --git a/tests/thread.rs b/tests/thread.rs index 4cb6ab10..de5aca2b 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -252,3 +252,18 @@ fn test_thread_resume_error() -> Result<()> { Ok(()) } + +#[test] +fn test_thread_yield_args() -> Result<()> { + let lua = Lua::new(); + let always_yield = lua.create_function( + |lua, ()| { + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + Ok(()) + })?; + + let thread = lua.create_thread(always_yield)?; + assert_eq!(thread.resume::<(i32, String, f32)>(())?, (42, String::from("69420"), 45.6)); + + Ok(()) +} \ No newline at end of file From 1c7743c9142ce627ba0a53e9c2679ad719c8c50c Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 23:03:50 -0400 Subject: [PATCH 09/46] fix --- src/state.rs | 26 ++++--- src/state/raw.rs | 20 +++-- src/state/util.rs | 15 ++-- tests/luau/cont.rs | 180 +++++++++++++++++++++++++++++++++------------ 4 files changed, 169 insertions(+), 72 deletions(-) diff --git a/src/state.rs b/src/state.rs index bfa73907..7d0c6885 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1292,11 +1292,19 @@ impl Lua { })) } - /// Same as ``create_function`` but with support for Luau continuations + /// Same as ``create_function`` but with support for continuations. /// - /// Note that yieldable luau continuations are not currently supported at this time + /// Currently only luau-style continuations are supported at this time. + /// + /// The values passed to the continuation will either be the yielded arguments + /// from the function or the arguments from resuming the thread. Yielded values + /// from a continuation become the resume args. #[cfg(feature = "luau")] - pub fn create_function_with_luau_continuation(&self, func: F, cont: FC) -> Result + pub fn create_function_with_luau_continuation( + &self, + func: F, + cont: FC, + ) -> Result where F: Fn(&Lua, A) -> Result + MaybeSend + 'static, FC: Fn(&Lua, LuauContinuationStatus, AC) -> Result + MaybeSend + 'static, @@ -1316,7 +1324,7 @@ impl Lua { cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) }), ) - } + } /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// @@ -2137,10 +2145,10 @@ impl Lua { /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function /// to actually yield. This method is mostly useful with Luau continuations and Rust-Rust /// yields - /// + /// /// If this function cannot yield, it will raise a runtime error. - /// - /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield + /// + /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// /// Potentially unsafe at this time. Use with caution. @@ -2148,7 +2156,7 @@ impl Lua { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] if !raw.is_yieldable() { - return Err(Error::runtime("cannot yield across Rust/Lua boundary.")) + return Err(Error::runtime("cannot yield across Rust/Lua boundary.")); } unsafe { raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; @@ -2157,7 +2165,7 @@ impl Lua { } /// Checks if Lua is currently allowed to yield. - #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub fn is_yieldable(&self) -> bool { self.lock().is_yieldable() diff --git a/src/state/raw.rs b/src/state/raw.rs index a9216a95..a151c231 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,17 +12,17 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; #[cfg(not(feature = "luau"))] use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::traits::IntoLua; use crate::types::{ - AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, - Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, + AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, + MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; #[cfg(feature = "luau")] @@ -1164,7 +1164,7 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1198,10 +1198,14 @@ impl RawLua { // Creates a Function out of a Callback and a continuation containing a 'static Fn. #[cfg(feature = "luau")] - pub(crate) fn create_callback_with_luau_continuation(&self, func: Callback, cont: LuauContinuation) -> Result { + pub(crate) fn create_callback_with_luau_continuation( + &self, + func: Callback, + cont: LuauContinuation, + ) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1214,7 +1218,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1407,7 +1411,7 @@ impl RawLua { mem::replace(&mut (*self.extra.get()).waker, waker) } - #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub(crate) fn is_yieldable(&self) -> bool { unsafe { ffi::lua_isyieldable(self.state()) != 0 } diff --git a/src/state/util.rs b/src/state/util.rs index 702ee6f4..922d36f3 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,9 +1,9 @@ +use crate::IntoLuaMulti; use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; -use crate::IntoLuaMulti; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; @@ -40,7 +40,7 @@ where } let nargs = ffi::lua_gettop(state); - + enum PreallocatedFailure { New(*mut WrappedFailure), Reserved, @@ -164,14 +164,14 @@ pub(crate) unsafe fn callback_error_ext_yieldable( f: F, ) -> c_int where - F: FnOnce(*mut ExtraData, c_int) -> Result, + F: FnOnce(*mut ExtraData, *mut ffi::lua_State, c_int) -> Result, { if extra.is_null() { extra = ExtraData::get(state); } let nargs = ffi::lua_gettop(state); - + enum PreallocatedFailure { New(*mut WrappedFailure), Reserved, @@ -238,7 +238,7 @@ where match catch_unwind(AssertUnwindSafe(|| { let rawlua = (*extra).raw_lua(); let _guard = StateGuard::new(rawlua, state); - f(extra, nargs) + f(extra, state, nargs) })) { Ok(Ok(r)) => { let raw = extra.as_ref().unwrap_unchecked().raw_lua(); @@ -247,9 +247,10 @@ where if !values.is_empty() { match values.push_into_stack_multi(raw) { Ok(nargs) => { + ffi::lua_pop(state, -1); ffi::lua_xmove(raw.state(), state, nargs); return ffi::lua_yield(state, nargs); - }, + } Err(err) => { let wrapped_error = prealloc_failure.r#use(state, extra); ptr::write( @@ -260,12 +261,10 @@ where ffi::lua_setmetatable(state, -2); ffi::lua_error(state) - } } } - // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index d7d021cc..0d90b8f4 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -1,81 +1,167 @@ +use mlua::IntoLuaMulti; #[cfg(feature = "luau")] use mlua::Lua; #[test] fn test_luau_continuation() { + // Yielding continuation + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + let lua = Lua::new(); - let cont_func = lua.create_function_with_luau_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - } - ).expect("Failed to create cont_func"); + let cont_func = lua + .create_function_with_luau_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); // Ensure normal calls work still assert_eq!( lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func).expect("Failed to call cont_func"), + .call::(cont_func) + .expect("Failed to call cont_func"), 2 ); // basic yield test before we go any further - let always_yield = lua.create_function( - |lua, ()| { + let always_yield = lua + .create_function(|lua, ()| { unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }) .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!(thread.resume::<(i32, String, f32)>(()).unwrap(), (42, String::from("69420"), 45.6)); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); // Trigger the continuation - let cont_func = lua.create_function_with_luau_continuation( - |lua, a: u64| { - unsafe { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e) - } - } - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - } - ).expect("Failed to create cont_func"); - - let luau_func = lua.load(" + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + unsafe { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + } + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " local cont_func = ... local res = cont_func(1) return res + 1 - ").into_function().expect("Failed to create function"); - let th = lua.create_thread(luau_func).expect("Failed to create luau thread"); + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - let v = th.resume::(cont_func).expect("Failed to resume"); + let v = th + .resume::(cont_func) + .expect("Failed to resume"); let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 41); - let always_yield = lua.create_function_with_luau_continuation( - |lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } - Ok(()) - }, - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - } - ) - .unwrap(); + let always_yield = lua + .create_function_with_luau_continuation( + |lua, ()| { + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + Ok(()) + }, + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); let mv = thread.resume::(()).unwrap(); - assert!(thread.resume::(mv).unwrap_err().to_string().starts_with("a3")); -} \ No newline at end of file + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + unsafe { + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + } + Ok(()) + }, + |lua, _status, args: mlua::MultiValue| { + println!("Reached cont recursive: {:?}", args); + + if args.len() == 5 { + return 6_i32.into_lua_multi(lua); + } + + unsafe { lua.set_yield_args((args.len() + 1, args))? } // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); +} From 85416fdd3064dc119a719261079329c5f1087bc1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 23:46:21 -0400 Subject: [PATCH 10/46] Document continuations a bit more --- src/state.rs | 30 +++++++++++++++++++++--------- src/state/raw.rs | 2 +- src/state/util.rs | 3 +++ tests/luau/cont.rs | 25 +++++++++++-------------- tests/thread.rs | 16 +++++++++------- 5 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/state.rs b/src/state.rs index 7d0c6885..341adf8d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1292,13 +1292,21 @@ impl Lua { })) } - /// Same as ``create_function`` but with support for continuations. + /// Same as ``create_function`` but with support for luau-style continuations. /// - /// Currently only luau-style continuations are supported at this time. + /// Other Lua versions have completely different continuation semantics and are both + /// not supported at this time, much more complicated to support within mlua and cannot + /// be supported with this API in any case. /// - /// The values passed to the continuation will either be the yielded arguments - /// from the function or the arguments from resuming the thread. Yielded values - /// from a continuation become the resume args. + /// The values passed to the continuation will be the yielded arguments + /// from the function for the initial continuation call. If yielding from a continuation, + /// the yielded results will be returned to the ``Thread::resume`` caller. The arguments + /// passed in the next ``Thread::resume`` call will then be the arguments passed to the yielding + /// continuation upon resumption. + /// + /// Returning a value from a continuation without setting yield + /// arguments will then be returned as the final return value of the Luau function call. + /// Values returned in a function in which there is also yielding will be ignored #[cfg(feature = "luau")] pub fn create_function_with_luau_continuation( &self, @@ -2143,16 +2151,20 @@ impl Lua { } /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function - /// to actually yield. This method is mostly useful with Luau continuations and Rust-Rust - /// yields + /// to actually yield. Any values returned in a function in which there is also yielding may + /// be ignored. + /// + /// This method is mostly useful with Luau continuations and Rust-Rust yields + /// due to the Rust/Lua boundary /// /// If this function cannot yield, it will raise a runtime error. /// /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// - /// Potentially unsafe at this time. Use with caution. - pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + /// While this method *should be safe*, it is new and may have bugs lurking within. Use + /// with caution + pub fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] if !raw.is_yieldable() { diff --git a/src/state/raw.rs b/src/state/raw.rs index a151c231..147e6cbe 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1218,7 +1218,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); diff --git a/src/state/util.rs b/src/state/util.rs index 922d36f3..5c4d1409 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -157,6 +157,9 @@ where } /// An yieldable version of `callback_error_ext` +/// +/// Unlike ``callback_error_ext``, this method requires a c_int return +/// and not a generic R pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 0d90b8f4..ce8f58f8 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -30,7 +30,7 @@ fn test_luau_continuation() { // basic yield test before we go any further let always_yield = lua .create_function(|lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + lua.set_yield_args((42, "69420".to_string(), 45.6))?; Ok(()) }) .unwrap(); @@ -45,12 +45,10 @@ fn test_luau_continuation() { let cont_func = lua .create_function_with_luau_continuation( |lua, a: u64| { - unsafe { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - } + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; Ok(()) }, |_lua, _status, a: u64| { @@ -70,6 +68,7 @@ fn test_luau_continuation() { ) .into_function() .expect("Failed to create function"); + let th = lua .create_thread(luau_func) .expect("Failed to create luau thread"); @@ -84,7 +83,7 @@ fn test_luau_continuation() { let always_yield = lua .create_function_with_luau_continuation( |lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + lua.set_yield_args((42, "69420".to_string(), 45.6))?; Ok(()) }, |_lua, _, mv: mlua::MultiValue| { @@ -108,11 +107,9 @@ fn test_luau_continuation() { let cont_func = lua .create_function_with_luau_continuation( |lua, a: u64| { - unsafe { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), } Ok(()) }, @@ -123,7 +120,7 @@ fn test_luau_continuation() { return 6_i32.into_lua_multi(lua); } - unsafe { lua.set_yield_args((args.len() + 1, args))? } // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored }, ) diff --git a/tests/thread.rs b/tests/thread.rs index de5aca2b..e6825ff3 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -256,14 +256,16 @@ fn test_thread_resume_error() -> Result<()> { #[test] fn test_thread_yield_args() -> Result<()> { let lua = Lua::new(); - let always_yield = lua.create_function( - |lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } - Ok(()) - })?; + let always_yield = lua.create_function(|lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + })?; let thread = lua.create_thread(always_yield)?; - assert_eq!(thread.resume::<(i32, String, f32)>(())?, (42, String::from("69420"), 45.6)); + assert_eq!( + thread.resume::<(i32, String, f32)>(())?, + (42, String::from("69420"), 45.6) + ); Ok(()) -} \ No newline at end of file +} From b3a54ac15a4b54f8424cdd842728bff5e5074e91 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 00:18:47 -0400 Subject: [PATCH 11/46] reuse same preallocatedfailure in both callback_error_ext and callback_error_ext_yieldable --- src/state/raw.rs | 6 +- src/state/util.rs | 181 +++++++++++++++------------------------------ tests/luau/cont.rs | 2 + 3 files changed, 66 insertions(+), 123 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 147e6cbe..74939b4c 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1164,7 +1164,7 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1205,7 +1205,7 @@ impl RawLua { ) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1218,7 +1218,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); diff --git a/src/state/util.rs b/src/state/util.rs index 5c4d1409..cbdee8b4 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -24,6 +24,65 @@ impl Drop for StateGuard<'_> { } } +pub(crate) enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, +} + +impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = (*extra).ref_thread; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread, state, index); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, index); + (*extra).ref_free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = (*extra).ref_thread; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread, 1); + let index = ref_stack_pop(extra); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } +} + // An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata // and instead reuses unused values from previous calls (or allocates new). pub(crate) unsafe fn callback_error_ext( @@ -41,65 +100,6 @@ where let nargs = ffi::lua_gettop(state); - enum PreallocatedFailure { - New(*mut WrappedFailure), - Reserved, - } - - impl PreallocatedFailure { - unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { - if (*extra).wrapped_failure_top > 0 { - (*extra).wrapped_failure_top -= 1; - return PreallocatedFailure::Reserved; - } - - // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - // Place it to the beginning of the stack - let ud = WrappedFailure::new_userdata(state); - ffi::lua_insert(state, 1); - PreallocatedFailure::New(ud) - } - - #[cold] - unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; - match *self { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Reserved => { - let index = (*extra).wrapped_failure_pool.pop().unwrap(); - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } - } - } - - unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; - match self { - PreallocatedFailure::New(_) => { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); - (*extra).wrapped_failure_pool.push(index); - (*extra).wrapped_failure_top += 1; - } - PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, - } - } - } - // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory // to store a wrapped failure (error or panic) *before* we proceed. let prealloc_failure = PreallocatedFailure::reserve(state, extra); @@ -167,7 +167,7 @@ pub(crate) unsafe fn callback_error_ext_yieldable( f: F, ) -> c_int where - F: FnOnce(*mut ExtraData, *mut ffi::lua_State, c_int) -> Result, + F: FnOnce(*mut ExtraData, c_int) -> Result, { if extra.is_null() { extra = ExtraData::get(state); @@ -175,65 +175,6 @@ where let nargs = ffi::lua_gettop(state); - enum PreallocatedFailure { - New(*mut WrappedFailure), - Reserved, - } - - impl PreallocatedFailure { - unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { - if (*extra).wrapped_failure_top > 0 { - (*extra).wrapped_failure_top -= 1; - return PreallocatedFailure::Reserved; - } - - // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - // Place it to the beginning of the stack - let ud = WrappedFailure::new_userdata(state); - ffi::lua_insert(state, 1); - PreallocatedFailure::New(ud) - } - - #[cold] - unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; - match *self { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Reserved => { - let index = (*extra).wrapped_failure_pool.pop().unwrap(); - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } - } - } - - unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; - match self { - PreallocatedFailure::New(_) => { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); - (*extra).wrapped_failure_pool.push(index); - (*extra).wrapped_failure_top += 1; - } - PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, - } - } - } - // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory // to store a wrapped failure (error or panic) *before* we proceed. let prealloc_failure = PreallocatedFailure::reserve(state, extra); @@ -241,7 +182,7 @@ where match catch_unwind(AssertUnwindSafe(|| { let rawlua = (*extra).raw_lua(); let _guard = StateGuard::new(rawlua, state); - f(extra, state, nargs) + f(extra, nargs) })) { Ok(Ok(r)) => { let raw = extra.as_ref().unwrap_unchecked().raw_lua(); diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index ce8f58f8..675d90d8 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -148,9 +148,11 @@ fn test_luau_continuation() { let v = th .resume::(v) .expect("Failed to load continuation"); + println!("v={:?}", v); let v = th .resume::(v) .expect("Failed to load continuation"); + println!("v={:?}", v); let v = th .resume::(v) .expect("Failed to load continuation"); From 085c62a68e36c95402b2a5c70d4acd013d3a9fb5 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 04:56:06 -0400 Subject: [PATCH 12/46] handle mainthread edgecase --- src/state/util.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index cbdee8b4..517ae487 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -189,10 +189,23 @@ where let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); if !values.is_empty() { + if raw.state() == state { + // Edge case: main thread is being yielded + // + // We need to pop/clear stack early, then push args + ffi::lua_pop(state, -1); + } + match values.push_into_stack_multi(raw) { Ok(nargs) => { - ffi::lua_pop(state, -1); - ffi::lua_xmove(raw.state(), state, nargs); + // If not main thread, then clear and xmove to target thread + if raw.state() != state { + // luau preserves the stack making yieldable continuations ugly and leaky + // + // Even outside of luau, clearing the stack is probably desirable + ffi::lua_pop(state, -1); + ffi::lua_xmove(raw.state(), state, nargs); + } return ffi::lua_yield(state, nargs); } Err(err) => { From 198b8575801d10557834b4c631f513ee3af3b3f7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 07:06:12 -0400 Subject: [PATCH 13/46] fix import --- src/state/raw.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 74939b4c..897e693a 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,9 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -#[cfg(not(feature = "luau"))] -use crate::state::util::callback_error_ext; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; From 2cd34db321417b0d28a5be3f6a38284caa8b4cc6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 12:04:44 -0400 Subject: [PATCH 14/46] support empty yield args --- src/state.rs | 2 +- src/state/extra.rs | 4 +-- src/state/util.rs | 2 +- tests/luau/cont.rs | 80 ++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/state.rs b/src/state.rs index 341adf8d..dacf62b9 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2171,7 +2171,7 @@ impl Lua { return Err(Error::runtime("cannot yield across Rust/Lua boundary.")); } unsafe { - raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); } Ok(()) } diff --git a/src/state/extra.rs b/src/state/extra.rs index 445c232c..7c18e9b8 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -96,7 +96,7 @@ pub(crate) struct ExtraData { pub(super) enable_jit: bool, // Values currently being yielded from Lua.yield() - pub(super) yielded_values: MultiValue, + pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -198,7 +198,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, - yielded_values: MultiValue::with_capacity(0), + yielded_values: None, })); // Store it in the registry diff --git a/src/state/util.rs b/src/state/util.rs index 517ae487..6b8d702d 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -188,7 +188,7 @@ where let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - if !values.is_empty() { + if let Some(values) = values { if raw.state() == state { // Edge case: main thread is being yielded // diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 675d90d8..9cb2b9a8 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -4,10 +4,84 @@ use mlua::Lua; #[test] fn test_luau_continuation() { - // Yielding continuation - mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); - let lua = Lua::new(); + // No yielding continuation fflag test + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + // empty yield args test + let cont_func = lua + .create_function_with_luau_continuation( + |lua, _: ()| { + match lua.set_yield_args(()) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res - 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + assert!(v.is_empty()); + let v = th.resume::(v).expect("Failed to load continuation"); + assert_eq!(v, -1); + + // Yielding continuation fflag test + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); let cont_func = lua .create_function_with_luau_continuation( From 23b77056229b316f5b0eea407924f8945fa03abc Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 12:29:52 -0400 Subject: [PATCH 15/46] add more tests for cont --- src/lib.rs | 3 +++ src/prelude.rs | 3 +++ tests/luau/cont.rs | 44 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index de321205..3b27345d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,6 +126,9 @@ pub use crate::value::{Nil, Value}; #[cfg(not(feature = "luau"))] pub use crate::hook::HookTriggers; +#[cfg(feature = "luau")] +pub use crate::thread::LuauContinuationStatus; + #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub use crate::{ diff --git a/src/prelude.rs b/src/prelude.rs index eeeaea26..5becef07 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -28,6 +28,9 @@ pub use crate::{ NavigateError as LuaNavigateError, Require as LuaRequire, Vector as LuaVector, }; +#[cfg(feature = "luau")] +pub use crate::LuauContinuationStatus; + #[cfg(feature = "async")] #[doc(no_inline)] pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn}; diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 9cb2b9a8..58ffdd2c 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -187,10 +187,11 @@ fn test_luau_continuation() { } Ok(()) }, - |lua, _status, args: mlua::MultiValue| { + |lua, status, args: mlua::MultiValue| { println!("Reached cont recursive: {:?}", args); if args.len() == 5 { + assert_eq!(status, mlua::LuauContinuationStatus::Ok); return 6_i32.into_lua_multi(lua); } @@ -237,4 +238,45 @@ fn test_luau_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); } From dd3ea057e79762eea1379c6a4abdde4353c1dce6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 12:43:43 -0400 Subject: [PATCH 16/46] ensure check_stack of state --- src/state/util.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/state/util.rs b/src/state/util.rs index 6b8d702d..ff39e70b 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; -use crate::util::{self, get_internal_metatable, WrappedFailure}; +use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); @@ -204,6 +204,17 @@ where // // Even outside of luau, clearing the stack is probably desirable ffi::lua_pop(state, -1); + if let Err(err) = check_stack(state, nargs) { + let wrapped_error = prealloc_failure.r#use(state, extra); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::external(err.to_string())), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } ffi::lua_xmove(raw.state(), state, nargs); } return ffi::lua_yield(state, nargs); From 52176a8271fb7709ad24636a4d64436c06aa6d87 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 20:39:21 -0400 Subject: [PATCH 17/46] rename to luau continuation status --- src/lib.rs | 5 +---- src/prelude.rs | 21 +++++++++------------ src/state.rs | 16 ++++++++++------ src/thread.rs | 8 +++----- tests/luau/cont.rs | 16 ++++++++-------- 5 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3b27345d..20bf3237 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,7 +110,7 @@ pub use crate::state::{GCMode, Lua, LuaOptions, WeakLua}; pub use crate::stdlib::StdLib; pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; -pub use crate::thread::{Thread, ThreadStatus}; +pub use crate::thread::{ContinuationStatus, Thread, ThreadStatus}; pub use crate::traits::{ FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike, }; @@ -126,9 +126,6 @@ pub use crate::value::{Nil, Value}; #[cfg(not(feature = "luau"))] pub use crate::hook::HookTriggers; -#[cfg(feature = "luau")] -pub use crate::thread::LuauContinuationStatus; - #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub use crate::{ diff --git a/src/prelude.rs b/src/prelude.rs index 5becef07..d2324c22 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -3,15 +3,15 @@ #[doc(no_inline)] pub use crate::{ AnyUserData as LuaAnyUserData, BorrowedBytes as LuaBorrowedBytes, BorrowedStr as LuaBorrowedStr, - Chunk as LuaChunk, Either as LuaEither, Error as LuaError, ErrorContext as LuaErrorContext, - ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, - Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger, - IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions, - MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, - ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, - String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, - Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData, - UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, + Chunk as LuaChunk, ContinuationStatus as LuaContinuationStatus, Either as LuaEither, Error as LuaError, + ErrorContext as LuaErrorContext, ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, + FromLua, FromLuaMulti, Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, + Integer as LuaInteger, IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, + LuaNativeFnMut, LuaOptions, MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, + Number as LuaNumber, ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, + StdLib as LuaStdLib, String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, + TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus, + UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, Variadic as LuaVariadic, VmState as LuaVmState, WeakLua, @@ -28,9 +28,6 @@ pub use crate::{ NavigateError as LuaNavigateError, Require as LuaRequire, Vector as LuaVector, }; -#[cfg(feature = "luau")] -pub use crate::LuauContinuationStatus; - #[cfg(feature = "async")] #[doc(no_inline)] pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn}; diff --git a/src/state.rs b/src/state.rs index dacf62b9..bb3644e0 100644 --- a/src/state.rs +++ b/src/state.rs @@ -20,7 +20,7 @@ use crate::table::Table; use crate::thread::Thread; #[cfg(feature = "luau")] -use crate::thread::LuauContinuationStatus; +use crate::thread::ContinuationStatus; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ @@ -1295,8 +1295,7 @@ impl Lua { /// Same as ``create_function`` but with support for luau-style continuations. /// /// Other Lua versions have completely different continuation semantics and are both - /// not supported at this time, much more complicated to support within mlua and cannot - /// be supported with this API in any case. + /// not supported at this time. /// /// The values passed to the continuation will be the yielded arguments /// from the function for the initial continuation call. If yielding from a continuation, @@ -1308,19 +1307,24 @@ impl Lua { /// arguments will then be returned as the final return value of the Luau function call. /// Values returned in a function in which there is also yielding will be ignored #[cfg(feature = "luau")] - pub fn create_function_with_luau_continuation( + pub fn create_function_with_continuation( &self, func: F, cont: FC, ) -> Result where F: Fn(&Lua, A) -> Result + MaybeSend + 'static, - FC: Fn(&Lua, LuauContinuationStatus, AC) -> Result + MaybeSend + 'static, + FC: Fn(&Lua, ContinuationStatus, AC) -> Result + MaybeSend + 'static, A: FromLuaMulti, AC: FromLuaMulti, R: IntoLuaMulti, RC: IntoLuaMulti, { + // On luau, use a callback with luau continuation + // + // For other lua versions (in future), this will instead + // make a wrapper function that at the end calls lua_yieldk + #[cfg(feature = "luau")] (self.lock()).create_callback_with_luau_continuation( Box::new(move |rawlua, nargs| unsafe { let args = A::from_stack_args(nargs, 1, None, rawlua)?; @@ -1328,7 +1332,7 @@ impl Lua { }), Box::new(move |rawlua, nargs, status| unsafe { let args = AC::from_stack_args(nargs, 1, None, rawlua)?; - let status = LuauContinuationStatus::from_status(status); + let status = ContinuationStatus::from_status(status); cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) }), ) diff --git a/src/thread.rs b/src/thread.rs index 6b834081..3eb9f9f7 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -26,17 +26,15 @@ use { }, }; -/// Luau continuation final status -#[cfg(feature = "luau")] +/// Continuation thread status. Can either be Ok, Yielded (rare, but can happen) or Error #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum LuauContinuationStatus { +pub enum ContinuationStatus { Ok, Yielded, Error, } -#[cfg(feature = "luau")] -impl LuauContinuationStatus { +impl ContinuationStatus { pub(crate) fn from_status(status: c_int) -> Self { match status { ffi::LUA_YIELD => Self::Yielded, diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 58ffdd2c..a22987a7 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -7,7 +7,7 @@ fn test_luau_continuation() { let lua = Lua::new(); // No yielding continuation fflag test let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args(a) { Ok(()) => println!("set_yield_args called"), @@ -46,7 +46,7 @@ fn test_luau_continuation() { // empty yield args test let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, _: ()| { match lua.set_yield_args(()) { Ok(()) => println!("set_yield_args called"), @@ -84,7 +84,7 @@ fn test_luau_continuation() { mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |_lua, a: u64| Ok(a + 1), |_lua, _status, a: u64| { println!("Reached cont"); @@ -117,7 +117,7 @@ fn test_luau_continuation() { // Trigger the continuation let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args(a) { Ok(()) => println!("set_yield_args called"), @@ -155,7 +155,7 @@ fn test_luau_continuation() { assert_eq!(v, 41); let always_yield = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, ()| { lua.set_yield_args((42, "69420".to_string(), 45.6))?; Ok(()) @@ -179,7 +179,7 @@ fn test_luau_continuation() { .starts_with("a3")); let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args((a + 1, 1)) { Ok(()) => println!("set_yield_args called"), @@ -191,7 +191,7 @@ fn test_luau_continuation() { println!("Reached cont recursive: {:?}", args); if args.len() == 5 { - assert_eq!(status, mlua::LuauContinuationStatus::Ok); + assert_eq!(status, mlua::ContinuationStatus::Ok); return 6_i32.into_lua_multi(lua); } @@ -241,7 +241,7 @@ fn test_luau_continuation() { // test panics let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args(a) { Ok(()) => println!("set_yield_args called"), From 10d2df40f523663978b6b4b67f74469226650755 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:46:58 -0400 Subject: [PATCH 18/46] add support for continuations outside luau --- mlua-sys/src/lua52/lua.rs | 5 + mlua-sys/src/lua53/lua.rs | 5 + mlua-sys/src/lua54/lua.rs | 5 + src/state.rs | 28 ++-- src/state/extra.rs | 6 + src/state/raw.rs | 145 +++++++++++++------ src/state/util.rs | 69 +++++++++ src/thread.rs | 1 + src/types.rs | 12 +- src/util/types.rs | 12 +- tests/luau.rs | 3 - tests/luau/cont.rs | 282 ------------------------------------- tests/thread.rs | 286 ++++++++++++++++++++++++++++++++++++++ 13 files changed, 499 insertions(+), 360 deletions(-) delete mode 100644 tests/luau/cont.rs diff --git a/mlua-sys/src/lua52/lua.rs b/mlua-sys/src/lua52/lua.rs index e5239cee..8de7cdee 100644 --- a/mlua-sys/src/lua52/lua.rs +++ b/mlua-sys/src/lua52/lua.rs @@ -272,6 +272,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_CFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua53/lua.rs b/mlua-sys/src/lua53/lua.rs index 2729fdcd..c3d82e63 100644 --- a/mlua-sys/src/lua53/lua.rs +++ b/mlua-sys/src/lua53/lua.rs @@ -286,6 +286,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua54/lua.rs b/mlua-sys/src/lua54/lua.rs index 15a30444..c74e1576 100644 --- a/mlua-sys/src/lua54/lua.rs +++ b/mlua-sys/src/lua54/lua.rs @@ -299,6 +299,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Warning-related functions // diff --git a/src/state.rs b/src/state.rs index bb3644e0..df001893 100644 --- a/src/state.rs +++ b/src/state.rs @@ -19,7 +19,7 @@ use crate::string::String; use crate::table::Table; use crate::thread::Thread; -#[cfg(feature = "luau")] +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] use crate::thread::ContinuationStatus; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; @@ -1292,21 +1292,20 @@ impl Lua { })) } - /// Same as ``create_function`` but with support for luau-style continuations. - /// - /// Other Lua versions have completely different continuation semantics and are both - /// not supported at this time. + /// Same as ``create_function`` but with an added continuation function. /// /// The values passed to the continuation will be the yielded arguments - /// from the function for the initial continuation call. If yielding from a continuation, - /// the yielded results will be returned to the ``Thread::resume`` caller. The arguments - /// passed in the next ``Thread::resume`` call will then be the arguments passed to the yielding - /// continuation upon resumption. + /// from the function for the initial continuation call. On Luau, if yielding from a + /// continuation, the yielded results will be returned to the ``Thread::resume`` caller. The + /// arguments passed in the next ``Thread::resume`` call will then be the arguments passed + /// to the yielding continuation upon resumption. /// /// Returning a value from a continuation without setting yield - /// arguments will then be returned as the final return value of the Luau function call. + /// arguments will then be returned as the final return value of the Lua function call. /// Values returned in a function in which there is also yielding will be ignored - #[cfg(feature = "luau")] + /// + /// Note that yielding in continuations is only supported on Luau + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] pub fn create_function_with_continuation( &self, func: F, @@ -1320,12 +1319,7 @@ impl Lua { R: IntoLuaMulti, RC: IntoLuaMulti, { - // On luau, use a callback with luau continuation - // - // For other lua versions (in future), this will instead - // make a wrapper function that at the end calls lua_yieldk - #[cfg(feature = "luau")] - (self.lock()).create_callback_with_luau_continuation( + (self.lock()).create_callback_with_continuation( Box::new(move |rawlua, nargs| unsafe { let args = A::from_stack_args(nargs, 1, None, rawlua)?; func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) diff --git a/src/state/extra.rs b/src/state/extra.rs index 7c18e9b8..1c49a064 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -13,6 +13,7 @@ use crate::error::Result; use crate::state::RawLua; use crate::stdlib::StdLib; use crate::types::{AppData, ReentrantMutex, XRc}; + use crate::userdata::RawUserDataRegistry; use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, WrappedFailure}; @@ -97,6 +98,9 @@ pub(crate) struct ExtraData { // Values currently being yielded from Lua.yield() pub(super) yielded_values: Option, + + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + pub(super) yield_continuation: bool, } impl Drop for ExtraData { @@ -199,6 +203,8 @@ impl ExtraData { #[cfg(feature = "luau")] running_gc: false, yielded_values: None, + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + yield_continuation: false, })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index 897e693a..298e88e7 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; +#[cfg(not(feature = "luau"))] +use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -23,8 +25,10 @@ use crate::types::{ MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; -#[cfg(feature = "luau")] -use crate::types::{LuauContinuation, LuauContinuationUpvalue}; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::Continuation; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, @@ -202,8 +206,8 @@ impl RawLua { init_internal_metatable::>>(state, None)?; init_internal_metatable::(state, None)?; init_internal_metatable::(state, None)?; - #[cfg(feature = "luau")] - init_internal_metatable::(state, None)?; + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + init_internal_metatable::(state, None)?; #[cfg(not(feature = "luau"))] init_internal_metatable::(state, None)?; #[cfg(feature = "async")] @@ -1195,56 +1199,105 @@ impl RawLua { } // Creates a Function out of a Callback and a continuation containing a 'static Fn. - #[cfg(feature = "luau")] - pub(crate) fn create_callback_with_luau_continuation( + // + // In Luau, uses pushcclosurek + // + // In Lua 5.2/5.3/5.4/JIT, makes a normal function that then yields to the continuation via yieldk + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + pub(crate) fn create_callback_with_continuation( &self, func: Callback, - cont: LuauContinuation, + cont: Continuation, ) -> Result { - unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.0)(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) - } + #[cfg(feature = "luau")] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } - unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.1)(rawlua, nargs, status), - None => Err(Error::CallbackDestructed), - } - }) - } + unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }) + } - let state = self.state(); - unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 4)?; + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; - let func = Some((func, cont)); - let extra = XRc::clone(&self.extra); - let protect = !self.unlikely_memory_error(); - push_internal_userdata(state, LuauContinuationUpvalue { data: func, extra }, protect)?; - if protect { - protect_lua!(state, 1, 1, fn(state) { + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + })?; + } else { ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); - })?; - } else { - ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + } + + Ok(Function(self.pop_ref())) } + } - Ok(Function(self.pop_ref())) + #[cfg(not(feature = "luau"))] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some((ref func, _)) => match func(rawlua, nargs) { + Ok(r) => { + (*extra).yield_continuation = true; + Ok(r) + } + Err(e) => Err(e), + }, + None => Err(Error::CallbackDestructed), + } + }) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, call_callback, 1); + })?; + } else { + ffi::lua_pushcclosure(state, call_callback, 1); + } + + Ok(Function(self.pop_ref())) + } } } diff --git a/src/state/util.rs b/src/state/util.rs index ff39e70b..dd906d56 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -9,6 +9,9 @@ use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit"), not(feature = "luau")))] +use crate::{types::ContinuationUpvalue, util::get_userdata}; + struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); impl<'a> StateGuard<'a> { @@ -112,6 +115,8 @@ where Ok(Ok(r)) => { // Ensure yielded values are cleared take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); @@ -188,6 +193,9 @@ where let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + let yield_cont = take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); + if let Some(values) = values { if raw.state() == state { // Edge case: main thread is being yielded @@ -217,6 +225,67 @@ where } ffi::lua_xmove(raw.state(), state, nargs); } + + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + { + // Yield to a continuation. Unlike luau, we need to do this manually and on the + // fly using a yieldk call + if yield_cont { + // On Lua 5.2, status and ctx are not present, so use 0 as status for + // compatibility + #[cfg(feature = "lua52")] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, 0), + None => Err(Error::CallbackDestructed), + } + }, + ) + } + + // Lua 5.3/5.4 case + #[cfg(not(feature = "lua52"))] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + status: c_int, + _ctx: ffi::lua_KContext, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + ) + } + + return ffi::lua_yieldc(state, nargs, cont_callback); + } + } + return ffi::lua_yield(state, nargs); } Err(err) => { diff --git a/src/thread.rs b/src/thread.rs index 3eb9f9f7..ce096f06 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -35,6 +35,7 @@ pub enum ContinuationStatus { } impl ContinuationStatus { + #[allow(dead_code)] pub(crate) fn from_status(status: c_int) -> Self { match status { ffi::LUA_YIELD => Self::Yielded, diff --git a/src/types.rs b/src/types.rs index 3b6b0d8f..45806247 100644 --- a/src/types.rs +++ b/src/types.rs @@ -39,11 +39,11 @@ pub(crate) type Callback = Box Result + Send + #[cfg(not(feature = "send"))] pub(crate) type Callback = Box Result + 'static>; +#[cfg(all(feature = "send", not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + Send + 'static>; -#[cfg(all(feature = "send", feature = "luau"))] -pub(crate) type LuauContinuation = Box Result + Send + 'static>; -#[cfg(all(not(feature = "send"), feature = "luau"))] -pub(crate) type LuauContinuation = Box Result + 'static>; +#[cfg(all(not(feature = "send"), not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + 'static>; pub(crate) type ScopedCallback<'s> = Box Result + 's>; @@ -53,8 +53,8 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; -#[cfg(feature = "luau")] -pub(crate) type LuauContinuationUpvalue = Upvalue>; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type ContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] pub(crate) type AsyncCallback = diff --git a/src/util/types.rs b/src/util/types.rs index 16945009..8627042f 100644 --- a/src/util/types.rs +++ b/src/util/types.rs @@ -3,8 +3,8 @@ use std::os::raw::c_void; use crate::types::{Callback, CallbackUpvalue}; -#[cfg(feature = "luau")] -use crate::types::LuauContinuationUpvalue; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; #[cfg(feature = "async")] use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; @@ -37,12 +37,12 @@ impl TypeKey for CallbackUpvalue { } } -#[cfg(feature = "luau")] -impl TypeKey for LuauContinuationUpvalue { +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +impl TypeKey for ContinuationUpvalue { #[inline(always)] fn type_key() -> *const c_void { - static LUAU_CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; - &LUAU_CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void + static CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; + &CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void } } diff --git a/tests/luau.rs b/tests/luau.rs index 7b62bae3..20fbee6a 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -449,6 +449,3 @@ fn test_typeof_error() -> Result<()> { #[path = "luau/require.rs"] mod require; - -#[path = "luau/cont.rs"] -mod cont; diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs deleted file mode 100644 index a22987a7..00000000 --- a/tests/luau/cont.rs +++ /dev/null @@ -1,282 +0,0 @@ -use mlua::IntoLuaMulti; -#[cfg(feature = "luau")] -use mlua::Lua; - -#[test] -fn test_luau_continuation() { - let lua = Lua::new(); - // No yielding continuation fflag test - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 41); - - // empty yield args test - let cont_func = lua - .create_function_with_continuation( - |lua, _: ()| { - match lua.set_yield_args(()) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res - 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - assert!(v.is_empty()); - let v = th.resume::(v).expect("Failed to load continuation"); - assert_eq!(v, -1); - - // Yielding continuation fflag test - mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); - - let cont_func = lua - .create_function_with_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - }, - ) - .expect("Failed to create cont_func"); - - // Ensure normal calls work still - assert_eq!( - lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func) - .expect("Failed to call cont_func"), - 2 - ); - - // basic yield test before we go any further - let always_yield = lua - .create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!( - thread.resume::<(i32, String, f32)>(()).unwrap(), - (42, String::from("69420"), 45.6) - ); - - // Trigger the continuation - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 41); - - let always_yield = lua - .create_function_with_continuation( - |lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }, - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - }, - ) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - let mv = thread.resume::(()).unwrap(); - assert!(thread - .resume::(mv) - .unwrap_err() - .to_string() - .starts_with("a3")); - - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - Ok(()) - }, - |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive: {:?}", args); - - if args.len() == 5 { - assert_eq!(status, mlua::ContinuationStatus::Ok); - return 6_i32.into_lua_multi(lua); - } - - lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) - (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - println!("v={:?}", v); - - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - - // (2, 1) followed by () - assert_eq!(v.len(), 2 + 3); - - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 7); - - // test panics - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, _a: u64| { - panic!("Reached continuation which should panic!"); - #[allow(unreachable_code)] - Ok(()) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local ok, res = pcall(cont_func, 1) - assert(not ok) - return tostring(res) - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - - let v = th.resume::(v).expect("Failed to load continuation"); - assert!(v.contains("Reached continuation which should panic!")); -} diff --git a/tests/thread.rs b/tests/thread.rs index e6825ff3..a51cf851 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -2,6 +2,9 @@ use std::panic::catch_unwind; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; +#[cfg(feature = "luau")] +use mlua::IntoLuaMulti; + #[test] fn test_thread() -> Result<()> { let lua = Lua::new(); @@ -269,3 +272,286 @@ fn test_thread_yield_args() -> Result<()> { Ok(()) } + +#[test] +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +fn test_continuation() { + let lua = Lua::new(); + // No yielding continuation fflag test + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + // empty yield args test + let cont_func = lua + .create_function_with_continuation( + |lua, _: ()| { + match lua.set_yield_args(()) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res - 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + assert!(v.is_empty()); + let v = th.resume::(v).expect("Failed to load continuation"); + assert_eq!(v, -1); + + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); + + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }, + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + Ok(()) + }, + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive: {:?}", args); + + if args.len() == 5 { + assert_eq!(status, mlua::ContinuationStatus::Ok); + return 6_i32.into_lua_multi(lua); + } + + lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); + } +} From 8400386fa476aa25c92a89bc1f93b71b57effad1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:52:04 -0400 Subject: [PATCH 19/46] fix --- tests/thread.rs | 64 ++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/thread.rs b/tests/thread.rs index a51cf851..5e1a199d 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -512,46 +512,46 @@ fn test_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 7); + } - // test panics - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, _a: u64| { - panic!("Reached continuation which should panic!"); - #[allow(unreachable_code)] - Ok(()) - }, - ) - .expect("Failed to create cont_func"); + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); - let luau_func = lua - .load( - " + let luau_func = lua + .load( + " local cont_func = ... local ok, res = pcall(cont_func, 1) assert(not ok) return tostring(res) ", - ) - .into_function() - .expect("Failed to create function"); + ) + .into_function() + .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - let v = th - .resume::(cont_func) - .expect("Failed to resume"); + let v = th + .resume::(cont_func) + .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - assert!(v.contains("Reached continuation which should panic!")); - } + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); } From 5d86852ff99580dbc05a6253982243a0afe5bc53 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:57:23 -0400 Subject: [PATCH 20/46] fix --- src/state/raw.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 298e88e7..a4eb8ccf 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -#[cfg(not(feature = "luau"))] +#[cfg(any(not(feature = "luau"), feature = "luau-vector4"))] use crate::state::util::callback_error_ext; use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; From 9e77e5c556fea8675cfa3a8dbb40346f226b7a77 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:59:42 -0400 Subject: [PATCH 21/46] fix --- src/state/raw.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index a4eb8ccf..c49d7dfa 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,9 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -#[cfg(any(not(feature = "luau"), feature = "luau-vector4"))] -use crate::state::util::callback_error_ext; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; From eb637554b9ccae59aed2593b94676d8254989361 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 22:13:45 -0400 Subject: [PATCH 22/46] fix yieldable continuation on non-luau --- src/state/extra.rs | 5 - src/state/raw.rs | 55 +++++---- src/state/util.rs | 10 +- tests/thread.rs | 285 +++++++++++++++++++++++---------------------- 4 files changed, 179 insertions(+), 176 deletions(-) diff --git a/src/state/extra.rs b/src/state/extra.rs index 1c49a064..fbb21c66 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -98,9 +98,6 @@ pub(crate) struct ExtraData { // Values currently being yielded from Lua.yield() pub(super) yielded_values: Option, - - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - pub(super) yield_continuation: bool, } impl Drop for ExtraData { @@ -203,8 +200,6 @@ impl ExtraData { #[cfg(feature = "luau")] running_gc: false, yielded_values: None, - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - yield_continuation: false, })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index c49d7dfa..bca9d1a1 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1164,15 +1164,21 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => func(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + false, + ) } let state = self.state(); @@ -1260,21 +1266,22 @@ impl RawLua { { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some((ref func, _)) => match func(rawlua, nargs) { - Ok(r) => { - (*extra).yield_continuation = true; - Ok(r) - } - Err(e) => Err(e), - }, - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some((ref func, _)) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) } let state = self.state(); diff --git a/src/state/util.rs b/src/state/util.rs index dd906d56..5c1bb4b6 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -115,8 +115,6 @@ where Ok(Ok(r)) => { // Ensure yielded values are cleared take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); @@ -170,6 +168,7 @@ pub(crate) unsafe fn callback_error_ext_yieldable( mut extra: *mut ExtraData, wrap_error: bool, f: F, + in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, @@ -193,9 +192,6 @@ where let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - let yield_cont = take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); - if let Some(values) = values { if raw.state() == state { // Edge case: main thread is being yielded @@ -230,7 +226,7 @@ where { // Yield to a continuation. Unlike luau, we need to do this manually and on the // fly using a yieldk call - if yield_cont { + if in_callback_with_continuation { // On Lua 5.2, status and ctx are not present, so use 0 as status for // compatibility #[cfg(feature = "lua52")] @@ -253,6 +249,7 @@ where None => Err(Error::CallbackDestructed), } }, + true, ) } @@ -279,6 +276,7 @@ where None => Err(Error::CallbackDestructed), } }, + true, ) } diff --git a/tests/thread.rs b/tests/thread.rs index 5e1a199d..ff346aef 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -2,7 +2,6 @@ use std::panic::catch_unwind; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; -#[cfg(feature = "luau")] use mlua::IntoLuaMulti; #[test] @@ -356,163 +355,167 @@ fn test_continuation() { #[cfg(feature = "luau")] { mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } - let cont_func = lua - .create_function_with_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - }, - ) - .expect("Failed to create cont_func"); - - // Ensure normal calls work still - assert_eq!( - lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func) - .expect("Failed to call cont_func"), - 2 - ); + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); - // basic yield test before we go any further - let always_yield = lua - .create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }) - .unwrap(); + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); - let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!( - thread.resume::<(i32, String, f32)>(()).unwrap(), - (42, String::from("69420"), 45.6) - ); + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); - // Trigger the continuation - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " + let luau_func = lua + .load( + " local cont_func = ... local res = cont_func(1) return res + 1 ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 41); - - let always_yield = lua - .create_function_with_continuation( - |lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }, - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - }, - ) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - let mv = thread.resume::(()).unwrap(); - assert!(thread - .resume::(mv) - .unwrap_err() - .to_string() - .starts_with("a3")); - - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - Ok(()) - }, - |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive: {:?}", args); + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - if args.len() == 5 { + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }, + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + Ok(()) + }, + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { assert_eq!(status, mlua::ContinuationStatus::Ok); - return 6_i32.into_lua_multi(lua); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); } + return 6_i32.into_lua_multi(lua); + } - lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) - (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored - }, - ) - .expect("Failed to create cont_func"); + lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + }, + ) + .expect("Failed to create cont_func"); - let luau_func = lua - .load( - " + let luau_func = lua + .load( + " local cont_func = ... local res = cont_func(1) return res + 1 ", - ) - .into_function() - .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - println!("v={:?}", v); - - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - - // (2, 1) followed by () - assert_eq!(v.len(), 2 + 3); - - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 7); - } + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); // test panics let cont_func = lua From 9def001b4d227c187ede5adabf591d4b2495cdd9 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 22:21:34 -0400 Subject: [PATCH 23/46] fix luau compiler bug --- src/state/raw.rs | 50 ++++++++++++++++++++++++++++++----------------- src/state/util.rs | 3 ++- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index bca9d1a1..98f7e55f 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1217,28 +1217,42 @@ impl RawLua { { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.0)(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) } unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.1)(rawlua, nargs, status), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) } let state = self.state(); diff --git a/src/state/util.rs b/src/state/util.rs index 5c1bb4b6..d40c6978 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -168,7 +168,8 @@ pub(crate) unsafe fn callback_error_ext_yieldable( mut extra: *mut ExtraData, wrap_error: bool, f: F, - in_callback_with_continuation: bool, + #[cfg(feature = "luau")] _in_callback_with_continuation: bool, + #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, From e985f8312c216718225807d2690b7949b1fd8019 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 22:35:20 -0400 Subject: [PATCH 24/46] add note on why pop --- src/state/util.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/state/util.rs b/src/state/util.rs index d40c6978..00688072 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -194,6 +194,16 @@ where let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); if let Some(values) = values { + // A note on Luau + // + // When using the yieldable continuations fflag (and in future when the fflag gets removed and + // yieldable continuations) becomes default, we must either pop the top of the + // stack on the state we are resuming or somehow store the number of + // args on top of stack pre-yield and then subtract in the resume in order to get predictable + // behaviour here. See https://github.com/luau-lang/luau/issues/1867 for more information + // + // In this case, popping is easier and leads to less bugs/more ergonomic API. + if raw.state() == state { // Edge case: main thread is being yielded // From 3dcb1315409f3937e5244e091d28669801cfa57a Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 23:49:09 -0400 Subject: [PATCH 25/46] fix docs which incorrectly state only luau supports yieldable continuation --- src/state.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/state.rs b/src/state.rs index df001893..8a856509 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1295,7 +1295,7 @@ impl Lua { /// Same as ``create_function`` but with an added continuation function. /// /// The values passed to the continuation will be the yielded arguments - /// from the function for the initial continuation call. On Luau, if yielding from a + /// from the function for the initial continuation call. If yielding from a /// continuation, the yielded results will be returned to the ``Thread::resume`` caller. The /// arguments passed in the next ``Thread::resume`` call will then be the arguments passed /// to the yielding continuation upon resumption. @@ -1303,8 +1303,6 @@ impl Lua { /// Returning a value from a continuation without setting yield /// arguments will then be returned as the final return value of the Lua function call. /// Values returned in a function in which there is also yielding will be ignored - /// - /// Note that yielding in continuations is only supported on Luau #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] pub fn create_function_with_continuation( &self, From 3a079fe825f2e6668557fca0f0cb22c7f8e27c93 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:20:35 -0400 Subject: [PATCH 26/46] make the api more ergonomic --- src/error.rs | 6 +++++ src/state.rs | 31 ++++++---------------- src/state/extra.rs | 5 ---- src/state/util.rs | 23 +++++++--------- tests/thread.rs | 65 +++++++++------------------------------------- 5 files changed, 35 insertions(+), 95 deletions(-) diff --git a/src/error.rs b/src/error.rs index 483171e5..cb95494b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use crate::MultiValue; use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; @@ -205,6 +206,8 @@ pub enum Error { /// Underlying error. cause: Arc, }, + /// A special error variant that tells Rust to yield to Lua with the specified args + Yield(MultiValue), } /// A specialized `Result` type used by `mlua`'s API. @@ -322,6 +325,9 @@ impl fmt::Display for Error { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") }, + Error::Yield(_) => { + write!(fmt, "attempt to yield within a context that does not support yielding") + } } } } diff --git a/src/state.rs b/src/state.rs index 8a856509..7a1a1fba 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2146,33 +2146,18 @@ impl Lua { &*self.raw.data_ptr() } - /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function - /// to actually yield. Any values returned in a function in which there is also yielding may - /// be ignored. + /// Helper method to set the yield arguments, returning a Error::Yield. /// - /// This method is mostly useful with Luau continuations and Rust-Rust yields + /// This method is mostly useful with continuations and Rust-Rust yields /// due to the Rust/Lua boundary - /// - /// If this function cannot yield, it will raise a runtime error. - /// - /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield - /// or not until it reaches the Lua state. - /// - /// While this method *should be safe*, it is new and may have bugs lurking within. Use - /// with caution - pub fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { - let raw = self.lock(); - #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] - if !raw.is_yieldable() { - return Err(Error::runtime("cannot yield across Rust/Lua boundary.")); - } - unsafe { - raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); - } - Ok(()) + pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { + Err(Error::Yield(args.into_lua_multi(self)?)) } - /// Checks if Lua is currently allowed to yield. + /// Checks if Lua could be allowed to yield. + /// + /// Note that this method is not fool proof and is prone to false negatives + /// especially when continuations are involved #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub fn is_yieldable(&self) -> bool { diff --git a/src/state/extra.rs b/src/state/extra.rs index fbb21c66..45f0e379 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -19,7 +19,6 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; -use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -95,9 +94,6 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, - - // Values currently being yielded from Lua.yield() - pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -199,7 +195,6 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, - yielded_values: None, })); // Store it in the registry diff --git a/src/state/util.rs b/src/state/util.rs index 00688072..f4cbde4f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,5 +1,4 @@ use crate::IntoLuaMulti; -use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; @@ -113,9 +112,6 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - // Ensure yielded values are cleared - take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r @@ -163,13 +159,13 @@ where /// /// Unlike ``callback_error_ext``, this method requires a c_int return /// and not a generic R +#[allow(unused_variables)] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, wrap_error: bool, f: F, - #[cfg(feature = "luau")] _in_callback_with_continuation: bool, - #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, + in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, @@ -190,10 +186,14 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - let raw = extra.as_ref().unwrap_unchecked().raw_lua(); - let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { + if let Error::Yield(values) = err { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); - if let Some(values) = values { // A note on Luau // // When using the yieldable continuations fflag (and in future when the fflag gets removed and @@ -311,11 +311,6 @@ where } } - // Return unused `WrappedFailure` to the pool - prealloc_failure.release(state, extra); - r - } - Ok(Err(err)) => { let wrapped_error = prealloc_failure.r#use(state, extra); if !wrap_error { diff --git a/tests/thread.rs b/tests/thread.rs index ff346aef..a5655af2 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -2,8 +2,6 @@ use std::panic::catch_unwind; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; -use mlua::IntoLuaMulti; - #[test] fn test_thread() -> Result<()> { let lua = Lua::new(); @@ -258,10 +256,7 @@ fn test_thread_resume_error() -> Result<()> { #[test] fn test_thread_yield_args() -> Result<()> { let lua = Lua::new(); - let always_yield = lua.create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - })?; + let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; let thread = lua.create_thread(always_yield)?; assert_eq!( @@ -279,13 +274,7 @@ fn test_continuation() { // No yielding continuation fflag test let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, a: u64| lua.yield_with(a), |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 39) @@ -318,13 +307,7 @@ fn test_continuation() { // empty yield args test let cont_func = lua .create_function_with_continuation( - |lua, _: ()| { - match lua.set_yield_args(()) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, _: ()| lua.yield_with(()), |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), ) .expect("Failed to create cont_func"); @@ -377,10 +360,7 @@ fn test_continuation() { // basic yield test before we go any further let always_yield = lua - .create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }) + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); @@ -392,13 +372,7 @@ fn test_continuation() { // Trigger the continuation let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, a: u64| lua.yield_with(a), |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 39) @@ -430,10 +404,7 @@ fn test_continuation() { let always_yield = lua .create_function_with_continuation( - |lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }, + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), |_lua, _, mv: mlua::MultiValue| { println!("Reached second continuation"); if mv.is_empty() { @@ -454,15 +425,9 @@ fn test_continuation() { let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - Ok(()) - }, + |lua, a: u64| lua.yield_with((a + 1, 1)), |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive: {:?}", args); + println!("Reached cont recursive/multiple: {:?}", args); if args.len() == 5 { if cfg!(any(feature = "luau", feature = "lua52")) { @@ -470,11 +435,11 @@ fn test_continuation() { } else { assert_eq!(status, mlua::ContinuationStatus::Yielded); } - return 6_i32.into_lua_multi(lua); + return Ok(6_i32); } - lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) - (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + unreachable!(); }, ) .expect("Failed to create cont_func"); @@ -520,13 +485,7 @@ fn test_continuation() { // test panics let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, a: u64| lua.yield_with(a), |_lua, _status, _a: u64| { panic!("Reached continuation which should panic!"); #[allow(unreachable_code)] From 3c005d7a85e7ad4d6ed8d38c50046e7a8ee7e40d Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:27:30 -0400 Subject: [PATCH 27/46] rename set_yield_args to yield_with --- src/error.rs | 2 +- src/state.rs | 21 ++++++++++++++++++++- src/state/raw.rs | 4 +++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/error.rs b/src/error.rs index cb95494b..2f1030f8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -206,7 +206,7 @@ pub enum Error { /// Underlying error. cause: Arc, }, - /// A special error variant that tells Rust to yield to Lua with the specified args + /// A special error variant that tells Rust to yield to Lua with the specified arguments Yield(MultiValue), } diff --git a/src/state.rs b/src/state.rs index 7a1a1fba..1220fe0e 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2146,10 +2146,29 @@ impl Lua { &*self.raw.data_ptr() } - /// Helper method to set the yield arguments, returning a Error::Yield. + /// Helper method to set the yield arguments, returning a ``Error::Yield``. + /// + /// Internally, this method is equivalent to ``Err(Error::Yield(args.into_lua_multi(self)?))`` /// /// This method is mostly useful with continuations and Rust-Rust yields /// due to the Rust/Lua boundary + /// + /// Example: + /// + /// ```rust + /// fn test() -> mlua::Result<()> { + /// let lua = mlua::Lua::new(); + /// let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; + /// + /// let thread = lua.create_thread(always_yield)?; + /// assert_eq!( + /// thread.resume::<(i32, String, f32)>(())?, + /// (42, String::from("69420"), 45.6) + /// ); + /// + /// Ok(()) + /// } + /// ``` pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { Err(Error::Yield(args.into_lua_multi(self)?)) } diff --git a/src/state/raw.rs b/src/state/raw.rs index 98f7e55f..26935a39 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; +#[allow(unused_imports)] +use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; From c7ddbe633c4096e3a73334352575c4b6cd0c9307 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:49:20 -0400 Subject: [PATCH 28/46] partially revert due to compiler error --- src/error.rs | 6 ------ src/state.rs | 10 ++++++---- src/state/extra.rs | 5 +++++ src/state/util.rs | 22 +++++++++++++--------- tests/thread.rs | 2 +- 5 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/error.rs b/src/error.rs index 2f1030f8..483171e5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,3 @@ -use crate::MultiValue; use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; @@ -206,8 +205,6 @@ pub enum Error { /// Underlying error. cause: Arc, }, - /// A special error variant that tells Rust to yield to Lua with the specified arguments - Yield(MultiValue), } /// A specialized `Result` type used by `mlua`'s API. @@ -325,9 +322,6 @@ impl fmt::Display for Error { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") }, - Error::Yield(_) => { - write!(fmt, "attempt to yield within a context that does not support yielding") - } } } } diff --git a/src/state.rs b/src/state.rs index 1220fe0e..3685e399 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2146,9 +2146,7 @@ impl Lua { &*self.raw.data_ptr() } - /// Helper method to set the yield arguments, returning a ``Error::Yield``. - /// - /// Internally, this method is equivalent to ``Err(Error::Yield(args.into_lua_multi(self)?))`` + /// Set the yield arguments. Note that Lua will not yield until you return from the function /// /// This method is mostly useful with continuations and Rust-Rust yields /// due to the Rust/Lua boundary @@ -2170,7 +2168,11 @@ impl Lua { /// } /// ``` pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { - Err(Error::Yield(args.into_lua_multi(self)?)) + let raw = self.lock(); + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); + } + Ok(()) } /// Checks if Lua could be allowed to yield. diff --git a/src/state/extra.rs b/src/state/extra.rs index 45f0e379..fbb21c66 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -19,6 +19,7 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -94,6 +95,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -195,6 +199,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, + yielded_values: None, })); // Store it in the registry diff --git a/src/state/util.rs b/src/state/util.rs index f4cbde4f..67f9edaf 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,4 +1,5 @@ use crate::IntoLuaMulti; +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; @@ -112,6 +113,8 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { + take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r @@ -159,13 +162,13 @@ where /// /// Unlike ``callback_error_ext``, this method requires a c_int return /// and not a generic R -#[allow(unused_variables)] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, wrap_error: bool, f: F, - in_callback_with_continuation: bool, + #[cfg(feature = "luau")] _in_callback_with_continuation: bool, + #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, @@ -186,14 +189,10 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - // Return unused `WrappedFailure` to the pool - prealloc_failure.release(state, extra); - r - } - Ok(Err(err)) => { - if let Error::Yield(values) = err { - let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + if let Some(values) = values { // A note on Luau // // When using the yieldable continuations fflag (and in future when the fflag gets removed and @@ -311,6 +310,11 @@ where } } + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { let wrapped_error = prealloc_failure.r#use(state, extra); if !wrap_error { diff --git a/tests/thread.rs b/tests/thread.rs index a5655af2..9c8c1b65 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -439,7 +439,7 @@ fn test_continuation() { } lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 - unreachable!(); + Ok(1_i32) // this will be ignored }, ) .expect("Failed to create cont_func"); From b44075faff983a6022b8cba1037089b237c05b54 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:51:01 -0400 Subject: [PATCH 29/46] fix --- src/state/util.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 67f9edaf..360b85d9 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -167,8 +167,7 @@ pub(crate) unsafe fn callback_error_ext_yieldable( mut extra: *mut ExtraData, wrap_error: bool, f: F, - #[cfg(feature = "luau")] _in_callback_with_continuation: bool, - #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, + #[allow(unused_variables)] in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, From a2e8694b5297453239a6a5f5d293a9c793fffc47 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:51:21 -0400 Subject: [PATCH 30/46] fix --- src/state/util.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 360b85d9..d3fa1c30 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -113,8 +113,6 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r From 4a37dc5f0fabb37e4bbb719aa91518f46d1a0ee4 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 07:06:54 -0400 Subject: [PATCH 31/46] ensure wrappedfailure is returned to pool --- src/state.rs | 5 +---- src/state/util.rs | 34 ++++++++++++++++------------------ 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/state.rs b/src/state.rs index 3685e399..1df736e5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2175,10 +2175,7 @@ impl Lua { Ok(()) } - /// Checks if Lua could be allowed to yield. - /// - /// Note that this method is not fool proof and is prone to false negatives - /// especially when continuations are involved + /// Checks if Lua is be allowed to yield. #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub fn is_yieldable(&self) -> bool { diff --git a/src/state/util.rs b/src/state/util.rs index d3fa1c30..eade8870 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -186,6 +186,12 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { + // Return unused `WrappedFailure` to the pool + // + // In either case, we cannot use it in the yield case anyways due to the lua_pop call + // so drop it properly now while we can. + prealloc_failure.release(state, extra); + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); @@ -216,14 +222,11 @@ where // Even outside of luau, clearing the stack is probably desirable ffi::lua_pop(state, -1); if let Err(err) = check_stack(state, nargs) { - let wrapped_error = prealloc_failure.r#use(state, extra); - ptr::write( - wrapped_error, - WrappedFailure::Error(Error::external(err.to_string())), - ); - get_internal_metatable::(state); - ffi::lua_setmetatable(state, -2); - + // Unfortunately, we can't do a wrapped error here + // due to the lua_pop call, so just push a CString + // manually instead (todo: look into a better way here) + let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); + ffi::lua_pushstring(state, cs.as_ptr()); ffi::lua_error(state) } ffi::lua_xmove(raw.state(), state, nargs); @@ -294,21 +297,16 @@ where return ffi::lua_yield(state, nargs); } Err(err) => { - let wrapped_error = prealloc_failure.r#use(state, extra); - ptr::write( - wrapped_error, - WrappedFailure::Error(Error::external(err.to_string())), - ); - get_internal_metatable::(state); - ffi::lua_setmetatable(state, -2); - + // Unfortunately, we can't do a wrapped error here + // due to the above lua_pop call, so just push a CString + // manually instead (todo: look into a better way here) + let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); + ffi::lua_pushstring(state, cs.as_ptr()); ffi::lua_error(state) } } } - // Return unused `WrappedFailure` to the pool - prealloc_failure.release(state, extra); r } Ok(Err(err)) => { From a5c1ea1511b12a935160c9613c43e80d202fc0cd Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 07:35:31 -0400 Subject: [PATCH 32/46] fix wrap error case --- src/state/util.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index eade8870..69edf40f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -222,12 +222,13 @@ where // Even outside of luau, clearing the stack is probably desirable ffi::lua_pop(state, -1); if let Err(err) = check_stack(state, nargs) { - // Unfortunately, we can't do a wrapped error here - // due to the lua_pop call, so just push a CString - // manually instead (todo: look into a better way here) - let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); - ffi::lua_pushstring(state, cs.as_ptr()); - ffi::lua_error(state) + // Make a *new* preallocated failure, and then do normal error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); } ffi::lua_xmove(raw.state(), state, nargs); } @@ -297,12 +298,13 @@ where return ffi::lua_yield(state, nargs); } Err(err) => { - // Unfortunately, we can't do a wrapped error here - // due to the above lua_pop call, so just push a CString - // manually instead (todo: look into a better way here) - let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); - ffi::lua_pushstring(state, cs.as_ptr()); - ffi::lua_error(state) + // Make a *new* preallocated failure, and then do normal wrap_error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); } } } From 541331ce32c352308f5712ec4807bd55633bae84 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 09:05:59 -0400 Subject: [PATCH 33/46] begin adding auxthread code --- src/buffer.rs | 2 +- src/conversion.rs | 9 ++- src/function.rs | 13 +++- src/scope.rs | 6 +- src/state.rs | 21 +++++-- src/state/extra.rs | 77 ++++++++++++++--------- src/state/raw.rs | 136 ++++++++++++++++++++++++++++++----------- src/state/util.rs | 134 +++++++++++++++++++++++++++++++++------- src/string.rs | 2 +- src/table.rs | 17 +++--- src/thread.rs | 2 +- src/types/value_ref.rs | 20 +++++- src/userdata.rs | 26 ++++++-- src/value.rs | 2 +- tests/byte_string.rs | 7 +++ tests/tests.rs | 11 ++-- tests/thread.rs | 8 +++ 17 files changed, 368 insertions(+), 125 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 42b20088..10a943d3 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -62,7 +62,7 @@ impl Buffer { unsafe fn as_raw_parts(&self) -> (*mut u8, usize) { let lua = self.0.lua.lock(); let mut size = 0usize; - let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size); + let buf = ffi::lua_tobuffer(lua.ref_thread(self.0.aux_thread), self.0.index, &mut size); mlua_assert!(!buf.is_null(), "invalid Luau buffer"); (buf as *mut u8, size) } diff --git a/src/conversion.rs b/src/conversion.rs index b7dfa305..65f04982 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -12,6 +12,7 @@ use num_traits::cast; use crate::error::{Error, Result}; use crate::function::Function; +use crate::state::util::get_next_spot; use crate::state::{Lua, RawLua}; use crate::string::{BorrowedBytes, BorrowedStr, String}; use crate::table::Table; @@ -83,8 +84,12 @@ impl FromLua for String { let state = lua.state(); let type_id = ffi::lua_type(state, idx); if type_id == ffi::LUA_TSTRING { - ffi::lua_xpush(state, lua.ref_thread(), idx); - return Ok(String(lua.pop_ref_thread())); + let (aux_thread, idxs, replace) = get_next_spot(lua.extra()); + ffi::lua_xpush(state, lua.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), idxs); + } + return Ok(String(lua.new_value_ref(aux_thread, idxs))); } // Fallback to default Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) diff --git a/src/function.rs b/src/function.rs index 9a5b291b..4ec5c8fe 100644 --- a/src/function.rs +++ b/src/function.rs @@ -3,6 +3,7 @@ use std::os::raw::{c_int, c_void}; use std::{mem, ptr, slice}; use crate::error::{Error, Result}; +use crate::state::util::get_next_spot; use crate::state::Lua; use crate::table::Table; use crate::traits::{FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut}; @@ -494,14 +495,22 @@ impl Function { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn deep_clone(&self) -> Self { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { if ffi::lua_iscfunction(ref_thread, self.0.index) != 0 { return self.clone(); } ffi::lua_clonefunction(ref_thread, self.0.index); - Function(lua.pop_ref_thread()) + + // Get the real next spot + let (aux_thread, index, replace) = get_next_spot(lua.extra()); + ffi::lua_xpush(lua.ref_thread(self.0.aux_thread), lua.ref_thread(aux_thread), -1); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), index); + } + + Function(lua.new_value_ref(aux_thread, index)) } } } diff --git a/src/scope.rs b/src/scope.rs index c56647a4..0ea61996 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -267,7 +267,7 @@ impl<'scope, 'env: 'scope> Scope<'scope, 'env> { let f = self.lua.create_callback(f)?; let destructor: DestructorCallback = Box::new(|rawlua, vref| { - let ref_thread = rawlua.ref_thread(); + let ref_thread = rawlua.ref_thread(vref.aux_thread); ffi::lua_getupvalue(ref_thread, vref.index, 1); let upvalue = get_userdata::(ref_thread, -1); let data = (*upvalue).data.take(); @@ -287,13 +287,13 @@ impl<'scope, 'env: 'scope> Scope<'scope, 'env> { Ok(Some(_)) => {} Ok(None) => { // Deregister metatable - let mt_ptr = get_metatable_ptr(rawlua.ref_thread(), vref.index); + let mt_ptr = get_metatable_ptr(rawlua.ref_thread(vref.aux_thread), vref.index); rawlua.deregister_userdata_metatable(mt_ptr); } Err(_) => return vec![], } - let data = take_userdata::>(rawlua.ref_thread(), vref.index); + let data = take_userdata::>(rawlua.ref_thread(vref.aux_thread), vref.index); vec![Box::new(move || drop(data))] }); self.destructors.0.borrow_mut().push((ud.0.clone(), destructor)); diff --git a/src/state.rs b/src/state.rs index 1df736e5..1619e73a 100644 --- a/src/state.rs +++ b/src/state.rs @@ -14,6 +14,7 @@ use crate::hook::Debug; use crate::memory::MemoryState; use crate::multi::MultiValue; use crate::scope::Scope; +use crate::state::util::get_next_spot; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -528,7 +529,7 @@ impl Lua { ffi::luaL_sandboxthread(state); } else { // Restore original `LUA_GLOBALSINDEX` - ffi::lua_xpush(lua.ref_thread(), state, ffi::LUA_GLOBALSINDEX); + ffi::lua_xpush(lua.ref_thread_internal(), state, ffi::LUA_GLOBALSINDEX); ffi::lua_replace(state, ffi::LUA_GLOBALSINDEX); ffi::luaL_sandbox(state, 0); } @@ -768,8 +769,12 @@ impl Lua { return; // Don't allow recursion } ffi::lua_pushthread(child); - ffi::lua_xmove(child, (*extra).ref_thread, 1); - let value = Thread((*extra).raw_lua().pop_ref_thread(), child); + let (aux_thread, index, replace) = get_next_spot(extra); + ffi::lua_xmove(child, (*extra).raw_lua().ref_thread(aux_thread), 1); + if replace { + ffi::lua_replace((*extra).raw_lua().ref_thread(aux_thread), index); + } + let value = Thread((*extra).raw_lua().new_value_ref(aux_thread, index), child); callback_error_ext(parent, extra, false, move |extra, _| { callback((*extra).lua(), value) }) @@ -1225,7 +1230,6 @@ impl Lua { ffi::lua_rawset(state, -3); } } - Ok(Table(lua.pop_ref())) } } @@ -1351,8 +1355,13 @@ impl Lua { /// This function is unsafe because provides a way to execute unsafe C function. pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result { let lua = self.lock(); - ffi::lua_pushcfunction(lua.ref_thread(), func); - Ok(Function(lua.pop_ref_thread())) + let (aux_thread, idx, replace) = get_next_spot(lua.extra()); + ffi::lua_pushcfunction(lua.ref_thread(aux_thread), func); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), idx); + } + + Ok(Function(lua.new_value_ref(aux_thread, idx))) } /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. diff --git a/src/state/extra.rs b/src/state/extra.rs index fbb21c66..150e7daf 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -32,6 +32,49 @@ static EXTRA_REGISTRY_KEY: u8 = 0; const WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY: usize = 64; const REF_STACK_RESERVE: c_int = 2; +pub(crate) struct RefThread { + pub(super) ref_thread: *mut ffi::lua_State, + pub(super) stack_size: c_int, + pub(super) stack_top: c_int, + pub(super) free: Vec, +} + +impl RefThread { + #[inline(always)] + pub(crate) unsafe fn new(state: *mut ffi::lua_State) -> Self { + // Create ref stack thread and place it in the registry to prevent it + // from being garbage collected. + let ref_thread = mlua_expect!( + protect_lua!(state, 0, 0, |state| { + let thread = ffi::lua_newthread(state); + ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); + thread + }), + "Error while creating ref thread", + ); + + // Store `error_traceback` function on the ref stack + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + { + ffi::lua_pushcfunction(ref_thread, crate::util::error_traceback); + assert_eq!(ffi::lua_gettop(ref_thread), ExtraData::ERROR_TRACEBACK_IDX); + } + + RefThread { + ref_thread, + // We need some reserved stack space to move values in and out of the ref stack. + stack_size: ffi::LUA_MINSTACK - REF_STACK_RESERVE, + stack_top: ffi::lua_gettop(ref_thread), + free: Vec::new(), + } + } + + #[inline(always)] + pub(crate) fn top(&self) -> c_int { + self.stack_top + } +} + /// Data associated with the Lua state. pub(crate) struct ExtraData { pub(super) lua: MaybeUninit, @@ -54,11 +97,10 @@ pub(crate) struct ExtraData { // Used in module mode pub(super) skip_memory_check: bool, - // Auxiliary thread to store references - pub(super) ref_thread: *mut ffi::lua_State, - pub(super) ref_stack_size: c_int, - pub(super) ref_stack_top: c_int, - pub(super) ref_free: Vec, + // Auxiliary threads to store references + pub(super) ref_thread: Vec, + // Special auxillary thread for mlua internal use + pub(super) ref_thread_internal: RefThread, // Pool of `WrappedFailure` enums in the ref thread (as userdata) pub(super) wrapped_failure_pool: Vec, @@ -128,17 +170,6 @@ impl ExtraData { pub(super) const ERROR_TRACEBACK_IDX: c_int = 1; pub(super) unsafe fn init(state: *mut ffi::lua_State, owned: bool) -> XRc> { - // Create ref stack thread and place it in the registry to prevent it - // from being garbage collected. - let ref_thread = mlua_expect!( - protect_lua!(state, 0, 0, |state| { - let thread = ffi::lua_newthread(state); - ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); - thread - }), - "Error while creating ref thread", - ); - let wrapped_failure_mt_ptr = { get_internal_metatable::(state); let ptr = ffi::lua_topointer(state, -1); @@ -146,13 +177,6 @@ impl ExtraData { ptr }; - // Store `error_traceback` function on the ref stack - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - { - ffi::lua_pushcfunction(ref_thread, crate::util::error_traceback); - assert_eq!(ffi::lua_gettop(ref_thread), Self::ERROR_TRACEBACK_IDX); - } - #[allow(clippy::arc_with_non_send_sync)] let extra = XRc::new(UnsafeCell::new(ExtraData { lua: MaybeUninit::uninit(), @@ -167,11 +191,8 @@ impl ExtraData { safe: false, libs: StdLib::NONE, skip_memory_check: false, - ref_thread, - // We need some reserved stack space to move values in and out of the ref stack. - ref_stack_size: ffi::LUA_MINSTACK - REF_STACK_RESERVE, - ref_stack_top: ffi::lua_gettop(ref_thread), - ref_free: Vec::new(), + ref_thread: vec![RefThread::new(state)], + ref_thread_internal: RefThread::new(state), wrapped_failure_pool: Vec::with_capacity(WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY), wrapped_failure_top: 0, #[cfg(feature = "async")] diff --git a/src/state/raw.rs b/src/state/raw.rs index 26935a39..35ad9812 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1,6 +1,7 @@ use std::any::TypeId; use std::cell::{Cell, UnsafeCell}; use std::ffi::CStr; +use std::io::Write; use std::mem; use std::os::raw::{c_char, c_int, c_void}; use std::panic::resume_unwind; @@ -14,7 +15,7 @@ use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; #[allow(unused_imports)] use crate::state::util::callback_error_ext; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext_yieldable, get_next_spot}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -125,8 +126,24 @@ impl RawLua { } #[inline(always)] - pub(crate) fn ref_thread(&self) -> *mut ffi::lua_State { - unsafe { (*self.extra.get()).ref_thread } + pub(crate) fn ref_thread(&self, aux_thread: usize) -> *mut ffi::lua_State { + unsafe { + (*self.extra.get()) + .ref_thread + .get(aux_thread) + .unwrap_unchecked() + .ref_thread + } + } + + #[inline(always)] + pub(crate) fn ref_thread_internal(&self) -> *mut ffi::lua_State { + unsafe { (*self.extra.get()).ref_thread_internal.ref_thread } + } + + #[inline(always)] + pub(crate) fn extra(&self) -> *mut ExtraData { + self.extra.get() } pub(super) unsafe fn new(libs: StdLib, options: &LuaOptions) -> XRc> { @@ -599,7 +616,7 @@ impl RawLua { self.set_thread_hook(thread_state, HookKind::Global)?; let thread = Thread(self.pop_ref(), thread_state); - ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); + ffi::lua_xpush(self.ref_thread(thread.0.aux_thread), thread_state, func.0.index); Ok(thread) } @@ -607,8 +624,8 @@ impl RawLua { #[cfg(feature = "async")] pub(crate) unsafe fn create_recycled_thread(&self, func: &Function) -> Result { if let Some(index) = (*self.extra.get()).thread_pool.pop() { - let thread_state = ffi::lua_tothread(self.ref_thread(), index); - ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); + let thread_state = ffi::lua_tothread(self.ref_thread(func.0.aux_thread), index); + ffi::lua_xpush(self.ref_thread(func.0.aux_thread), thread_state, func.0.index); #[cfg(feature = "luau")] { @@ -689,6 +706,7 @@ impl RawLua { /// Uses 2 stack spaces, does not call checkstack. pub(crate) unsafe fn stack_value(&self, idx: c_int, type_hint: Option) -> Value { let state = self.state(); + println!("ABC"); match type_hint.unwrap_or_else(|| ffi::lua_type(state, idx)) { ffi::LUA_TNIL => Nil, @@ -727,18 +745,32 @@ impl RawLua { } ffi::LUA_TSTRING => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::String(String(self.pop_ref_thread())) + println!("STRING"); + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::String(String(self.new_value_ref(aux_thread, idxs))) } ffi::LUA_TTABLE => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Table(Table(self.pop_ref_thread())) + println!("TABLE"); + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Table(Table(self.new_value_ref(aux_thread, idxs))) } ffi::LUA_TFUNCTION => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Function(Function(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Function(Function(self.new_value_ref(aux_thread, idxs))) } ffi::LUA_TUSERDATA => { @@ -754,27 +786,44 @@ impl RawLua { Value::Nil } _ => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::UserData(AnyUserData(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + + Value::UserData(AnyUserData(self.new_value_ref(aux_thread, idxs))) } } } ffi::LUA_TTHREAD => { - ffi::lua_xpush(state, self.ref_thread(), idx); - let thread_state = ffi::lua_tothread(self.ref_thread(), -1); - Value::Thread(Thread(self.pop_ref_thread(), thread_state)) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + let thread_state = ffi::lua_tothread(self.ref_thread(aux_thread), -1); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Thread(Thread(self.new_value_ref(aux_thread, idxs), thread_state)) } #[cfg(feature = "luau")] ffi::LUA_TBUFFER => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Buffer(crate::Buffer(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Buffer(crate::Buffer(self.new_value_ref(aux_thread, idxs))) } _ => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Other(self.pop_ref_thread()) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Other(self.new_value_ref(aux_thread, idxs)) } } } @@ -786,7 +835,7 @@ impl RawLua { self.weak() == &vref.lua, "Lua instance passed Value created from a different main Lua state" ); - unsafe { ffi::lua_xpush(self.ref_thread(), self.state(), vref.index) }; + unsafe { ffi::lua_xpush(self.ref_thread(vref.aux_thread), self.state(), vref.index) }; } // Pops the topmost element of the stack and stores a reference to it. This pins the object, @@ -798,41 +847,54 @@ impl RawLua { // used stack. #[inline] pub(crate) unsafe fn pop_ref(&self) -> ValueRef { - ffi::lua_xmove(self.state(), self.ref_thread(), 1); - let index = ref_stack_pop(self.extra.get()); - ValueRef::new(self, index) + let (aux_thread, idx, replace) = get_next_spot(self.extra.get()); + ffi::lua_xmove(self.state(), self.ref_thread(aux_thread), 1); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idx); + } + + ValueRef::new(self, aux_thread, idx) } - // Same as `pop_ref` but assumes the value is already on the reference thread + // Given a known aux_thread and index, creates a ValueRef. #[inline] - pub(crate) unsafe fn pop_ref_thread(&self) -> ValueRef { - let index = ref_stack_pop(self.extra.get()); - ValueRef::new(self, index) + pub(crate) unsafe fn new_value_ref(&self, aux_thread: usize, index: c_int) -> ValueRef { + ValueRef::new(self, aux_thread, index) } #[inline] pub(crate) unsafe fn clone_ref(&self, vref: &ValueRef) -> ValueRef { - ffi::lua_pushvalue(self.ref_thread(), vref.index); - let index = ref_stack_pop(self.extra.get()); - ValueRef::new(self, index) + let (aux_thread, index, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush( + self.ref_thread(vref.aux_thread), + self.ref_thread(aux_thread), + vref.index, + ); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), index); + } + ValueRef::new(self, aux_thread, index) } pub(crate) unsafe fn drop_ref(&self, vref: &ValueRef) { - let ref_thread = self.ref_thread(); + let ref_thread = self.ref_thread(vref.aux_thread); mlua_debug_assert!( ffi::lua_gettop(ref_thread) >= vref.index, "GC finalizer is not allowed in ref_thread" ); + println!("Trying to dropref"); ffi::lua_pushnil(ref_thread); ffi::lua_replace(ref_thread, vref.index); - (*self.extra.get()).ref_free.push(vref.index); + (*self.extra.get()).ref_thread[vref.aux_thread] + .free + .push(vref.index); } #[inline] pub(crate) unsafe fn push_error_traceback(&self) { let state = self.state(); #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - ffi::lua_xpush(self.ref_thread(), state, ExtraData::ERROR_TRACEBACK_IDX); + ffi::lua_xpush(self.ref_thread_internal(), state, ExtraData::ERROR_TRACEBACK_IDX); // Lua 5.2+ support light C functions that does not require extra allocations #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] ffi::lua_pushcfunction(state, crate::util::error_traceback); @@ -1104,7 +1166,7 @@ impl RawLua { // Returns `None` if the userdata is registered but non-static. #[inline(always)] pub(crate) fn get_userdata_ref_type_id(&self, vref: &ValueRef) -> Result> { - unsafe { self.get_userdata_type_id_inner(self.ref_thread(), vref.index) } + unsafe { self.get_userdata_type_id_inner(self.ref_thread(vref.aux_thread), vref.index) } } // Same as `get_userdata_ref_type_id` but assumes the userdata is already on the stack. @@ -1157,7 +1219,7 @@ impl RawLua { // Pushes a ValueRef (userdata) value onto the stack, returning their `TypeId`. // Uses 1 stack space, does not call checkstack. pub(crate) unsafe fn push_userdata_ref(&self, vref: &ValueRef) -> Result> { - let type_id = self.get_userdata_type_id_inner(self.ref_thread(), vref.index)?; + let type_id = self.get_userdata_type_id_inner(self.ref_thread(vref.aux_thread), vref.index)?; self.push_ref(vref); Ok(type_id) } diff --git a/src/state/util.rs b/src/state/util.rs index 69edf40f..d5984add 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -6,6 +6,7 @@ use std::ptr; use std::sync::Arc; use crate::error::{Error, Result}; +use crate::state::extra::RefThread; use crate::state::{ExtraData, RawLua}; use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; @@ -51,7 +52,7 @@ impl PreallocatedFailure { #[cold] unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; + let ref_thread = &(*extra).ref_thread_internal; match *self { PreallocatedFailure::New(ud) => { ffi::lua_settop(state, 1); @@ -62,22 +63,22 @@ impl PreallocatedFailure { ffi::lua_settop(state, 0); #[cfg(feature = "luau")] ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); + ffi::lua_xpush(ref_thread.ref_thread, state, index); + ffi::lua_pushnil(ref_thread.ref_thread); + ffi::lua_replace(ref_thread.ref_thread, index); + (*extra).ref_thread_internal.free.push(index); ffi::lua_touserdata(state, -1) as *mut WrappedFailure } } } unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; + let ref_thread = &(*extra).ref_thread_internal; match self { PreallocatedFailure::New(_) => { ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); + ffi::lua_xmove(state, ref_thread.ref_thread, 1); + let index = ref_stack_pop_internal(extra); (*extra).wrapped_failure_pool.push(index); (*extra).wrapped_failure_top += 1; } @@ -350,30 +351,121 @@ where } } -pub(super) unsafe fn ref_stack_pop(extra: *mut ExtraData) -> c_int { +pub(super) unsafe fn ref_stack_pop_internal(extra: *mut ExtraData) -> c_int { let extra = &mut *extra; - if let Some(free) = extra.ref_free.pop() { - ffi::lua_replace(extra.ref_thread, free); + let ref_th = &mut extra.ref_thread_internal; + + if let Some(free) = ref_th.free.pop() { + ffi::lua_replace(ref_th.ref_thread, free); return free; } // Try to grow max stack size - if extra.ref_stack_top >= extra.ref_stack_size { - let mut inc = extra.ref_stack_size; // Try to double stack size - while inc > 0 && ffi::lua_checkstack(extra.ref_thread, inc) == 0 { + if ref_th.stack_top >= ref_th.stack_size { + let mut inc = ref_th.stack_size; // Try to double stack size + while inc > 0 && ffi::lua_checkstack(ref_th.ref_thread, inc) == 0 { inc /= 2; } if inc == 0 { // Pop item on top of the stack to avoid stack leaking and successfully run destructors // during unwinding. - ffi::lua_pop(extra.ref_thread, 1); - let top = extra.ref_stack_top; + ffi::lua_pop(ref_th.ref_thread, 1); + let top = ref_th.stack_top; // It is a user error to create enough references to exhaust the Lua max stack size for - // the ref thread. - panic!("cannot create a Lua reference, out of auxiliary stack space (used {top} slots)"); + // the ref thread. This should never happen for the internal aux thread but still + panic!("internal error: cannot create a Lua reference, out of internal auxiliary stack space (used {top} slots)"); } - extra.ref_stack_size += inc; + ref_th.stack_size += inc; + } + ref_th.stack_top += 1; + return ref_th.stack_top; +} + +// Run a comparison function on two Lua references from different auxiliary threads. +pub(crate) unsafe fn compare_refs( + extra: *mut ExtraData, + aux_thread_a: usize, + aux_thread_a_index: c_int, + aux_thread_b: usize, + aux_thread_b_index: c_int, + f: impl FnOnce(*mut ffi::lua_State, c_int, c_int) -> R, +) -> R { + let extra = &mut *extra; + + if aux_thread_a == aux_thread_b { + // If both threads are the same, just return the value at the index + let th = &mut extra.ref_thread[aux_thread_a]; + return f(th.ref_thread, aux_thread_a_index, aux_thread_b_index); } - extra.ref_stack_top += 1; - extra.ref_stack_top + + let th_a = &extra.ref_thread[aux_thread_a]; + let th_b = &extra.ref_thread[aux_thread_b]; + let internal_thread = &mut extra.ref_thread_internal; + + // 4 spaces needed: 1st element on A, idx element on A, 1st element on B, idx element on B + check_stack(internal_thread.ref_thread, 4) + .expect("internal error: cannot merge references, out of internal auxiliary stack space"); + + panic!("Unsupported"); + + // Push the first element from thread A to ensure we have enough stack space on thread A + ffi::lua_xmove(th_a.ref_thread, internal_thread.ref_thread, 1); + // Push the first element from thread B to ensure we have enough stack space on thread B + ffi::lua_xmove(th_b.ref_thread, internal_thread.ref_thread, 1); + // Push the index element from thread A to top + ffi::lua_pushvalue(th_a.ref_thread, aux_thread_a_index); + ffi::lua_xmove(th_a.ref_thread, internal_thread.ref_thread, 1); + // Push the index element from thread B to top + ffi::lua_pushvalue(th_b.ref_thread, aux_thread_b_index); + ffi::lua_xmove(th_b.ref_thread, internal_thread.ref_thread, 1); + // Now we have the following stack: + // - 1st element from thread A (4) + // - 1st element from thread B (3) + // - index element from thread A (2) [copy from pushvalue] + // - index element from thread B (1) [copy from pushvalue] + // We want to compare the index elements from both threads, so use 3 and 4 as indices + let result = f(internal_thread.ref_thread, 2, 1); + + // Pop the top 2 elements to clean the copies + ffi::lua_pop(internal_thread.ref_thread, 2); + // Move the first element from thread B back to thread B + ffi::lua_xmove(internal_thread.ref_thread, th_b.ref_thread, 1); + // Move the first element from thread A back to thread A + ffi::lua_xmove(internal_thread.ref_thread, th_a.ref_thread, 1); + + result +} + +pub(crate) unsafe fn get_next_spot(extra: *mut ExtraData) -> (usize, c_int, bool) { + let extra = &mut *extra; + + // Find the first thread with a free slot + for (i, ref_th) in extra.ref_thread.iter_mut().enumerate() { + if let Some(free) = ref_th.free.pop() { + println!("{} {} {}", i, free, true); + return (i, free, true); + } + + // Try to grow max stack size + if ref_th.stack_top >= ref_th.stack_size { + let mut inc = ref_th.stack_size; // Try to double stack size + while inc > 0 && ffi::lua_checkstack(ref_th.ref_thread, inc + 1) == 0 { + inc /= 2; + } + if inc == 0 { + continue; // No stack space available, try next thread + } + ref_th.stack_size += inc; + } + + ref_th.stack_top += 1; + println!("{} {} {}", i, ref_th.stack_top, false); + return (i, ref_th.stack_top, false); + } + + // No free slots found, create a new one + println!("No free slots found, creating a new ref thread"); + let new_ref_thread = RefThread::new(extra.raw_lua().state()); + extra.ref_thread.push(new_ref_thread); + return get_next_spot(extra); } diff --git a/src/string.rs b/src/string.rs index 9c86102b..f5b5ae1b 100644 --- a/src/string.rs +++ b/src/string.rs @@ -119,7 +119,7 @@ impl String { let lua = self.0.lua.upgrade(); let slice = { let rawlua = lua.lock(); - let ref_thread = rawlua.ref_thread(); + let ref_thread = rawlua.ref_thread(self.0.aux_thread); mlua_debug_assert!( ffi::lua_type(ref_thread, self.0.index) == ffi::LUA_TSTRING, diff --git a/src/table.rs b/src/table.rs index 92228683..9cb8cd4e 100644 --- a/src/table.rs +++ b/src/table.rs @@ -256,6 +256,7 @@ impl Table { value.push_into_stack(&lua)?; if lua.unlikely_memory_error() { + println!("Typeoftable: {}", ffi::lua_type(state, -3)); ffi::lua_rawset(state, -3); ffi::lua_pop(state, 1); Ok(()) @@ -406,7 +407,7 @@ impl Table { #[cfg(feature = "luau")] { self.check_readonly_write(&lua)?; - ffi::lua_cleartable(lua.ref_thread(), self.0.index); + ffi::lua_cleartable(lua.ref_thread(self.0.aux_thread), self.0.index); } #[cfg(not(feature = "luau"))] @@ -461,7 +462,7 @@ impl Table { /// Returns the result of the Lua `#` operator, without invoking the `__len` metamethod. pub fn raw_len(&self) -> usize { let lua = self.0.lua.lock(); - unsafe { ffi::lua_rawlen(lua.ref_thread(), self.0.index) } + unsafe { ffi::lua_rawlen(lua.ref_thread(self.0.aux_thread), self.0.index) } } /// Returns `true` if the table is empty, without invoking metamethods. @@ -469,7 +470,7 @@ impl Table { /// It checks both the array part and the hash part. pub fn is_empty(&self) -> bool { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { ffi::lua_pushnil(ref_thread); if ffi::lua_next(ref_thread, self.0.index) == 0 { @@ -533,7 +534,7 @@ impl Table { #[inline] pub fn has_metatable(&self) -> bool { let lua = self.0.lua.lock(); - unsafe { !get_metatable_ptr(lua.ref_thread(), self.0.index).is_null() } + unsafe { !get_metatable_ptr(lua.ref_thread(self.0.aux_thread), self.0.index).is_null() } } /// Sets `readonly` attribute on the table. @@ -541,7 +542,7 @@ impl Table { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_readonly(&self, enabled: bool) { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { ffi::lua_setreadonly(ref_thread, self.0.index, enabled as _); if !enabled { @@ -556,7 +557,7 @@ impl Table { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn is_readonly(&self) -> bool { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { ffi::lua_getreadonly(ref_thread, self.0.index) != 0 } } @@ -573,7 +574,7 @@ impl Table { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_safeenv(&self, enabled: bool) { let lua = self.0.lua.lock(); - unsafe { ffi::lua_setsafeenv(lua.ref_thread(), self.0.index, enabled as _) }; + unsafe { ffi::lua_setsafeenv(lua.ref_thread(self.0.aux_thread), self.0.index, enabled as _) }; } /// Converts this table to a generic C pointer. @@ -755,7 +756,7 @@ impl Table { #[cfg(feature = "luau")] #[inline(always)] fn check_readonly_write(&self, lua: &RawLua) -> Result<()> { - if unsafe { ffi::lua_getreadonly(lua.ref_thread(), self.0.index) != 0 } { + if unsafe { ffi::lua_getreadonly(lua.ref_thread(self.0.aux_thread), self.0.index) != 0 } { return Err(Error::runtime("attempt to modify a readonly table")); } Ok(()) diff --git a/src/thread.rs b/src/thread.rs index ce096f06..b6db85f6 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -330,7 +330,7 @@ impl Thread { self.reset_inner(status)?; // Push function to the top of the thread stack - ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index); + ffi::lua_xpush(lua.ref_thread(func.0.aux_thread), thread_state, func.0.index); #[cfg(feature = "luau")] { diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index 89bac543..fd9027e4 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -1,20 +1,23 @@ use std::fmt; use std::os::raw::{c_int, c_void}; +use crate::state::util::compare_refs; use crate::state::{RawLua, WeakLua}; /// A reference to a Lua (complex) value stored in the Lua auxiliary thread. pub struct ValueRef { pub(crate) lua: WeakLua, + pub(crate) aux_thread: usize, pub(crate) index: c_int, pub(crate) drop: bool, } impl ValueRef { #[inline] - pub(crate) fn new(lua: &RawLua, index: c_int) -> Self { + pub(crate) fn new(lua: &RawLua, aux_thread: usize, index: c_int) -> Self { ValueRef { lua: lua.weak().clone(), + aux_thread, index, drop: true, } @@ -23,7 +26,7 @@ impl ValueRef { #[inline] pub(crate) fn to_pointer(&self) -> *const c_void { let lua = self.lua.lock(); - unsafe { ffi::lua_topointer(lua.ref_thread(), self.index) } + unsafe { ffi::lua_topointer(lua.ref_thread(self.aux_thread), self.index) } } /// Returns a copy of the value, which is valid as long as the original value is held. @@ -31,6 +34,7 @@ impl ValueRef { pub(crate) fn copy(&self) -> Self { ValueRef { lua: self.lua.clone(), + aux_thread: self.aux_thread, index: self.index, drop: false, } @@ -66,6 +70,16 @@ impl PartialEq for ValueRef { "Lua instance passed Value created from a different main Lua state" ); let lua = self.lua.lock(); - unsafe { ffi::lua_rawequal(lua.ref_thread(), self.index, other.index) == 1 } + + unsafe { + compare_refs( + lua.extra(), + self.aux_thread, + self.index, + other.aux_thread, + other.index, + |state, a, b| ffi::lua_rawequal(state, a, b) == 1, + ) + } } } diff --git a/src/userdata.rs b/src/userdata.rs index 9bdfbd81..f102acb8 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -634,7 +634,7 @@ impl AnyUserData { #[inline] pub fn borrow(&self) -> Result> { let lua = self.0.lua.lock(); - unsafe { UserDataRef::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } + unsafe { UserDataRef::borrow_from_stack(&lua, lua.ref_thread(self.0.aux_thread), self.0.index) } } /// Borrow this userdata immutably if it is of type `T`, passing the borrowed value @@ -645,7 +645,15 @@ impl AnyUserData { let lua = self.0.lua.lock(); let type_id = lua.get_userdata_ref_type_id(&self.0)?; let type_hints = TypeIdHints::new::(); - unsafe { borrow_userdata_scoped(lua.ref_thread(), self.0.index, type_id, type_hints, f) } + unsafe { + borrow_userdata_scoped( + lua.ref_thread(self.0.aux_thread), + self.0.index, + type_id, + type_hints, + f, + ) + } } /// Borrow this userdata mutably if it is of type `T`. @@ -661,7 +669,7 @@ impl AnyUserData { #[inline] pub fn borrow_mut(&self) -> Result> { let lua = self.0.lua.lock(); - unsafe { UserDataRefMut::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } + unsafe { UserDataRefMut::borrow_from_stack(&lua, lua.ref_thread(self.0.aux_thread), self.0.index) } } /// Borrow this userdata mutably if it is of type `T`, passing the borrowed value @@ -672,7 +680,15 @@ impl AnyUserData { let lua = self.0.lua.lock(); let type_id = lua.get_userdata_ref_type_id(&self.0)?; let type_hints = TypeIdHints::new::(); - unsafe { borrow_userdata_scoped_mut(lua.ref_thread(), self.0.index, type_id, type_hints, f) } + unsafe { + borrow_userdata_scoped_mut( + lua.ref_thread(self.0.aux_thread), + self.0.index, + type_id, + type_hints, + f, + ) + } } /// Takes the value out of this userdata. @@ -685,7 +701,7 @@ impl AnyUserData { let lua = self.0.lua.lock(); match lua.get_userdata_ref_type_id(&self.0)? { Some(type_id) if type_id == TypeId::of::() => unsafe { - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); if (*get_userdata::>(ref_thread, self.0.index)).has_exclusive_access() { take_userdata::>(ref_thread, self.0.index).into_inner() } else { diff --git a/src/value.rs b/src/value.rs index 119f147e..db84c328 100644 --- a/src/value.rs +++ b/src/value.rs @@ -132,7 +132,7 @@ impl Value { // In Lua < 5.4 (excluding Luau), string pointers are NULL // Use alternative approach let lua = vref.lua.lock(); - unsafe { ffi::lua_tostring(lua.ref_thread(), vref.index) as *const c_void } + unsafe { ffi::lua_tostring(lua.ref_thread(vref.aux_thread), vref.index) as *const c_void } } Value::LightUserData(ud) => ud.0, Value::Table(Table(vref)) diff --git a/tests/byte_string.rs b/tests/byte_string.rs index 76e43e14..4768a475 100644 --- a/tests/byte_string.rs +++ b/tests/byte_string.rs @@ -2,6 +2,13 @@ use bstr::{BStr, BString}; use mlua::{Lua, Result}; #[test] +fn create_lua() { + let lua = Lua::new(); + let th = lua.create_table().unwrap(); + println!("{th:#?}"); +} + +//#[test] fn test_byte_string_round_trip() -> Result<()> { let lua = Lua::new(); diff --git a/tests/tests.rs b/tests/tests.rs index 7bf40af5..a6ba169c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1016,16 +1016,14 @@ fn test_ref_stack_exhaustion() { match catch_unwind(AssertUnwindSafe(|| -> Result<()> { let lua = Lua::new(); let mut vals = Vec::new(); - for _ in 0..10000000 { + for _ in 0..200000 { + println!("Creating table {}", vals.len()); vals.push(lua.create_table()?); } Ok(()) })) { - Ok(_) => panic!("no panic was detected"), - Err(p) => assert!(p - .downcast::() - .unwrap() - .starts_with("cannot create a Lua reference, out of auxiliary stack space")), + Ok(_) => {} + Err(p) => panic!("got panic: {:?}", p), } } @@ -1490,6 +1488,7 @@ fn test_gc_drop_ref_thread() -> Result<()> { #[cfg(not(feature = "luau"))] #[test] fn test_get_or_init_from_ptr() -> Result<()> { + println!("ABC"); // This would not work with Luau, the state must be init by mlua internally let state = unsafe { ffi::luaL_newstate() }; diff --git a/tests/thread.rs b/tests/thread.rs index 9c8c1b65..69957542 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -516,4 +516,12 @@ fn test_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert!(v.contains("Reached continuation which should panic!")); + + let mut ths = Vec::new(); + for i in 1..1000000 { + let th = lua + .create_thread(lua.create_function(|_, ()| Ok(())).unwrap()) + .expect("Failed to create thread"); + ths.push(th); + } } From cfc942d1ccab339f557ee433b5c0a967b61a7320 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 10:14:54 -0400 Subject: [PATCH 34/46] fix more places --- src/serde/mod.rs | 11 +++++++++-- src/state/extra.rs | 2 +- src/state/raw.rs | 8 ++++---- src/state/util.rs | 2 -- src/userdata.rs | 4 ++-- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/serde/mod.rs b/src/serde/mod.rs index f4752145..d491953d 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -7,6 +7,7 @@ use serde::ser::Serialize; use crate::error::Result; use crate::private::Sealed; +use crate::state::util::get_next_spot; use crate::state::Lua; use crate::table::Table; use crate::util::check_stack; @@ -183,8 +184,14 @@ impl LuaSerdeExt for Lua { fn array_metatable(&self) -> Table { let lua = self.lock(); unsafe { - push_array_metatable(lua.ref_thread()); - Table(lua.pop_ref_thread()) + let (aux_thread, index, replace) = get_next_spot(lua.extra()); + push_array_metatable(lua.state()); + ffi::lua_xmove(lua.state(), lua.ref_thread(aux_thread), 1); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), index); + } + + Table(lua.new_value_ref(aux_thread, index)) } } diff --git a/src/state/extra.rs b/src/state/extra.rs index 150e7daf..df3ebd85 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -107,7 +107,7 @@ pub(crate) struct ExtraData { pub(super) wrapped_failure_top: usize, // Pool of `Thread`s (coroutines) for async execution #[cfg(feature = "async")] - pub(super) thread_pool: Vec, + pub(super) thread_pool: Vec<(usize, c_int)>, // Address of `WrappedFailure` metatable pub(super) wrapped_failure_mt_ptr: *const c_void, diff --git a/src/state/raw.rs b/src/state/raw.rs index 35ad9812..67313979 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -623,8 +623,8 @@ impl RawLua { /// Wraps a Lua function into a new or recycled thread (coroutine). #[cfg(feature = "async")] pub(crate) unsafe fn create_recycled_thread(&self, func: &Function) -> Result { - if let Some(index) = (*self.extra.get()).thread_pool.pop() { - let thread_state = ffi::lua_tothread(self.ref_thread(func.0.aux_thread), index); + if let Some((aux_thread, index)) = (*self.extra.get()).thread_pool.pop() { + let thread_state = ffi::lua_tothread(self.ref_thread(aux_thread), index); ffi::lua_xpush(self.ref_thread(func.0.aux_thread), thread_state, func.0.index); #[cfg(feature = "luau")] @@ -634,7 +634,7 @@ impl RawLua { ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); } - return Ok(Thread(ValueRef::new(self, index), thread_state)); + return Ok(Thread(ValueRef::new(self, aux_thread, index), thread_state)); } self.create_thread(func) @@ -645,7 +645,7 @@ impl RawLua { pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) { let extra = &mut *self.extra.get(); if extra.thread_pool.len() < extra.thread_pool.capacity() { - extra.thread_pool.push(thread.0.index); + extra.thread_pool.push((thread.0.aux_thread, thread.0.index)); thread.0.drop = false; // Prevent thread from being garbage collected } } diff --git a/src/state/util.rs b/src/state/util.rs index d5984add..ee2376d4 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -406,8 +406,6 @@ pub(crate) unsafe fn compare_refs( check_stack(internal_thread.ref_thread, 4) .expect("internal error: cannot merge references, out of internal auxiliary stack space"); - panic!("Unsupported"); - // Push the first element from thread A to ensure we have enough stack space on thread A ffi::lua_xmove(th_a.ref_thread, internal_thread.ref_thread, 1); // Push the first element from thread B to ensure we have enough stack space on thread B diff --git a/src/userdata.rs b/src/userdata.rs index f102acb8..5aed0e57 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -979,7 +979,7 @@ impl AnyUserData { let is_serializable = || unsafe { // Userdata must be registered and not destructed let _ = lua.get_userdata_ref_type_id(&self.0)?; - let ud = &*get_userdata::>(lua.ref_thread(), self.0.index); + let ud = &*get_userdata::>(lua.ref_thread(self.0.aux_thread), self.0.index); Ok::<_, Error>((*ud).is_serializable()) }; is_serializable().unwrap_or(false) @@ -1068,7 +1068,7 @@ impl Serialize for AnyUserData { let _ = lua .get_userdata_ref_type_id(&self.0) .map_err(ser::Error::custom)?; - let ud = &*get_userdata::>(lua.ref_thread(), self.0.index); + let ud = &*get_userdata::>(lua.ref_thread(self.0.aux_thread), self.0.index); ud.serialize(serializer) } } From 4174bdd445cefe50ba2adb9e4ea41ae3e174eb10 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 10:27:42 -0400 Subject: [PATCH 35/46] remove useless print calls --- src/state/raw.rs | 4 ---- src/state/util.rs | 3 --- src/table.rs | 1 - src/util/mod.rs | 1 - tests/tests.rs | 1 - 5 files changed, 10 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 67313979..4917226e 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -706,7 +706,6 @@ impl RawLua { /// Uses 2 stack spaces, does not call checkstack. pub(crate) unsafe fn stack_value(&self, idx: c_int, type_hint: Option) -> Value { let state = self.state(); - println!("ABC"); match type_hint.unwrap_or_else(|| ffi::lua_type(state, idx)) { ffi::LUA_TNIL => Nil, @@ -745,7 +744,6 @@ impl RawLua { } ffi::LUA_TSTRING => { - println!("STRING"); let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); if replace { @@ -755,7 +753,6 @@ impl RawLua { } ffi::LUA_TTABLE => { - println!("TABLE"); let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); if replace { @@ -882,7 +879,6 @@ impl RawLua { ffi::lua_gettop(ref_thread) >= vref.index, "GC finalizer is not allowed in ref_thread" ); - println!("Trying to dropref"); ffi::lua_pushnil(ref_thread); ffi::lua_replace(ref_thread, vref.index); (*self.extra.get()).ref_thread[vref.aux_thread] diff --git a/src/state/util.rs b/src/state/util.rs index ee2376d4..1f75d4b7 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -440,7 +440,6 @@ pub(crate) unsafe fn get_next_spot(extra: *mut ExtraData) -> (usize, c_int, bool // Find the first thread with a free slot for (i, ref_th) in extra.ref_thread.iter_mut().enumerate() { if let Some(free) = ref_th.free.pop() { - println!("{} {} {}", i, free, true); return (i, free, true); } @@ -457,12 +456,10 @@ pub(crate) unsafe fn get_next_spot(extra: *mut ExtraData) -> (usize, c_int, bool } ref_th.stack_top += 1; - println!("{} {} {}", i, ref_th.stack_top, false); return (i, ref_th.stack_top, false); } // No free slots found, create a new one - println!("No free slots found, creating a new ref thread"); let new_ref_thread = RefThread::new(extra.raw_lua().state()); extra.ref_thread.push(new_ref_thread); return get_next_spot(extra); diff --git a/src/table.rs b/src/table.rs index 9cb8cd4e..3bf1062b 100644 --- a/src/table.rs +++ b/src/table.rs @@ -256,7 +256,6 @@ impl Table { value.push_into_stack(&lua)?; if lua.unlikely_memory_error() { - println!("Typeoftable: {}", ffi::lua_type(state, -3)); ffi::lua_rawset(state, -3); ffi::lua_pop(state, 1); Ok(()) diff --git a/src/util/mod.rs b/src/util/mod.rs index 5dd6b19c..f5fbae52 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -75,7 +75,6 @@ impl Drop for StackGuard { unsafe { let top = ffi::lua_gettop(self.state); if top < self.top { - println!("top={}, self.top={}", top, self.top); mlua_panic!("{} too many stack values popped", self.top - top) } if top > self.top { diff --git a/tests/tests.rs b/tests/tests.rs index a6ba169c..4d6cc805 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1488,7 +1488,6 @@ fn test_gc_drop_ref_thread() -> Result<()> { #[cfg(not(feature = "luau"))] #[test] fn test_get_or_init_from_ptr() -> Result<()> { - println!("ABC"); // This would not work with Luau, the state must be init by mlua internally let state = unsafe { ffi::luaL_newstate() }; From 42c8d160ed9036045f14140a949d6cd6ff07280f Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 11:00:40 -0400 Subject: [PATCH 36/46] ensure compare_refs use right indices --- src/state/util.rs | 16 +++++++--------- tests/thread.rs | 18 +++++++++++++++++- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 1f75d4b7..ec3ed427 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -411,18 +411,16 @@ pub(crate) unsafe fn compare_refs( // Push the first element from thread B to ensure we have enough stack space on thread B ffi::lua_xmove(th_b.ref_thread, internal_thread.ref_thread, 1); // Push the index element from thread A to top - ffi::lua_pushvalue(th_a.ref_thread, aux_thread_a_index); - ffi::lua_xmove(th_a.ref_thread, internal_thread.ref_thread, 1); + ffi::lua_xpush(th_a.ref_thread, internal_thread.ref_thread, aux_thread_a_index); // Push the index element from thread B to top - ffi::lua_pushvalue(th_b.ref_thread, aux_thread_b_index); - ffi::lua_xmove(th_b.ref_thread, internal_thread.ref_thread, 1); + ffi::lua_xpush(th_b.ref_thread, internal_thread.ref_thread, aux_thread_b_index); // Now we have the following stack: - // - 1st element from thread A (4) - // - 1st element from thread B (3) - // - index element from thread A (2) [copy from pushvalue] - // - index element from thread B (1) [copy from pushvalue] + // - index element from thread A (1) [copy from pushvalue] + // - index element from thread B (2) [copy from pushvalue] + // - 1st element from thread A (3) + // - 1st element from thread B (4) // We want to compare the index elements from both threads, so use 3 and 4 as indices - let result = f(internal_thread.ref_thread, 2, 1); + let result = f(internal_thread.ref_thread, 3, 4); // Pop the top 2 elements to clean the copies ffi::lua_pop(internal_thread.ref_thread, 2); diff --git a/tests/thread.rs b/tests/thread.rs index 69957542..baddc9a8 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -517,11 +517,27 @@ fn test_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert!(v.contains("Reached continuation which should panic!")); + let th1 = lua + .create_thread(lua.create_function(|lua, _: ()| Ok(())).unwrap()) + .unwrap(); let mut ths = Vec::new(); - for i in 1..1000000 { + for i in 1..2000000 { let th = lua .create_thread(lua.create_function(|_, ()| Ok(())).unwrap()) .expect("Failed to create thread"); ths.push(th); } + let th2 = lua + .create_thread(lua.create_function(|lua, _: ()| Ok(())).unwrap()) + .unwrap(); + + for rth in ths { + assert!( + th1 != rth && th2 != rth, + "Thread {:?} is equal to th1 ({:?}) or th2 ({:?})", + rth, + th1, + th2 + ); + } } From a43b0db10b60db0317eb0e284b57e355b9df7421 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 11:05:45 -0400 Subject: [PATCH 37/46] fix indices again --- src/state/util.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index ec3ed427..0088ea25 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -424,10 +424,10 @@ pub(crate) unsafe fn compare_refs( // Pop the top 2 elements to clean the copies ffi::lua_pop(internal_thread.ref_thread, 2); - // Move the first element from thread B back to thread B - ffi::lua_xmove(internal_thread.ref_thread, th_b.ref_thread, 1); // Move the first element from thread A back to thread A ffi::lua_xmove(internal_thread.ref_thread, th_a.ref_thread, 1); + // Move the first element from thread B back to thread B + ffi::lua_xmove(internal_thread.ref_thread, th_b.ref_thread, 1); result } From dcd87c30228a5dc83320f67ee61f2138c78af2ec Mon Sep 17 00:00:00 2001 From: Rootspring Date: Sat, 7 Jun 2025 20:59:02 +0530 Subject: [PATCH 38/46] Update util.rs --- src/state/util.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 0088ea25..ec3ed427 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -424,10 +424,10 @@ pub(crate) unsafe fn compare_refs( // Pop the top 2 elements to clean the copies ffi::lua_pop(internal_thread.ref_thread, 2); - // Move the first element from thread A back to thread A - ffi::lua_xmove(internal_thread.ref_thread, th_a.ref_thread, 1); // Move the first element from thread B back to thread B ffi::lua_xmove(internal_thread.ref_thread, th_b.ref_thread, 1); + // Move the first element from thread A back to thread A + ffi::lua_xmove(internal_thread.ref_thread, th_a.ref_thread, 1); result } From 5d0a2d952a5d38062be4fb6833590ac4cc2a7a86 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 23:08:14 -0400 Subject: [PATCH 39/46] fix --- src/state/util.rs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index ec3ed427..a82b931e 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -402,14 +402,10 @@ pub(crate) unsafe fn compare_refs( let th_b = &extra.ref_thread[aux_thread_b]; let internal_thread = &mut extra.ref_thread_internal; - // 4 spaces needed: 1st element on A, idx element on A, 1st element on B, idx element on B - check_stack(internal_thread.ref_thread, 4) + // 4 spaces needed: idx element on A, idx element on B + check_stack(internal_thread.ref_thread, 2) .expect("internal error: cannot merge references, out of internal auxiliary stack space"); - // Push the first element from thread A to ensure we have enough stack space on thread A - ffi::lua_xmove(th_a.ref_thread, internal_thread.ref_thread, 1); - // Push the first element from thread B to ensure we have enough stack space on thread B - ffi::lua_xmove(th_b.ref_thread, internal_thread.ref_thread, 1); // Push the index element from thread A to top ffi::lua_xpush(th_a.ref_thread, internal_thread.ref_thread, aux_thread_a_index); // Push the index element from thread B to top @@ -417,17 +413,11 @@ pub(crate) unsafe fn compare_refs( // Now we have the following stack: // - index element from thread A (1) [copy from pushvalue] // - index element from thread B (2) [copy from pushvalue] - // - 1st element from thread A (3) - // - 1st element from thread B (4) // We want to compare the index elements from both threads, so use 3 and 4 as indices let result = f(internal_thread.ref_thread, 3, 4); // Pop the top 2 elements to clean the copies ffi::lua_pop(internal_thread.ref_thread, 2); - // Move the first element from thread B back to thread B - ffi::lua_xmove(internal_thread.ref_thread, th_b.ref_thread, 1); - // Move the first element from thread A back to thread A - ffi::lua_xmove(internal_thread.ref_thread, th_a.ref_thread, 1); result } From 3fb7abcfe5bb080faf57666a2270c48aace63088 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 7 Jun 2025 23:18:36 -0400 Subject: [PATCH 40/46] split out large thread create tests --- src/function.rs | 1 + src/state/extra.rs | 5 ----- src/state/raw.rs | 2 +- tests/thread.rs | 18 ++++++++++++++++++ 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/function.rs b/src/function.rs index 4ec5c8fe..bf4a6b63 100644 --- a/src/function.rs +++ b/src/function.rs @@ -3,6 +3,7 @@ use std::os::raw::{c_int, c_void}; use std::{mem, ptr, slice}; use crate::error::{Error, Result}; +#[cfg(feature = "luau")] use crate::state::util::get_next_spot; use crate::state::Lua; use crate::table::Table; diff --git a/src/state/extra.rs b/src/state/extra.rs index df3ebd85..e4676419 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -68,11 +68,6 @@ impl RefThread { free: Vec::new(), } } - - #[inline(always)] - pub(crate) fn top(&self) -> c_int { - self.stack_top - } } /// Data associated with the Lua state. diff --git a/src/state/raw.rs b/src/state/raw.rs index 4917226e..c7d43581 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1,7 +1,6 @@ use std::any::TypeId; use std::cell::{Cell, UnsafeCell}; use std::ffi::CStr; -use std::io::Write; use std::mem; use std::os::raw::{c_char, c_int, c_void}; use std::panic::resume_unwind; @@ -137,6 +136,7 @@ impl RawLua { } #[inline(always)] + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] pub(crate) fn ref_thread_internal(&self) -> *mut ffi::lua_State { unsafe { (*self.extra.get()).ref_thread_internal.ref_thread } } diff --git a/tests/thread.rs b/tests/thread.rs index baddc9a8..437b0936 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -516,10 +516,16 @@ fn test_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert!(v.contains("Reached continuation which should panic!")); +} +#[test] +fn test_large_thread_creation() { + let lua = Lua::new(); + lua.set_memory_limit(100_000_000_000).unwrap(); let th1 = lua .create_thread(lua.create_function(|lua, _: ()| Ok(())).unwrap()) .unwrap(); + let mut ths = Vec::new(); for i in 1..2000000 { let th = lua @@ -532,6 +538,9 @@ fn test_continuation() { .unwrap(); for rth in ths { + let dbg_a = format!("{:?}", rth); + let th_a = format!("{:?}", th1); + let th_b = format!("{:?}", th2); assert!( th1 != rth && th2 != rth, "Thread {:?} is equal to th1 ({:?}) or th2 ({:?})", @@ -539,5 +548,14 @@ fn test_continuation() { th1, th2 ); + let dbg_b = format!("{:?}", rth); + let dbg_th1 = format!("{:?}", th1); + let dbg_th2 = format!("{:?}", th2); + + // Ensure that the PartialEq across auxillary threads does not affect the values on stack + // themselves. + assert_eq!(dbg_a, dbg_b, "Thread {:?} debug format changed", rth); + assert_eq!(th_a, dbg_th1, "Thread {:?} debug format changed for th1", rth); + assert_eq!(th_b, dbg_th2, "Thread {:?} debug format changed for th2", rth); } } From 4b8140b19eea343cb87c543aca65439192247bdc Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 8 Jun 2025 01:30:38 -0400 Subject: [PATCH 41/46] add retest of continuation/yield across continuation --- tests/thread.rs | 184 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/tests/thread.rs b/tests/thread.rs index 437b0936..5750f9e9 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -558,4 +558,188 @@ fn test_large_thread_creation() { assert_eq!(th_a, dbg_th1, "Thread {:?} debug format changed for th1", rth); assert_eq!(th_b, dbg_th2, "Thread {:?} debug format changed for th2", rth); } + + // Repeat yielded continuation test now with a new aux thread + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); + + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with((a + 1, 1)), + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive/multiple: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { + assert_eq!(status, mlua::ContinuationStatus::Ok); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); + } + return Ok(6_i32); + } + + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + Ok(1_i32) // this will be ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); } From c68b990e3846e1a87783ad5f5ad76f38147d99b8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 8 Jun 2025 02:20:47 -0400 Subject: [PATCH 42/46] fix partialeq --- src/state/util.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/util.rs b/src/state/util.rs index a82b931e..881016ae 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -414,7 +414,7 @@ pub(crate) unsafe fn compare_refs( // - index element from thread A (1) [copy from pushvalue] // - index element from thread B (2) [copy from pushvalue] // We want to compare the index elements from both threads, so use 3 and 4 as indices - let result = f(internal_thread.ref_thread, 3, 4); + let result = f(internal_thread.ref_thread, -1, -2); // Pop the top 2 elements to clean the copies ffi::lua_pop(internal_thread.ref_thread, 2); From cbca60c506eaf5b15bca7ca3f5ef88dbf72f9ab8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 8 Jun 2025 02:24:06 -0400 Subject: [PATCH 43/46] debug --- src/state/util.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/state/util.rs b/src/state/util.rs index 881016ae..8567b955 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -406,6 +406,8 @@ pub(crate) unsafe fn compare_refs( check_stack(internal_thread.ref_thread, 2) .expect("internal error: cannot merge references, out of internal auxiliary stack space"); + println!("Using cref comparison across threads"); + // Push the index element from thread A to top ffi::lua_xpush(th_a.ref_thread, internal_thread.ref_thread, aux_thread_a_index); // Push the index element from thread B to top From c385453c1d8e1ab57124c792d8c65d3dd39adfb8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 8 Jun 2025 02:28:16 -0400 Subject: [PATCH 44/46] fix index in create_thread --- src/state/raw.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index c7d43581..0e11a7f0 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -616,7 +616,7 @@ impl RawLua { self.set_thread_hook(thread_state, HookKind::Global)?; let thread = Thread(self.pop_ref(), thread_state); - ffi::lua_xpush(self.ref_thread(thread.0.aux_thread), thread_state, func.0.index); + ffi::lua_xpush(self.ref_thread(func.0.aux_thread), thread_state, func.0.index); Ok(thread) } From aee199ffecde00ce663e8ad6ebd6146ec2c40361 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 8 Jun 2025 03:05:17 -0400 Subject: [PATCH 45/46] remove --- src/state/util.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 8567b955..881016ae 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -406,8 +406,6 @@ pub(crate) unsafe fn compare_refs( check_stack(internal_thread.ref_thread, 2) .expect("internal error: cannot merge references, out of internal auxiliary stack space"); - println!("Using cref comparison across threads"); - // Push the index element from thread A to top ffi::lua_xpush(th_a.ref_thread, internal_thread.ref_thread, aux_thread_a_index); // Push the index element from thread B to top From 2e6855572c52bffc838a04e7b4476f737fb1404c Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 8 Jun 2025 05:19:07 -0400 Subject: [PATCH 46/46] fix thread --- tests/thread.rs | 349 ++++++++++++++++++++++++------------------------ 1 file changed, 176 insertions(+), 173 deletions(-) diff --git a/tests/thread.rs b/tests/thread.rs index 5750f9e9..0dd9c6e1 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -559,187 +559,190 @@ fn test_large_thread_creation() { assert_eq!(th_b, dbg_th2, "Thread {:?} debug format changed for th2", rth); } - // Repeat yielded continuation test now with a new aux thread - // Yielding continuation test (only supported on luau) - #[cfg(feature = "luau")] + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] { - mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); - } - - let cont_func = lua - .create_function_with_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - }, - ) - .expect("Failed to create cont_func"); - - // Ensure normal calls work still - assert_eq!( - lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func) - .expect("Failed to call cont_func"), - 2 - ); - - // basic yield test before we go any further - let always_yield = lua - .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!( - thread.resume::<(i32, String, f32)>(()).unwrap(), - (42, String::from("69420"), 45.6) - ); - - // Trigger the continuation - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| lua.yield_with(a), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); + // Repeat yielded continuation test now with a new aux thread + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); - assert_eq!(v, 41); + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) + .unwrap(); - let always_yield = lua - .create_function_with_continuation( - |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - }, - ) - .unwrap(); + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); - let thread = lua.create_thread(always_yield).unwrap(); - let mv = thread.resume::(()).unwrap(); - assert!(thread - .resume::(mv) - .unwrap_err() - .to_string() - .starts_with("a3")); + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| lua.yield_with((a + 1, 1)), - |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive/multiple: {:?}", args); - - if args.len() == 5 { - if cfg!(any(feature = "luau", feature = "lua52")) { - assert_eq!(status, mlua::ContinuationStatus::Ok); - } else { - assert_eq!(status, mlua::ContinuationStatus::Yielded); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with((a + 1, 1)), + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive/multiple: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { + assert_eq!(status, mlua::ContinuationStatus::Ok); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); + } + return Ok(6_i32); } - return Ok(6_i32); - } - - lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 - Ok(1_i32) // this will be ignored - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - println!("v={:?}", v); - - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - - // (2, 1) followed by () - assert_eq!(v.len(), 2 + 3); - - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 7); - - // test panics - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| lua.yield_with(a), - |_lua, _status, _a: u64| { - panic!("Reached continuation which should panic!"); - #[allow(unreachable_code)] - Ok(()) - }, - ) - .expect("Failed to create cont_func"); - let luau_func = lua - .load( - " - local cont_func = ... - local ok, res = pcall(cont_func, 1) - assert(not ok) - return tostring(res) - ", - ) - .into_function() - .expect("Failed to create function"); + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + Ok(1_i32) // this will be ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - let v = th - .resume::(cont_func) - .expect("Failed to resume"); + let v = th + .resume::(cont_func) + .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - assert!(v.contains("Reached continuation which should panic!")); + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); + } }