From 6cfa2267c42cf24f59570ddbad692842a9ff8852 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 04:22:09 -0400 Subject: [PATCH 01/32] mlua continuation yield --- mlua-sys/src/luau/lua.rs | 5 ++ src/error.rs | 2 +- src/state.rs | 56 +++++++++++++- src/state/extra.rs | 5 ++ src/state/raw.rs | 67 +++++++++++++++- src/state/util.rs | 161 +++++++++++++++++++++++++++++++++++++++ src/thread.rs | 19 +++++ src/types.rs | 6 ++ src/util/mod.rs | 1 + src/util/types.rs | 12 +++ tests/luau.rs | 3 + tests/luau/cont.rs | 59 ++++++++++++++ 12 files changed, 390 insertions(+), 6 deletions(-) create mode 100644 tests/luau/cont.rs diff --git a/mlua-sys/src/luau/lua.rs b/mlua-sys/src/luau/lua.rs index 8a55eef1..d898534c 100644 --- a/mlua-sys/src/luau/lua.rs +++ b/mlua-sys/src/luau/lua.rs @@ -426,6 +426,11 @@ pub unsafe fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, nup: c_int) lua_pushcclosurek(L, f, ptr::null(), nup, None) } +#[inline(always)] +pub unsafe fn lua_pushcclosurec(L: *mut lua_State, f: lua_CFunction, cont: lua_Continuation, nup: c_int) { + lua_pushcclosurek(L, f, ptr::null(), nup, Some(cont)) +} + #[inline(always)] pub unsafe fn lua_pushcclosured(L: *mut lua_State, f: lua_CFunction, debugname: *const c_char, nup: c_int) { lua_pushcclosurek(L, f, debugname, nup, None) diff --git a/src/error.rs b/src/error.rs index 1f243967..483171e5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -321,7 +321,7 @@ impl fmt::Display for Error { Error::WithContext { context, cause } => { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") - } + }, } } } diff --git a/src/state.rs b/src/state.rs index 4be9ba4c..a01f5332 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,5 +1,6 @@ use std::any::TypeId; use std::cell::{BorrowError, BorrowMutError, RefCell}; +use std::convert::Infallible; use std::marker::PhantomData; use std::ops::Deref; use std::os::raw::{c_char, c_int}; @@ -17,7 +18,7 @@ use crate::scope::Scope; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; -use crate::thread::Thread; +use crate::thread::{Thread, LuauContinuationStatus}; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, @@ -1288,6 +1289,32 @@ impl Lua { })) } + /// Same as ``create_function`` but with support for Luau continuations + /// + /// Note that yieldable luau continuations are not currently supported at this time + #[cfg(feature = "luau")] + pub fn create_function_with_luau_continuation(&self, func: F, cont: FC) -> Result + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + FC: Fn(&Lua, LuauContinuationStatus, AC) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + AC: FromLuaMulti, + R: IntoLuaMulti, + RC: IntoLuaMulti, + { + (self.lock()).create_callback_with_luau_continuation( + Box::new(move |rawlua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) + }), + Box::new(move |rawlua, nargs, status| unsafe { + let args = AC::from_stack_args(nargs, 1, None, rawlua)?; + let status = LuauContinuationStatus::from_status(status); + cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) + }), + ) + } + /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// /// This is a version of [`Lua::create_function`] that accepts a `FnMut` argument. @@ -2103,6 +2130,33 @@ impl Lua { pub(crate) unsafe fn raw_lua(&self) -> &RawLua { &*self.raw.data_ptr() } + + /// Yields arguments + /// + /// If this function cannot yield, it will raise a runtime error. + /// + /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield + /// or not until it reaches the Lua state. + /// + /// Unsafe and should only be used in a function with a luau continuation for now + pub unsafe fn yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + let raw = self.lock(); + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + if !raw.is_yieldable() { + return Err(Error::runtime("cannot yield across Rust/Lua boundary.")) + } + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; + } + Ok(()) + } + + /// Checks if Lua is currently allowed to yield. + #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + self.lock().is_yieldable() + } } impl WeakLua { diff --git a/src/state/extra.rs b/src/state/extra.rs index 12133639..445c232c 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -18,6 +18,7 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -93,6 +94,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: MultiValue, } impl Drop for ExtraData { @@ -194,6 +198,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, + yielded_values: MultiValue::with_capacity(0), })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index cc4ce1cf..341520ec 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,15 +12,15 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop}; +use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::traits::IntoLua; use crate::types::{ - AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, - MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, + AppDataRef, AppDataRefMut, Callback, LuauContinuation, CallbackUpvalue, LuauContinuationUpvalue, DestructedUserdata, + Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, @@ -198,6 +198,8 @@ impl RawLua { init_internal_metatable::>>(state, None)?; init_internal_metatable::(state, None)?; init_internal_metatable::(state, None)?; + #[cfg(feature = "luau")] + init_internal_metatable::(state, None)?; #[cfg(not(feature = "luau"))] init_internal_metatable::(state, None)?; #[cfg(feature = "async")] @@ -1156,10 +1158,11 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); + #[cfg(feature = "luau")] match (*upvalue).data { Some(ref func) => func(rawlua, nargs), None => Err(Error::CallbackDestructed), @@ -1188,6 +1191,56 @@ impl RawLua { } } + // Creates a Function out of a Callback and a continuation containing a 'static Fn. + #[cfg(feature = "luau")] + pub(crate) fn create_callback_with_luau_continuation(&self, func: Callback, cont: LuauContinuation) -> Result { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } + + unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, LuauContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + })?; + } else { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + #[cfg(feature = "async")] pub(crate) fn create_async_callback(&self, func: AsyncCallback) -> Result { // Ensure that the coroutine library is loaded @@ -1348,6 +1401,12 @@ impl RawLua { pub(crate) unsafe fn set_waker(&self, waker: NonNull) -> NonNull { mem::replace(&mut (*self.extra.get()).waker, waker) } + + #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + unsafe { ffi::lua_isyieldable(self.state()) != 0 } + } } // Uses 3 stack spaces diff --git a/src/state/util.rs b/src/state/util.rs index c3c79302..97cb2f60 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,7 +1,9 @@ +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; +use crate::IntoLuaMulti; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; @@ -38,7 +40,140 @@ where } let nargs = ffi::lua_gettop(state); + + enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, + } + + impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = (*extra).ref_thread; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread, state, index); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, index); + (*extra).ref_free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = (*extra).ref_thread; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread, 1); + let index = ref_stack_pop(extra); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } + } + + // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory + // to store a wrapped failure (error or panic) *before* we proceed. + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + + match catch_unwind(AssertUnwindSafe(|| { + let rawlua = (*extra).raw_lua(); + let _guard = StateGuard::new(rawlua, state); + f(extra, nargs) + })) { + Ok(Ok(r)) => { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { + let wrapped_error = prealloc_failure.r#use(state, extra); + + if !wrap_error { + ptr::write(wrapped_error, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) + } + + // Build `CallbackError` with traceback + let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, ptr::null(), 0); + let traceback = util::to_string(state, -1); + ffi::lua_pop(state, 1); + traceback + } else { + "".to_string() + }; + let cause = Arc::new(err); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::CallbackError { traceback, cause }), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } + Err(p) => { + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Panic(Some(p))); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) + } + } +} + +// An yieldable version of `callback_error_ext` +pub(crate) unsafe fn callback_error_ext_yieldable( + state: *mut ffi::lua_State, + mut extra: *mut ExtraData, + wrap_error: bool, + f: F, +) -> c_int +where + F: FnOnce(*mut ExtraData, c_int) -> Result, +{ + println!("callback_error_ext_yieldable"); + + if extra.is_null() { + extra = ExtraData::get(state); + } + + let nargs = ffi::lua_gettop(state); + enum PreallocatedFailure { New(*mut WrappedFailure), Reserved, @@ -108,6 +243,32 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + + if !values.is_empty() { + println!("YIELD {:?}", values); + match values.push_into_stack_multi(raw) { + Ok(nargs) => { + println!("YIELDARGS {}", nargs); + return ffi::lua_yield(state, nargs); + }, + Err(err) => { + let wrapped_error = prealloc_failure.r#use(state, extra); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::external(err.to_string())), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + + } + } + } + + // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r diff --git a/src/thread.rs b/src/thread.rs index da7962b0..0657b3e7 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -26,6 +26,25 @@ use { }, }; +/// Luau continuation final status +#[cfg(feature = "luau")] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum LuauContinuationStatus { + Ok, + Yielded, + Error, +} + +impl LuauContinuationStatus { + pub(crate) fn from_status(status: c_int) -> Self { + match status { + ffi::LUA_YIELD => Self::Yielded, + ffi::LUA_OK => Self::Ok, + _ => Self::Error, + } + } +} + /// Status of a Lua thread (coroutine). #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ThreadStatus { diff --git a/src/types.rs b/src/types.rs index 2589ea6e..c1efed0e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -40,6 +40,11 @@ pub(crate) type Callback = Box Result + Send + #[cfg(not(feature = "send"))] pub(crate) type Callback = Box Result + 'static>; +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type LuauContinuation = Box Result + Send + 'static>; +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type LuauContinuation = Box Result + 'static>; + pub(crate) type ScopedCallback<'s> = Box Result + 's>; pub(crate) struct Upvalue { @@ -48,6 +53,7 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; +pub(crate) type LuauContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] pub(crate) type AsyncCallback = diff --git a/src/util/mod.rs b/src/util/mod.rs index f5fbae52..5dd6b19c 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -75,6 +75,7 @@ impl Drop for StackGuard { unsafe { let top = ffi::lua_gettop(self.state); if top < self.top { + println!("top={}, self.top={}", top, self.top); mlua_panic!("{} too many stack values popped", self.top - top) } if top > self.top { diff --git a/src/util/types.rs b/src/util/types.rs index 8bc9d8b2..16945009 100644 --- a/src/util/types.rs +++ b/src/util/types.rs @@ -3,6 +3,9 @@ use std::os::raw::c_void; use crate::types::{Callback, CallbackUpvalue}; +#[cfg(feature = "luau")] +use crate::types::LuauContinuationUpvalue; + #[cfg(feature = "async")] use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; @@ -34,6 +37,15 @@ impl TypeKey for CallbackUpvalue { } } +#[cfg(feature = "luau")] +impl TypeKey for LuauContinuationUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static LUAU_CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; + &LUAU_CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + #[cfg(not(feature = "luau"))] impl TypeKey for crate::types::HookCallback { #[inline(always)] diff --git a/tests/luau.rs b/tests/luau.rs index 20fbee6a..7b62bae3 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -449,3 +449,6 @@ fn test_typeof_error() -> Result<()> { #[path = "luau/require.rs"] mod require; + +#[path = "luau/cont.rs"] +mod cont; diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs new file mode 100644 index 00000000..9cb7e89f --- /dev/null +++ b/tests/luau/cont.rs @@ -0,0 +1,59 @@ +#[cfg(feature = "luau")] +use mlua::Lua; + +#[test] +fn test_luau_continuation() { + let lua = Lua::new(); + + let cont_func = lua.create_function_with_luau_continuation( + |lua, a: u64| Ok(a + 1), + |lua, status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + } + ).expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func).expect("Failed to call cont_func"), + 2 + ); + + // does not work yet + /*let always_yield = lua.create_function(|lua, ()| { + unsafe { lua.yield_args((42, "69420")) } + }).unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420")));*/ + + // Trigger the continuation + let cont_func = lua.create_function_with_luau_continuation( + |lua, a: u64| { + unsafe { + match lua.yield_args(a) { + Ok(()) => println!("yield_args called"), + Err(e) => println!("{:?}", e) + } + } + Ok(()) + }, + |lua, status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + } + ).expect("Failed to create cont_func"); + + let luau_func = lua.load(" + local cont_func = ... + local res = cont_func(1) + return res + ").into_function().expect("Failed to create function"); + let th = lua.create_thread(luau_func).expect("Failed to create luau thread"); + + let v = th.resume::(cont_func).expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 40); +} \ No newline at end of file From e01e45fe21ae8d7a3d80f53ef080b963036e73c1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:15:25 -0400 Subject: [PATCH 02/32] add missing xmove --- src/state/util.rs | 1 + src/thread.rs | 1 + tests/luau/cont.rs | 15 +++++++++------ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 97cb2f60..957c069f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -251,6 +251,7 @@ where match values.push_into_stack_multi(raw) { Ok(nargs) => { println!("YIELDARGS {}", nargs); + ffi::lua_xmove(raw.state(), state, nargs); return ffi::lua_yield(state, nargs); }, Err(err) => { diff --git a/src/thread.rs b/src/thread.rs index 0657b3e7..4ad040af 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -234,6 +234,7 @@ impl Thread { let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); #[cfg(feature = "luau")] let ret = ffi::lua_resumex(thread_state, state, nargs, &mut nresults as *mut c_int); + match ret { ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)), ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)), diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 9cb7e89f..87c1ca78 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -21,12 +21,15 @@ fn test_luau_continuation() { ); // does not work yet - /*let always_yield = lua.create_function(|lua, ()| { - unsafe { lua.yield_args((42, "69420")) } - }).unwrap(); + let always_yield = lua.create_function( + |lua, ()| { + unsafe { lua.yield_args((42, "69420".to_string()))? } + Ok(()) + }) + .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420")));*/ + assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420"))); // Trigger the continuation let cont_func = lua.create_function_with_luau_continuation( @@ -48,12 +51,12 @@ fn test_luau_continuation() { let luau_func = lua.load(" local cont_func = ... local res = cont_func(1) - return res + return res + 1 ").into_function().expect("Failed to create function"); let th = lua.create_thread(luau_func).expect("Failed to create luau thread"); let v = th.resume::(cont_func).expect("Failed to resume"); let v = th.resume::(v).expect("Failed to load continuation"); - assert_eq!(v, 40); + assert_eq!(v, 41); } \ No newline at end of file From bd6b65b20da638c7a52270b3d75be186f44b2210 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:18:11 -0400 Subject: [PATCH 03/32] remove comment --- tests/luau/cont.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 87c1ca78..04122035 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -20,7 +20,6 @@ fn test_luau_continuation() { 2 ); - // does not work yet let always_yield = lua.create_function( |lua, ()| { unsafe { lua.yield_args((42, "69420".to_string()))? } From c6f1ff250a99f0754c2cf0bdf1801f8066f80729 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:19:58 -0400 Subject: [PATCH 04/32] add a f32 to test stack --- tests/luau/cont.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 04122035..7e6fc6ee 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -22,13 +22,13 @@ fn test_luau_continuation() { let always_yield = lua.create_function( |lua, ()| { - unsafe { lua.yield_args((42, "69420".to_string()))? } + unsafe { lua.yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }) .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420"))); + assert_eq!(thread.resume::<(i32, String, f32)>(()).unwrap(), (42, String::from("69420"), 45.6)); // Trigger the continuation let cont_func = lua.create_function_with_luau_continuation( From 875d6e8ca6d4310bbb208d95dc161c0038f1ffb2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 06:57:48 -0400 Subject: [PATCH 05/32] cleanup --- src/state.rs | 19 +++++++++++++------ src/state/raw.rs | 25 +++++++++++++++++++++++-- src/state/util.rs | 13 +++++++------ src/thread.rs | 1 + src/types.rs | 1 + tests/luau/cont.rs | 25 ++++++++++++++++++++++--- 6 files changed, 67 insertions(+), 17 deletions(-) diff --git a/src/state.rs b/src/state.rs index a01f5332..8113dd92 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,5 @@ use std::any::TypeId; use std::cell::{BorrowError, BorrowMutError, RefCell}; -use std::convert::Infallible; use std::marker::PhantomData; use std::ops::Deref; use std::os::raw::{c_char, c_int}; @@ -18,7 +17,13 @@ use crate::scope::Scope; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; -use crate::thread::{Thread, LuauContinuationStatus}; +use crate::thread::Thread; + +#[cfg(feature = "luau")] +use crate::thread::LuauContinuationStatus; +#[cfg(feature = "luau")] +use std::convert::Infallible; + use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, @@ -2131,15 +2136,17 @@ impl Lua { &*self.raw.data_ptr() } - /// Yields arguments + /// Sets the yields arguments. Note that Ok(()) must be returned for the Rust function + /// to actually yield. This method is mostly useful with Luau continuations /// /// If this function cannot yield, it will raise a runtime error. /// /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// - /// Unsafe and should only be used in a function with a luau continuation for now - pub unsafe fn yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + /// Potentially unsafe at this time. Use with caution + #[cfg(feature = "luau")] // todo: support non-luau set_yield_args, the groundwork is here + pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] if !raw.is_yieldable() { @@ -2154,7 +2161,7 @@ impl Lua { /// Checks if Lua is currently allowed to yield. #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] #[inline] - pub(crate) fn is_yieldable(&self) -> bool { + pub fn is_yieldable(&self) -> bool { self.lock().is_yieldable() } } diff --git a/src/state/raw.rs b/src/state/raw.rs index 341520ec..5d2ba91c 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,16 +12,22 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext, ref_stack_pop}; +#[cfg(feature = "luau")] +use crate::state::util::callback_error_ext_yieldable; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::traits::IntoLua; use crate::types::{ - AppDataRef, AppDataRefMut, Callback, LuauContinuation, CallbackUpvalue, LuauContinuationUpvalue, DestructedUserdata, + AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; + +#[cfg(feature = "luau")] +use crate::types::{LuauContinuation, LuauContinuationUpvalue}; + use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, UserDataStorage, @@ -1156,6 +1162,21 @@ impl RawLua { // Creates a Function out of a Callback containing a 'static Fn. pub(crate) fn create_callback(&self, func: Callback) -> Result { + #[cfg(not(feature = "luau"))] + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } + + #[cfg(feature = "luau")] unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { diff --git a/src/state/util.rs b/src/state/util.rs index 957c069f..230a26f9 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -3,6 +3,8 @@ use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; + +#[cfg(feature = "luau")] use crate::IntoLuaMulti; use crate::error::{Error, Result}; @@ -110,7 +112,7 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + // Ensure yielded values are cleared take(&mut extra.as_mut().unwrap_unchecked().yielded_values); // Return unused `WrappedFailure` to the pool @@ -156,7 +158,10 @@ where } } -// An yieldable version of `callback_error_ext` +/// An yieldable version of `callback_error_ext` +/// +/// Outside of Luau, this does the same thing as ``callback_error_ext`` right now +#[cfg(feature = "luau")] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, @@ -166,8 +171,6 @@ pub(crate) unsafe fn callback_error_ext_yieldable( where F: FnOnce(*mut ExtraData, c_int) -> Result, { - println!("callback_error_ext_yieldable"); - if extra.is_null() { extra = ExtraData::get(state); } @@ -247,10 +250,8 @@ where let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); if !values.is_empty() { - println!("YIELD {:?}", values); match values.push_into_stack_multi(raw) { Ok(nargs) => { - println!("YIELDARGS {}", nargs); ffi::lua_xmove(raw.state(), state, nargs); return ffi::lua_yield(state, nargs); }, diff --git a/src/thread.rs b/src/thread.rs index 4ad040af..6b834081 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -35,6 +35,7 @@ pub enum LuauContinuationStatus { Error, } +#[cfg(feature = "luau")] impl LuauContinuationStatus { pub(crate) fn from_status(status: c_int) -> Self { match status { diff --git a/src/types.rs b/src/types.rs index c1efed0e..3b6b0d8f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -53,6 +53,7 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; +#[cfg(feature = "luau")] pub(crate) type LuauContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 7e6fc6ee..80154d03 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -22,7 +22,7 @@ fn test_luau_continuation() { let always_yield = lua.create_function( |lua, ()| { - unsafe { lua.yield_args((42, "69420".to_string(), 45.6))? } + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }) .unwrap(); @@ -34,8 +34,8 @@ fn test_luau_continuation() { let cont_func = lua.create_function_with_luau_continuation( |lua, a: u64| { unsafe { - match lua.yield_args(a) { - Ok(()) => println!("yield_args called"), + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), Err(e) => println!("{:?}", e) } } @@ -58,4 +58,23 @@ fn test_luau_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 41); + + let always_yield = lua.create_function_with_luau_continuation( + |lua, ()| { + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + Ok(()) + }, + |lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + } + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread.resume::(mv).unwrap_err().to_string().starts_with("a3")); } \ No newline at end of file From 2a67b268f8b4983c7e2664ba6aacc059a3509a90 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 07:04:18 -0400 Subject: [PATCH 06/32] deal with warnings --- src/state.rs | 2 -- tests/luau/cont.rs | 9 +++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/state.rs b/src/state.rs index 8113dd92..cca8a56c 100644 --- a/src/state.rs +++ b/src/state.rs @@ -21,8 +21,6 @@ use crate::thread::Thread; #[cfg(feature = "luau")] use crate::thread::LuauContinuationStatus; -#[cfg(feature = "luau")] -use std::convert::Infallible; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 80154d03..d7d021cc 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -6,8 +6,8 @@ fn test_luau_continuation() { let lua = Lua::new(); let cont_func = lua.create_function_with_luau_continuation( - |lua, a: u64| Ok(a + 1), - |lua, status, a: u64| { + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 2) } @@ -20,6 +20,7 @@ fn test_luau_continuation() { 2 ); + // basic yield test before we go any further let always_yield = lua.create_function( |lua, ()| { unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } @@ -41,7 +42,7 @@ fn test_luau_continuation() { } Ok(()) }, - |lua, status, a: u64| { + |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 39) } @@ -64,7 +65,7 @@ fn test_luau_continuation() { unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }, - |lua, _, mv: mlua::MultiValue| { + |_lua, _, mv: mlua::MultiValue| { println!("Reached second continuation"); if mv.is_empty() { return Ok(mv); From d23f058549da3f0681b60aa0c78798eee08f7ffa Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 07:10:24 -0400 Subject: [PATCH 07/32] update warning --- src/state.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/state.rs b/src/state.rs index cca8a56c..e6f7f8d6 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2142,7 +2142,10 @@ impl Lua { /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// - /// Potentially unsafe at this time. Use with caution + /// Potentially unsafe at this time. Use with caution. + /// + /// This method only supports Luau for now as proper Rust yielding in other Lua variants is + /// more complicated. This limitation may be lifted in the future. #[cfg(feature = "luau")] // todo: support non-luau set_yield_args, the groundwork is here pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); From 62330e78e09d3d9dbc63042e0f57e5a5f8ea5e5e Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 10:34:08 -0400 Subject: [PATCH 08/32] allow yield on non-luau as well --- src/state.rs | 9 +++------ src/state/raw.rs | 24 ++++-------------------- src/state/util.rs | 5 ----- tests/thread.rs | 15 +++++++++++++++ 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/state.rs b/src/state.rs index e6f7f8d6..bfa73907 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2134,8 +2134,9 @@ impl Lua { &*self.raw.data_ptr() } - /// Sets the yields arguments. Note that Ok(()) must be returned for the Rust function - /// to actually yield. This method is mostly useful with Luau continuations + /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function + /// to actually yield. This method is mostly useful with Luau continuations and Rust-Rust + /// yields /// /// If this function cannot yield, it will raise a runtime error. /// @@ -2143,10 +2144,6 @@ impl Lua { /// or not until it reaches the Lua state. /// /// Potentially unsafe at this time. Use with caution. - /// - /// This method only supports Luau for now as proper Rust yielding in other Lua variants is - /// more complicated. This limitation may be lifted in the future. - #[cfg(feature = "luau")] // todo: support non-luau set_yield_args, the groundwork is here pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] diff --git a/src/state/raw.rs b/src/state/raw.rs index 5d2ba91c..a9216a95 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,9 +12,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop}; -#[cfg(feature = "luau")] -use crate::state::util::callback_error_ext_yieldable; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +#[cfg(not(feature = "luau"))] +use crate::state::util::callback_error_ext; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -1162,28 +1162,12 @@ impl RawLua { // Creates a Function out of a Callback containing a 'static Fn. pub(crate) fn create_callback(&self, func: Callback) -> Result { - #[cfg(not(feature = "luau"))] - unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => func(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) - } - - #[cfg(feature = "luau")] unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); - #[cfg(feature = "luau")] match (*upvalue).data { Some(ref func) => func(rawlua, nargs), None => Err(Error::CallbackDestructed), @@ -1230,7 +1214,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); diff --git a/src/state/util.rs b/src/state/util.rs index 230a26f9..702ee6f4 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -3,8 +3,6 @@ use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; - -#[cfg(feature = "luau")] use crate::IntoLuaMulti; use crate::error::{Error, Result}; @@ -159,9 +157,6 @@ where } /// An yieldable version of `callback_error_ext` -/// -/// Outside of Luau, this does the same thing as ``callback_error_ext`` right now -#[cfg(feature = "luau")] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, diff --git a/tests/thread.rs b/tests/thread.rs index 4cb6ab10..de5aca2b 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -252,3 +252,18 @@ fn test_thread_resume_error() -> Result<()> { Ok(()) } + +#[test] +fn test_thread_yield_args() -> Result<()> { + let lua = Lua::new(); + let always_yield = lua.create_function( + |lua, ()| { + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + Ok(()) + })?; + + let thread = lua.create_thread(always_yield)?; + assert_eq!(thread.resume::<(i32, String, f32)>(())?, (42, String::from("69420"), 45.6)); + + Ok(()) +} \ No newline at end of file From 1c7743c9142ce627ba0a53e9c2679ad719c8c50c Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 23:03:50 -0400 Subject: [PATCH 09/32] fix --- src/state.rs | 26 ++++--- src/state/raw.rs | 20 +++-- src/state/util.rs | 15 ++-- tests/luau/cont.rs | 180 +++++++++++++++++++++++++++++++++------------ 4 files changed, 169 insertions(+), 72 deletions(-) diff --git a/src/state.rs b/src/state.rs index bfa73907..7d0c6885 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1292,11 +1292,19 @@ impl Lua { })) } - /// Same as ``create_function`` but with support for Luau continuations + /// Same as ``create_function`` but with support for continuations. /// - /// Note that yieldable luau continuations are not currently supported at this time + /// Currently only luau-style continuations are supported at this time. + /// + /// The values passed to the continuation will either be the yielded arguments + /// from the function or the arguments from resuming the thread. Yielded values + /// from a continuation become the resume args. #[cfg(feature = "luau")] - pub fn create_function_with_luau_continuation(&self, func: F, cont: FC) -> Result + pub fn create_function_with_luau_continuation( + &self, + func: F, + cont: FC, + ) -> Result where F: Fn(&Lua, A) -> Result + MaybeSend + 'static, FC: Fn(&Lua, LuauContinuationStatus, AC) -> Result + MaybeSend + 'static, @@ -1316,7 +1324,7 @@ impl Lua { cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) }), ) - } + } /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// @@ -2137,10 +2145,10 @@ impl Lua { /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function /// to actually yield. This method is mostly useful with Luau continuations and Rust-Rust /// yields - /// + /// /// If this function cannot yield, it will raise a runtime error. - /// - /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield + /// + /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// /// Potentially unsafe at this time. Use with caution. @@ -2148,7 +2156,7 @@ impl Lua { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] if !raw.is_yieldable() { - return Err(Error::runtime("cannot yield across Rust/Lua boundary.")) + return Err(Error::runtime("cannot yield across Rust/Lua boundary.")); } unsafe { raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; @@ -2157,7 +2165,7 @@ impl Lua { } /// Checks if Lua is currently allowed to yield. - #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub fn is_yieldable(&self) -> bool { self.lock().is_yieldable() diff --git a/src/state/raw.rs b/src/state/raw.rs index a9216a95..a151c231 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,17 +12,17 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; #[cfg(not(feature = "luau"))] use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::traits::IntoLua; use crate::types::{ - AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, - Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, + AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, + MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; #[cfg(feature = "luau")] @@ -1164,7 +1164,7 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1198,10 +1198,14 @@ impl RawLua { // Creates a Function out of a Callback and a continuation containing a 'static Fn. #[cfg(feature = "luau")] - pub(crate) fn create_callback_with_luau_continuation(&self, func: Callback, cont: LuauContinuation) -> Result { + pub(crate) fn create_callback_with_luau_continuation( + &self, + func: Callback, + cont: LuauContinuation, + ) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1214,7 +1218,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1407,7 +1411,7 @@ impl RawLua { mem::replace(&mut (*self.extra.get()).waker, waker) } - #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub(crate) fn is_yieldable(&self) -> bool { unsafe { ffi::lua_isyieldable(self.state()) != 0 } diff --git a/src/state/util.rs b/src/state/util.rs index 702ee6f4..922d36f3 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,9 +1,9 @@ +use crate::IntoLuaMulti; use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; -use crate::IntoLuaMulti; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; @@ -40,7 +40,7 @@ where } let nargs = ffi::lua_gettop(state); - + enum PreallocatedFailure { New(*mut WrappedFailure), Reserved, @@ -164,14 +164,14 @@ pub(crate) unsafe fn callback_error_ext_yieldable( f: F, ) -> c_int where - F: FnOnce(*mut ExtraData, c_int) -> Result, + F: FnOnce(*mut ExtraData, *mut ffi::lua_State, c_int) -> Result, { if extra.is_null() { extra = ExtraData::get(state); } let nargs = ffi::lua_gettop(state); - + enum PreallocatedFailure { New(*mut WrappedFailure), Reserved, @@ -238,7 +238,7 @@ where match catch_unwind(AssertUnwindSafe(|| { let rawlua = (*extra).raw_lua(); let _guard = StateGuard::new(rawlua, state); - f(extra, nargs) + f(extra, state, nargs) })) { Ok(Ok(r)) => { let raw = extra.as_ref().unwrap_unchecked().raw_lua(); @@ -247,9 +247,10 @@ where if !values.is_empty() { match values.push_into_stack_multi(raw) { Ok(nargs) => { + ffi::lua_pop(state, -1); ffi::lua_xmove(raw.state(), state, nargs); return ffi::lua_yield(state, nargs); - }, + } Err(err) => { let wrapped_error = prealloc_failure.r#use(state, extra); ptr::write( @@ -260,12 +261,10 @@ where ffi::lua_setmetatable(state, -2); ffi::lua_error(state) - } } } - // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index d7d021cc..0d90b8f4 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -1,81 +1,167 @@ +use mlua::IntoLuaMulti; #[cfg(feature = "luau")] use mlua::Lua; #[test] fn test_luau_continuation() { + // Yielding continuation + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + let lua = Lua::new(); - let cont_func = lua.create_function_with_luau_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - } - ).expect("Failed to create cont_func"); + let cont_func = lua + .create_function_with_luau_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); // Ensure normal calls work still assert_eq!( lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func).expect("Failed to call cont_func"), + .call::(cont_func) + .expect("Failed to call cont_func"), 2 ); // basic yield test before we go any further - let always_yield = lua.create_function( - |lua, ()| { + let always_yield = lua + .create_function(|lua, ()| { unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } Ok(()) }) .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!(thread.resume::<(i32, String, f32)>(()).unwrap(), (42, String::from("69420"), 45.6)); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); // Trigger the continuation - let cont_func = lua.create_function_with_luau_continuation( - |lua, a: u64| { - unsafe { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e) - } - } - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - } - ).expect("Failed to create cont_func"); - - let luau_func = lua.load(" + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + unsafe { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + } + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " local cont_func = ... local res = cont_func(1) return res + 1 - ").into_function().expect("Failed to create function"); - let th = lua.create_thread(luau_func).expect("Failed to create luau thread"); + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - let v = th.resume::(cont_func).expect("Failed to resume"); + let v = th + .resume::(cont_func) + .expect("Failed to resume"); let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 41); - let always_yield = lua.create_function_with_luau_continuation( - |lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } - Ok(()) - }, - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - } - ) - .unwrap(); + let always_yield = lua + .create_function_with_luau_continuation( + |lua, ()| { + unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + Ok(()) + }, + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); let mv = thread.resume::(()).unwrap(); - assert!(thread.resume::(mv).unwrap_err().to_string().starts_with("a3")); -} \ No newline at end of file + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + unsafe { + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + } + Ok(()) + }, + |lua, _status, args: mlua::MultiValue| { + println!("Reached cont recursive: {:?}", args); + + if args.len() == 5 { + return 6_i32.into_lua_multi(lua); + } + + unsafe { lua.set_yield_args((args.len() + 1, args))? } // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); +} From 85416fdd3064dc119a719261079329c5f1087bc1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Jun 2025 23:46:21 -0400 Subject: [PATCH 10/32] Document continuations a bit more --- src/state.rs | 30 +++++++++++++++++++++--------- src/state/raw.rs | 2 +- src/state/util.rs | 3 +++ tests/luau/cont.rs | 25 +++++++++++-------------- tests/thread.rs | 16 +++++++++------- 5 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/state.rs b/src/state.rs index 7d0c6885..341adf8d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1292,13 +1292,21 @@ impl Lua { })) } - /// Same as ``create_function`` but with support for continuations. + /// Same as ``create_function`` but with support for luau-style continuations. /// - /// Currently only luau-style continuations are supported at this time. + /// Other Lua versions have completely different continuation semantics and are both + /// not supported at this time, much more complicated to support within mlua and cannot + /// be supported with this API in any case. /// - /// The values passed to the continuation will either be the yielded arguments - /// from the function or the arguments from resuming the thread. Yielded values - /// from a continuation become the resume args. + /// The values passed to the continuation will be the yielded arguments + /// from the function for the initial continuation call. If yielding from a continuation, + /// the yielded results will be returned to the ``Thread::resume`` caller. The arguments + /// passed in the next ``Thread::resume`` call will then be the arguments passed to the yielding + /// continuation upon resumption. + /// + /// Returning a value from a continuation without setting yield + /// arguments will then be returned as the final return value of the Luau function call. + /// Values returned in a function in which there is also yielding will be ignored #[cfg(feature = "luau")] pub fn create_function_with_luau_continuation( &self, @@ -2143,16 +2151,20 @@ impl Lua { } /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function - /// to actually yield. This method is mostly useful with Luau continuations and Rust-Rust - /// yields + /// to actually yield. Any values returned in a function in which there is also yielding may + /// be ignored. + /// + /// This method is mostly useful with Luau continuations and Rust-Rust yields + /// due to the Rust/Lua boundary /// /// If this function cannot yield, it will raise a runtime error. /// /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield /// or not until it reaches the Lua state. /// - /// Potentially unsafe at this time. Use with caution. - pub unsafe fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + /// While this method *should be safe*, it is new and may have bugs lurking within. Use + /// with caution + pub fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { let raw = self.lock(); #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] if !raw.is_yieldable() { diff --git a/src/state/raw.rs b/src/state/raw.rs index a151c231..147e6cbe 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1218,7 +1218,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); diff --git a/src/state/util.rs b/src/state/util.rs index 922d36f3..5c4d1409 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -157,6 +157,9 @@ where } /// An yieldable version of `callback_error_ext` +/// +/// Unlike ``callback_error_ext``, this method requires a c_int return +/// and not a generic R pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 0d90b8f4..ce8f58f8 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -30,7 +30,7 @@ fn test_luau_continuation() { // basic yield test before we go any further let always_yield = lua .create_function(|lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + lua.set_yield_args((42, "69420".to_string(), 45.6))?; Ok(()) }) .unwrap(); @@ -45,12 +45,10 @@ fn test_luau_continuation() { let cont_func = lua .create_function_with_luau_continuation( |lua, a: u64| { - unsafe { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - } + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; Ok(()) }, |_lua, _status, a: u64| { @@ -70,6 +68,7 @@ fn test_luau_continuation() { ) .into_function() .expect("Failed to create function"); + let th = lua .create_thread(luau_func) .expect("Failed to create luau thread"); @@ -84,7 +83,7 @@ fn test_luau_continuation() { let always_yield = lua .create_function_with_luau_continuation( |lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } + lua.set_yield_args((42, "69420".to_string(), 45.6))?; Ok(()) }, |_lua, _, mv: mlua::MultiValue| { @@ -108,11 +107,9 @@ fn test_luau_continuation() { let cont_func = lua .create_function_with_luau_continuation( |lua, a: u64| { - unsafe { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), } Ok(()) }, @@ -123,7 +120,7 @@ fn test_luau_continuation() { return 6_i32.into_lua_multi(lua); } - unsafe { lua.set_yield_args((args.len() + 1, args))? } // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored }, ) diff --git a/tests/thread.rs b/tests/thread.rs index de5aca2b..e6825ff3 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -256,14 +256,16 @@ fn test_thread_resume_error() -> Result<()> { #[test] fn test_thread_yield_args() -> Result<()> { let lua = Lua::new(); - let always_yield = lua.create_function( - |lua, ()| { - unsafe { lua.set_yield_args((42, "69420".to_string(), 45.6))? } - Ok(()) - })?; + let always_yield = lua.create_function(|lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + })?; let thread = lua.create_thread(always_yield)?; - assert_eq!(thread.resume::<(i32, String, f32)>(())?, (42, String::from("69420"), 45.6)); + assert_eq!( + thread.resume::<(i32, String, f32)>(())?, + (42, String::from("69420"), 45.6) + ); Ok(()) -} \ No newline at end of file +} From b3a54ac15a4b54f8424cdd842728bff5e5074e91 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 00:18:47 -0400 Subject: [PATCH 11/32] reuse same preallocatedfailure in both callback_error_ext and callback_error_ext_yieldable --- src/state/raw.rs | 6 +- src/state/util.rs | 181 +++++++++++++++------------------------------ tests/luau/cont.rs | 2 + 3 files changed, 66 insertions(+), 123 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 147e6cbe..74939b4c 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1164,7 +1164,7 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1205,7 +1205,7 @@ impl RawLua { ) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); @@ -1218,7 +1218,7 @@ impl RawLua { unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, _state, nargs| { + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); diff --git a/src/state/util.rs b/src/state/util.rs index 5c4d1409..cbdee8b4 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -24,6 +24,65 @@ impl Drop for StateGuard<'_> { } } +pub(crate) enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, +} + +impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = (*extra).ref_thread; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread, state, index); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, index); + (*extra).ref_free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = (*extra).ref_thread; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread, 1); + let index = ref_stack_pop(extra); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } +} + // An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata // and instead reuses unused values from previous calls (or allocates new). pub(crate) unsafe fn callback_error_ext( @@ -41,65 +100,6 @@ where let nargs = ffi::lua_gettop(state); - enum PreallocatedFailure { - New(*mut WrappedFailure), - Reserved, - } - - impl PreallocatedFailure { - unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { - if (*extra).wrapped_failure_top > 0 { - (*extra).wrapped_failure_top -= 1; - return PreallocatedFailure::Reserved; - } - - // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - // Place it to the beginning of the stack - let ud = WrappedFailure::new_userdata(state); - ffi::lua_insert(state, 1); - PreallocatedFailure::New(ud) - } - - #[cold] - unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; - match *self { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Reserved => { - let index = (*extra).wrapped_failure_pool.pop().unwrap(); - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } - } - } - - unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; - match self { - PreallocatedFailure::New(_) => { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); - (*extra).wrapped_failure_pool.push(index); - (*extra).wrapped_failure_top += 1; - } - PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, - } - } - } - // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory // to store a wrapped failure (error or panic) *before* we proceed. let prealloc_failure = PreallocatedFailure::reserve(state, extra); @@ -167,7 +167,7 @@ pub(crate) unsafe fn callback_error_ext_yieldable( f: F, ) -> c_int where - F: FnOnce(*mut ExtraData, *mut ffi::lua_State, c_int) -> Result, + F: FnOnce(*mut ExtraData, c_int) -> Result, { if extra.is_null() { extra = ExtraData::get(state); @@ -175,65 +175,6 @@ where let nargs = ffi::lua_gettop(state); - enum PreallocatedFailure { - New(*mut WrappedFailure), - Reserved, - } - - impl PreallocatedFailure { - unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { - if (*extra).wrapped_failure_top > 0 { - (*extra).wrapped_failure_top -= 1; - return PreallocatedFailure::Reserved; - } - - // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - // Place it to the beginning of the stack - let ud = WrappedFailure::new_userdata(state); - ffi::lua_insert(state, 1); - PreallocatedFailure::New(ud) - } - - #[cold] - unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; - match *self { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Reserved => { - let index = (*extra).wrapped_failure_pool.pop().unwrap(); - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } - } - } - - unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; - match self { - PreallocatedFailure::New(_) => { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); - (*extra).wrapped_failure_pool.push(index); - (*extra).wrapped_failure_top += 1; - } - PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, - } - } - } - // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory // to store a wrapped failure (error or panic) *before* we proceed. let prealloc_failure = PreallocatedFailure::reserve(state, extra); @@ -241,7 +182,7 @@ where match catch_unwind(AssertUnwindSafe(|| { let rawlua = (*extra).raw_lua(); let _guard = StateGuard::new(rawlua, state); - f(extra, state, nargs) + f(extra, nargs) })) { Ok(Ok(r)) => { let raw = extra.as_ref().unwrap_unchecked().raw_lua(); diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index ce8f58f8..675d90d8 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -148,9 +148,11 @@ fn test_luau_continuation() { let v = th .resume::(v) .expect("Failed to load continuation"); + println!("v={:?}", v); let v = th .resume::(v) .expect("Failed to load continuation"); + println!("v={:?}", v); let v = th .resume::(v) .expect("Failed to load continuation"); From 085c62a68e36c95402b2a5c70d4acd013d3a9fb5 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 04:56:06 -0400 Subject: [PATCH 12/32] handle mainthread edgecase --- src/state/util.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index cbdee8b4..517ae487 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -189,10 +189,23 @@ where let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); if !values.is_empty() { + if raw.state() == state { + // Edge case: main thread is being yielded + // + // We need to pop/clear stack early, then push args + ffi::lua_pop(state, -1); + } + match values.push_into_stack_multi(raw) { Ok(nargs) => { - ffi::lua_pop(state, -1); - ffi::lua_xmove(raw.state(), state, nargs); + // If not main thread, then clear and xmove to target thread + if raw.state() != state { + // luau preserves the stack making yieldable continuations ugly and leaky + // + // Even outside of luau, clearing the stack is probably desirable + ffi::lua_pop(state, -1); + ffi::lua_xmove(raw.state(), state, nargs); + } return ffi::lua_yield(state, nargs); } Err(err) => { From 198b8575801d10557834b4c631f513ee3af3b3f7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 07:06:12 -0400 Subject: [PATCH 13/32] fix import --- src/state/raw.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 74939b4c..897e693a 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,9 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -#[cfg(not(feature = "luau"))] -use crate::state::util::callback_error_ext; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; From 2cd34db321417b0d28a5be3f6a38284caa8b4cc6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 12:04:44 -0400 Subject: [PATCH 14/32] support empty yield args --- src/state.rs | 2 +- src/state/extra.rs | 4 +-- src/state/util.rs | 2 +- tests/luau/cont.rs | 80 ++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/state.rs b/src/state.rs index 341adf8d..dacf62b9 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2171,7 +2171,7 @@ impl Lua { return Err(Error::runtime("cannot yield across Rust/Lua boundary.")); } unsafe { - raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); } Ok(()) } diff --git a/src/state/extra.rs b/src/state/extra.rs index 445c232c..7c18e9b8 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -96,7 +96,7 @@ pub(crate) struct ExtraData { pub(super) enable_jit: bool, // Values currently being yielded from Lua.yield() - pub(super) yielded_values: MultiValue, + pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -198,7 +198,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, - yielded_values: MultiValue::with_capacity(0), + yielded_values: None, })); // Store it in the registry diff --git a/src/state/util.rs b/src/state/util.rs index 517ae487..6b8d702d 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -188,7 +188,7 @@ where let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - if !values.is_empty() { + if let Some(values) = values { if raw.state() == state { // Edge case: main thread is being yielded // diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 675d90d8..9cb2b9a8 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -4,10 +4,84 @@ use mlua::Lua; #[test] fn test_luau_continuation() { - // Yielding continuation - mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); - let lua = Lua::new(); + // No yielding continuation fflag test + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + // empty yield args test + let cont_func = lua + .create_function_with_luau_continuation( + |lua, _: ()| { + match lua.set_yield_args(()) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res - 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + assert!(v.is_empty()); + let v = th.resume::(v).expect("Failed to load continuation"); + assert_eq!(v, -1); + + // Yielding continuation fflag test + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); let cont_func = lua .create_function_with_luau_continuation( From 23b77056229b316f5b0eea407924f8945fa03abc Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 12:29:52 -0400 Subject: [PATCH 15/32] add more tests for cont --- src/lib.rs | 3 +++ src/prelude.rs | 3 +++ tests/luau/cont.rs | 44 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index de321205..3b27345d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,6 +126,9 @@ pub use crate::value::{Nil, Value}; #[cfg(not(feature = "luau"))] pub use crate::hook::HookTriggers; +#[cfg(feature = "luau")] +pub use crate::thread::LuauContinuationStatus; + #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub use crate::{ diff --git a/src/prelude.rs b/src/prelude.rs index eeeaea26..5becef07 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -28,6 +28,9 @@ pub use crate::{ NavigateError as LuaNavigateError, Require as LuaRequire, Vector as LuaVector, }; +#[cfg(feature = "luau")] +pub use crate::LuauContinuationStatus; + #[cfg(feature = "async")] #[doc(no_inline)] pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn}; diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 9cb2b9a8..58ffdd2c 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -187,10 +187,11 @@ fn test_luau_continuation() { } Ok(()) }, - |lua, _status, args: mlua::MultiValue| { + |lua, status, args: mlua::MultiValue| { println!("Reached cont recursive: {:?}", args); if args.len() == 5 { + assert_eq!(status, mlua::LuauContinuationStatus::Ok); return 6_i32.into_lua_multi(lua); } @@ -237,4 +238,45 @@ fn test_luau_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_luau_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); } From dd3ea057e79762eea1379c6a4abdde4353c1dce6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 12:43:43 -0400 Subject: [PATCH 16/32] ensure check_stack of state --- src/state/util.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/state/util.rs b/src/state/util.rs index 6b8d702d..ff39e70b 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; -use crate::util::{self, get_internal_metatable, WrappedFailure}; +use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); @@ -204,6 +204,17 @@ where // // Even outside of luau, clearing the stack is probably desirable ffi::lua_pop(state, -1); + if let Err(err) = check_stack(state, nargs) { + let wrapped_error = prealloc_failure.r#use(state, extra); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::external(err.to_string())), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } ffi::lua_xmove(raw.state(), state, nargs); } return ffi::lua_yield(state, nargs); From 52176a8271fb7709ad24636a4d64436c06aa6d87 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 20:39:21 -0400 Subject: [PATCH 17/32] rename to luau continuation status --- src/lib.rs | 5 +---- src/prelude.rs | 21 +++++++++------------ src/state.rs | 16 ++++++++++------ src/thread.rs | 8 +++----- tests/luau/cont.rs | 16 ++++++++-------- 5 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3b27345d..20bf3237 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,7 +110,7 @@ pub use crate::state::{GCMode, Lua, LuaOptions, WeakLua}; pub use crate::stdlib::StdLib; pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; -pub use crate::thread::{Thread, ThreadStatus}; +pub use crate::thread::{ContinuationStatus, Thread, ThreadStatus}; pub use crate::traits::{ FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike, }; @@ -126,9 +126,6 @@ pub use crate::value::{Nil, Value}; #[cfg(not(feature = "luau"))] pub use crate::hook::HookTriggers; -#[cfg(feature = "luau")] -pub use crate::thread::LuauContinuationStatus; - #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub use crate::{ diff --git a/src/prelude.rs b/src/prelude.rs index 5becef07..d2324c22 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -3,15 +3,15 @@ #[doc(no_inline)] pub use crate::{ AnyUserData as LuaAnyUserData, BorrowedBytes as LuaBorrowedBytes, BorrowedStr as LuaBorrowedStr, - Chunk as LuaChunk, Either as LuaEither, Error as LuaError, ErrorContext as LuaErrorContext, - ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, - Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger, - IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions, - MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, - ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, - String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, - Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData, - UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, + Chunk as LuaChunk, ContinuationStatus as LuaContinuationStatus, Either as LuaEither, Error as LuaError, + ErrorContext as LuaErrorContext, ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, + FromLua, FromLuaMulti, Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, + Integer as LuaInteger, IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, + LuaNativeFnMut, LuaOptions, MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, + Number as LuaNumber, ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, + StdLib as LuaStdLib, String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, + TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus, + UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, Variadic as LuaVariadic, VmState as LuaVmState, WeakLua, @@ -28,9 +28,6 @@ pub use crate::{ NavigateError as LuaNavigateError, Require as LuaRequire, Vector as LuaVector, }; -#[cfg(feature = "luau")] -pub use crate::LuauContinuationStatus; - #[cfg(feature = "async")] #[doc(no_inline)] pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn}; diff --git a/src/state.rs b/src/state.rs index dacf62b9..bb3644e0 100644 --- a/src/state.rs +++ b/src/state.rs @@ -20,7 +20,7 @@ use crate::table::Table; use crate::thread::Thread; #[cfg(feature = "luau")] -use crate::thread::LuauContinuationStatus; +use crate::thread::ContinuationStatus; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ @@ -1295,8 +1295,7 @@ impl Lua { /// Same as ``create_function`` but with support for luau-style continuations. /// /// Other Lua versions have completely different continuation semantics and are both - /// not supported at this time, much more complicated to support within mlua and cannot - /// be supported with this API in any case. + /// not supported at this time. /// /// The values passed to the continuation will be the yielded arguments /// from the function for the initial continuation call. If yielding from a continuation, @@ -1308,19 +1307,24 @@ impl Lua { /// arguments will then be returned as the final return value of the Luau function call. /// Values returned in a function in which there is also yielding will be ignored #[cfg(feature = "luau")] - pub fn create_function_with_luau_continuation( + pub fn create_function_with_continuation( &self, func: F, cont: FC, ) -> Result where F: Fn(&Lua, A) -> Result + MaybeSend + 'static, - FC: Fn(&Lua, LuauContinuationStatus, AC) -> Result + MaybeSend + 'static, + FC: Fn(&Lua, ContinuationStatus, AC) -> Result + MaybeSend + 'static, A: FromLuaMulti, AC: FromLuaMulti, R: IntoLuaMulti, RC: IntoLuaMulti, { + // On luau, use a callback with luau continuation + // + // For other lua versions (in future), this will instead + // make a wrapper function that at the end calls lua_yieldk + #[cfg(feature = "luau")] (self.lock()).create_callback_with_luau_continuation( Box::new(move |rawlua, nargs| unsafe { let args = A::from_stack_args(nargs, 1, None, rawlua)?; @@ -1328,7 +1332,7 @@ impl Lua { }), Box::new(move |rawlua, nargs, status| unsafe { let args = AC::from_stack_args(nargs, 1, None, rawlua)?; - let status = LuauContinuationStatus::from_status(status); + let status = ContinuationStatus::from_status(status); cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) }), ) diff --git a/src/thread.rs b/src/thread.rs index 6b834081..3eb9f9f7 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -26,17 +26,15 @@ use { }, }; -/// Luau continuation final status -#[cfg(feature = "luau")] +/// Continuation thread status. Can either be Ok, Yielded (rare, but can happen) or Error #[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum LuauContinuationStatus { +pub enum ContinuationStatus { Ok, Yielded, Error, } -#[cfg(feature = "luau")] -impl LuauContinuationStatus { +impl ContinuationStatus { pub(crate) fn from_status(status: c_int) -> Self { match status { ffi::LUA_YIELD => Self::Yielded, diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs index 58ffdd2c..a22987a7 100644 --- a/tests/luau/cont.rs +++ b/tests/luau/cont.rs @@ -7,7 +7,7 @@ fn test_luau_continuation() { let lua = Lua::new(); // No yielding continuation fflag test let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args(a) { Ok(()) => println!("set_yield_args called"), @@ -46,7 +46,7 @@ fn test_luau_continuation() { // empty yield args test let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, _: ()| { match lua.set_yield_args(()) { Ok(()) => println!("set_yield_args called"), @@ -84,7 +84,7 @@ fn test_luau_continuation() { mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |_lua, a: u64| Ok(a + 1), |_lua, _status, a: u64| { println!("Reached cont"); @@ -117,7 +117,7 @@ fn test_luau_continuation() { // Trigger the continuation let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args(a) { Ok(()) => println!("set_yield_args called"), @@ -155,7 +155,7 @@ fn test_luau_continuation() { assert_eq!(v, 41); let always_yield = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, ()| { lua.set_yield_args((42, "69420".to_string(), 45.6))?; Ok(()) @@ -179,7 +179,7 @@ fn test_luau_continuation() { .starts_with("a3")); let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args((a + 1, 1)) { Ok(()) => println!("set_yield_args called"), @@ -191,7 +191,7 @@ fn test_luau_continuation() { println!("Reached cont recursive: {:?}", args); if args.len() == 5 { - assert_eq!(status, mlua::LuauContinuationStatus::Ok); + assert_eq!(status, mlua::ContinuationStatus::Ok); return 6_i32.into_lua_multi(lua); } @@ -241,7 +241,7 @@ fn test_luau_continuation() { // test panics let cont_func = lua - .create_function_with_luau_continuation( + .create_function_with_continuation( |lua, a: u64| { match lua.set_yield_args(a) { Ok(()) => println!("set_yield_args called"), From 10d2df40f523663978b6b4b67f74469226650755 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:46:58 -0400 Subject: [PATCH 18/32] add support for continuations outside luau --- mlua-sys/src/lua52/lua.rs | 5 + mlua-sys/src/lua53/lua.rs | 5 + mlua-sys/src/lua54/lua.rs | 5 + src/state.rs | 28 ++-- src/state/extra.rs | 6 + src/state/raw.rs | 145 +++++++++++++------ src/state/util.rs | 69 +++++++++ src/thread.rs | 1 + src/types.rs | 12 +- src/util/types.rs | 12 +- tests/luau.rs | 3 - tests/luau/cont.rs | 282 ------------------------------------- tests/thread.rs | 286 ++++++++++++++++++++++++++++++++++++++ 13 files changed, 499 insertions(+), 360 deletions(-) delete mode 100644 tests/luau/cont.rs diff --git a/mlua-sys/src/lua52/lua.rs b/mlua-sys/src/lua52/lua.rs index e5239cee..8de7cdee 100644 --- a/mlua-sys/src/lua52/lua.rs +++ b/mlua-sys/src/lua52/lua.rs @@ -272,6 +272,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_CFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua53/lua.rs b/mlua-sys/src/lua53/lua.rs index 2729fdcd..c3d82e63 100644 --- a/mlua-sys/src/lua53/lua.rs +++ b/mlua-sys/src/lua53/lua.rs @@ -286,6 +286,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua54/lua.rs b/mlua-sys/src/lua54/lua.rs index 15a30444..c74e1576 100644 --- a/mlua-sys/src/lua54/lua.rs +++ b/mlua-sys/src/lua54/lua.rs @@ -299,6 +299,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Warning-related functions // diff --git a/src/state.rs b/src/state.rs index bb3644e0..df001893 100644 --- a/src/state.rs +++ b/src/state.rs @@ -19,7 +19,7 @@ use crate::string::String; use crate::table::Table; use crate::thread::Thread; -#[cfg(feature = "luau")] +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] use crate::thread::ContinuationStatus; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; @@ -1292,21 +1292,20 @@ impl Lua { })) } - /// Same as ``create_function`` but with support for luau-style continuations. - /// - /// Other Lua versions have completely different continuation semantics and are both - /// not supported at this time. + /// Same as ``create_function`` but with an added continuation function. /// /// The values passed to the continuation will be the yielded arguments - /// from the function for the initial continuation call. If yielding from a continuation, - /// the yielded results will be returned to the ``Thread::resume`` caller. The arguments - /// passed in the next ``Thread::resume`` call will then be the arguments passed to the yielding - /// continuation upon resumption. + /// from the function for the initial continuation call. On Luau, if yielding from a + /// continuation, the yielded results will be returned to the ``Thread::resume`` caller. The + /// arguments passed in the next ``Thread::resume`` call will then be the arguments passed + /// to the yielding continuation upon resumption. /// /// Returning a value from a continuation without setting yield - /// arguments will then be returned as the final return value of the Luau function call. + /// arguments will then be returned as the final return value of the Lua function call. /// Values returned in a function in which there is also yielding will be ignored - #[cfg(feature = "luau")] + /// + /// Note that yielding in continuations is only supported on Luau + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] pub fn create_function_with_continuation( &self, func: F, @@ -1320,12 +1319,7 @@ impl Lua { R: IntoLuaMulti, RC: IntoLuaMulti, { - // On luau, use a callback with luau continuation - // - // For other lua versions (in future), this will instead - // make a wrapper function that at the end calls lua_yieldk - #[cfg(feature = "luau")] - (self.lock()).create_callback_with_luau_continuation( + (self.lock()).create_callback_with_continuation( Box::new(move |rawlua, nargs| unsafe { let args = A::from_stack_args(nargs, 1, None, rawlua)?; func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) diff --git a/src/state/extra.rs b/src/state/extra.rs index 7c18e9b8..1c49a064 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -13,6 +13,7 @@ use crate::error::Result; use crate::state::RawLua; use crate::stdlib::StdLib; use crate::types::{AppData, ReentrantMutex, XRc}; + use crate::userdata::RawUserDataRegistry; use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, WrappedFailure}; @@ -97,6 +98,9 @@ pub(crate) struct ExtraData { // Values currently being yielded from Lua.yield() pub(super) yielded_values: Option, + + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + pub(super) yield_continuation: bool, } impl Drop for ExtraData { @@ -199,6 +203,8 @@ impl ExtraData { #[cfg(feature = "luau")] running_gc: false, yielded_values: None, + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + yield_continuation: false, })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index 897e693a..298e88e7 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; +#[cfg(not(feature = "luau"))] +use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -23,8 +25,10 @@ use crate::types::{ MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; -#[cfg(feature = "luau")] -use crate::types::{LuauContinuation, LuauContinuationUpvalue}; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::Continuation; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, @@ -202,8 +206,8 @@ impl RawLua { init_internal_metatable::>>(state, None)?; init_internal_metatable::(state, None)?; init_internal_metatable::(state, None)?; - #[cfg(feature = "luau")] - init_internal_metatable::(state, None)?; + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + init_internal_metatable::(state, None)?; #[cfg(not(feature = "luau"))] init_internal_metatable::(state, None)?; #[cfg(feature = "async")] @@ -1195,56 +1199,105 @@ impl RawLua { } // Creates a Function out of a Callback and a continuation containing a 'static Fn. - #[cfg(feature = "luau")] - pub(crate) fn create_callback_with_luau_continuation( + // + // In Luau, uses pushcclosurek + // + // In Lua 5.2/5.3/5.4/JIT, makes a normal function that then yields to the continuation via yieldk + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + pub(crate) fn create_callback_with_continuation( &self, func: Callback, - cont: LuauContinuation, + cont: Continuation, ) -> Result { - unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.0)(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) - } + #[cfg(feature = "luau")] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } - unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.1)(rawlua, nargs, status), - None => Err(Error::CallbackDestructed), - } - }) - } + unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }) + } - let state = self.state(); - unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 4)?; + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; - let func = Some((func, cont)); - let extra = XRc::clone(&self.extra); - let protect = !self.unlikely_memory_error(); - push_internal_userdata(state, LuauContinuationUpvalue { data: func, extra }, protect)?; - if protect { - protect_lua!(state, 1, 1, fn(state) { + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + })?; + } else { ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); - })?; - } else { - ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + } + + Ok(Function(self.pop_ref())) } + } - Ok(Function(self.pop_ref())) + #[cfg(not(feature = "luau"))] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some((ref func, _)) => match func(rawlua, nargs) { + Ok(r) => { + (*extra).yield_continuation = true; + Ok(r) + } + Err(e) => Err(e), + }, + None => Err(Error::CallbackDestructed), + } + }) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, call_callback, 1); + })?; + } else { + ffi::lua_pushcclosure(state, call_callback, 1); + } + + Ok(Function(self.pop_ref())) + } } } diff --git a/src/state/util.rs b/src/state/util.rs index ff39e70b..dd906d56 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -9,6 +9,9 @@ use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit"), not(feature = "luau")))] +use crate::{types::ContinuationUpvalue, util::get_userdata}; + struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); impl<'a> StateGuard<'a> { @@ -112,6 +115,8 @@ where Ok(Ok(r)) => { // Ensure yielded values are cleared take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); @@ -188,6 +193,9 @@ where let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + let yield_cont = take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); + if let Some(values) = values { if raw.state() == state { // Edge case: main thread is being yielded @@ -217,6 +225,67 @@ where } ffi::lua_xmove(raw.state(), state, nargs); } + + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + { + // Yield to a continuation. Unlike luau, we need to do this manually and on the + // fly using a yieldk call + if yield_cont { + // On Lua 5.2, status and ctx are not present, so use 0 as status for + // compatibility + #[cfg(feature = "lua52")] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, 0), + None => Err(Error::CallbackDestructed), + } + }, + ) + } + + // Lua 5.3/5.4 case + #[cfg(not(feature = "lua52"))] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + status: c_int, + _ctx: ffi::lua_KContext, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + ) + } + + return ffi::lua_yieldc(state, nargs, cont_callback); + } + } + return ffi::lua_yield(state, nargs); } Err(err) => { diff --git a/src/thread.rs b/src/thread.rs index 3eb9f9f7..ce096f06 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -35,6 +35,7 @@ pub enum ContinuationStatus { } impl ContinuationStatus { + #[allow(dead_code)] pub(crate) fn from_status(status: c_int) -> Self { match status { ffi::LUA_YIELD => Self::Yielded, diff --git a/src/types.rs b/src/types.rs index 3b6b0d8f..45806247 100644 --- a/src/types.rs +++ b/src/types.rs @@ -39,11 +39,11 @@ pub(crate) type Callback = Box Result + Send + #[cfg(not(feature = "send"))] pub(crate) type Callback = Box Result + 'static>; +#[cfg(all(feature = "send", not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + Send + 'static>; -#[cfg(all(feature = "send", feature = "luau"))] -pub(crate) type LuauContinuation = Box Result + Send + 'static>; -#[cfg(all(not(feature = "send"), feature = "luau"))] -pub(crate) type LuauContinuation = Box Result + 'static>; +#[cfg(all(not(feature = "send"), not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + 'static>; pub(crate) type ScopedCallback<'s> = Box Result + 's>; @@ -53,8 +53,8 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; -#[cfg(feature = "luau")] -pub(crate) type LuauContinuationUpvalue = Upvalue>; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type ContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] pub(crate) type AsyncCallback = diff --git a/src/util/types.rs b/src/util/types.rs index 16945009..8627042f 100644 --- a/src/util/types.rs +++ b/src/util/types.rs @@ -3,8 +3,8 @@ use std::os::raw::c_void; use crate::types::{Callback, CallbackUpvalue}; -#[cfg(feature = "luau")] -use crate::types::LuauContinuationUpvalue; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; #[cfg(feature = "async")] use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; @@ -37,12 +37,12 @@ impl TypeKey for CallbackUpvalue { } } -#[cfg(feature = "luau")] -impl TypeKey for LuauContinuationUpvalue { +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +impl TypeKey for ContinuationUpvalue { #[inline(always)] fn type_key() -> *const c_void { - static LUAU_CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; - &LUAU_CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void + static CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; + &CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void } } diff --git a/tests/luau.rs b/tests/luau.rs index 7b62bae3..20fbee6a 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -449,6 +449,3 @@ fn test_typeof_error() -> Result<()> { #[path = "luau/require.rs"] mod require; - -#[path = "luau/cont.rs"] -mod cont; diff --git a/tests/luau/cont.rs b/tests/luau/cont.rs deleted file mode 100644 index a22987a7..00000000 --- a/tests/luau/cont.rs +++ /dev/null @@ -1,282 +0,0 @@ -use mlua::IntoLuaMulti; -#[cfg(feature = "luau")] -use mlua::Lua; - -#[test] -fn test_luau_continuation() { - let lua = Lua::new(); - // No yielding continuation fflag test - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 41); - - // empty yield args test - let cont_func = lua - .create_function_with_continuation( - |lua, _: ()| { - match lua.set_yield_args(()) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res - 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - assert!(v.is_empty()); - let v = th.resume::(v).expect("Failed to load continuation"); - assert_eq!(v, -1); - - // Yielding continuation fflag test - mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); - - let cont_func = lua - .create_function_with_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - }, - ) - .expect("Failed to create cont_func"); - - // Ensure normal calls work still - assert_eq!( - lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func) - .expect("Failed to call cont_func"), - 2 - ); - - // basic yield test before we go any further - let always_yield = lua - .create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!( - thread.resume::<(i32, String, f32)>(()).unwrap(), - (42, String::from("69420"), 45.6) - ); - - // Trigger the continuation - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 41); - - let always_yield = lua - .create_function_with_continuation( - |lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }, - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - }, - ) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - let mv = thread.resume::(()).unwrap(); - assert!(thread - .resume::(mv) - .unwrap_err() - .to_string() - .starts_with("a3")); - - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - Ok(()) - }, - |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive: {:?}", args); - - if args.len() == 5 { - assert_eq!(status, mlua::ContinuationStatus::Ok); - return 6_i32.into_lua_multi(lua); - } - - lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) - (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local res = cont_func(1) - return res + 1 - ", - ) - .into_function() - .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - println!("v={:?}", v); - - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - - // (2, 1) followed by () - assert_eq!(v.len(), 2 + 3); - - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 7); - - // test panics - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, _a: u64| { - panic!("Reached continuation which should panic!"); - #[allow(unreachable_code)] - Ok(()) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " - local cont_func = ... - local ok, res = pcall(cont_func, 1) - assert(not ok) - return tostring(res) - ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - - let v = th.resume::(v).expect("Failed to load continuation"); - assert!(v.contains("Reached continuation which should panic!")); -} diff --git a/tests/thread.rs b/tests/thread.rs index e6825ff3..a51cf851 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -2,6 +2,9 @@ use std::panic::catch_unwind; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; +#[cfg(feature = "luau")] +use mlua::IntoLuaMulti; + #[test] fn test_thread() -> Result<()> { let lua = Lua::new(); @@ -269,3 +272,286 @@ fn test_thread_yield_args() -> Result<()> { Ok(()) } + +#[test] +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +fn test_continuation() { + let lua = Lua::new(); + // No yielding continuation fflag test + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + // empty yield args test + let cont_func = lua + .create_function_with_continuation( + |lua, _: ()| { + match lua.set_yield_args(()) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res - 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + assert!(v.is_empty()); + let v = th.resume::(v).expect("Failed to load continuation"); + assert_eq!(v, -1); + + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); + + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }, + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + Ok(()) + }, + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive: {:?}", args); + + if args.len() == 5 { + assert_eq!(status, mlua::ContinuationStatus::Ok); + return 6_i32.into_lua_multi(lua); + } + + lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); + } +} From 8400386fa476aa25c92a89bc1f93b71b57effad1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:52:04 -0400 Subject: [PATCH 19/32] fix --- tests/thread.rs | 64 ++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/thread.rs b/tests/thread.rs index a51cf851..5e1a199d 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -512,46 +512,46 @@ fn test_continuation() { let v = th.resume::(v).expect("Failed to load continuation"); assert_eq!(v, 7); + } - // test panics - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, _a: u64| { - panic!("Reached continuation which should panic!"); - #[allow(unreachable_code)] - Ok(()) - }, - ) - .expect("Failed to create cont_func"); + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); - let luau_func = lua - .load( - " + let luau_func = lua + .load( + " local cont_func = ... local ok, res = pcall(cont_func, 1) assert(not ok) return tostring(res) ", - ) - .into_function() - .expect("Failed to create function"); + ) + .into_function() + .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - let v = th - .resume::(cont_func) - .expect("Failed to resume"); + let v = th + .resume::(cont_func) + .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - assert!(v.contains("Reached continuation which should panic!")); - } + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); } From 5d86852ff99580dbc05a6253982243a0afe5bc53 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:57:23 -0400 Subject: [PATCH 20/32] fix --- src/state/raw.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index 298e88e7..a4eb8ccf 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -#[cfg(not(feature = "luau"))] +#[cfg(any(not(feature = "luau"), feature = "luau-vector4"))] use crate::state::util::callback_error_ext; use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; From 9e77e5c556fea8675cfa3a8dbb40346f226b7a77 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 21:59:42 -0400 Subject: [PATCH 21/32] fix --- src/state/raw.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index a4eb8ccf..c49d7dfa 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,9 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -#[cfg(any(not(feature = "luau"), feature = "luau-vector4"))] -use crate::state::util::callback_error_ext; -use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; +use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; From eb637554b9ccae59aed2593b94676d8254989361 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 22:13:45 -0400 Subject: [PATCH 22/32] fix yieldable continuation on non-luau --- src/state/extra.rs | 5 - src/state/raw.rs | 55 +++++---- src/state/util.rs | 10 +- tests/thread.rs | 285 +++++++++++++++++++++++---------------------- 4 files changed, 179 insertions(+), 176 deletions(-) diff --git a/src/state/extra.rs b/src/state/extra.rs index 1c49a064..fbb21c66 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -98,9 +98,6 @@ pub(crate) struct ExtraData { // Values currently being yielded from Lua.yield() pub(super) yielded_values: Option, - - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - pub(super) yield_continuation: bool, } impl Drop for ExtraData { @@ -203,8 +200,6 @@ impl ExtraData { #[cfg(feature = "luau")] running_gc: false, yielded_values: None, - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - yield_continuation: false, })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index c49d7dfa..bca9d1a1 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1164,15 +1164,21 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => func(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + false, + ) } let state = self.state(); @@ -1260,21 +1266,22 @@ impl RawLua { { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some((ref func, _)) => match func(rawlua, nargs) { - Ok(r) => { - (*extra).yield_continuation = true; - Ok(r) - } - Err(e) => Err(e), - }, - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some((ref func, _)) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) } let state = self.state(); diff --git a/src/state/util.rs b/src/state/util.rs index dd906d56..5c1bb4b6 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -115,8 +115,6 @@ where Ok(Ok(r)) => { // Ensure yielded values are cleared take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); @@ -170,6 +168,7 @@ pub(crate) unsafe fn callback_error_ext_yieldable( mut extra: *mut ExtraData, wrap_error: bool, f: F, + in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, @@ -193,9 +192,6 @@ where let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] - let yield_cont = take(&mut extra.as_mut().unwrap_unchecked().yield_continuation); - if let Some(values) = values { if raw.state() == state { // Edge case: main thread is being yielded @@ -230,7 +226,7 @@ where { // Yield to a continuation. Unlike luau, we need to do this manually and on the // fly using a yieldk call - if yield_cont { + if in_callback_with_continuation { // On Lua 5.2, status and ctx are not present, so use 0 as status for // compatibility #[cfg(feature = "lua52")] @@ -253,6 +249,7 @@ where None => Err(Error::CallbackDestructed), } }, + true, ) } @@ -279,6 +276,7 @@ where None => Err(Error::CallbackDestructed), } }, + true, ) } diff --git a/tests/thread.rs b/tests/thread.rs index 5e1a199d..ff346aef 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -2,7 +2,6 @@ use std::panic::catch_unwind; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; -#[cfg(feature = "luau")] use mlua::IntoLuaMulti; #[test] @@ -356,163 +355,167 @@ fn test_continuation() { #[cfg(feature = "luau")] { mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } - let cont_func = lua - .create_function_with_continuation( - |_lua, a: u64| Ok(a + 1), - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 2) - }, - ) - .expect("Failed to create cont_func"); - - // Ensure normal calls work still - assert_eq!( - lua.load("local cont_func = ...\nreturn cont_func(1)") - .call::(cont_func) - .expect("Failed to call cont_func"), - 2 - ); + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); - // basic yield test before we go any further - let always_yield = lua - .create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }) - .unwrap(); + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); - let thread = lua.create_thread(always_yield).unwrap(); - assert_eq!( - thread.resume::<(i32, String, f32)>(()).unwrap(), - (42, String::from("69420"), 45.6) - ); + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args(a) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + }; + Ok(()) + }, + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); - // Trigger the continuation - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, - |_lua, _status, a: u64| { - println!("Reached cont"); - Ok(a + 39) - }, - ) - .expect("Failed to create cont_func"); - - let luau_func = lua - .load( - " + let luau_func = lua + .load( + " local cont_func = ... local res = cont_func(1) return res + 1 ", - ) - .into_function() - .expect("Failed to create function"); - - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 41); - - let always_yield = lua - .create_function_with_continuation( - |lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }, - |_lua, _, mv: mlua::MultiValue| { - println!("Reached second continuation"); - if mv.is_empty() { - return Ok(mv); - } - Err(mlua::Error::external(format!("a{}", mv.len()))) - }, - ) - .unwrap(); - - let thread = lua.create_thread(always_yield).unwrap(); - let mv = thread.resume::(()).unwrap(); - assert!(thread - .resume::(mv) - .unwrap_err() - .to_string() - .starts_with("a3")); - - let cont_func = lua - .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - Ok(()) - }, - |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive: {:?}", args); + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); - if args.len() == 5 { + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| { + lua.set_yield_args((42, "69420".to_string(), 45.6))?; + Ok(()) + }, + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| { + match lua.set_yield_args((a + 1, 1)) { + Ok(()) => println!("set_yield_args called"), + Err(e) => println!("{:?}", e), + } + Ok(()) + }, + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { assert_eq!(status, mlua::ContinuationStatus::Ok); - return 6_i32.into_lua_multi(lua); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); } + return 6_i32.into_lua_multi(lua); + } - lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) - (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored - }, - ) - .expect("Failed to create cont_func"); + lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) + (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + }, + ) + .expect("Failed to create cont_func"); - let luau_func = lua - .load( - " + let luau_func = lua + .load( + " local cont_func = ... local res = cont_func(1) return res + 1 ", - ) - .into_function() - .expect("Failed to create function"); - let th = lua - .create_thread(luau_func) - .expect("Failed to create luau thread"); - - let v = th - .resume::(cont_func) - .expect("Failed to resume"); - println!("v={:?}", v); - - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - println!("v={:?}", v); - let v = th - .resume::(v) - .expect("Failed to load continuation"); - - // (2, 1) followed by () - assert_eq!(v.len(), 2 + 3); - - let v = th.resume::(v).expect("Failed to load continuation"); - - assert_eq!(v, 7); - } + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); // test panics let cont_func = lua From 9def001b4d227c187ede5adabf591d4b2495cdd9 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 22:21:34 -0400 Subject: [PATCH 23/32] fix luau compiler bug --- src/state/raw.rs | 50 ++++++++++++++++++++++++++++++----------------- src/state/util.rs | 3 ++- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/state/raw.rs b/src/state/raw.rs index bca9d1a1..98f7e55f 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -1217,28 +1217,42 @@ impl RawLua { { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.0)(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) } unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext_yieldable(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => (func.1)(rawlua, nargs, status), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) } let state = self.state(); diff --git a/src/state/util.rs b/src/state/util.rs index 5c1bb4b6..d40c6978 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -168,7 +168,8 @@ pub(crate) unsafe fn callback_error_ext_yieldable( mut extra: *mut ExtraData, wrap_error: bool, f: F, - in_callback_with_continuation: bool, + #[cfg(feature = "luau")] _in_callback_with_continuation: bool, + #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, From e985f8312c216718225807d2690b7949b1fd8019 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 22:35:20 -0400 Subject: [PATCH 24/32] add note on why pop --- src/state/util.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/state/util.rs b/src/state/util.rs index d40c6978..00688072 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -194,6 +194,16 @@ where let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); if let Some(values) = values { + // A note on Luau + // + // When using the yieldable continuations fflag (and in future when the fflag gets removed and + // yieldable continuations) becomes default, we must either pop the top of the + // stack on the state we are resuming or somehow store the number of + // args on top of stack pre-yield and then subtract in the resume in order to get predictable + // behaviour here. See https://github.com/luau-lang/luau/issues/1867 for more information + // + // In this case, popping is easier and leads to less bugs/more ergonomic API. + if raw.state() == state { // Edge case: main thread is being yielded // From 3dcb1315409f3937e5244e091d28669801cfa57a Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 5 Jun 2025 23:49:09 -0400 Subject: [PATCH 25/32] fix docs which incorrectly state only luau supports yieldable continuation --- src/state.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/state.rs b/src/state.rs index df001893..8a856509 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1295,7 +1295,7 @@ impl Lua { /// Same as ``create_function`` but with an added continuation function. /// /// The values passed to the continuation will be the yielded arguments - /// from the function for the initial continuation call. On Luau, if yielding from a + /// from the function for the initial continuation call. If yielding from a /// continuation, the yielded results will be returned to the ``Thread::resume`` caller. The /// arguments passed in the next ``Thread::resume`` call will then be the arguments passed /// to the yielding continuation upon resumption. @@ -1303,8 +1303,6 @@ impl Lua { /// Returning a value from a continuation without setting yield /// arguments will then be returned as the final return value of the Lua function call. /// Values returned in a function in which there is also yielding will be ignored - /// - /// Note that yielding in continuations is only supported on Luau #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] pub fn create_function_with_continuation( &self, From 3a079fe825f2e6668557fca0f0cb22c7f8e27c93 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:20:35 -0400 Subject: [PATCH 26/32] make the api more ergonomic --- src/error.rs | 6 +++++ src/state.rs | 31 ++++++---------------- src/state/extra.rs | 5 ---- src/state/util.rs | 23 +++++++--------- tests/thread.rs | 65 +++++++++------------------------------------- 5 files changed, 35 insertions(+), 95 deletions(-) diff --git a/src/error.rs b/src/error.rs index 483171e5..cb95494b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use crate::MultiValue; use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; @@ -205,6 +206,8 @@ pub enum Error { /// Underlying error. cause: Arc, }, + /// A special error variant that tells Rust to yield to Lua with the specified args + Yield(MultiValue), } /// A specialized `Result` type used by `mlua`'s API. @@ -322,6 +325,9 @@ impl fmt::Display for Error { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") }, + Error::Yield(_) => { + write!(fmt, "attempt to yield within a context that does not support yielding") + } } } } diff --git a/src/state.rs b/src/state.rs index 8a856509..7a1a1fba 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2146,33 +2146,18 @@ impl Lua { &*self.raw.data_ptr() } - /// Sets the yields arguments. Note that ``Ok(())`` must be returned for the Rust function - /// to actually yield. Any values returned in a function in which there is also yielding may - /// be ignored. + /// Helper method to set the yield arguments, returning a Error::Yield. /// - /// This method is mostly useful with Luau continuations and Rust-Rust yields + /// This method is mostly useful with continuations and Rust-Rust yields /// due to the Rust/Lua boundary - /// - /// If this function cannot yield, it will raise a runtime error. - /// - /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield - /// or not until it reaches the Lua state. - /// - /// While this method *should be safe*, it is new and may have bugs lurking within. Use - /// with caution - pub fn set_yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { - let raw = self.lock(); - #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] - if !raw.is_yieldable() { - return Err(Error::runtime("cannot yield across Rust/Lua boundary.")); - } - unsafe { - raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); - } - Ok(()) + pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { + Err(Error::Yield(args.into_lua_multi(self)?)) } - /// Checks if Lua is currently allowed to yield. + /// Checks if Lua could be allowed to yield. + /// + /// Note that this method is not fool proof and is prone to false negatives + /// especially when continuations are involved #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub fn is_yieldable(&self) -> bool { diff --git a/src/state/extra.rs b/src/state/extra.rs index fbb21c66..45f0e379 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -19,7 +19,6 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; -use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -95,9 +94,6 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, - - // Values currently being yielded from Lua.yield() - pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -199,7 +195,6 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, - yielded_values: None, })); // Store it in the registry diff --git a/src/state/util.rs b/src/state/util.rs index 00688072..f4cbde4f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,5 +1,4 @@ use crate::IntoLuaMulti; -use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; @@ -113,9 +112,6 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - // Ensure yielded values are cleared - take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r @@ -163,13 +159,13 @@ where /// /// Unlike ``callback_error_ext``, this method requires a c_int return /// and not a generic R +#[allow(unused_variables)] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, wrap_error: bool, f: F, - #[cfg(feature = "luau")] _in_callback_with_continuation: bool, - #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, + in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, @@ -190,10 +186,14 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - let raw = extra.as_ref().unwrap_unchecked().raw_lua(); - let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { + if let Error::Yield(values) = err { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); - if let Some(values) = values { // A note on Luau // // When using the yieldable continuations fflag (and in future when the fflag gets removed and @@ -311,11 +311,6 @@ where } } - // Return unused `WrappedFailure` to the pool - prealloc_failure.release(state, extra); - r - } - Ok(Err(err)) => { let wrapped_error = prealloc_failure.r#use(state, extra); if !wrap_error { diff --git a/tests/thread.rs b/tests/thread.rs index ff346aef..a5655af2 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -2,8 +2,6 @@ use std::panic::catch_unwind; use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; -use mlua::IntoLuaMulti; - #[test] fn test_thread() -> Result<()> { let lua = Lua::new(); @@ -258,10 +256,7 @@ fn test_thread_resume_error() -> Result<()> { #[test] fn test_thread_yield_args() -> Result<()> { let lua = Lua::new(); - let always_yield = lua.create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - })?; + let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; let thread = lua.create_thread(always_yield)?; assert_eq!( @@ -279,13 +274,7 @@ fn test_continuation() { // No yielding continuation fflag test let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, a: u64| lua.yield_with(a), |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 39) @@ -318,13 +307,7 @@ fn test_continuation() { // empty yield args test let cont_func = lua .create_function_with_continuation( - |lua, _: ()| { - match lua.set_yield_args(()) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, _: ()| lua.yield_with(()), |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), ) .expect("Failed to create cont_func"); @@ -377,10 +360,7 @@ fn test_continuation() { // basic yield test before we go any further let always_yield = lua - .create_function(|lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }) + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) .unwrap(); let thread = lua.create_thread(always_yield).unwrap(); @@ -392,13 +372,7 @@ fn test_continuation() { // Trigger the continuation let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, a: u64| lua.yield_with(a), |_lua, _status, a: u64| { println!("Reached cont"); Ok(a + 39) @@ -430,10 +404,7 @@ fn test_continuation() { let always_yield = lua .create_function_with_continuation( - |lua, ()| { - lua.set_yield_args((42, "69420".to_string(), 45.6))?; - Ok(()) - }, + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), |_lua, _, mv: mlua::MultiValue| { println!("Reached second continuation"); if mv.is_empty() { @@ -454,15 +425,9 @@ fn test_continuation() { let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args((a + 1, 1)) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - } - Ok(()) - }, + |lua, a: u64| lua.yield_with((a + 1, 1)), |lua, status, args: mlua::MultiValue| { - println!("Reached cont recursive: {:?}", args); + println!("Reached cont recursive/multiple: {:?}", args); if args.len() == 5 { if cfg!(any(feature = "luau", feature = "lua52")) { @@ -470,11 +435,11 @@ fn test_continuation() { } else { assert_eq!(status, mlua::ContinuationStatus::Yielded); } - return 6_i32.into_lua_multi(lua); + return Ok(6_i32); } - lua.set_yield_args((args.len() + 1, args))?; // thread state becomes Integer(2), Integer(1), Integer(8), Integer(9), Integer(10) - (1, 2, 3, 4, 5).into_lua_multi(lua) // this value is ignored + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + unreachable!(); }, ) .expect("Failed to create cont_func"); @@ -520,13 +485,7 @@ fn test_continuation() { // test panics let cont_func = lua .create_function_with_continuation( - |lua, a: u64| { - match lua.set_yield_args(a) { - Ok(()) => println!("set_yield_args called"), - Err(e) => println!("{:?}", e), - }; - Ok(()) - }, + |lua, a: u64| lua.yield_with(a), |_lua, _status, _a: u64| { panic!("Reached continuation which should panic!"); #[allow(unreachable_code)] From 3c005d7a85e7ad4d6ed8d38c50046e7a8ee7e40d Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:27:30 -0400 Subject: [PATCH 27/32] rename set_yield_args to yield_with --- src/error.rs | 2 +- src/state.rs | 21 ++++++++++++++++++++- src/state/raw.rs | 4 +++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/error.rs b/src/error.rs index cb95494b..2f1030f8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -206,7 +206,7 @@ pub enum Error { /// Underlying error. cause: Arc, }, - /// A special error variant that tells Rust to yield to Lua with the specified args + /// A special error variant that tells Rust to yield to Lua with the specified arguments Yield(MultiValue), } diff --git a/src/state.rs b/src/state.rs index 7a1a1fba..1220fe0e 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2146,10 +2146,29 @@ impl Lua { &*self.raw.data_ptr() } - /// Helper method to set the yield arguments, returning a Error::Yield. + /// Helper method to set the yield arguments, returning a ``Error::Yield``. + /// + /// Internally, this method is equivalent to ``Err(Error::Yield(args.into_lua_multi(self)?))`` /// /// This method is mostly useful with continuations and Rust-Rust yields /// due to the Rust/Lua boundary + /// + /// Example: + /// + /// ```rust + /// fn test() -> mlua::Result<()> { + /// let lua = mlua::Lua::new(); + /// let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; + /// + /// let thread = lua.create_thread(always_yield)?; + /// assert_eq!( + /// thread.resume::<(i32, String, f32)>(())?, + /// (42, String::from("69420"), 45.6) + /// ); + /// + /// Ok(()) + /// } + /// ``` pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { Err(Error::Yield(args.into_lua_multi(self)?)) } diff --git a/src/state/raw.rs b/src/state/raw.rs index 98f7e55f..26935a39 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, callback_error_ext_yieldable, ref_stack_pop}; +#[allow(unused_imports)] +use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; From c7ddbe633c4096e3a73334352575c4b6cd0c9307 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:49:20 -0400 Subject: [PATCH 28/32] partially revert due to compiler error --- src/error.rs | 6 ------ src/state.rs | 10 ++++++---- src/state/extra.rs | 5 +++++ src/state/util.rs | 22 +++++++++++++--------- tests/thread.rs | 2 +- 5 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/error.rs b/src/error.rs index 2f1030f8..483171e5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,3 @@ -use crate::MultiValue; use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; @@ -206,8 +205,6 @@ pub enum Error { /// Underlying error. cause: Arc, }, - /// A special error variant that tells Rust to yield to Lua with the specified arguments - Yield(MultiValue), } /// A specialized `Result` type used by `mlua`'s API. @@ -325,9 +322,6 @@ impl fmt::Display for Error { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") }, - Error::Yield(_) => { - write!(fmt, "attempt to yield within a context that does not support yielding") - } } } } diff --git a/src/state.rs b/src/state.rs index 1220fe0e..3685e399 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2146,9 +2146,7 @@ impl Lua { &*self.raw.data_ptr() } - /// Helper method to set the yield arguments, returning a ``Error::Yield``. - /// - /// Internally, this method is equivalent to ``Err(Error::Yield(args.into_lua_multi(self)?))`` + /// Set the yield arguments. Note that Lua will not yield until you return from the function /// /// This method is mostly useful with continuations and Rust-Rust yields /// due to the Rust/Lua boundary @@ -2170,7 +2168,11 @@ impl Lua { /// } /// ``` pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { - Err(Error::Yield(args.into_lua_multi(self)?)) + let raw = self.lock(); + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); + } + Ok(()) } /// Checks if Lua could be allowed to yield. diff --git a/src/state/extra.rs b/src/state/extra.rs index 45f0e379..fbb21c66 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -19,6 +19,7 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -94,6 +95,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -195,6 +199,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, + yielded_values: None, })); // Store it in the registry diff --git a/src/state/util.rs b/src/state/util.rs index f4cbde4f..67f9edaf 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,4 +1,5 @@ use crate::IntoLuaMulti; +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; @@ -112,6 +113,8 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { + take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r @@ -159,13 +162,13 @@ where /// /// Unlike ``callback_error_ext``, this method requires a c_int return /// and not a generic R -#[allow(unused_variables)] pub(crate) unsafe fn callback_error_ext_yieldable( state: *mut ffi::lua_State, mut extra: *mut ExtraData, wrap_error: bool, f: F, - in_callback_with_continuation: bool, + #[cfg(feature = "luau")] _in_callback_with_continuation: bool, + #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, @@ -186,14 +189,10 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - // Return unused `WrappedFailure` to the pool - prealloc_failure.release(state, extra); - r - } - Ok(Err(err)) => { - if let Error::Yield(values) = err { - let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + if let Some(values) = values { // A note on Luau // // When using the yieldable continuations fflag (and in future when the fflag gets removed and @@ -311,6 +310,11 @@ where } } + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { let wrapped_error = prealloc_failure.r#use(state, extra); if !wrap_error { diff --git a/tests/thread.rs b/tests/thread.rs index a5655af2..9c8c1b65 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -439,7 +439,7 @@ fn test_continuation() { } lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 - unreachable!(); + Ok(1_i32) // this will be ignored }, ) .expect("Failed to create cont_func"); From b44075faff983a6022b8cba1037089b237c05b54 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:51:01 -0400 Subject: [PATCH 29/32] fix --- src/state/util.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 67f9edaf..360b85d9 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -167,8 +167,7 @@ pub(crate) unsafe fn callback_error_ext_yieldable( mut extra: *mut ExtraData, wrap_error: bool, f: F, - #[cfg(feature = "luau")] _in_callback_with_continuation: bool, - #[cfg(not(feature = "luau"))] in_callback_with_continuation: bool, + #[allow(unused_variables)] in_callback_with_continuation: bool, ) -> c_int where F: FnOnce(*mut ExtraData, c_int) -> Result, From a2e8694b5297453239a6a5f5d293a9c793fffc47 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 03:51:21 -0400 Subject: [PATCH 30/32] fix --- src/state/util.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index 360b85d9..d3fa1c30 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -113,8 +113,6 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { - take(&mut extra.as_mut().unwrap_unchecked().yielded_values); - // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); r From 4a37dc5f0fabb37e4bbb719aa91518f46d1a0ee4 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 07:06:54 -0400 Subject: [PATCH 31/32] ensure wrappedfailure is returned to pool --- src/state.rs | 5 +---- src/state/util.rs | 34 ++++++++++++++++------------------ 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/state.rs b/src/state.rs index 3685e399..1df736e5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2175,10 +2175,7 @@ impl Lua { Ok(()) } - /// Checks if Lua could be allowed to yield. - /// - /// Note that this method is not fool proof and is prone to false negatives - /// especially when continuations are involved + /// Checks if Lua is be allowed to yield. #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] #[inline] pub fn is_yieldable(&self) -> bool { diff --git a/src/state/util.rs b/src/state/util.rs index d3fa1c30..eade8870 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -186,6 +186,12 @@ where f(extra, nargs) })) { Ok(Ok(r)) => { + // Return unused `WrappedFailure` to the pool + // + // In either case, we cannot use it in the yield case anyways due to the lua_pop call + // so drop it properly now while we can. + prealloc_failure.release(state, extra); + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); @@ -216,14 +222,11 @@ where // Even outside of luau, clearing the stack is probably desirable ffi::lua_pop(state, -1); if let Err(err) = check_stack(state, nargs) { - let wrapped_error = prealloc_failure.r#use(state, extra); - ptr::write( - wrapped_error, - WrappedFailure::Error(Error::external(err.to_string())), - ); - get_internal_metatable::(state); - ffi::lua_setmetatable(state, -2); - + // Unfortunately, we can't do a wrapped error here + // due to the lua_pop call, so just push a CString + // manually instead (todo: look into a better way here) + let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); + ffi::lua_pushstring(state, cs.as_ptr()); ffi::lua_error(state) } ffi::lua_xmove(raw.state(), state, nargs); @@ -294,21 +297,16 @@ where return ffi::lua_yield(state, nargs); } Err(err) => { - let wrapped_error = prealloc_failure.r#use(state, extra); - ptr::write( - wrapped_error, - WrappedFailure::Error(Error::external(err.to_string())), - ); - get_internal_metatable::(state); - ffi::lua_setmetatable(state, -2); - + // Unfortunately, we can't do a wrapped error here + // due to the above lua_pop call, so just push a CString + // manually instead (todo: look into a better way here) + let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); + ffi::lua_pushstring(state, cs.as_ptr()); ffi::lua_error(state) } } } - // Return unused `WrappedFailure` to the pool - prealloc_failure.release(state, extra); r } Ok(Err(err)) => { From a5c1ea1511b12a935160c9613c43e80d202fc0cd Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 6 Jun 2025 07:35:31 -0400 Subject: [PATCH 32/32] fix wrap error case --- src/state/util.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/state/util.rs b/src/state/util.rs index eade8870..69edf40f 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -222,12 +222,13 @@ where // Even outside of luau, clearing the stack is probably desirable ffi::lua_pop(state, -1); if let Err(err) = check_stack(state, nargs) { - // Unfortunately, we can't do a wrapped error here - // due to the lua_pop call, so just push a CString - // manually instead (todo: look into a better way here) - let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); - ffi::lua_pushstring(state, cs.as_ptr()); - ffi::lua_error(state) + // Make a *new* preallocated failure, and then do normal error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); } ffi::lua_xmove(raw.state(), state, nargs); } @@ -297,12 +298,13 @@ where return ffi::lua_yield(state, nargs); } Err(err) => { - // Unfortunately, we can't do a wrapped error here - // due to the above lua_pop call, so just push a CString - // manually instead (todo: look into a better way here) - let cs = std::ffi::CString::new(err.to_string()).unwrap_unchecked(); - ffi::lua_pushstring(state, cs.as_ptr()); - ffi::lua_error(state) + // Make a *new* preallocated failure, and then do normal wrap_error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); } } }