diff --git a/mlua-sys/src/lua52/lua.rs b/mlua-sys/src/lua52/lua.rs index e5239cee..8de7cdee 100644 --- a/mlua-sys/src/lua52/lua.rs +++ b/mlua-sys/src/lua52/lua.rs @@ -272,6 +272,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_CFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua53/lua.rs b/mlua-sys/src/lua53/lua.rs index 2729fdcd..c3d82e63 100644 --- a/mlua-sys/src/lua53/lua.rs +++ b/mlua-sys/src/lua53/lua.rs @@ -286,6 +286,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Garbage-collection function and options // diff --git a/mlua-sys/src/lua54/lua.rs b/mlua-sys/src/lua54/lua.rs index 15a30444..c74e1576 100644 --- a/mlua-sys/src/lua54/lua.rs +++ b/mlua-sys/src/lua54/lua.rs @@ -299,6 +299,11 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { lua_yieldk(L, n, 0, None) } +#[inline(always)] +pub unsafe fn lua_yieldc(L: *mut lua_State, n: c_int, k: lua_KFunction) -> c_int { + lua_yieldk(L, n, 0, Some(k)) +} + // // Warning-related functions // diff --git a/mlua-sys/src/luau/lua.rs b/mlua-sys/src/luau/lua.rs index 8a55eef1..d898534c 100644 --- a/mlua-sys/src/luau/lua.rs +++ b/mlua-sys/src/luau/lua.rs @@ -426,6 +426,11 @@ pub unsafe fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, nup: c_int) lua_pushcclosurek(L, f, ptr::null(), nup, None) } +#[inline(always)] +pub unsafe fn lua_pushcclosurec(L: *mut lua_State, f: lua_CFunction, cont: lua_Continuation, nup: c_int) { + lua_pushcclosurek(L, f, ptr::null(), nup, Some(cont)) +} + #[inline(always)] pub unsafe fn lua_pushcclosured(L: *mut lua_State, f: lua_CFunction, debugname: *const c_char, nup: c_int) { lua_pushcclosurek(L, f, debugname, nup, None) diff --git a/src/buffer.rs b/src/buffer.rs index 181e3696..4989acef 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -61,7 +61,7 @@ impl Buffer { unsafe fn as_raw_parts(&self) -> (*mut u8, usize) { let lua = self.0.lua.lock(); let mut size = 0usize; - let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size); + let buf = ffi::lua_tobuffer(lua.ref_thread(self.0.aux_thread), self.0.index, &mut size); mlua_assert!(!buf.is_null(), "invalid Luau buffer"); (buf as *mut u8, size) } diff --git a/src/conversion.rs b/src/conversion.rs index b7dfa305..65f04982 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -12,6 +12,7 @@ use num_traits::cast; use crate::error::{Error, Result}; use crate::function::Function; +use crate::state::util::get_next_spot; use crate::state::{Lua, RawLua}; use crate::string::{BorrowedBytes, BorrowedStr, String}; use crate::table::Table; @@ -83,8 +84,12 @@ impl FromLua for String { let state = lua.state(); let type_id = ffi::lua_type(state, idx); if type_id == ffi::LUA_TSTRING { - ffi::lua_xpush(state, lua.ref_thread(), idx); - return Ok(String(lua.pop_ref_thread())); + let (aux_thread, idxs, replace) = get_next_spot(lua.extra()); + ffi::lua_xpush(state, lua.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), idxs); + } + return Ok(String(lua.new_value_ref(aux_thread, idxs))); } // Fallback to default Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) diff --git a/src/error.rs b/src/error.rs index c20a39ef..2d369109 100644 --- a/src/error.rs +++ b/src/error.rs @@ -321,7 +321,7 @@ impl fmt::Display for Error { Error::WithContext { context, cause } => { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") - } + }, } } } diff --git a/src/function.rs b/src/function.rs index 9a5b291b..bf4a6b63 100644 --- a/src/function.rs +++ b/src/function.rs @@ -3,6 +3,8 @@ use std::os::raw::{c_int, c_void}; use std::{mem, ptr, slice}; use crate::error::{Error, Result}; +#[cfg(feature = "luau")] +use crate::state::util::get_next_spot; use crate::state::Lua; use crate::table::Table; use crate::traits::{FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut}; @@ -494,14 +496,22 @@ impl Function { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn deep_clone(&self) -> Self { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { if ffi::lua_iscfunction(ref_thread, self.0.index) != 0 { return self.clone(); } ffi::lua_clonefunction(ref_thread, self.0.index); - Function(lua.pop_ref_thread()) + + // Get the real next spot + let (aux_thread, index, replace) = get_next_spot(lua.extra()); + ffi::lua_xpush(lua.ref_thread(self.0.aux_thread), lua.ref_thread(aux_thread), -1); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), index); + } + + Function(lua.new_value_ref(aux_thread, index)) } } } diff --git a/src/lib.rs b/src/lib.rs index e1589ce6..8c7668ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,7 +110,7 @@ pub use crate::state::{GCMode, Lua, LuaOptions, WeakLua}; pub use crate::stdlib::StdLib; pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; -pub use crate::thread::{Thread, ThreadStatus}; +pub use crate::thread::{ContinuationStatus, Thread, ThreadStatus}; pub use crate::traits::{ FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike, }; diff --git a/src/prelude.rs b/src/prelude.rs index a3a03201..c3b5d443 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -3,15 +3,15 @@ #[doc(no_inline)] pub use crate::{ AnyUserData as LuaAnyUserData, BorrowedBytes as LuaBorrowedBytes, BorrowedStr as LuaBorrowedStr, - Chunk as LuaChunk, Either as LuaEither, Error as LuaError, ErrorContext as LuaErrorContext, - ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, - Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger, - IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions, - MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, - ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, - String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, - Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData, - UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, + Chunk as LuaChunk, ContinuationStatus as LuaContinuationStatus, Either as LuaEither, Error as LuaError, + ErrorContext as LuaErrorContext, ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, + FromLua, FromLuaMulti, Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, + Integer as LuaInteger, IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, + LuaNativeFnMut, LuaOptions, MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, + Number as LuaNumber, ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, + StdLib as LuaStdLib, String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, + TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus, + UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, Variadic as LuaVariadic, VmState as LuaVmState, WeakLua, diff --git a/src/scope.rs b/src/scope.rs index c56647a4..0ea61996 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -267,7 +267,7 @@ impl<'scope, 'env: 'scope> Scope<'scope, 'env> { let f = self.lua.create_callback(f)?; let destructor: DestructorCallback = Box::new(|rawlua, vref| { - let ref_thread = rawlua.ref_thread(); + let ref_thread = rawlua.ref_thread(vref.aux_thread); ffi::lua_getupvalue(ref_thread, vref.index, 1); let upvalue = get_userdata::(ref_thread, -1); let data = (*upvalue).data.take(); @@ -287,13 +287,13 @@ impl<'scope, 'env: 'scope> Scope<'scope, 'env> { Ok(Some(_)) => {} Ok(None) => { // Deregister metatable - let mt_ptr = get_metatable_ptr(rawlua.ref_thread(), vref.index); + let mt_ptr = get_metatable_ptr(rawlua.ref_thread(vref.aux_thread), vref.index); rawlua.deregister_userdata_metatable(mt_ptr); } Err(_) => return vec![], } - let data = take_userdata::>(rawlua.ref_thread(), vref.index); + let data = take_userdata::>(rawlua.ref_thread(vref.aux_thread), vref.index); vec![Box::new(move || drop(data))] }); self.destructors.0.borrow_mut().push((ud.0.clone(), destructor)); diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 1b85a763..87563ae8 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -7,6 +7,7 @@ use serde::ser::Serialize; use crate::error::Result; use crate::private::Sealed; +use crate::state::util::get_next_spot; use crate::state::Lua; use crate::table::Table; use crate::util::check_stack; @@ -183,8 +184,14 @@ impl LuaSerdeExt for Lua { fn array_metatable(&self) -> Table { let lua = self.lock(); unsafe { - push_array_metatable(lua.ref_thread()); - Table(lua.pop_ref_thread()) + let (aux_thread, index, replace) = get_next_spot(lua.extra()); + push_array_metatable(lua.state()); + ffi::lua_xmove(lua.state(), lua.ref_thread(aux_thread), 1); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), index); + } + + Table(lua.new_value_ref(aux_thread, index)) } } diff --git a/src/state.rs b/src/state.rs index dbe4c2b2..c344dc8e 100644 --- a/src/state.rs +++ b/src/state.rs @@ -14,10 +14,15 @@ use crate::hook::Debug; use crate::memory::MemoryState; use crate::multi::MultiValue; use crate::scope::Scope; +use crate::state::util::get_next_spot; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; use crate::thread::Thread; + +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::thread::ContinuationStatus; + use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, @@ -522,7 +527,7 @@ impl Lua { ffi::luaL_sandboxthread(state); } else { // Restore original `LUA_GLOBALSINDEX` - ffi::lua_xpush(lua.ref_thread(), state, ffi::LUA_GLOBALSINDEX); + ffi::lua_xpush(lua.ref_thread_internal(), state, ffi::LUA_GLOBALSINDEX); ffi::lua_replace(state, ffi::LUA_GLOBALSINDEX); ffi::luaL_sandbox(state, 0); } @@ -762,8 +767,12 @@ impl Lua { return; // Don't allow recursion } ffi::lua_pushthread(child); - ffi::lua_xmove(child, (*extra).ref_thread, 1); - let value = Thread((*extra).raw_lua().pop_ref_thread(), child); + let (aux_thread, index, replace) = get_next_spot(extra); + ffi::lua_xmove(child, (*extra).raw_lua().ref_thread(aux_thread), 1); + if replace { + ffi::lua_replace((*extra).raw_lua().ref_thread(aux_thread), index); + } + let value = Thread((*extra).raw_lua().new_value_ref(aux_thread, index), child); callback_error_ext(parent, extra, false, move |extra, _| { callback((*extra).lua(), value) }) @@ -1265,6 +1274,44 @@ impl Lua { })) } + /// Same as ``create_function`` but with an added continuation function. + /// + /// The values passed to the continuation will be the yielded arguments + /// from the function for the initial continuation call. If yielding from a + /// continuation, the yielded results will be returned to the ``Thread::resume`` caller. The + /// arguments passed in the next ``Thread::resume`` call will then be the arguments passed + /// to the yielding continuation upon resumption. + /// + /// Returning a value from a continuation without setting yield + /// arguments will then be returned as the final return value of the Lua function call. + /// Values returned in a function in which there is also yielding will be ignored + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + pub fn create_function_with_continuation( + &self, + func: F, + cont: FC, + ) -> Result + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + FC: Fn(&Lua, ContinuationStatus, AC) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + AC: FromLuaMulti, + R: IntoLuaMulti, + RC: IntoLuaMulti, + { + (self.lock()).create_callback_with_continuation( + Box::new(move |rawlua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) + }), + Box::new(move |rawlua, nargs, status| unsafe { + let args = AC::from_stack_args(nargs, 1, None, rawlua)?; + let status = ContinuationStatus::from_status(status); + cont(rawlua.lua(), status, args)?.push_into_stack_multi(rawlua) + }), + ) + } + /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. /// /// This is a version of [`Lua::create_function`] that accepts a `FnMut` argument. @@ -1286,8 +1333,13 @@ impl Lua { /// This function is unsafe because provides a way to execute unsafe C function. pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result { let lua = self.lock(); - ffi::lua_pushcfunction(lua.ref_thread(), func); - Ok(Function(lua.pop_ref_thread())) + let (aux_thread, idx, replace) = get_next_spot(lua.extra()); + ffi::lua_pushcfunction(lua.ref_thread(aux_thread), func); + if replace { + ffi::lua_replace(lua.ref_thread(aux_thread), idx); + } + + Ok(Function(lua.new_value_ref(aux_thread, idx))) } /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. @@ -2080,6 +2132,42 @@ impl Lua { pub(crate) unsafe fn raw_lua(&self) -> &RawLua { &*self.raw.data_ptr() } + + /// Set the yield arguments. Note that Lua will not yield until you return from the function + /// + /// This method is mostly useful with continuations and Rust-Rust yields + /// due to the Rust/Lua boundary + /// + /// Example: + /// + /// ```rust + /// fn test() -> mlua::Result<()> { + /// let lua = mlua::Lua::new(); + /// let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; + /// + /// let thread = lua.create_thread(always_yield)?; + /// assert_eq!( + /// thread.resume::<(i32, String, f32)>(())?, + /// (42, String::from("69420"), 45.6) + /// ); + /// + /// Ok(()) + /// } + /// ``` + pub fn yield_with(&self, args: impl IntoLuaMulti) -> Result<()> { + let raw = self.lock(); + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = Some(args.into_lua_multi(self)?); + } + Ok(()) + } + + /// Checks if Lua is be allowed to yield. + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + #[inline] + pub fn is_yieldable(&self) -> bool { + self.lock().is_yieldable() + } } impl WeakLua { diff --git a/src/state/extra.rs b/src/state/extra.rs index 5ff74a33..fa6d2d3a 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -13,11 +13,13 @@ use crate::error::Result; use crate::state::RawLua; use crate::stdlib::StdLib; use crate::types::{AppData, ReentrantMutex, XRc}; + use crate::userdata::RawUserDataRegistry; use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, WrappedFailure}; #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -30,6 +32,44 @@ static EXTRA_REGISTRY_KEY: u8 = 0; const WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY: usize = 64; const REF_STACK_RESERVE: c_int = 2; +pub(crate) struct RefThread { + pub(super) ref_thread: *mut ffi::lua_State, + pub(super) stack_size: c_int, + pub(super) stack_top: c_int, + pub(super) free: Vec, +} + +impl RefThread { + #[inline(always)] + pub(crate) unsafe fn new(state: *mut ffi::lua_State) -> Self { + // Create ref stack thread and place it in the registry to prevent it + // from being garbage collected. + let ref_thread = mlua_expect!( + protect_lua!(state, 0, 0, |state| { + let thread = ffi::lua_newthread(state); + ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); + thread + }), + "Error while creating ref thread", + ); + + // Store `error_traceback` function on the ref stack + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + { + ffi::lua_pushcfunction(ref_thread, crate::util::error_traceback); + assert_eq!(ffi::lua_gettop(ref_thread), ExtraData::ERROR_TRACEBACK_IDX); + } + + RefThread { + ref_thread, + // We need some reserved stack space to move values in and out of the ref stack. + stack_size: ffi::LUA_MINSTACK - REF_STACK_RESERVE, + stack_top: ffi::lua_gettop(ref_thread), + free: Vec::new(), + } + } +} + /// Data associated with the Lua state. pub(crate) struct ExtraData { pub(super) lua: MaybeUninit, @@ -53,18 +93,17 @@ pub(crate) struct ExtraData { // Used in module mode pub(super) skip_memory_check: bool, - // Auxiliary thread to store references - pub(super) ref_thread: *mut ffi::lua_State, - pub(super) ref_stack_size: c_int, - pub(super) ref_stack_top: c_int, - pub(super) ref_free: Vec, + // Auxiliary threads to store references + pub(super) ref_thread: Vec, + // Special auxillary thread for mlua internal use + pub(super) ref_thread_internal: RefThread, // Pool of `WrappedFailure` enums in the ref thread (as userdata) pub(super) wrapped_failure_pool: Vec, pub(super) wrapped_failure_top: usize, // Pool of `Thread`s (coroutines) for async execution #[cfg(feature = "async")] - pub(super) thread_pool: Vec, + pub(super) thread_pool: Vec<(usize, c_int)>, // Address of `WrappedFailure` metatable pub(super) wrapped_failure_mt_ptr: *const c_void, @@ -94,6 +133,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: Option, } impl Drop for ExtraData { @@ -124,17 +166,6 @@ impl ExtraData { pub(super) const ERROR_TRACEBACK_IDX: c_int = 1; pub(super) unsafe fn init(state: *mut ffi::lua_State, owned: bool) -> XRc> { - // Create ref stack thread and place it in the registry to prevent it - // from being garbage collected. - let ref_thread = mlua_expect!( - protect_lua!(state, 0, 0, |state| { - let thread = ffi::lua_newthread(state); - ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); - thread - }), - "Error while creating ref thread", - ); - let wrapped_failure_mt_ptr = { get_internal_metatable::(state); let ptr = ffi::lua_topointer(state, -1); @@ -142,13 +173,6 @@ impl ExtraData { ptr }; - // Store `error_traceback` function on the ref stack - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - { - ffi::lua_pushcfunction(ref_thread, crate::util::error_traceback); - assert_eq!(ffi::lua_gettop(ref_thread), Self::ERROR_TRACEBACK_IDX); - } - #[allow(clippy::arc_with_non_send_sync)] let extra = XRc::new(UnsafeCell::new(ExtraData { lua: MaybeUninit::uninit(), @@ -164,11 +188,8 @@ impl ExtraData { safe: false, libs: StdLib::NONE, skip_memory_check: false, - ref_thread, - // We need some reserved stack space to move values in and out of the ref stack. - ref_stack_size: ffi::LUA_MINSTACK - REF_STACK_RESERVE, - ref_stack_top: ffi::lua_gettop(ref_thread), - ref_free: Vec::new(), + ref_thread: vec![RefThread::new(state)], + ref_thread_internal: RefThread::new(state), wrapped_failure_pool: Vec::with_capacity(WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY), wrapped_failure_top: 0, #[cfg(feature = "async")] @@ -196,6 +217,7 @@ impl ExtraData { enable_jit: true, #[cfg(feature = "luau")] running_gc: false, + yielded_values: None, })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index b7de97f2..eca8c9d5 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -11,7 +11,9 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop}; +#[allow(unused_imports)] +use crate::state::util::callback_error_ext; +use crate::state::util::{callback_error_ext_yieldable, get_next_spot}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -21,6 +23,12 @@ use crate::types::{ AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; + +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::Continuation; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; + use crate::userdata::{ init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, UserDataStorage, @@ -116,8 +124,25 @@ impl RawLua { } #[inline(always)] - pub(crate) fn ref_thread(&self) -> *mut ffi::lua_State { - unsafe { (*self.extra.get()).ref_thread } + pub(crate) fn ref_thread(&self, aux_thread: usize) -> *mut ffi::lua_State { + unsafe { + (*self.extra.get()) + .ref_thread + .get(aux_thread) + .unwrap_unchecked() + .ref_thread + } + } + + #[inline(always)] + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + pub(crate) fn ref_thread_internal(&self) -> *mut ffi::lua_State { + unsafe { (*self.extra.get()).ref_thread_internal.ref_thread } + } + + #[inline(always)] + pub(crate) fn extra(&self) -> *mut ExtraData { + self.extra.get() } pub(super) unsafe fn new(libs: StdLib, options: &LuaOptions) -> XRc> { @@ -197,6 +222,8 @@ impl RawLua { init_internal_metatable::>>(state, None)?; init_internal_metatable::(state, None)?; init_internal_metatable::(state, None)?; + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + init_internal_metatable::(state, None)?; #[cfg(not(feature = "luau"))] init_internal_metatable::(state, None)?; #[cfg(feature = "async")] @@ -616,16 +643,16 @@ impl RawLua { self.set_thread_hook(thread_state, HookKind::Global)?; let thread = Thread(self.pop_ref(), thread_state); - ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); + ffi::lua_xpush(self.ref_thread(func.0.aux_thread), thread_state, func.0.index); Ok(thread) } /// Wraps a Lua function into a new or recycled thread (coroutine). #[cfg(feature = "async")] pub(crate) unsafe fn create_recycled_thread(&self, func: &Function) -> Result { - if let Some(index) = (*self.extra.get()).thread_pool.pop() { - let thread_state = ffi::lua_tothread(self.ref_thread(), index); - ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); + if let Some((aux_thread, index)) = (*self.extra.get()).thread_pool.pop() { + let thread_state = ffi::lua_tothread(self.ref_thread(aux_thread), index); + ffi::lua_xpush(self.ref_thread(func.0.aux_thread), thread_state, func.0.index); #[cfg(feature = "luau")] { @@ -634,7 +661,7 @@ impl RawLua { ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); } - return Ok(Thread(ValueRef::new(self, index), thread_state)); + return Ok(Thread(ValueRef::new(self, aux_thread, index), thread_state)); } self.create_thread(func) @@ -645,7 +672,7 @@ impl RawLua { pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) { let extra = &mut *self.extra.get(); if extra.thread_pool.len() < extra.thread_pool.capacity() { - extra.thread_pool.push(thread.0.index); + extra.thread_pool.push((thread.0.aux_thread, thread.0.index)); thread.0.drop = false; // Prevent thread from being garbage collected } } @@ -744,18 +771,30 @@ impl RawLua { } ffi::LUA_TSTRING => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::String(String(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::String(String(self.new_value_ref(aux_thread, idxs))) } ffi::LUA_TTABLE => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Table(Table(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Table(Table(self.new_value_ref(aux_thread, idxs))) } ffi::LUA_TFUNCTION => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Function(Function(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Function(Function(self.new_value_ref(aux_thread, idxs))) } ffi::LUA_TUSERDATA => { @@ -771,27 +810,44 @@ impl RawLua { Value::Nil } _ => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::UserData(AnyUserData(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + + Value::UserData(AnyUserData(self.new_value_ref(aux_thread, idxs))) } } } ffi::LUA_TTHREAD => { - ffi::lua_xpush(state, self.ref_thread(), idx); - let thread_state = ffi::lua_tothread(self.ref_thread(), -1); - Value::Thread(Thread(self.pop_ref_thread(), thread_state)) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + let thread_state = ffi::lua_tothread(self.ref_thread(aux_thread), -1); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Thread(Thread(self.new_value_ref(aux_thread, idxs), thread_state)) } #[cfg(feature = "luau")] ffi::LUA_TBUFFER => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Buffer(crate::Buffer(self.pop_ref_thread())) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Buffer(crate::Buffer(self.new_value_ref(aux_thread, idxs))) } _ => { - ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Other(self.pop_ref_thread()) + let (aux_thread, idxs, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush(state, self.ref_thread(aux_thread), idx); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idxs); + } + Value::Other(self.new_value_ref(aux_thread, idxs)) } } } @@ -803,7 +859,7 @@ impl RawLua { self.weak() == &vref.lua, "Lua instance passed Value created from a different main Lua state" ); - unsafe { ffi::lua_xpush(self.ref_thread(), self.state(), vref.index) }; + unsafe { ffi::lua_xpush(self.ref_thread(vref.aux_thread), self.state(), vref.index) }; } // Pops the topmost element of the stack and stores a reference to it. This pins the object, @@ -815,41 +871,53 @@ impl RawLua { // used stack. #[inline] pub(crate) unsafe fn pop_ref(&self) -> ValueRef { - ffi::lua_xmove(self.state(), self.ref_thread(), 1); - let index = ref_stack_pop(self.extra.get()); - ValueRef::new(self, index) + let (aux_thread, idx, replace) = get_next_spot(self.extra.get()); + ffi::lua_xmove(self.state(), self.ref_thread(aux_thread), 1); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), idx); + } + + ValueRef::new(self, aux_thread, idx) } - // Same as `pop_ref` but assumes the value is already on the reference thread + // Given a known aux_thread and index, creates a ValueRef. #[inline] - pub(crate) unsafe fn pop_ref_thread(&self) -> ValueRef { - let index = ref_stack_pop(self.extra.get()); - ValueRef::new(self, index) + pub(crate) unsafe fn new_value_ref(&self, aux_thread: usize, index: c_int) -> ValueRef { + ValueRef::new(self, aux_thread, index) } #[inline] pub(crate) unsafe fn clone_ref(&self, vref: &ValueRef) -> ValueRef { - ffi::lua_pushvalue(self.ref_thread(), vref.index); - let index = ref_stack_pop(self.extra.get()); - ValueRef::new(self, index) + let (aux_thread, index, replace) = get_next_spot(self.extra.get()); + ffi::lua_xpush( + self.ref_thread(vref.aux_thread), + self.ref_thread(aux_thread), + vref.index, + ); + if replace { + ffi::lua_replace(self.ref_thread(aux_thread), index); + } + ValueRef::new(self, aux_thread, index) } pub(crate) unsafe fn drop_ref(&self, vref: &ValueRef) { - let ref_thread = self.ref_thread(); + let ref_thread = self.ref_thread(vref.aux_thread); mlua_debug_assert!( ffi::lua_gettop(ref_thread) >= vref.index, "GC finalizer is not allowed in ref_thread" ); ffi::lua_pushnil(ref_thread); ffi::lua_replace(ref_thread, vref.index); - (*self.extra.get()).ref_free.push(vref.index); + (*self.extra.get()).ref_thread[vref.aux_thread] + .free + .push(vref.index); } #[inline] pub(crate) unsafe fn push_error_traceback(&self) { let state = self.state(); #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - ffi::lua_xpush(self.ref_thread(), state, ExtraData::ERROR_TRACEBACK_IDX); + ffi::lua_xpush(self.ref_thread_internal(), state, ExtraData::ERROR_TRACEBACK_IDX); // Lua 5.2+ support light C functions that does not require extra allocations #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] ffi::lua_pushcfunction(state, crate::util::error_traceback); @@ -1121,7 +1189,7 @@ impl RawLua { // Returns `None` if the userdata is registered but non-static. #[inline(always)] pub(crate) fn get_userdata_ref_type_id(&self, vref: &ValueRef) -> Result> { - unsafe { self.get_userdata_type_id_inner(self.ref_thread(), vref.index) } + unsafe { self.get_userdata_type_id_inner(self.ref_thread(vref.aux_thread), vref.index) } } // Same as `get_userdata_ref_type_id` but assumes the userdata is already on the stack. @@ -1174,7 +1242,7 @@ impl RawLua { // Pushes a ValueRef (userdata) value onto the stack, returning their `TypeId`. // Uses 1 stack space, does not call checkstack. pub(crate) unsafe fn push_userdata_ref(&self, vref: &ValueRef) -> Result> { - let type_id = self.get_userdata_type_id_inner(self.ref_thread(), vref.index)?; + let type_id = self.get_userdata_type_id_inner(self.ref_thread(vref.aux_thread), vref.index)?; self.push_ref(vref); Ok(type_id) } @@ -1183,15 +1251,21 @@ impl RawLua { pub(crate) fn create_callback(&self, func: Callback) -> Result { unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { - // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) - // The lock must be already held as the callback is executed - let rawlua = (*extra).raw_lua(); - match (*upvalue).data { - Some(ref func) => func(rawlua, nargs), - None => Err(Error::CallbackDestructed), - } - }) + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + false, + ) } let state = self.state(); @@ -1215,6 +1289,124 @@ impl RawLua { } } + // Creates a Function out of a Callback and a continuation containing a 'static Fn. + // + // In Luau, uses pushcclosurek + // + // In Lua 5.2/5.3/5.4/JIT, makes a normal function that then yields to the continuation via yieldk + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + pub(crate) fn create_callback_with_continuation( + &self, + func: Callback, + cont: Continuation, + ) -> Result { + #[cfg(feature = "luau")] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.0)(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + unsafe extern "C-unwind" fn cont_callback(state: *mut ffi::lua_State, status: c_int) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + })?; + } else { + ffi::lua_pushcclosurec(state, call_callback, cont_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + + #[cfg(not(feature = "luau"))] + { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing + // arguments) The lock must be already held as the callback is + // executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some((ref func, _)) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some((func, cont)); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, ContinuationUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, call_callback, 1); + })?; + } else { + ffi::lua_pushcclosure(state, call_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + } + #[cfg(feature = "async")] pub(crate) fn create_async_callback(&self, func: AsyncCallback) -> Result { // Ensure that the coroutine library is loaded @@ -1375,6 +1567,12 @@ impl RawLua { pub(crate) unsafe fn set_waker(&self, waker: NonNull) -> NonNull { mem::replace(&mut (*self.extra.get()).waker, waker) } + + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + unsafe { ffi::lua_isyieldable(self.state()) != 0 } + } } // Uses 3 stack spaces diff --git a/src/state/util.rs b/src/state/util.rs index c3c79302..881016ae 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,11 +1,17 @@ +use crate::IntoLuaMulti; +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; use std::sync::Arc; use crate::error::{Error, Result}; +use crate::state::extra::RefThread; use crate::state::{ExtraData, RawLua}; -use crate::util::{self, get_internal_metatable, WrappedFailure}; +use crate::util::{self, check_stack, get_internal_metatable, WrappedFailure}; + +#[cfg(all(not(feature = "lua51"), not(feature = "luajit"), not(feature = "luau")))] +use crate::{types::ContinuationUpvalue, util::get_userdata}; struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); @@ -22,6 +28,65 @@ impl Drop for StateGuard<'_> { } } +pub(crate) enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, +} + +impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = &(*extra).ref_thread_internal; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread.ref_thread, state, index); + ffi::lua_pushnil(ref_thread.ref_thread); + ffi::lua_replace(ref_thread.ref_thread, index); + (*extra).ref_thread_internal.free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = &(*extra).ref_thread_internal; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread.ref_thread, 1); + let index = ref_stack_pop_internal(extra); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } +} + // An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata // and instead reuses unused values from previous calls (or allocates new). pub(crate) unsafe fn callback_error_ext( @@ -39,64 +104,78 @@ where let nargs = ffi::lua_gettop(state); - enum PreallocatedFailure { - New(*mut WrappedFailure), - Reserved, - } - - impl PreallocatedFailure { - unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { - if (*extra).wrapped_failure_top > 0 { - (*extra).wrapped_failure_top -= 1; - return PreallocatedFailure::Reserved; - } + // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory + // to store a wrapped failure (error or panic) *before* we proceed. + let prealloc_failure = PreallocatedFailure::reserve(state, extra); - // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - // Place it to the beginning of the stack - let ud = WrappedFailure::new_userdata(state); - ffi::lua_insert(state, 1); - PreallocatedFailure::New(ud) + match catch_unwind(AssertUnwindSafe(|| { + let rawlua = (*extra).raw_lua(); + let _guard = StateGuard::new(rawlua, state); + f(extra, nargs) + })) { + Ok(Ok(r)) => { + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r } + Ok(Err(err)) => { + let wrapped_error = prealloc_failure.r#use(state, extra); - #[cold] - unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { - let ref_thread = (*extra).ref_thread; - match *self { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Reserved => { - let index = (*extra).wrapped_failure_pool.pop().unwrap(); - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - ffi::lua_rawcheckstack(state, 2); - ffi::lua_xpush(ref_thread, state, index); - ffi::lua_pushnil(ref_thread); - ffi::lua_replace(ref_thread, index); - (*extra).ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } + if !wrap_error { + ptr::write(wrapped_error, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) } - } - unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { - let ref_thread = (*extra).ref_thread; - match self { - PreallocatedFailure::New(_) => { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, ref_thread, 1); - let index = ref_stack_pop(extra); - (*extra).wrapped_failure_pool.push(index); - (*extra).wrapped_failure_top += 1; - } - PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, - } + // Build `CallbackError` with traceback + let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, ptr::null(), 0); + let traceback = util::to_string(state, -1); + ffi::lua_pop(state, 1); + traceback + } else { + "".to_string() + }; + let cause = Arc::new(err); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::CallbackError { traceback, cause }), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } + Err(p) => { + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Panic(Some(p))); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) } } +} + +/// An yieldable version of `callback_error_ext` +/// +/// Unlike ``callback_error_ext``, this method requires a c_int return +/// and not a generic R +pub(crate) unsafe fn callback_error_ext_yieldable( + state: *mut ffi::lua_State, + mut extra: *mut ExtraData, + wrap_error: bool, + f: F, + #[allow(unused_variables)] in_callback_with_continuation: bool, +) -> c_int +where + F: FnOnce(*mut ExtraData, c_int) -> Result, +{ + if extra.is_null() { + extra = ExtraData::get(state); + } + + let nargs = ffi::lua_gettop(state); // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory // to store a wrapped failure (error or panic) *before* we proceed. @@ -109,7 +188,128 @@ where })) { Ok(Ok(r)) => { // Return unused `WrappedFailure` to the pool + // + // In either case, we cannot use it in the yield case anyways due to the lua_pop call + // so drop it properly now while we can. prealloc_failure.release(state, extra); + + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + + if let Some(values) = values { + // A note on Luau + // + // When using the yieldable continuations fflag (and in future when the fflag gets removed and + // yieldable continuations) becomes default, we must either pop the top of the + // stack on the state we are resuming or somehow store the number of + // args on top of stack pre-yield and then subtract in the resume in order to get predictable + // behaviour here. See https://github.com/luau-lang/luau/issues/1867 for more information + // + // In this case, popping is easier and leads to less bugs/more ergonomic API. + + if raw.state() == state { + // Edge case: main thread is being yielded + // + // We need to pop/clear stack early, then push args + ffi::lua_pop(state, -1); + } + + match values.push_into_stack_multi(raw) { + Ok(nargs) => { + // If not main thread, then clear and xmove to target thread + if raw.state() != state { + // luau preserves the stack making yieldable continuations ugly and leaky + // + // Even outside of luau, clearing the stack is probably desirable + ffi::lua_pop(state, -1); + if let Err(err) = check_stack(state, nargs) { + // Make a *new* preallocated failure, and then do normal error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); + } + ffi::lua_xmove(raw.state(), state, nargs); + } + + #[cfg(all(not(feature = "luau"), not(feature = "lua51"), not(feature = "luajit")))] + { + // Yield to a continuation. Unlike luau, we need to do this manually and on the + // fly using a yieldk call + if in_callback_with_continuation { + // On Lua 5.2, status and ctx are not present, so use 0 as status for + // compatibility + #[cfg(feature = "lua52")] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, 0), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + // Lua 5.3/5.4 case + #[cfg(not(feature = "lua52"))] + unsafe extern "C-unwind" fn cont_callback( + state: *mut ffi::lua_State, + status: c_int, + _ctx: ffi::lua_KContext, + ) -> c_int { + let upvalue = + get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext_yieldable( + state, + (*upvalue).extra.get(), + true, + |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available + // (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => (func.1)(rawlua, nargs, status), + None => Err(Error::CallbackDestructed), + } + }, + true, + ) + } + + return ffi::lua_yieldc(state, nargs, cont_callback); + } + } + + return ffi::lua_yield(state, nargs); + } + Err(err) => { + // Make a *new* preallocated failure, and then do normal wrap_error + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state); + } + } + } + r } Ok(Err(err)) => { @@ -151,30 +351,104 @@ where } } -pub(super) unsafe fn ref_stack_pop(extra: *mut ExtraData) -> c_int { +pub(super) unsafe fn ref_stack_pop_internal(extra: *mut ExtraData) -> c_int { let extra = &mut *extra; - if let Some(free) = extra.ref_free.pop() { - ffi::lua_replace(extra.ref_thread, free); + let ref_th = &mut extra.ref_thread_internal; + + if let Some(free) = ref_th.free.pop() { + ffi::lua_replace(ref_th.ref_thread, free); return free; } // Try to grow max stack size - if extra.ref_stack_top >= extra.ref_stack_size { - let mut inc = extra.ref_stack_size; // Try to double stack size - while inc > 0 && ffi::lua_checkstack(extra.ref_thread, inc) == 0 { + if ref_th.stack_top >= ref_th.stack_size { + let mut inc = ref_th.stack_size; // Try to double stack size + while inc > 0 && ffi::lua_checkstack(ref_th.ref_thread, inc) == 0 { inc /= 2; } if inc == 0 { // Pop item on top of the stack to avoid stack leaking and successfully run destructors // during unwinding. - ffi::lua_pop(extra.ref_thread, 1); - let top = extra.ref_stack_top; + ffi::lua_pop(ref_th.ref_thread, 1); + let top = ref_th.stack_top; // It is a user error to create enough references to exhaust the Lua max stack size for - // the ref thread. - panic!("cannot create a Lua reference, out of auxiliary stack space (used {top} slots)"); + // the ref thread. This should never happen for the internal aux thread but still + panic!("internal error: cannot create a Lua reference, out of internal auxiliary stack space (used {top} slots)"); } - extra.ref_stack_size += inc; + ref_th.stack_size += inc; } - extra.ref_stack_top += 1; - extra.ref_stack_top + ref_th.stack_top += 1; + return ref_th.stack_top; +} + +// Run a comparison function on two Lua references from different auxiliary threads. +pub(crate) unsafe fn compare_refs( + extra: *mut ExtraData, + aux_thread_a: usize, + aux_thread_a_index: c_int, + aux_thread_b: usize, + aux_thread_b_index: c_int, + f: impl FnOnce(*mut ffi::lua_State, c_int, c_int) -> R, +) -> R { + let extra = &mut *extra; + + if aux_thread_a == aux_thread_b { + // If both threads are the same, just return the value at the index + let th = &mut extra.ref_thread[aux_thread_a]; + return f(th.ref_thread, aux_thread_a_index, aux_thread_b_index); + } + + let th_a = &extra.ref_thread[aux_thread_a]; + let th_b = &extra.ref_thread[aux_thread_b]; + let internal_thread = &mut extra.ref_thread_internal; + + // 4 spaces needed: idx element on A, idx element on B + check_stack(internal_thread.ref_thread, 2) + .expect("internal error: cannot merge references, out of internal auxiliary stack space"); + + // Push the index element from thread A to top + ffi::lua_xpush(th_a.ref_thread, internal_thread.ref_thread, aux_thread_a_index); + // Push the index element from thread B to top + ffi::lua_xpush(th_b.ref_thread, internal_thread.ref_thread, aux_thread_b_index); + // Now we have the following stack: + // - index element from thread A (1) [copy from pushvalue] + // - index element from thread B (2) [copy from pushvalue] + // We want to compare the index elements from both threads, so use 3 and 4 as indices + let result = f(internal_thread.ref_thread, -1, -2); + + // Pop the top 2 elements to clean the copies + ffi::lua_pop(internal_thread.ref_thread, 2); + + result +} + +pub(crate) unsafe fn get_next_spot(extra: *mut ExtraData) -> (usize, c_int, bool) { + let extra = &mut *extra; + + // Find the first thread with a free slot + for (i, ref_th) in extra.ref_thread.iter_mut().enumerate() { + if let Some(free) = ref_th.free.pop() { + return (i, free, true); + } + + // Try to grow max stack size + if ref_th.stack_top >= ref_th.stack_size { + let mut inc = ref_th.stack_size; // Try to double stack size + while inc > 0 && ffi::lua_checkstack(ref_th.ref_thread, inc + 1) == 0 { + inc /= 2; + } + if inc == 0 { + continue; // No stack space available, try next thread + } + ref_th.stack_size += inc; + } + + ref_th.stack_top += 1; + return (i, ref_th.stack_top, false); + } + + // No free slots found, create a new one + let new_ref_thread = RefThread::new(extra.raw_lua().state()); + extra.ref_thread.push(new_ref_thread); + return get_next_spot(extra); } diff --git a/src/string.rs b/src/string.rs index 6304d484..d2567ca0 100644 --- a/src/string.rs +++ b/src/string.rs @@ -119,7 +119,7 @@ impl String { let lua = self.0.lua.upgrade(); let slice = { let rawlua = lua.lock(); - let ref_thread = rawlua.ref_thread(); + let ref_thread = rawlua.ref_thread(self.0.aux_thread); mlua_debug_assert!( ffi::lua_type(ref_thread, self.0.index) == ffi::LUA_TSTRING, diff --git a/src/table.rs b/src/table.rs index 4b62705d..8957c436 100644 --- a/src/table.rs +++ b/src/table.rs @@ -406,7 +406,7 @@ impl Table { #[cfg(feature = "luau")] { self.check_readonly_write(&lua)?; - ffi::lua_cleartable(lua.ref_thread(), self.0.index); + ffi::lua_cleartable(lua.ref_thread(self.0.aux_thread), self.0.index); } #[cfg(not(feature = "luau"))] @@ -461,7 +461,7 @@ impl Table { /// Returns the result of the Lua `#` operator, without invoking the `__len` metamethod. pub fn raw_len(&self) -> usize { let lua = self.0.lua.lock(); - unsafe { ffi::lua_rawlen(lua.ref_thread(), self.0.index) } + unsafe { ffi::lua_rawlen(lua.ref_thread(self.0.aux_thread), self.0.index) } } /// Returns `true` if the table is empty, without invoking metamethods. @@ -469,7 +469,7 @@ impl Table { /// It checks both the array part and the hash part. pub fn is_empty(&self) -> bool { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { ffi::lua_pushnil(ref_thread); if ffi::lua_next(ref_thread, self.0.index) == 0 { @@ -533,7 +533,7 @@ impl Table { #[inline] pub fn has_metatable(&self) -> bool { let lua = self.0.lua.lock(); - unsafe { !get_metatable_ptr(lua.ref_thread(), self.0.index).is_null() } + unsafe { !get_metatable_ptr(lua.ref_thread(self.0.aux_thread), self.0.index).is_null() } } /// Sets `readonly` attribute on the table. @@ -541,7 +541,7 @@ impl Table { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_readonly(&self, enabled: bool) { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { ffi::lua_setreadonly(ref_thread, self.0.index, enabled as _); if !enabled { @@ -556,7 +556,7 @@ impl Table { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn is_readonly(&self) -> bool { let lua = self.0.lua.lock(); - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); unsafe { ffi::lua_getreadonly(ref_thread, self.0.index) != 0 } } @@ -573,7 +573,7 @@ impl Table { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_safeenv(&self, enabled: bool) { let lua = self.0.lua.lock(); - unsafe { ffi::lua_setsafeenv(lua.ref_thread(), self.0.index, enabled as _) }; + unsafe { ffi::lua_setsafeenv(lua.ref_thread(self.0.aux_thread), self.0.index, enabled as _) }; } /// Converts this table to a generic C pointer. @@ -755,7 +755,7 @@ impl Table { #[cfg(feature = "luau")] #[inline(always)] fn check_readonly_write(&self, lua: &RawLua) -> Result<()> { - if unsafe { ffi::lua_getreadonly(lua.ref_thread(), self.0.index) != 0 } { + if unsafe { ffi::lua_getreadonly(lua.ref_thread(self.0.aux_thread), self.0.index) != 0 } { return Err(Error::runtime("attempt to modify a readonly table")); } Ok(()) diff --git a/src/thread.rs b/src/thread.rs index 13d95532..0c923380 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -26,6 +26,25 @@ use { }, }; +/// Continuation thread status. Can either be Ok, Yielded (rare, but can happen) or Error +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ContinuationStatus { + Ok, + Yielded, + Error, +} + +impl ContinuationStatus { + #[allow(dead_code)] + pub(crate) fn from_status(status: c_int) -> Self { + match status { + ffi::LUA_YIELD => Self::Yielded, + ffi::LUA_OK => Self::Ok, + _ => Self::Error, + } + } +} + /// Status of a Lua thread (coroutine). #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ThreadStatus { @@ -215,6 +234,7 @@ impl Thread { let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); #[cfg(feature = "luau")] let ret = ffi::lua_resumex(thread_state, state, nargs, &mut nresults as *mut c_int); + match ret { ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)), ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)), @@ -312,7 +332,7 @@ impl Thread { self.reset_inner(status)?; // Push function to the top of the thread stack - ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index); + ffi::lua_xpush(lua.ref_thread(func.0.aux_thread), thread_state, func.0.index); #[cfg(feature = "luau")] { diff --git a/src/types.rs b/src/types.rs index 2589ea6e..45806247 100644 --- a/src/types.rs +++ b/src/types.rs @@ -39,6 +39,11 @@ pub(crate) type Callback = Box Result + Send + #[cfg(not(feature = "send"))] pub(crate) type Callback = Box Result + 'static>; +#[cfg(all(feature = "send", not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + Send + 'static>; + +#[cfg(all(not(feature = "send"), not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type Continuation = Box Result + 'static>; pub(crate) type ScopedCallback<'s> = Box Result + 's>; @@ -48,6 +53,8 @@ pub(crate) struct Upvalue { } pub(crate) type CallbackUpvalue = Upvalue>; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +pub(crate) type ContinuationUpvalue = Upvalue>; #[cfg(all(feature = "async", feature = "send"))] pub(crate) type AsyncCallback = diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index 89bac543..fd9027e4 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -1,20 +1,23 @@ use std::fmt; use std::os::raw::{c_int, c_void}; +use crate::state::util::compare_refs; use crate::state::{RawLua, WeakLua}; /// A reference to a Lua (complex) value stored in the Lua auxiliary thread. pub struct ValueRef { pub(crate) lua: WeakLua, + pub(crate) aux_thread: usize, pub(crate) index: c_int, pub(crate) drop: bool, } impl ValueRef { #[inline] - pub(crate) fn new(lua: &RawLua, index: c_int) -> Self { + pub(crate) fn new(lua: &RawLua, aux_thread: usize, index: c_int) -> Self { ValueRef { lua: lua.weak().clone(), + aux_thread, index, drop: true, } @@ -23,7 +26,7 @@ impl ValueRef { #[inline] pub(crate) fn to_pointer(&self) -> *const c_void { let lua = self.lua.lock(); - unsafe { ffi::lua_topointer(lua.ref_thread(), self.index) } + unsafe { ffi::lua_topointer(lua.ref_thread(self.aux_thread), self.index) } } /// Returns a copy of the value, which is valid as long as the original value is held. @@ -31,6 +34,7 @@ impl ValueRef { pub(crate) fn copy(&self) -> Self { ValueRef { lua: self.lua.clone(), + aux_thread: self.aux_thread, index: self.index, drop: false, } @@ -66,6 +70,16 @@ impl PartialEq for ValueRef { "Lua instance passed Value created from a different main Lua state" ); let lua = self.lua.lock(); - unsafe { ffi::lua_rawequal(lua.ref_thread(), self.index, other.index) == 1 } + + unsafe { + compare_refs( + lua.extra(), + self.aux_thread, + self.index, + other.aux_thread, + other.index, + |state, a, b| ffi::lua_rawequal(state, a, b) == 1, + ) + } } } diff --git a/src/userdata.rs b/src/userdata.rs index a568485d..cddc931a 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -634,7 +634,7 @@ impl AnyUserData { #[inline] pub fn borrow(&self) -> Result> { let lua = self.0.lua.lock(); - unsafe { UserDataRef::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } + unsafe { UserDataRef::borrow_from_stack(&lua, lua.ref_thread(self.0.aux_thread), self.0.index) } } /// Borrow this userdata immutably if it is of type `T`, passing the borrowed value @@ -645,7 +645,15 @@ impl AnyUserData { let lua = self.0.lua.lock(); let type_id = lua.get_userdata_ref_type_id(&self.0)?; let type_hints = TypeIdHints::new::(); - unsafe { borrow_userdata_scoped(lua.ref_thread(), self.0.index, type_id, type_hints, f) } + unsafe { + borrow_userdata_scoped( + lua.ref_thread(self.0.aux_thread), + self.0.index, + type_id, + type_hints, + f, + ) + } } /// Borrow this userdata mutably if it is of type `T`. @@ -661,7 +669,7 @@ impl AnyUserData { #[inline] pub fn borrow_mut(&self) -> Result> { let lua = self.0.lua.lock(); - unsafe { UserDataRefMut::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } + unsafe { UserDataRefMut::borrow_from_stack(&lua, lua.ref_thread(self.0.aux_thread), self.0.index) } } /// Borrow this userdata mutably if it is of type `T`, passing the borrowed value @@ -672,7 +680,15 @@ impl AnyUserData { let lua = self.0.lua.lock(); let type_id = lua.get_userdata_ref_type_id(&self.0)?; let type_hints = TypeIdHints::new::(); - unsafe { borrow_userdata_scoped_mut(lua.ref_thread(), self.0.index, type_id, type_hints, f) } + unsafe { + borrow_userdata_scoped_mut( + lua.ref_thread(self.0.aux_thread), + self.0.index, + type_id, + type_hints, + f, + ) + } } /// Takes the value out of this userdata. @@ -685,7 +701,7 @@ impl AnyUserData { let lua = self.0.lua.lock(); match lua.get_userdata_ref_type_id(&self.0)? { Some(type_id) if type_id == TypeId::of::() => unsafe { - let ref_thread = lua.ref_thread(); + let ref_thread = lua.ref_thread(self.0.aux_thread); if (*get_userdata::>(ref_thread, self.0.index)).has_exclusive_access() { take_userdata::>(ref_thread, self.0.index).into_inner() } else { @@ -963,7 +979,7 @@ impl AnyUserData { let is_serializable = || unsafe { // Userdata must be registered and not destructed let _ = lua.get_userdata_ref_type_id(&self.0)?; - let ud = &*get_userdata::>(lua.ref_thread(), self.0.index); + let ud = &*get_userdata::>(lua.ref_thread(self.0.aux_thread), self.0.index); Ok::<_, Error>((*ud).is_serializable()) }; is_serializable().unwrap_or(false) @@ -1052,7 +1068,7 @@ impl Serialize for AnyUserData { let _ = lua .get_userdata_ref_type_id(&self.0) .map_err(ser::Error::custom)?; - let ud = &*get_userdata::>(lua.ref_thread(), self.0.index); + let ud = &*get_userdata::>(lua.ref_thread(self.0.aux_thread), self.0.index); ud.serialize(serializer) } } diff --git a/src/util/types.rs b/src/util/types.rs index 8bc9d8b2..8627042f 100644 --- a/src/util/types.rs +++ b/src/util/types.rs @@ -3,6 +3,9 @@ use std::os::raw::c_void; use crate::types::{Callback, CallbackUpvalue}; +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +use crate::types::ContinuationUpvalue; + #[cfg(feature = "async")] use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; @@ -34,6 +37,15 @@ impl TypeKey for CallbackUpvalue { } } +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +impl TypeKey for ContinuationUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static CONTINUATION_UPVALUE_TYPE_KEY: u8 = 0; + &CONTINUATION_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + #[cfg(not(feature = "luau"))] impl TypeKey for crate::types::HookCallback { #[inline(always)] diff --git a/src/value.rs b/src/value.rs index cfc13251..bc27bcfa 100644 --- a/src/value.rs +++ b/src/value.rs @@ -132,7 +132,7 @@ impl Value { // In Lua < 5.4 (excluding Luau), string pointers are NULL // Use alternative approach let lua = vref.lua.lock(); - unsafe { ffi::lua_tostring(lua.ref_thread(), vref.index) as *const c_void } + unsafe { ffi::lua_tostring(lua.ref_thread(vref.aux_thread), vref.index) as *const c_void } } Value::LightUserData(ud) => ud.0, Value::Table(Table(vref)) diff --git a/tests/byte_string.rs b/tests/byte_string.rs index 76e43e14..4768a475 100644 --- a/tests/byte_string.rs +++ b/tests/byte_string.rs @@ -2,6 +2,13 @@ use bstr::{BStr, BString}; use mlua::{Lua, Result}; #[test] +fn create_lua() { + let lua = Lua::new(); + let th = lua.create_table().unwrap(); + println!("{th:#?}"); +} + +//#[test] fn test_byte_string_round_trip() -> Result<()> { let lua = Lua::new(); diff --git a/tests/tests.rs b/tests/tests.rs index 7bf40af5..4d6cc805 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1016,16 +1016,14 @@ fn test_ref_stack_exhaustion() { match catch_unwind(AssertUnwindSafe(|| -> Result<()> { let lua = Lua::new(); let mut vals = Vec::new(); - for _ in 0..10000000 { + for _ in 0..200000 { + println!("Creating table {}", vals.len()); vals.push(lua.create_table()?); } Ok(()) })) { - Ok(_) => panic!("no panic was detected"), - Err(p) => assert!(p - .downcast::() - .unwrap() - .starts_with("cannot create a Lua reference, out of auxiliary stack space")), + Ok(_) => {} + Err(p) => panic!("got panic: {:?}", p), } } diff --git a/tests/thread.rs b/tests/thread.rs index 4cb6ab10..0dd9c6e1 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -252,3 +252,497 @@ fn test_thread_resume_error() -> Result<()> { Ok(()) } + +#[test] +fn test_thread_yield_args() -> Result<()> { + let lua = Lua::new(); + let always_yield = lua.create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)))?; + + let thread = lua.create_thread(always_yield)?; + assert_eq!( + thread.resume::<(i32, String, f32)>(())?, + (42, String::from("69420"), 45.6) + ); + + Ok(()) +} + +#[test] +#[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] +fn test_continuation() { + let lua = Lua::new(); + // No yielding continuation fflag test + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + // empty yield args test + let cont_func = lua + .create_function_with_continuation( + |lua, _: ()| lua.yield_with(()), + |_lua, _status, mv: mlua::MultiValue| Ok(mv.len()), + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res - 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + assert!(v.is_empty()); + let v = th.resume::(v).expect("Failed to load continuation"); + assert_eq!(v, -1); + + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); + + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with((a + 1, 1)), + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive/multiple: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { + assert_eq!(status, mlua::ContinuationStatus::Ok); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); + } + return Ok(6_i32); + } + + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + Ok(1_i32) // this will be ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); +} + +#[test] +fn test_large_thread_creation() { + let lua = Lua::new(); + lua.set_memory_limit(100_000_000_000).unwrap(); + let th1 = lua + .create_thread(lua.create_function(|lua, _: ()| Ok(())).unwrap()) + .unwrap(); + + let mut ths = Vec::new(); + for i in 1..2000000 { + let th = lua + .create_thread(lua.create_function(|_, ()| Ok(())).unwrap()) + .expect("Failed to create thread"); + ths.push(th); + } + let th2 = lua + .create_thread(lua.create_function(|lua, _: ()| Ok(())).unwrap()) + .unwrap(); + + for rth in ths { + let dbg_a = format!("{:?}", rth); + let th_a = format!("{:?}", th1); + let th_b = format!("{:?}", th2); + assert!( + th1 != rth && th2 != rth, + "Thread {:?} is equal to th1 ({:?}) or th2 ({:?})", + rth, + th1, + th2 + ); + let dbg_b = format!("{:?}", rth); + let dbg_th1 = format!("{:?}", th1); + let dbg_th2 = format!("{:?}", th2); + + // Ensure that the PartialEq across auxillary threads does not affect the values on stack + // themselves. + assert_eq!(dbg_a, dbg_b, "Thread {:?} debug format changed", rth); + assert_eq!(th_a, dbg_th1, "Thread {:?} debug format changed for th1", rth); + assert_eq!(th_b, dbg_th2, "Thread {:?} debug format changed for th2", rth); + } + + #[cfg(all(not(feature = "lua51"), not(feature = "luajit")))] + { + // Repeat yielded continuation test now with a new aux thread + // Yielding continuation test (only supported on luau) + #[cfg(feature = "luau")] + { + mlua::Lua::set_fflag("LuauYieldableContinuations", true).unwrap(); + } + + let cont_func = lua + .create_function_with_continuation( + |_lua, a: u64| Ok(a + 1), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 2) + }, + ) + .expect("Failed to create cont_func"); + + // Ensure normal calls work still + assert_eq!( + lua.load("local cont_func = ...\nreturn cont_func(1)") + .call::(cont_func) + .expect("Failed to call cont_func"), + 2 + ); + + // basic yield test before we go any further + let always_yield = lua + .create_function(|lua, ()| lua.yield_with((42, "69420".to_string(), 45.6))) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!( + thread.resume::<(i32, String, f32)>(()).unwrap(), + (42, String::from("69420"), 45.6) + ); + + // Trigger the continuation + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, a: u64| { + println!("Reached cont"); + Ok(a + 39) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 41); + + let always_yield = lua + .create_function_with_continuation( + |lua, ()| lua.yield_with((42, "69420".to_string(), 45.6)), + |_lua, _, mv: mlua::MultiValue| { + println!("Reached second continuation"); + if mv.is_empty() { + return Ok(mv); + } + Err(mlua::Error::external(format!("a{}", mv.len()))) + }, + ) + .unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + let mv = thread.resume::(()).unwrap(); + assert!(thread + .resume::(mv) + .unwrap_err() + .to_string() + .starts_with("a3")); + + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with((a + 1, 1)), + |lua, status, args: mlua::MultiValue| { + println!("Reached cont recursive/multiple: {:?}", args); + + if args.len() == 5 { + if cfg!(any(feature = "luau", feature = "lua52")) { + assert_eq!(status, mlua::ContinuationStatus::Ok); + } else { + assert_eq!(status, mlua::ContinuationStatus::Yielded); + } + return Ok(6_i32); + } + + lua.yield_with((args.len() + 1, args))?; // thread state becomes LEN, LEN-1... 1 + Ok(1_i32) // this will be ignored + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local res = cont_func(1) + return res + 1 + ", + ) + .into_function() + .expect("Failed to create function"); + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + println!("v={:?}", v); + + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + println!("v={:?}", v); + let v = th + .resume::(v) + .expect("Failed to load continuation"); + + // (2, 1) followed by () + assert_eq!(v.len(), 2 + 3); + + let v = th.resume::(v).expect("Failed to load continuation"); + + assert_eq!(v, 7); + + // test panics + let cont_func = lua + .create_function_with_continuation( + |lua, a: u64| lua.yield_with(a), + |_lua, _status, _a: u64| { + panic!("Reached continuation which should panic!"); + #[allow(unreachable_code)] + Ok(()) + }, + ) + .expect("Failed to create cont_func"); + + let luau_func = lua + .load( + " + local cont_func = ... + local ok, res = pcall(cont_func, 1) + assert(not ok) + return tostring(res) + ", + ) + .into_function() + .expect("Failed to create function"); + + let th = lua + .create_thread(luau_func) + .expect("Failed to create luau thread"); + + let v = th + .resume::(cont_func) + .expect("Failed to resume"); + + let v = th.resume::(v).expect("Failed to load continuation"); + assert!(v.contains("Reached continuation which should panic!")); + } +}