Skip to content

Commit fa217d3

Browse files
committed
Better Luau buffer type support.
- Add `Lua::create_buffer()` function - Support serializing buffer type as a byte slice - Support accessing copy of underlying bytes using `BString`
1 parent b62f2ee commit fa217d3

File tree

7 files changed

+162
-5
lines changed

7 files changed

+162
-5
lines changed

src/conversion.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -679,19 +679,49 @@ impl<'lua> IntoLua<'lua> for BString {
679679
}
680680

681681
impl<'lua> FromLua<'lua> for BString {
682-
#[inline]
683682
fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result<Self> {
684683
let ty = value.type_name();
685-
Ok(BString::from(
686-
lua.coerce_string(value)?
684+
match value {
685+
Value::String(s) => Ok(s.as_bytes().into()),
686+
#[cfg(feature = "luau")]
687+
Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
688+
let mut size = 0usize;
689+
let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
690+
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
691+
Ok(slice::from_raw_parts(buf as *const u8, size).into())
692+
},
693+
_ => Ok(lua
694+
.coerce_string(value)?
687695
.ok_or_else(|| Error::FromLuaConversionError {
688696
from: ty,
689697
to: "BString",
690698
message: Some("expected string or number".to_string()),
691699
})?
692700
.as_bytes()
693-
.to_vec(),
694-
))
701+
.into()),
702+
}
703+
}
704+
705+
unsafe fn from_stack(idx: c_int, lua: &'lua Lua) -> Result<Self> {
706+
let state = lua.state();
707+
match ffi::lua_type(state, idx) {
708+
ffi::LUA_TSTRING => {
709+
let mut size = 0;
710+
let data = ffi::lua_tolstring(state, idx, &mut size);
711+
Ok(slice::from_raw_parts(data as *const u8, size).into())
712+
}
713+
#[cfg(feature = "luau")]
714+
ffi::LUA_TBUFFER => {
715+
let mut size = 0;
716+
let buf = ffi::lua_tobuffer(state, idx, &mut size);
717+
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
718+
Ok(slice::from_raw_parts(buf as *const u8, size).into())
719+
}
720+
_ => {
721+
// Fallback to default
722+
Self::from_lua(lua.stack_value(idx), lua)
723+
}
724+
}
695725
}
696726
}
697727

src/lua.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,27 @@ impl Lua {
13731373
}
13741374
}
13751375

1376+
/// Create and return a Luau [buffer] object from a byte slice of data.
1377+
///
1378+
/// Requires `feature = "luau"`
1379+
///
1380+
/// [buffer]: https://luau-lang.org/library#buffer-library
1381+
#[cfg(feature = "luau")]
1382+
pub fn create_buffer(&self, buf: impl AsRef<[u8]>) -> Result<AnyUserData> {
1383+
let state = self.state();
1384+
unsafe {
1385+
if self.unlikely_memory_error() {
1386+
crate::util::push_buffer(self.ref_thread(), buf.as_ref(), false)?;
1387+
return Ok(AnyUserData(self.pop_ref_thread(), SubtypeId::Buffer));
1388+
}
1389+
1390+
let _sg = StackGuard::new(state);
1391+
check_stack(state, 4)?;
1392+
crate::util::push_buffer(state, buf.as_ref(), true)?;
1393+
Ok(AnyUserData(self.pop_ref(), SubtypeId::Buffer))
1394+
}
1395+
}
1396+
13761397
/// Creates and returns a new empty table.
13771398
pub fn create_table(&self) -> Result<Table> {
13781399
self.create_table_with_capacity(0, 0)

src/serde/de.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
148148
Value::UserData(ud) if ud.is_serializable() => {
149149
serde_userdata(ud, |value| value.deserialize_any(visitor))
150150
}
151+
#[cfg(feature = "luau")]
152+
Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
153+
let mut size = 0usize;
154+
let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
155+
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
156+
let buf = std::slice::from_raw_parts(buf as *const u8, size);
157+
visitor.visit_bytes(buf)
158+
},
151159
Value::Function(_)
152160
| Value::Thread(_)
153161
| Value::UserData(_)

src/userdata.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,19 @@ impl<'lua> Serialize for AnyUserData<'lua> {
13401340
S: Serializer,
13411341
{
13421342
let lua = self.0.lua;
1343+
1344+
// Special case for Luau buffer type
1345+
#[cfg(feature = "luau")]
1346+
if self.1 == SubtypeId::Buffer {
1347+
let buf = unsafe {
1348+
let mut size = 0usize;
1349+
let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size);
1350+
mlua_assert!(!buf.is_null(), "invalid Luau buffer");
1351+
std::slice::from_raw_parts(buf as *const u8, size)
1352+
};
1353+
return serializer.serialize_bytes(buf);
1354+
}
1355+
13431356
let data = unsafe {
13441357
let _ = lua
13451358
.get_userdata_ref_type_id(&self.0)

src/util/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,20 @@ pub unsafe fn push_string(state: *mut ffi::lua_State, s: &[u8], protect: bool) -
253253
}
254254
}
255255

256+
// Uses 3 stack spaces (when protect), does not call checkstack.
257+
#[cfg(feature = "luau")]
258+
#[inline(always)]
259+
pub unsafe fn push_buffer(state: *mut ffi::lua_State, b: &[u8], protect: bool) -> Result<()> {
260+
let data = if protect {
261+
protect_lua!(state, 0, 1, |state| ffi::lua_newbuffer(state, b.len()))?
262+
} else {
263+
ffi::lua_newbuffer(state, b.len())
264+
};
265+
let buf = slice::from_raw_parts_mut(data as *mut u8, b.len());
266+
buf.copy_from_slice(b);
267+
Ok(())
268+
}
269+
256270
// Uses 3 stack spaces, does not call checkstack.
257271
#[inline]
258272
pub unsafe fn push_table(

tests/conversion.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::Cow;
22
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
33
use std::ffi::{CStr, CString};
44

5+
use bstr::BString;
56
use maplit::{btreemap, btreeset, hashmap, hashset};
67
use mlua::{
78
AnyUserData, Error, Function, IntoLua, Lua, RegistryKey, Result, Table, Thread, UserDataRef,
@@ -409,3 +410,46 @@ fn test_conv_array() -> Result<()> {
409410

410411
Ok(())
411412
}
413+
414+
#[test]
415+
fn test_bstring_from_lua() -> Result<()> {
416+
let lua = Lua::new();
417+
418+
let s = lua.create_string("hello, world")?;
419+
let bstr = lua.unpack::<BString>(Value::String(s))?;
420+
assert_eq!(bstr, "hello, world");
421+
422+
let bstr = lua.unpack::<BString>(Value::Integer(123))?;
423+
assert_eq!(bstr, "123");
424+
425+
let bstr = lua.unpack::<BString>(Value::Number(-123.55))?;
426+
assert_eq!(bstr, "-123.55");
427+
428+
// Test from stack
429+
let f = lua.create_function(|_, bstr: BString| Ok(bstr))?;
430+
let bstr = f.call::<_, BString>("hello, world")?;
431+
assert_eq!(bstr, "hello, world");
432+
433+
let bstr = f.call::<_, BString>(-43.22)?;
434+
assert_eq!(bstr, "-43.22");
435+
436+
Ok(())
437+
}
438+
439+
#[cfg(feature = "luau")]
440+
#[test]
441+
fn test_bstring_from_lua_buffer() -> Result<()> {
442+
let lua = Lua::new();
443+
444+
let b = lua.create_buffer("hello, world")?;
445+
let bstr = lua.unpack::<BString>(Value::UserData(b))?;
446+
assert_eq!(bstr, "hello, world");
447+
448+
// Test from stack
449+
let f = lua.create_function(|_, bstr: BString| Ok(bstr))?;
450+
let buf = lua.create_buffer("hello, world")?;
451+
let bstr = f.call::<_, BString>(buf)?;
452+
assert_eq!(bstr, "hello, world");
453+
454+
Ok(())
455+
}

tests/serde.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,3 +728,30 @@ fn test_arbitrary_precision() {
728728
"{\n [\"$serde_json::private::Number\"] = \"124.4\",\n}"
729729
);
730730
}
731+
732+
#[cfg(feature = "luau")]
733+
#[test]
734+
fn test_buffer_serialize() {
735+
let lua = Lua::new();
736+
737+
let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap();
738+
let val = serde_value::to_value(&buf).unwrap();
739+
assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4]));
740+
741+
// Try empty buffer
742+
let buf = lua.create_buffer(&[]).unwrap();
743+
let val = serde_value::to_value(&buf).unwrap();
744+
assert_eq!(val, serde_value::Value::Bytes(vec![]));
745+
}
746+
747+
#[cfg(feature = "luau")]
748+
#[test]
749+
fn test_buffer_from_value() {
750+
let lua = Lua::new();
751+
752+
let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap();
753+
let val = lua
754+
.from_value::<serde_value::Value>(Value::UserData(buf))
755+
.unwrap();
756+
assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4]));
757+
}

0 commit comments

Comments
 (0)