Skip to content

Commit fc1570d

Browse files
committed
Support yielding from hooks for Lua 5.3+
1 parent fce8538 commit fc1570d

File tree

7 files changed

+83
-31
lines changed

7 files changed

+83
-31
lines changed

src/lib.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ pub use crate::string::{BorrowedBytes, BorrowedStr, String};
116116
pub use crate::table::{Table, TablePairs, TableSequence};
117117
pub use crate::thread::{Thread, ThreadStatus};
118118
pub use crate::traits::ObjectLike;
119-
pub use crate::types::{AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey};
119+
pub use crate::types::{
120+
AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
121+
};
120122
pub use crate::userdata::{
121123
AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, UserDataRef,
122124
UserDataRefMut, UserDataRegistry,
@@ -128,11 +130,7 @@ pub use crate::hook::HookTriggers;
128130

129131
#[cfg(any(feature = "luau", doc))]
130132
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
131-
pub use crate::{
132-
chunk::Compiler,
133-
function::CoverageInfo,
134-
types::{Vector, VmState},
135-
};
133+
pub use crate::{chunk::Compiler, function::CoverageInfo, types::Vector};
136134

137135
#[cfg(feature = "async")]
138136
pub use crate::thread::AsyncThread;

src/prelude.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub use crate::{
1212
ThreadStatus as LuaThreadStatus, UserData as LuaUserData, UserDataFields as LuaUserDataFields,
1313
UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods,
1414
UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut,
15-
UserDataRegistry as LuaUserDataRegistry, Value as LuaValue,
15+
UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState,
1616
};
1717

1818
#[cfg(not(feature = "luau"))]
@@ -21,7 +21,7 @@ pub use crate::HookTriggers as LuaHookTriggers;
2121

2222
#[cfg(feature = "luau")]
2323
#[doc(no_inline)]
24-
pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector, VmState as LuaVmState};
24+
pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector};
2525

2626
#[cfg(feature = "async")]
2727
#[doc(no_inline)]

src/state.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::table::Table;
1919
use crate::thread::Thread;
2020
use crate::types::{
2121
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex,
22-
ReentrantMutexGuard, RegistryKey, XRc, XWeak,
22+
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
2323
};
2424
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
2525
use crate::util::{
@@ -31,7 +31,7 @@ use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, MultiValue, Nil
3131
use crate::hook::HookTriggers;
3232

3333
#[cfg(any(feature = "luau", doc))]
34-
use crate::{chunk::Compiler, types::VmState};
34+
use crate::chunk::Compiler;
3535

3636
#[cfg(feature = "async")]
3737
use {
@@ -499,12 +499,12 @@ impl Lua {
499499
/// Shows each line number of code being executed by the Lua interpreter.
500500
///
501501
/// ```
502-
/// # use mlua::{Lua, HookTriggers, Result};
502+
/// # use mlua::{Lua, HookTriggers, Result, VmState};
503503
/// # fn main() -> Result<()> {
504504
/// let lua = Lua::new();
505505
/// lua.set_hook(HookTriggers::EVERY_LINE, |_lua, debug| {
506506
/// println!("line {}", debug.curr_line());
507-
/// Ok(())
507+
/// Ok(VmState::Continue)
508508
/// });
509509
///
510510
/// lua.load(r#"
@@ -521,7 +521,7 @@ impl Lua {
521521
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
522522
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
523523
where
524-
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
524+
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
525525
{
526526
let lua = self.lock();
527527
unsafe { lua.set_thread_hook(lua.state(), triggers, callback) };

src/state/raw.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::table::Table;
1818
use crate::thread::Thread;
1919
use crate::types::{
2020
AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData,
21-
MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, XRc,
21+
MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, VmState, XRc,
2222
};
2323
use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataRegistry, UserDataStorage};
2424
use crate::util::{
@@ -356,7 +356,7 @@ impl RawLua {
356356
triggers: HookTriggers,
357357
callback: F,
358358
) where
359-
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
359+
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
360360
{
361361
unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
362362
let extra = ExtraData::get(state);
@@ -365,17 +365,34 @@ impl RawLua {
365365
ffi::lua_sethook(state, None, 0, 0);
366366
return;
367367
}
368-
callback_error_ext(state, extra, move |extra, _| {
368+
let result = callback_error_ext(state, extra, move |extra, _| {
369369
let hook_cb = (*extra).hook_callback.clone();
370370
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
371371
if std::rc::Rc::strong_count(&hook_cb) > 2 {
372-
return Ok(()); // Don't allow recursion
372+
return Ok(VmState::Continue); // Don't allow recursion
373373
}
374374
let rawlua = (*extra).raw_lua();
375375
let _guard = StateGuard::new(rawlua, state);
376376
let debug = Debug::new(rawlua, ar);
377377
hook_cb((*extra).lua(), debug)
378-
})
378+
});
379+
match result {
380+
VmState::Continue => {}
381+
VmState::Yield => {
382+
// Only count and line events can yield
383+
if (*ar).event == ffi::LUA_HOOKCOUNT || (*ar).event == ffi::LUA_HOOKLINE {
384+
#[cfg(any(feature = "lua54", feature = "lua53"))]
385+
if ffi::lua_isyieldable(state) != 0 {
386+
ffi::lua_yield(state, 0);
387+
}
388+
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
389+
{
390+
ffi::lua_pushliteral(state, "attempt to yield from a hook");
391+
ffi::lua_error(state);
392+
}
393+
}
394+
}
395+
}
379396
}
380397

381398
(*self.extra.get()).hook_callback = Some(std::rc::Rc::new(callback));

src/thread.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::error::{Error, Result};
44
#[allow(unused)]
55
use crate::state::Lua;
66
use crate::state::RawLua;
7-
use crate::types::ValueRef;
7+
use crate::types::{ValueRef, VmState};
88
use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
99
use crate::value::{FromLuaMulti, IntoLuaMulti};
1010

@@ -194,7 +194,7 @@ impl Thread {
194194
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
195195
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
196196
where
197-
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
197+
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
198198
{
199199
let lua = self.0.lua.lock();
200200
unsafe {

src/types.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,19 @@ pub(crate) type AsyncCallbackUpvalue = Upvalue<AsyncCallback>;
7676
pub(crate) type AsyncPollUpvalue = Upvalue<BoxFuture<'static, Result<c_int>>>;
7777

7878
/// Type to set next Luau VM action after executing interrupt function.
79-
#[cfg(any(feature = "luau", doc))]
80-
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
8179
pub enum VmState {
8280
Continue,
81+
/// Yield the current thread.
82+
///
83+
/// Supported by Lua 5.3+ and Luau.
8384
Yield,
8485
}
8586

8687
#[cfg(all(feature = "send", not(feature = "luau")))]
87-
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<()> + Send>;
88+
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
8889

8990
#[cfg(all(not(feature = "send"), not(feature = "luau")))]
90-
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<()>>;
91+
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState>>;
9192

9293
#[cfg(all(feature = "send", feature = "luau"))]
9394
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;

tests/hooks.rs

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::ops::Deref;
44
use std::sync::atomic::{AtomicI64, Ordering};
55
use std::sync::{Arc, Mutex};
66

7-
use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value};
7+
use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, ThreadStatus, Value, VmState};
88

99
#[test]
1010
fn test_hook_triggers() {
@@ -26,7 +26,7 @@ fn test_line_counts() -> Result<()> {
2626
lua.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| {
2727
assert_eq!(debug.event(), DebugEvent::Line);
2828
hook_output.lock().unwrap().push(debug.curr_line());
29-
Ok(())
29+
Ok(VmState::Continue)
3030
});
3131
lua.load(
3232
r#"
@@ -61,7 +61,7 @@ fn test_function_calls() -> Result<()> {
6161
let source = debug.source();
6262
let name = names.name.map(|s| s.into_owned());
6363
hook_output.lock().unwrap().push((name, source.what));
64-
Ok(())
64+
Ok(VmState::Continue)
6565
});
6666

6767
lua.load(
@@ -120,7 +120,7 @@ fn test_limit_execution_instructions() -> Result<()> {
120120
if max_instructions.fetch_sub(30, Ordering::Relaxed) <= 30 {
121121
Err(Error::runtime("time's up"))
122122
} else {
123-
Ok(())
123+
Ok(VmState::Continue)
124124
}
125125
},
126126
);
@@ -191,10 +191,10 @@ fn test_hook_swap_within_hook() -> Result<()> {
191191
TL_LUA.with(|tl| {
192192
tl.borrow().as_ref().unwrap().remove_hook();
193193
});
194-
Ok(())
194+
Ok(VmState::Continue)
195195
})
196196
});
197-
Ok(())
197+
Ok(VmState::Continue)
198198
})
199199
});
200200

@@ -234,7 +234,7 @@ fn test_hook_threads() -> Result<()> {
234234
co.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| {
235235
assert_eq!(debug.event(), DebugEvent::Line);
236236
hook_output.lock().unwrap().push(debug.curr_line());
237-
Ok(())
237+
Ok(VmState::Continue)
238238
});
239239

240240
co.resume::<()>(())?;
@@ -249,3 +249,39 @@ fn test_hook_threads() -> Result<()> {
249249

250250
Ok(())
251251
}
252+
253+
#[test]
254+
fn test_hook_yield() -> Result<()> {
255+
let lua = Lua::new();
256+
257+
let func = lua
258+
.load(
259+
r#"
260+
local x = 2 + 3
261+
local y = x * 63
262+
local z = string.len(x..", "..y)
263+
"#,
264+
)
265+
.into_function()?;
266+
let co = lua.create_thread(func)?;
267+
268+
co.set_hook(HookTriggers::EVERY_LINE, move |_lua, _debug| Ok(VmState::Yield));
269+
270+
#[cfg(any(feature = "lua54", feature = "lua53"))]
271+
{
272+
assert!(co.resume::<()>(()).is_ok());
273+
assert!(co.resume::<()>(()).is_ok());
274+
assert!(co.resume::<()>(()).is_ok());
275+
assert!(co.resume::<()>(()).is_ok());
276+
assert!(co.status() == ThreadStatus::Finished);
277+
}
278+
#[cfg(any(feature = "lua51", feature = "lua52", feature = "luajit"))]
279+
{
280+
assert!(
281+
matches!(co.resume::<()>(()), Err(Error::RuntimeError(err)) if err.contains("attempt to yield from a hook"))
282+
);
283+
assert!(co.status() == ThreadStatus::Error);
284+
}
285+
286+
Ok(())
287+
}

0 commit comments

Comments
 (0)