diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a453361f..80e5f353 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -28,7 +28,7 @@ variables: KANIKO_REGISTRY_MIRROR: docker-proxy.binary.picodata.io CACHE_PATHS: target CARGO_INCREMENTAL: 0 - RUST_VERSION: "1.71" # Note: without the quotes yaml thinks this is a float and removes the trailing 0 + RUST_VERSION: "1.82" # Note: without the quotes yaml thinks this is a float and removes the trailing 0 CARGO_HOME: /shared-storage/tarantool-module/.cargo BASE_IMAGE_VANILLA: docker-public.binary.picodata.io/tarantool-module-build-base-vanilla BASE_IMAGE_FORK: docker-public.binary.picodata.io/tarantool-module-build-base-fork diff --git a/CHANGELOG.md b/CHANGELOG.md index 87eb9d69..6dd572c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,146 @@ -# Change Log + + +# Change Log + +# [8.0.1] Unreleased + +### Changed + +- Unpinned the version of `linkme` crate (`0.3.29` -> `0.3`). + +- Replaced dependency from unmaintained `dlopen` crate to `libloading` of version `0.3`. + +- Replaced dependency from unmaintained `proc-macro-error` crate to `proc-macro-error2` of version `2`. + +### Fixed +- Unpinned the version of `time` crate from a two-year old `0.3.17` +- `space::Space::bsize` now returns `box.space.:bsize()` instead of `box.space..index[0]:bsize()` + + +# [8.0.0] Jun 24 2025 + +### Changed + +- Unpinned the version of `tokio` crate optional dependency + +### Fixed + +- Fix `tlua::Push`, `tlua::PushInto` and `tlua::LuaRead` derive macros to support default values for type parameters, + and `tlua::LuaRead` to support default values for default values for const parameters. + +### Deprecated + +### Breaking + +- Split a single `index::Part` type into `Part`, `Part` and `Part`, + depending on the accepted/returned types by the tarantool API + +# [7.0.0] May 26 2025 + +### Added + +- `AuthMethod::DEFAULT` represents the default value of `AuthMethod`, but is available in constant contexts. +- Support MsgPack ExtType encoding and decoding via `ExtStruct` +- Support Tuple msgpack decoding +- Support Decimal, UUID, Datetime msgpack encoding and decoding + +### Changed + +- 1.82 is now MSRV. + +### Fixed + +- `network::protocol::codec::Header::encode` will not truncate to `u8` an + `network::protocol::codec::IProtoType` integer from an outside stream at decoding. + +### Deprecated + +### Breaking changes + +- Use `extern "C-unwind"` instead of `extern "C"` for lua ffi functions, which + seems to help with lua_error!() in release builds on recent versions of rust. + +### Added (picodata) + +- `ffi::sql::PortC` either allows to append tuples and msgpacks to the port or + to iterate over the port data. It is also possible to manipulate with `PortC` + from the `FunctionCtx` of the called function. +- Introduce `sql::sql_execute_into_port` and `sql::Statement::execute_into_port` + methods. Now it is possible to store SQL results directly into the port. + +### Changed (picodata) + +### Fixed (picodata) + +### Breaking changes (picodata) + +- `AuthMethod::default()` is now `Md5`, not `ChapSha1`. +- `tarantool::tuple::Tuple::{as_named_buffer, names, name_count}` did not find their use, so they are now logging an + error message and panic. These functions are scheduled for removal in next major release and labeled as deprecated. + +### Deprecated (picodata) + +# [6.1.0] Dec 10 2024 + +### Added + - `network::client::tcp::TcpStream` not supports async connection, provided with `connect_async` and `connect_timeout_async` methods +- `impl Default for log::TarantoolLogger` + +### Changed +- `error::Result` type alias now has another generic parameter E, which defaults + to `tarantool::error::Error`, but allows this type alias to be used as a drop + in replacement from `std::result::Result`. + +### Fixed + +- `network::client::tcp::TcpStream` does not close underlying fd anymore. Now fd will be closed only when the last copy of tcp stream is dropped. + +### Deprecated +- `network::client::tcp::UnsafeSendSyncTcpStream` is now deprected. `network::client::tcp::TcpStream` should be used instead. # [6.0.0] Nov 20 2024 ### Added + - `tlua::Push` trait implementations for `OsString`, `OsStr`, `Path`, `PathBuf` - `tlua::LuaRead` trait implementations for `OsString`, `PathBuf` - tlua::LuaTable::metatable which is a better alternative to the existing `tlua::LuaTable::get_or_create_metatable` @@ -26,6 +157,7 @@ restricting time connection establishment. - `tlua::Nil` now supports (de)serialization via serde ### Changed + - `network::protocol::codec::IProtoType` uses C language representation - `cbus::sync::std::ThreadWaker` now uses internal thread FIFO queue when blocking threads on send. - `#[tarantool::proc]` attribute doesn't add procs to a global array unless @@ -33,6 +165,7 @@ restricting time connection establishment. - `proc::all_procs` will now panic if `stored_procs_slice` feature is disabled. ### Fixed + - `tlua::{Push, PushInto, LuaRead}` now work for HashSet & HashMap with custom hashers. - Use after free in `fiber::Builder::start_non_joinable` when the fiber exits without yielding. - Incorrect, off-spec MP Ext type: caused runtime errors on some platforms. @@ -40,19 +173,20 @@ restricting time connection establishment. - Impossible to use procedural macros(like `tarantool::proc`, `tarantool::test`) through reexporting tarantool. ### Deprecated + - tlua::LuaTable::get_or_create_metatable is deprecated now in favor of tlua::LuaTable::metatable. ### Breaking changes + - Use `extern "C-unwind"` instead of `extern "C"` for all trampolines which take `*mut ffi::lua_State` (checked with `rg 'extern "C".*lua_State'`). `tlua::error!` throws an exception to unwind the stack, hence we need to use a proper ABI to fix UB in picodata. ### Breaking changes (picodata) + - Add session ID to the argument list of the `sql_prepare_ext`. - Replace `sql_unprepare` with `sql_unprepare_ext` (contains an additional session ID argument). - - # [5.0.0] Aug 06 2024 ### Added @@ -216,12 +350,14 @@ restricting time connection establishment. **experimental** rust allocated implementation of tuple virtual table. ### Changed (picodata) + - `Tuple::decode` & `ToTupleBuffer` implementation for `Tuple` is now a bit more efficient because one redundant tuple data copy is removed. ### Fixed (picodata) ### Breaking changes (picodata) + - SQL module was totally refactored: all its public structures functions and FFIs have been changed. @@ -294,7 +430,7 @@ restricting time connection establishment. - `define_str_enum` will no longer produce warning "`&` without an explicit lifetime name cannot be used here". For more information, see - https://github.com/rust-lang/rust/issues/115010. + . - `#[tarantool::test]` declares a special static variable which is usually invisible to the users, but previously it would have a not so unique name which would sometimes lead to name conflicts with user-defined items. diff --git a/Cargo.toml b/Cargo.toml index e80f4dcc..efedc8b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "1" members = [ "tarantool", "tarantool-proc", diff --git a/README.md b/README.md index f86665a4..a90ef668 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ rustflags = [ Add the following lines to your project's Cargo.toml: ```toml [dependencies] -tarantool = "6.0" +tarantool = "7.0" [lib] crate-type = ["cdylib"] @@ -66,8 +66,10 @@ See https://github.com/picodata/brod for example usage. ### Features -- `net_box` - Enables protocol implementation (enabled by default) -- `schema` - Enables schema manipulation utils (WIP as of now) +- `net_box` - Enables IPROTO fiber aware client (enabled by default) +- `network_client` - Enables a second implementation of IPROTO fiber aware client but with Rust async support (enabled by default) +- `picodata` - Enables support for custom features added in picodata's fork of tarantool +- `tokio_components` - Enables support for interfacing with tokio based logic through cbus ### Stored procedures @@ -103,7 +105,7 @@ edition = "2018" # author, license, etc [dependencies] -tarantool = "6.0" +tarantool = "7.0" serde = "1.0" [lib] diff --git a/examples/async-h1-client/src/lib.rs b/examples/async-h1-client/src/lib.rs index a27ae5a5..0132d540 100644 --- a/examples/async-h1-client/src/lib.rs +++ b/examples/async-h1-client/src/lib.rs @@ -2,7 +2,6 @@ use http_types::{Method, Request, Url}; use tarantool::error::Error; use tarantool::fiber; use tarantool::network::client::tcp::TcpStream; -use tarantool::network::client::tcp::UnsafeSendSyncTcpStream; use tarantool::proc; #[proc] @@ -17,14 +16,12 @@ fn get(url: &str) -> Result<(), Error> { let mut res = match url.scheme() { "http" => { let stream = TcpStream::connect(host, 80).map_err(Error::other)?; - let stream = UnsafeSendSyncTcpStream(stream); println!("Sending request over http..."); async_h1::connect(stream, req).await.map_err(Error::other)? } #[cfg(feature = "tls")] "https" => { let stream = TcpStream::connect(host, 443).map_err(Error::other)?; - let stream = UnsafeSendSyncTcpStream(stream); let stream = async_native_tls::connect(host, stream) .await .map_err(Error::other)?; diff --git a/examples/tokio-hyper/Cargo.toml b/examples/tokio-hyper/Cargo.toml index e9e8144e..b6643c1d 100644 --- a/examples/tokio-hyper/Cargo.toml +++ b/examples/tokio-hyper/Cargo.toml @@ -7,7 +7,7 @@ license = "BSD-2-Clause" [dependencies] tarantool = { path = "../../tarantool" } hyper = { version = "0.14", features = ["full"] } -tokio = { version = "=1.29.1", features = ["full"] } +tokio = { version = "1", features = ["full"] } futures-util = "*" http-body-util = "0.1.0-rc.2" env_logger = "0.9.0" diff --git a/tarantool-proc/Cargo.toml b/tarantool-proc/Cargo.toml index ae0bd48a..7de63a58 100644 --- a/tarantool-proc/Cargo.toml +++ b/tarantool-proc/Cargo.toml @@ -4,12 +4,12 @@ authors = [ ] name = "tarantool-proc" description = "Tarantool proc macros" -version = "3.2.0" +version = "4.0.0" edition = "2021" license = "BSD-2-Clause" documentation = "https://docs.rs/tarantool-proc/" repository = "https://github.com/picodata/tarantool-module" -rust-version = "1.61" +rust-version = "1.82" [lib] proc-macro = true @@ -19,7 +19,7 @@ syn = { version = "^1.0", features = [ "full", "extra-traits" ] } quote = "^1.0" proc-macro2 = "^1.0" darling = "0.14.2" -proc-macro-error = "1" +proc-macro-error2 = "2" [features] stored_procs_slice = [] diff --git a/tarantool-proc/src/lib.rs b/tarantool-proc/src/lib.rs index c8c469fe..425e0e06 100644 --- a/tarantool-proc/src/lib.rs +++ b/tarantool-proc/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; -use proc_macro_error::{proc_macro_error, SpanRange}; +use proc_macro_error2::{proc_macro_error, SpanRange}; use quote::{quote, ToTokens}; use syn::{ parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, AttributeArgs, DeriveInput, @@ -34,7 +34,7 @@ pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream { mod msgpack { use darling::FromDeriveInput; use proc_macro2::TokenStream; - use proc_macro_error::{abort, SpanRange}; + use proc_macro_error2::{abort, SpanRange}; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::{ parse_quote, spanned::Spanned, Data, Field, Fields, FieldsNamed, FieldsUnnamed, diff --git a/tarantool/Cargo.toml b/tarantool/Cargo.toml index 2acfb529..3fcbef47 100644 --- a/tarantool/Cargo.toml +++ b/tarantool/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tarantool" description = "Tarantool rust bindings" -version = "6.0.0" +version = "8.0.1" authors = [ "Dmitriy Koltsov ", "Georgy Moshkin ", @@ -15,18 +15,18 @@ documentation = "https://docs.rs/tarantool/" repository = "https://github.com/picodata/tarantool-module" keywords = ["ffi", "database", "tarantool"] categories = ["database"] -rust-version = "1.71" +rust-version = "1.82" [dependencies] base64 = "0.13" bitflags = "1.2" dec = { version = "0.4.8", optional = true } -dlopen = "0.1.8" +libloading = "0.8" thiserror = "1.0.30" libc = { version = "0.2", features = ["extra_traits"] } log = "0.4" once_cell = "1.4.0" -tlua = { path = "../tlua", version = "4.0.0" } +tlua = { path = "../tlua", version = "6.0.1" } refpool = { version = "0.4.3", optional = true } rmp = "0.8.11" rmp-serde = "1.1" @@ -36,13 +36,13 @@ serde_json = "1.0" serde_bytes = "^0" sha-1 = "0.9" md-5 = "0.10" -tarantool-proc = { path = "../tarantool-proc", version = "3.2.0" } +tarantool-proc = { path = "../tarantool-proc", version = "4.0.0" } uuid = "0.8.2" futures = "0.3.25" -linkme = "0.3.29" +linkme = "0.3" async-trait = "0.1.64" tester = { version = "0.7.0", optional = true } -time = ">=0.3.0, <0.3.18" +time = "0.3.37" crossbeam-queue = { version = "0.3.8", optional = true } async-std = { version = "1.12.0", optional = true, default-features = false, features = [ "std", @@ -50,11 +50,7 @@ async-std = { version = "1.12.0", optional = true, default-features = false, fea pretty_assertions = { version = "1.4", optional = true } tempfile = { version = "3.9", optional = true } va_list = ">=0.1.4" -tokio = { version = "=1.29.1", features = [ - "sync", - "rt", - "time", -], optional = true } +tokio = { version = "1", features = ["sync", "rt", "time"], optional = true } anyhow = { version = "1", optional = true } [features] @@ -81,5 +77,5 @@ standalone_decimal = ["dec"] stored_procs_slice = ["tarantool-proc/stored_procs_slice"] [dev-dependencies] -time-macros = "=0.2.6" +time-macros = "0.2.6" pretty_assertions = "1.4" diff --git a/tarantool/src/auth.rs b/tarantool/src/auth.rs index 21853672..b2edb9ae 100644 --- a/tarantool/src/auth.rs +++ b/tarantool/src/auth.rs @@ -9,17 +9,27 @@ crate::define_str_enum! { } } +#[cfg(not(feature = "picodata"))] +impl AuthMethod { + pub const DEFAULT: Self = Self::ChapSha1; +} + #[cfg(feature = "picodata")] crate::define_str_enum! { #[derive(Default)] pub enum AuthMethod { - #[default] ChapSha1 = "chap-sha1", + #[default] Md5 = "md5", Ldap = "ldap", } } +#[cfg(feature = "picodata")] +impl AuthMethod { + pub const DEFAULT: Self = Self::Md5; +} + #[cfg(feature = "picodata")] mod picodata { use super::AuthMethod; diff --git a/tarantool/src/cbus/mod.rs b/tarantool/src/cbus/mod.rs index ce690090..ea018ec6 100644 --- a/tarantool/src/cbus/mod.rs +++ b/tarantool/src/cbus/mod.rs @@ -239,6 +239,7 @@ mod tests { use crate::cbus::Message; use crate::fiber; use crate::fiber::Cond; + use crate::static_ref; use std::thread; use std::thread::ThreadId; @@ -280,9 +281,12 @@ mod tests { cond.wait(); unsafe { - assert!(SENDER_THREAD_ID.is_some()); - assert!(TX_THREAD_ID.is_some()); - assert_ne!(SENDER_THREAD_ID, TX_THREAD_ID); + assert!(static_ref!(const SENDER_THREAD_ID).is_some()); + assert!(static_ref!(const TX_THREAD_ID).is_some()); + assert_ne!( + static_ref!(const SENDER_THREAD_ID), + static_ref!(const TX_THREAD_ID) + ); } thread.join().unwrap(); diff --git a/tarantool/src/cbus/oneshot.rs b/tarantool/src/cbus/oneshot.rs index dff6383a..8c0eaa62 100644 --- a/tarantool/src/cbus/oneshot.rs +++ b/tarantool/src/cbus/oneshot.rs @@ -63,7 +63,7 @@ pub struct EndpointReceiver { /// # Arguments /// /// * `cbus_endpoint`: cbus endpoint name. Note that the tx thread (or any other cord) -/// must have a fiber occupied by the endpoint cbus_loop. +/// must have a fiber occupied by the endpoint cbus_loop. /// /// # Examples /// diff --git a/tarantool/src/cbus/sync/mod.rs b/tarantool/src/cbus/sync/mod.rs index 1c3a3794..4220e105 100644 --- a/tarantool/src/cbus/sync/mod.rs +++ b/tarantool/src/cbus/sync/mod.rs @@ -1,12 +1,12 @@ #![cfg(any(feature = "picodata", doc))] -/// A synchronous channels for popular runtimes. -/// Synchronous channel - means that channel has internal buffer with user-defined capacity. -/// Synchronous channel differs against of unbounded channel in the semantics of the sender: if -/// channel buffer is full then all sends called from producer will block a runtime, until channel -/// buffer is freed. -/// -/// It is important to use a channel that suits the runtime in which the producer works. +//! A synchronous channels for popular runtimes. +//! Synchronous channel - means that channel has internal buffer with user-defined capacity. +//! Synchronous channel differs against of unbounded channel in the semantics of the sender: if +//! channel buffer is full then all sends called from producer will block a runtime, until channel +//! buffer is freed. +//! +//! It is important to use a channel that suits the runtime in which the producer works. /// A channels for messaging between an OS thread (producer) and tarantool cord (consumer). pub mod std; diff --git a/tarantool/src/cbus/sync/std.rs b/tarantool/src/cbus/sync/std.rs index 17361c8d..e5aab459 100644 --- a/tarantool/src/cbus/sync/std.rs +++ b/tarantool/src/cbus/sync/std.rs @@ -113,7 +113,7 @@ impl Channel { /// # Arguments /// /// * `cbus_endpoint`: cbus endpoint name. Note that the tx thread (or any other cord) -/// must have a fiber occupied by the endpoint cbus_loop. +/// must have a fiber occupied by the endpoint cbus_loop. /// * `cap`: specifies the buffer size. /// /// # Examples diff --git a/tarantool/src/cbus/sync/tokio.rs b/tarantool/src/cbus/sync/tokio.rs index 5da1b681..9588af20 100644 --- a/tarantool/src/cbus/sync/tokio.rs +++ b/tarantool/src/cbus/sync/tokio.rs @@ -90,7 +90,7 @@ impl Channel { /// # Arguments /// /// * `cbus_endpoint`: cbus endpoint name. Note that the tx thread (or any other cord) -/// must have a fiber occupied by the endpoint cbus_loop. +/// must have a fiber occupied by the endpoint cbus_loop. /// * `cap`: specifies the buffer size. /// /// # Examples @@ -495,7 +495,7 @@ mod tests { }); for _ in 0..MESSAGES_PER_PRODUCER * 3 { - assert!(matches!(rx.receive(), Ok(_))); + assert!(rx.receive().is_ok()); } assert!(matches!(rx.receive(), Err(RecvError::Disconnected))); @@ -542,7 +542,7 @@ mod tests { // assert that all threads produce 10 messages and sleep after assert_eq!(SEND_COUNTER.load(Ordering::SeqCst), (i + 1) * 10); for _ in 0..10 { - assert!(matches!(rx.receive(), Ok(_))); + assert!(rx.receive().is_ok()); } fiber::sleep(Duration::from_millis(100)); } diff --git a/tarantool/src/cbus/unbounded.rs b/tarantool/src/cbus/unbounded.rs index 0af1fd61..9d92f800 100644 --- a/tarantool/src/cbus/unbounded.rs +++ b/tarantool/src/cbus/unbounded.rs @@ -98,7 +98,7 @@ impl Channel { /// # Arguments /// /// * `cbus_endpoint`: cbus endpoint name. Note that the tx thread (or any other cord) -/// must have a fiber occupied by the endpoint cbus_loop. +/// must have a fiber occupied by the endpoint cbus_loop. /// /// # Examples /// diff --git a/tarantool/src/datetime.rs b/tarantool/src/datetime.rs index 2cc837f3..214851dd 100644 --- a/tarantool/src/datetime.rs +++ b/tarantool/src/datetime.rs @@ -1,7 +1,11 @@ use crate::ffi::datetime as ffi; +use crate::msgpack; +use crate::msgpack::{Context, Decode, DecodeError, Encode, EncodeError}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; +use std::convert::{TryFrom, TryInto}; use std::fmt::Display; +use std::io::Write; use time::{Duration, UtcOffset}; type Inner = time::OffsetDateTime; @@ -99,6 +103,39 @@ impl Datetime { tzindex: 0, } } + + fn msgpack_bytes(&self) -> ([u8; 16], usize) { + let data = self.as_bytes_tt(); + let mut len = data.len(); + if data[8..] == [0, 0, 0, 0, 0, 0, 0, 0] { + len = 8; + } + (data, len) + } + + fn from_ext_structure(tag: i8, bytes: &[u8]) -> Result { + if tag != ffi::MP_DATETIME { + return Err(format!("Expected Datetime, found msgpack ext #{}", tag)); + } + + if bytes.len() != 8 && bytes.len() != 16 { + return Err(format!( + "Unexpected number of bytes for Datetime: expected 8 or 16, got {}", + bytes.len() + )); + } + + Self::from_bytes_tt(bytes).map_err(|_| "Error decoding msgpack bytes".to_string()) + } +} + +impl<'a> TryFrom> for Datetime { + type Error = String; + + #[inline(always)] + fn try_from(value: msgpack::ExtStruct<'a>) -> Result { + Self::from_ext_structure(value.tag, value.data) + } } impl From for Datetime { @@ -122,9 +159,23 @@ impl Display for Datetime { } //////////////////////////////////////////////////////////////////////////////// -/// Tuple +// Tuple //////////////////////////////////////////////////////////////////////////////// +impl Encode for Datetime { + fn encode(&self, w: &mut impl Write, context: &Context) -> Result<(), EncodeError> { + let (data, len) = self.msgpack_bytes(); + msgpack::ExtStruct::new(ffi::MP_DATETIME, &data[..len]).encode(w, context) + } +} +impl<'de> Decode<'de> for Datetime { + fn decode(r: &mut &'de [u8], context: &Context) -> Result { + msgpack::ExtStruct::decode(r, context)? + .try_into() + .map_err(DecodeError::new::) + } +} + impl serde::Serialize for Datetime { fn serialize(&self, serializer: S) -> Result where @@ -133,12 +184,9 @@ impl serde::Serialize for Datetime { #[derive(Serialize)] struct _ExtStruct<'a>((i8, &'a serde_bytes::Bytes)); - let data = self.as_bytes_tt(); - let mut data = data.as_slice(); - if data[8..] == [0, 0, 0, 0, 0, 0, 0, 0] { - data = &data[..8]; - } - _ExtStruct((ffi::MP_DATETIME, serde_bytes::Bytes::new(data))).serialize(serializer) + let (data, len) = self.msgpack_bytes(); + + _ExtStruct((ffi::MP_DATETIME, serde_bytes::Bytes::new(&data[..len]))).serialize(serializer) } } @@ -152,28 +200,12 @@ impl<'de> serde::Deserialize<'de> for Datetime { let _ExtStruct((kind, bytes)) = serde::Deserialize::deserialize(deserializer)?; - if kind != ffi::MP_DATETIME { - return Err(serde::de::Error::custom(format!( - "Expected Datetime, found msgpack ext #{}", - kind - ))); - } - - let data = bytes.as_slice(); - if data.len() != 8 && data.len() != 16 { - return Err(serde::de::Error::custom(format!( - "Unexpected number of bytes for Datetime: expected 8 or 16, got {}", - data.len() - ))); - } - - Self::from_bytes_tt(data) - .map_err(|_| serde::de::Error::custom("Error decoding msgpack bytes")) + Self::from_ext_structure(kind, bytes.as_slice()).map_err(serde::de::Error::custom) } } //////////////////////////////////////////////////////////////////////////////// -/// Lua +// Lua //////////////////////////////////////////////////////////////////////////////// static CTID_DATETIME: Lazy = Lazy::new(|| { @@ -239,6 +271,7 @@ impl tlua::PushOneInto for Datetime {} #[cfg(test)] mod tests { use super::*; + use crate::msgpack; use time_macros::datetime; #[test] @@ -266,6 +299,44 @@ mod tests { let expected: Datetime = datetime!(2023-11-11 0:00:0.0000 -0).into(); assert_eq!(only_date, expected); } + + #[test] + fn encode() { + let datetime: Datetime = datetime!(2023-11-11 2:03:19.35421 -3).into(); + let data = msgpack::encode(&datetime); + let expected = b"\xd8\x04\x17\x0b\x4f\x65\x00\x00\x00\x00\xd0\xd0\x1c\x15\x4c\xff\x00\x00"; + assert_eq!(data, expected); + + let only_date: Datetime = datetime!(1993-05-19 0:00:0.0000 +0).into(); + let data = msgpack::encode(&only_date); + let expected = b"\xd7\x04\x80\x78\xf9\x2b\x00\x00\x00\x00"; + assert_eq!(data, expected); + } + + #[test] + fn decode() { + let data = b"\xd8\x04\x46\x9f\r\x66\x00\x00\x00\x00\x50\x41\x7e\x3b\x4c\xff\x00\x00"; + let datetime: Datetime = msgpack::decode(data).unwrap(); + let expected: Datetime = datetime!(2024-04-03 15:26:14.99813 -3).into(); + assert_eq!(datetime, expected); + + let data = b"\xd7\x04\x00\xc4\x4e\x65\x00\x00\x00\x00"; + let only_date: Datetime = msgpack::decode(data).unwrap(); + let expected: Datetime = datetime!(2023-11-11 0:00:0.0000 -0).into(); + assert_eq!(only_date, expected); + } + + #[test] + fn decode_inside_structure() { + let expected = ( + 123, + datetime!(2023-11-11 0:00:0.0000 -0).into(), + "foobar".into(), + ); + let data = b"\x93{\xd7\x04\x00\xc4Ne\x00\x00\x00\x00\xa6foobar"; + let actual: (i32, Datetime, String) = msgpack::decode(data).unwrap(); + assert_eq!(expected, actual); + } } #[cfg(feature = "internal_test")] diff --git a/tarantool/src/decimal.rs b/tarantool/src/decimal.rs index 4f1502b3..ee23e538 100644 --- a/tarantool/src/decimal.rs +++ b/tarantool/src/decimal.rs @@ -7,7 +7,11 @@ use crate::ffi::decimal as ffi; #[cfg(all(not(feature = "standalone_decimal"), feature = "picodata"))] mod tarantool_decimal { use super::ffi; + use crate::msgpack; + use crate::msgpack::{Decode, DecodeError, Encode, EncodeError}; use serde::{Deserialize, Serialize}; + use std::convert::TryInto; + use std::io::Write; /// A Decimal number implemented using the builtin tarantool api. /// @@ -200,8 +204,38 @@ mod tarantool_decimal { Some(self) } } + + fn msgpack_bytes(&self) -> ([u8; MAX_MSGPACK_BYTES], usize) { + unsafe { + let mut data = [0u8; MAX_MSGPACK_BYTES]; + let len = ffi::decimal_len(&self.inner) as usize; + ffi::decimal_pack(data.as_mut_ptr() as _, &self.inner); + (data, len) + } + } + + fn from_ext_structure(tag: i8, bytes: &[u8]) -> Result { + if tag != ffi::MP_DECIMAL { + return Err(format!("Expected Decimal, found msgpack ext #{}", tag)); + } + + let data_p = &mut bytes.as_ptr().cast(); + let mut dec = std::mem::MaybeUninit::uninit(); + let res = unsafe { ffi::decimal_unpack(data_p, bytes.len() as _, dec.as_mut_ptr()) }; + if res.is_null() { + Err("Decimal out of range or corrupt".to_string()) + } else { + unsafe { Ok(Self::from_raw(dec.assume_init())) } + } + } } + // from tarantool + // sizeof_scale + ceil((digits + 1) / 2) + // sizeof_scale (max 9) + // digits (max 38) + pub const MAX_MSGPACK_BYTES: usize = 29; + #[allow(clippy::non_canonical_partial_ord_impl)] impl std::cmp::PartialOrd for Decimal { #[inline(always)] @@ -302,6 +336,15 @@ mod tarantool_decimal { } } + impl<'a> std::convert::TryFrom> for Decimal { + type Error = String; + + #[inline(always)] + fn try_from(value: msgpack::ExtStruct<'a>) -> Result { + Self::from_ext_structure(value.tag, value.data) + } + } + impl std::convert::TryFrom<&str> for Decimal { type Error = ::Err; @@ -431,6 +474,24 @@ mod tarantool_decimal { impl_try_into_int! {i64 isize => ffi::decimal_to_int64} impl_try_into_int! {u64 usize => ffi::decimal_to_uint64} + impl Encode for Decimal { + fn encode( + &self, + w: &mut impl Write, + context: &msgpack::Context, + ) -> Result<(), EncodeError> { + let (data, size) = self.msgpack_bytes(); + msgpack::ExtStruct::new(ffi::MP_DECIMAL, &data[..size]).encode(w, context) + } + } + + impl<'de> Decode<'de> for Decimal { + fn decode(r: &mut &'de [u8], context: &msgpack::Context) -> Result { + msgpack::ExtStruct::decode(r, context)? + .try_into() + .map_err(DecodeError::new::) + } + } impl serde::Serialize for Decimal { fn serialize(&self, serializer: S) -> Result @@ -438,16 +499,11 @@ mod tarantool_decimal { S: serde::Serializer, { #[derive(Serialize)] - struct _ExtStruct((i8, serde_bytes::ByteBuf)); + struct _ExtStruct<'a>((i8, &'a serde_bytes::Bytes)); - let data = unsafe { - let len = ffi::decimal_len(&self.inner) as usize; - let mut data = Vec::::with_capacity(len); - ffi::decimal_pack(data.as_mut_ptr() as _, &self.inner); - data.set_len(len); - data - }; - _ExtStruct((ffi::MP_DECIMAL, serde_bytes::ByteBuf::from(data))).serialize(serializer) + let (data, len) = self.msgpack_bytes(); + _ExtStruct((ffi::MP_DECIMAL, serde_bytes::Bytes::new(&data[..len]))) + .serialize(serializer) } } @@ -459,24 +515,8 @@ mod tarantool_decimal { #[derive(Deserialize)] struct _ExtStruct((i8, serde_bytes::ByteBuf)); - match serde::Deserialize::deserialize(deserializer)? { - _ExtStruct((ffi::MP_DECIMAL, bytes)) => { - let data = bytes.into_vec(); - let data_p = &mut data.as_ptr().cast(); - let mut dec = std::mem::MaybeUninit::uninit(); - let res = - unsafe { ffi::decimal_unpack(data_p, data.len() as _, dec.as_mut_ptr()) }; - if res.is_null() { - Err(serde::de::Error::custom("Decimal out of range or corrupt")) - } else { - unsafe { Ok(Self::from_raw(dec.assume_init())) } - } - } - _ExtStruct((kind, _)) => Err(serde::de::Error::custom(format!( - "Expected Decimal, found msgpack ext #{}", - kind - ))), - } + let _ExtStruct((tag, bytes)) = serde::Deserialize::deserialize(deserializer)?; + Self::from_ext_structure(tag, bytes.as_slice()).map_err(serde::de::Error::custom) } } @@ -504,13 +544,14 @@ mod tarantool_decimal { #[cfg(feature = "standalone_decimal")] mod standalone_decimal { + use super::ffi; + use crate::msgpack; + use crate::msgpack::{Decode, DecodeError, Encode, EncodeError}; + use once_cell::sync::Lazy; use std::convert::TryInto; + use std::io::{Cursor, Write}; use std::{convert::TryFrom, mem::size_of}; - use once_cell::sync::Lazy; - - use super::ffi; - /// A Decimal number implemented using the builtin tarantool api. /// /// ## Availability @@ -759,7 +800,34 @@ mod standalone_decimal { with_context(|ctx| ctx.pow(&mut self.inner, &pow.inner))?; Self::try_from(self.inner).ok() } + + fn msgpack_bytes(&self) -> ([u8; MAX_MSGPACK_BYTES], usize) { + let mut data = [0u8; MAX_MSGPACK_BYTES]; + let mut cursor = Cursor::new(&mut data[..]); + let (bcd, scale) = self.inner.clone().to_packed_bcd().unwrap(); + rmp::encode::write_sint(&mut cursor, scale as i64).unwrap(); + cursor.write_all(&bcd).unwrap(); + let size = cursor.position() as usize; + (data, size) + } + + fn from_ext_structure(tag: i8, bytes: &[u8]) -> Result { + if tag != ffi::MP_DECIMAL { + return Err(format!("Expected Decimal, found msgpack ext #{}", tag)); + } + + let mut data = bytes; + let scale = rmp::decode::read_int(&mut data).unwrap(); + let bcd = data; + + DecimalImpl::from_packed_bcd(bcd, scale) + .map_err(|e| format!("Failed to unpack decimal: {e}"))? + .try_into() + .map_err(|e| format!("Failed to unpack decimal: {e}")) + } } + // (DECIMAL_MAX_DIGITS / 2) + 1 + 1 (msgpack header) + 8 (i64 representation) + const MAX_MSGPACK_BYTES: usize = 29; type DecimalImpl = dec::Decimal<{ ffi::DECNUMUNITS as _ }>; @@ -771,6 +839,15 @@ mod standalone_decimal { Nan, } + impl<'a> TryFrom> for Decimal { + type Error = String; + + #[inline(always)] + fn try_from(value: msgpack::ExtStruct<'a>) -> Result { + Self::from_ext_structure(value.tag, value.data) + } + } + impl TryFrom for Decimal { type Error = ToDecimalError; @@ -1135,6 +1212,24 @@ mod standalone_decimal { u64 => try_into_u64 usize => try_into_usize } + impl Encode for Decimal { + fn encode( + &self, + w: &mut impl Write, + context: &msgpack::Context, + ) -> Result<(), EncodeError> { + let (data, len) = self.msgpack_bytes(); + msgpack::ExtStruct::new(ffi::MP_DECIMAL, &data[..len]).encode(w, context) + } + } + + impl<'de> Decode<'de> for Decimal { + fn decode(r: &mut &'de [u8], context: &msgpack::Context) -> Result { + msgpack::ExtStruct::decode(r, context)? + .try_into() + .map_err(DecodeError::new::) + } + } impl serde::Serialize for Decimal { fn serialize(&self, serializer: S) -> Result @@ -1142,16 +1237,11 @@ mod standalone_decimal { S: serde::Serializer, { #[derive(serde::Serialize)] - struct _ExtStruct((i8, serde_bytes::ByteBuf)); + struct _ExtStruct<'a>((i8, &'a serde_bytes::Bytes)); - let data = { - let mut data = vec![]; - let (bcd, scale) = self.inner.clone().to_packed_bcd().unwrap(); - rmp::encode::write_sint(&mut data, scale as i64).unwrap(); - data.extend(bcd); - data - }; - _ExtStruct((ffi::MP_DECIMAL, serde_bytes::ByteBuf::from(data))).serialize(serializer) + let (data, len) = self.msgpack_bytes(); + _ExtStruct((ffi::MP_DECIMAL, serde_bytes::Bytes::new(&data[..len]))) + .serialize(serializer) } } @@ -1164,21 +1254,8 @@ mod standalone_decimal { #[derive(serde::Deserialize)] struct _ExtStruct((i8, serde_bytes::ByteBuf)); - match serde::Deserialize::deserialize(deserializer)? { - _ExtStruct((ffi::MP_DECIMAL, bytes)) => { - let mut data = bytes.as_slice(); - let scale = rmp::decode::read_int(&mut data).unwrap(); - let bcd = data; - DecimalImpl::from_packed_bcd(bcd, scale) - .map_err(|e| Error::custom(format!("Failed to unpack decimal: {e}")))? - .try_into() - .map_err(|e| Error::custom(format!("Failed to unpack decimal: {e}"))) - } - _ExtStruct((kind, _)) => Err(serde::de::Error::custom(format!( - "Expected Decimal, found msgpack ext #{}", - kind - ))), - } + let _ExtStruct((tag, bytes)) = serde::Deserialize::deserialize(deserializer)?; + Self::from_ext_structure(tag, bytes.as_slice()).map_err(Error::custom) } } @@ -1373,7 +1450,7 @@ impl std::error::Error for DecimalToIntError { mod tests { use std::convert::TryFrom; - use crate::{decimal, decimal::Decimal, tuple::Tuple}; + use crate::{decimal, decimal::Decimal, msgpack, tuple::Tuple}; #[crate::test(tarantool = "crate")] pub fn from_lua() { @@ -1909,6 +1986,18 @@ mod tests { let new_value: Decimal = rmp_serde::decode::from_slice(serialized).expect("cant deserialize decimal"); assert_eq!(&new_value, value); + + // new encode and decode + + // serialized value has the same representation as expected + let new_serialized = msgpack::encode(value); + assert_eq!( + serialized, &new_serialized, + "{value} was not encoded correctly" + ); + // we can deserialize from expected bytes to the same value + let new_value: Decimal = msgpack::decode(serialized).expect("cant deserialize decimal"); + assert_eq!(&new_value, value); } // separately test that we can decode from the shape we used to encode decimals with standalonoe_decimal @@ -1916,5 +2005,9 @@ mod tests { let value: Decimal = rmp_serde::decode::from_slice(&[199, 7, 1, 210, 0, 0, 0, 2, 3, 60]) .expect("cant deserialize decimal"); assert_eq!(value, decimal!(0.33)); + + let value: Decimal = msgpack::decode(&[199, 7, 1, 210, 0, 0, 0, 2, 3, 60]) + .expect("cant deserialize decimal"); + assert_eq!(value, decimal!(0.33)); } } diff --git a/tarantool/src/error.rs b/tarantool/src/error.rs index 3f1e2eaa..310e38aa 100644 --- a/tarantool/src/error.rs +++ b/tarantool/src/error.rs @@ -33,7 +33,7 @@ use crate::transaction::TransactionError; use crate::util::to_cstring_lossy; /// A specialized [`Result`] type for the crate -pub type Result = std::result::Result; +pub type Result = std::result::Result; pub type TimeoutError = crate::fiber::r#async::timeout::Error; @@ -441,7 +441,7 @@ unsafe fn error_get_file_line(ptr: *const ffi::BoxError) -> Option<(String, u32) struct Failure; static mut FIELD_OFFSETS: Option> = None; - if FIELD_OFFSETS.is_none() { + if (*std::ptr::addr_of!(FIELD_OFFSETS)).is_none() { let lua = crate::lua_state(); let res = lua.eval::<(u32, u32)>( "ffi = require 'ffi' diff --git a/tarantool/src/ffi/helper.rs b/tarantool/src/ffi/helper.rs index 366b24d2..e37fcf66 100644 --- a/tarantool/src/ffi/helper.rs +++ b/tarantool/src/ffi/helper.rs @@ -1,9 +1,11 @@ -use dlopen::symbor::Library; +use crate::error::{BoxError, TarantoolErrorCode}; use std::ffi::CStr; use std::os::raw::c_char; use std::ptr::NonNull; +use libloading::os::unix::Library; + //////////////////////////////////////////////////////////////////////////////// // c_str! //////////////////////////////////////////////////////////////////////////////// @@ -167,8 +169,9 @@ pub unsafe fn tnt_internal_symbol(name: &CStr) -> Option { static mut RELOC_FN: Option = Some(init); unsafe fn init(name: *const c_char) -> Option> { - let lib = Library::open_self().ok()?; - match lib.symbol_cstr(c_str!("tnt_internal_symbol")) { + let current_library = Library::this(); + let internal_symbol = c_str!("tnt_internal_symbol").to_bytes(); + match current_library.get(internal_symbol) { Ok(sym) => { RELOC_FN = Some(*sym); (RELOC_FN.unwrap())(name) @@ -189,20 +192,30 @@ pub unsafe fn has_dyn_symbol(name: &CStr) -> bool { /// Find a sybmol in the current executable using dlsym. #[inline] -pub unsafe fn get_dyn_symbol(name: &CStr) -> Result { - let lib = Library::open_self()?; - let sym = lib.symbol_cstr(name)?; - Ok(*sym) +pub unsafe fn get_dyn_symbol(name: &CStr) -> Result { + let current_library = Library::this(); + let symbol_name = name.to_bytes_with_nul(); + let symbol_pointer = current_library.get::(symbol_name).map_err(|e| { + let code = TarantoolErrorCode::NoSuchProc; + let message = format!("symbol '{name:?}' not found: {e}"); + BoxError::new(code, message) + })?; + Ok(*symbol_pointer) } /// Find a symbol either using the `tnt_internal_symbol` api or using dlsym as a /// fallback. #[inline] -pub unsafe fn get_any_symbol(name: &CStr) -> Result { +pub unsafe fn get_any_symbol(name: &CStr) -> Result { if let Some(sym) = tnt_internal_symbol(name) { return Ok(sym); } - let lib = Library::open_self()?; - let sym = lib.symbol_cstr(name)?; - Ok(*sym) + let current_library = Library::this(); + let symbol_name = name.to_bytes_with_nul(); + let symbol_pointer = current_library.get::(symbol_name).map_err(|e| { + let code = TarantoolErrorCode::NoSuchProc; + let message = format!("symbol '{name:?}' not found: {e}"); + BoxError::new(code, message) + })?; + Ok(*symbol_pointer) } diff --git a/tarantool/src/ffi/mod.rs b/tarantool/src/ffi/mod.rs index df3cf348..86ac95e1 100644 --- a/tarantool/src/ffi/mod.rs +++ b/tarantool/src/ffi/mod.rs @@ -78,7 +78,7 @@ pub fn has_datetime() -> bool { #[inline] pub unsafe fn has_fiber_set_ctx() -> bool { static mut RESULT: Option = None; - if RESULT.is_none() { + if (*std::ptr::addr_of!(RESULT)).is_none() { RESULT = Some(helper::has_dyn_symbol(crate::c_str!("fiber_set_ctx"))); } RESULT.unwrap() @@ -114,7 +114,7 @@ pub fn has_fully_temporary_spaces() -> bool { #[inline] pub unsafe fn has_fiber_id() -> bool { static mut RESULT: Option = None; - if RESULT.is_none() { + if (*std::ptr::addr_of!(RESULT)).is_none() { RESULT = Some(helper::has_dyn_symbol(crate::c_str!("fiber_id"))); } RESULT.unwrap() diff --git a/tarantool/src/ffi/sql.rs b/tarantool/src/ffi/sql.rs index 53fc0994..9f7b54e1 100644 --- a/tarantool/src/ffi/sql.rs +++ b/tarantool/src/ffi/sql.rs @@ -4,7 +4,15 @@ use libc::{iovec, size_t}; use std::cmp; use std::io::Read; use std::mem::MaybeUninit; +use std::ops::Range; use std::os::raw::{c_char, c_int, c_void}; +use std::ptr::{null, NonNull}; +use tlua::ffi::lua_State; + +use crate::error::{TarantoolError, TarantoolErrorCode}; +use crate::tuple::Tuple; + +use super::tarantool::BoxTuple; pub const IPROTO_DATA: u8 = 0x30; @@ -12,10 +20,22 @@ pub const IPROTO_DATA: u8 = 0x30; // even if they're only used in this file. This is because the `define_dlsym_reloc` // macro doesn't support private function declarations because rust's macro syntax is trash. crate::define_dlsym_reloc! { + pub fn port_destroy(port: *mut Port); + pub(crate) fn port_c_add_tuple(port: *mut Port, tuple: *mut BoxTuple); + pub(crate) fn port_c_add_mp(port: *mut Port, mp: *const c_char, mp_end: *const c_char); + pub(crate) fn port_c_create(port: *mut Port); + pub fn port_c_destroy(port: *mut Port); + pub(crate) fn cord_slab_cache() -> *const SlabCache; pub(crate) fn obuf_create(obuf: *mut Obuf, slab_cache: *const SlabCache, start_cap: size_t); pub(crate) fn obuf_destroy(obuf: *mut Obuf); + pub(crate) fn obuf_reset(obuf: *mut Obuf); + pub(crate) fn obuf_dup( + obuf: *mut Obuf, + data: *const c_void, + size: size_t, + ) -> size_t; /// Free memory allocated by this buffer pub fn ibuf_reinit(ibuf: *mut Ibuf); @@ -41,6 +61,19 @@ crate::define_dlsym_reloc! { vdbe_max_steps: u64, obuf: *mut Obuf, ) -> c_int; + pub(crate) fn stmt_execute_into_port( + stmt_id: u32, + mp_params: *const u8, + vdbe_max_steps: u64, + port: *mut Port + ) -> c_int; + pub(crate) fn sql_execute_into_port( + sql: *const u8, + len: c_int, + mp_params: *const u8, + vdbe_max_steps: u64, + port: *mut Port, + ) -> c_int; } #[repr(C)] @@ -61,7 +94,19 @@ pub struct Ibuf { start_capacity: usize, } -pub(crate) struct ObufWrapper { +pub unsafe fn obuf_append(obuf: *mut Obuf, mp: &[u8]) -> crate::Result<()> { + let size = obuf_dup(obuf, mp.as_ptr() as *const c_void, mp.len() as size_t); + if size != mp.len() as size_t { + return Err(TarantoolError::new( + TarantoolErrorCode::MemoryIssue, + format!("Failed to allocate {} bytes in obuf for data", mp.len()), + ) + .into()); + } + Ok(()) +} + +pub struct ObufWrapper { pub inner: Obuf, read_pos: usize, read_iov_n: usize, @@ -69,7 +114,10 @@ pub(crate) struct ObufWrapper { } impl ObufWrapper { + /// Create a new `ObufWrapper` with the given initial capacity. + /// The capacity must be greater than 0. pub fn new(start_capacity: usize) -> Self { + assert!(start_capacity > 0); let inner_buf = unsafe { let slab_c = cord_slab_cache(); @@ -86,7 +134,19 @@ impl ObufWrapper { } } - pub(crate) fn obuf(&mut self) -> *mut Obuf { + pub unsafe fn append_mp(&mut self, mp: &[u8]) -> crate::Result<()> { + obuf_append(self.obuf(), mp)?; + Ok(()) + } + + pub fn reset(&mut self) { + unsafe { obuf_reset(self.obuf()) } + self.read_pos = 0; + self.read_iov_n = 0; + self.read_iov_pos = 0; + } + + pub fn obuf(&mut self) -> *mut Obuf { &mut self.inner as *mut Obuf } } @@ -131,7 +191,7 @@ impl Read for ObufWrapper { // TODO: ASan-enabled build has a different layout (obuf_asan.h). #[repr(C)] -pub(crate) struct Obuf { +pub struct Obuf { _slab_cache: *const c_void, pub pos: i32, pub n_iov: i32, @@ -151,3 +211,312 @@ impl Drop for Obuf { unsafe { obuf_destroy(self as *mut Obuf) } } } + +#[repr(C)] +pub struct SqlValue { + _private: [u8; 0], +} + +#[repr(C)] +pub struct PortVTable { + pub dump_msgpack: unsafe extern "C" fn(port: *mut Port, out: *mut Obuf), + pub dump_msgpack_16: unsafe extern "C" fn(port: *mut Port, out: *mut Obuf), + pub dump_lua: unsafe extern "C" fn(port: *mut Port, l: *mut lua_State, is_flat: bool), + pub dump_plain: unsafe extern "C" fn(port: *mut Port, size: *mut u32) -> *const c_char, + pub get_msgpack: unsafe extern "C" fn(port: *mut Port, size: *mut u32) -> *const c_char, + pub get_vdbemem: unsafe extern "C" fn(port: *mut Port, size: *mut u32) -> *mut SqlValue, + pub destroy: unsafe extern "C" fn(port: *mut Port), +} + +impl PortVTable { + pub const fn new( + dump_msgpack: unsafe extern "C" fn(port: *mut Port, out: *mut Obuf), + dump_lua: unsafe extern "C" fn(port: *mut Port, l: *mut lua_State, is_flat: bool), + ) -> Self { + Self { + dump_msgpack, + dump_msgpack_16, + dump_lua, + dump_plain, + get_msgpack, + get_vdbemem, + destroy, + } + } +} + +#[no_mangle] +unsafe extern "C" fn dump_msgpack_16(_port: *mut Port, _out: *mut Obuf) { + unimplemented!(); +} + +#[no_mangle] +unsafe extern "C" fn dump_plain(_port: *mut Port, _size: *mut u32) -> *const c_char { + unimplemented!(); +} + +#[no_mangle] +unsafe extern "C" fn get_msgpack(_port: *mut Port, _size: *mut u32) -> *const c_char { + unimplemented!(); +} + +#[no_mangle] +unsafe extern "C" fn get_vdbemem(_port: *mut Port, _size: *mut u32) -> *mut SqlValue { + unimplemented!(); +} + +#[no_mangle] +unsafe extern "C" fn destroy(port: *mut Port) { + port_c_destroy(port); +} + +#[repr(C)] +#[derive(Debug)] +pub struct Port { + pub vtab: *const PortVTable, + _pad: [u8; 68], +} + +impl Drop for Port { + fn drop(&mut self) { + unsafe { port_destroy(self as *mut Port) } + } +} + +impl Port { + /// Interpret `Port` as a mutable raw pointer to `PortC`. + /// + /// # Safety + /// + /// The caller must be sure that the port was initialized with `new_port_c`. + pub unsafe fn as_mut_port_c(&mut self) -> &mut PortC { + unsafe { NonNull::new_unchecked(self as *mut Port as *mut PortC).as_mut() } + } + + pub fn as_ptr(&self) -> *const Port { + self as *const Port + } + + pub fn as_mut(&mut self) -> *mut Port { + self as *mut Port + } +} + +impl Port { + pub unsafe fn zeroed() -> Self { + unsafe { + Self { + vtab: null(), + _pad: std::mem::zeroed(), + } + } + } + + pub fn new_port_c() -> Self { + unsafe { + let mut port = Self::zeroed(); + port_c_create(&mut port as *mut Port); + port + } + } +} + +#[repr(C)] +union U { + tuple: NonNull, + mp: *const u8, +} + +#[repr(C)] +struct PortCEntry { + next: *const PortCEntry, + data: U, + mp_sz: u32, + tuple_format: *const c_void, +} + +impl PortCEntry { + unsafe fn data(&self) -> &[u8] { + if self.mp_sz == 0 { + let tuple_data = self.data.tuple.as_ref().data(); + return std::slice::from_raw_parts(tuple_data.as_ptr(), tuple_data.len()); + } + std::slice::from_raw_parts(self.data.mp, self.mp_sz as usize) + } +} + +#[repr(C)] +pub struct PortC { + pub vtab: *const PortVTable, + first: *const PortCEntry, + last: *const PortCEntry, + first_entry: PortCEntry, + size: i32, +} + +impl PortC { + pub fn size(&self) -> i32 { + self.size + } + + pub fn add_tuple(&mut self, tuple: &Tuple) { + unsafe { + port_c_add_tuple( + self as *mut PortC as *mut Port, + tuple.as_ptr() as *mut BoxTuple, + ); + } + } + + /// Add a msgpack-encoded data to the C port. + /// + /// # Safety + /// + /// The caller must ensure that the `mp` slice is valid msgpack data. + pub unsafe fn add_mp(&mut self, mp: &[u8]) { + let Range { start, end } = mp.as_ptr_range(); + unsafe { + port_c_add_mp( + self as *mut PortC as *mut Port, + start as *const c_char, + end as *const c_char, + ); + } + } + + pub fn iter(&self) -> PortCIterator { + PortCIterator::new(self) + } + + /// Interpret `PortC` as a mutable raw pointer to `Port`. + /// + /// # Safety + /// + /// The caller must ensure that `PortC`: + /// - occupies the same amount of memory as the `Port` does; + /// - is properly initialized with `port_c_create`. + pub unsafe fn as_mut_ptr(&mut self) -> *mut Port { + self as *mut PortC as *mut Port + } + + pub fn first_mp(&self) -> Option<&[u8]> { + if self.first.is_null() { + return None; + } + let entry = unsafe { &*(self.first as *const PortCEntry) }; + Some(unsafe { entry.data() }) + } + + pub fn last_mp(&self) -> Option<&[u8]> { + if self.last.is_null() { + return None; + } + let entry = unsafe { &*(self.last as *const PortCEntry) }; + Some(unsafe { entry.data() }) + } +} + +#[allow(dead_code)] +pub struct PortCIterator<'port> { + port: &'port PortC, + entry: *const PortCEntry, +} + +impl<'port> PortCIterator<'port> { + fn new(port: &'port PortC) -> Self { + Self { + port, + entry: port.first, + } + } +} + +impl<'port> Iterator for PortCIterator<'port> { + type Item = &'port [u8]; + + fn next(&mut self) -> Option { + if self.entry.is_null() { + return None; + } + + // The code was inspired by `port_c_dump_msgpack` function from `box/port.c`. + let entry = unsafe { &*self.entry }; + self.entry = entry.next; + Some(unsafe { entry.data() }) + } +} + +#[cfg(feature = "picodata")] +#[cfg(feature = "internal_test")] +mod tests { + use super::*; + use crate::offset_of; + + #[crate::test(tarantool = "crate")] + pub fn test_port_definition() { + let lua = crate::lua_state(); + let [size_of_port, offset_of_vtab, offset_of_pad]: [usize; 3] = lua + .eval( + "local ffi = require('ffi') + return { + ffi.sizeof('struct port'), + ffi.offsetof('struct port', 'vtab'), + ffi.offsetof('struct port', 'pad') + }", + ) + .unwrap(); + + assert_eq!(size_of_port, std::mem::size_of::()); + assert_eq!(offset_of_vtab, offset_of!(Port, vtab)); + assert_eq!(offset_of_pad, offset_of!(Port, _pad)); + } + + #[crate::test(tarantool = "crate")] + pub fn test_port_c_definition() { + let lua = crate::lua_state(); + let [size_of_port_c, offset_of_vtab, + offset_of_first, offset_of_last, + offset_of_first_entry, offset_of_size]: [usize; 6] = lua + .eval( + "local ffi = require('ffi') + return { + ffi.sizeof('struct port_c'), + ffi.offsetof('struct port_c', 'vtab'), + ffi.offsetof('struct port_c', 'first'), + ffi.offsetof('struct port_c', 'last'), + ffi.offsetof('struct port_c', 'first_entry'), + ffi.offsetof('struct port_c', 'size') + }", + ) + .unwrap(); + + assert_eq!(size_of_port_c, std::mem::size_of::()); + assert_eq!(offset_of_vtab, offset_of!(PortC, vtab)); + assert_eq!(offset_of_first, offset_of!(PortC, first)); + assert_eq!(offset_of_last, offset_of!(PortC, last)); + assert_eq!(offset_of_first_entry, offset_of!(PortC, first_entry)); + assert_eq!(offset_of_size, offset_of!(PortC, size)); + } + + #[crate::test(tarantool = "crate")] + pub fn test_obuf() { + let mut obuf = ObufWrapper::new(1024); + + //Check appending data. + let mp = b"\x92\x01\x02"; + unsafe { + obuf.append_mp(mp).unwrap(); + } + assert_eq!(obuf.read_pos, 0); + let mut buf = [0u8; 3]; + let read = obuf.read(&mut buf).unwrap(); + assert_eq!(read, 3); + assert_eq!(&buf, mp); + // Check that the read position is updated. + assert_eq!(obuf.read_pos, 3); + + // Check reset. + obuf.reset(); + assert_eq!(obuf.read_pos, 0); + } +} diff --git a/tarantool/src/ffi/tarantool.rs b/tarantool/src/ffi/tarantool.rs index ed6ab549..b5177220 100644 --- a/tarantool/src/ffi/tarantool.rs +++ b/tarantool/src/ffi/tarantool.rs @@ -37,19 +37,15 @@ extern "C" { /// Wait until **READ** or **WRITE** event on socket (`fd`). Yields. /// - `fd` - non-blocking socket file description /// - `events` - requested events to wait. - /// Combination of `TNT_IO_READ` | `TNT_IO_WRITE` bit flags. + /// Combination of `TNT_IO_READ` | `TNT_IO_WRITE` bit flags. /// - `timeout` - timeout in seconds. /// /// Returns: /// - `0` - timeout - /// - `>0` - returned events. Combination of `TNT_IO_READ` | `TNT_IO_WRITE` - /// bit flags. + /// - `>0` - returned events. Combination of `TNT_IO_READ` | `TNT_IO_WRITE` bit flags. pub fn coio_wait(fd: c_int, event: c_int, timeout: f64) -> c_int; - /** - * Close the fd and wake any fiber blocked in - * coio_wait() call on this fd. - */ + /// Close the fd and wake any fiber blocked in coio_wait() call on this fd. pub fn coio_close(fd: c_int) -> c_int; /// Fiber-friendly version of getaddrinfo(3). @@ -63,8 +59,8 @@ extern "C" { /// Returns: /// - `0` on success, please free @a res using freeaddrinfo(3). /// - `-1` on error, check diag. - /// Please note that the return value is not compatible with - /// getaddrinfo(3). + /// + /// > Please note that the return value is not compatible with getaddrinfo(3). pub fn coio_getaddrinfo( host: *const c_char, port: *const c_char, @@ -799,6 +795,27 @@ extern "C" { pub fn box_truncate(space_id: u32) -> c_int; } +extern "C" { + /// Try to look up a space by space number in the space cache. + /// FFI-friendly no-exception-thrown space lookup function. + /// + /// Return NULL if space not found, otherwise space object. + /// + /// # Safety + /// The caller must make sure not to hold on to the pointer for too + /// long as the space object may get deleted at some point after which + /// the derefencing the pointer will be **undefined behavior**. + pub(crate) fn space_by_id(id: u32) -> *mut space; + + /// Returns number of bytes used in memory by tuples in the space. + pub(crate) fn space_bsize(space: *mut space) -> usize; +} + +#[repr(C)] +pub(crate) struct space { + unused: [u8; 0], +} + //////////////////////////////////////////////////////////////////////////////// // ... //////////////////////////////////////////////////////////////////////////////// @@ -875,7 +892,7 @@ pub struct BoxTuple { #[cfg(feature = "picodata")] #[repr(C, packed)] pub struct BoxTuple { - pub(crate) refs: u8, + pub refs: u8, pub(crate) flags: u8, pub(crate) format_id: u16, pub(crate) data_offset: u16, @@ -909,6 +926,17 @@ impl BoxTuple { pub fn bsize(&self) -> usize { unsafe { box_tuple_bsize(self) } } + + /// # Safety + /// This is how tuple data is stored in tarantool. + #[inline] + pub unsafe fn data(&self) -> &[u8] { + unsafe { + let data_offset = self.data_offset() as isize; + let data = (self as *const BoxTuple).cast::().offset(data_offset); + std::slice::from_raw_parts(data, self.bsize()) + } + } } #[cfg(feature = "picodata")] @@ -1251,7 +1279,10 @@ extern "C" { #[repr(C)] pub struct BoxFunctionCtx { + #[cfg(not(feature = "picodata"))] _unused: [u8; 0], + #[cfg(feature = "picodata")] + pub(crate) port: *mut crate::ffi::sql::Port, } extern "C" { @@ -1290,6 +1321,7 @@ extern "C" { } use crate::ffi::lua::lua_State; + extern "C" { pub fn luaT_state() -> *mut lua_State; pub fn luaT_call(l: *mut lua_State, nargs: c_int, nreturns: c_int) -> isize; diff --git a/tarantool/src/fiber.rs b/tarantool/src/fiber.rs index bf4bcdee..eaeaa747 100644 --- a/tarantool/src/fiber.rs +++ b/tarantool/src/fiber.rs @@ -3,7 +3,7 @@ //! With the fiber module, you can: //! - create, run and manage [fibers](Builder), //! - use a synchronization mechanism for fibers, similar to “condition variables” and similar to operating-system -//! functions such as `pthread_cond_wait()` plus `pthread_cond_signal()`, +//! functions such as `pthread_cond_wait()` plus `pthread_cond_signal()`, //! - spawn a fiber based [async runtime](async). //! //! See also: @@ -110,14 +110,14 @@ pub struct Fiber<'a, T: 'a> { } #[allow(deprecated)] -impl<'a, T> ::std::fmt::Debug for Fiber<'a, T> { +impl ::std::fmt::Debug for Fiber<'_, T> { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { f.debug_struct("Fiber").finish_non_exhaustive() } } #[allow(deprecated)] -impl<'a, T> Fiber<'a, T> { +impl Fiber<'_, T> { /// Create a new fiber. /// /// Takes a fiber from fiber cache, if it's not empty. Can fail only if there is not enough memory for @@ -862,10 +862,9 @@ where lua::lua_getfield(l, -1, c_ptr!("new")); impl_details::push_userdata(l, f); lua::lua_pushcclosure(l, Self::trampoline_for_lua, 1); - impl_details::guarded_pcall(l, 1, 1).map_err(|e| { + impl_details::guarded_pcall(l, 1, 1).inspect_err(|_| { // Pop the fiber module from the stack lua::lua_pop(l, 1); - e })?; // stack[top] = fiber.new(c_closure) lua::lua_getfield(l, -1, c_ptr!("set_joinable")); @@ -970,10 +969,9 @@ mod impl_details { lua::lua_getfield(l, -1, c_ptr!("join")); lua::lua_pushinteger(l, f_id as _); - guarded_pcall(l, 1, 2).map_err(|e| { + guarded_pcall(l, 1, 2).inspect_err(|_| { // Pop the fiber module from the stack lua::lua_pop(l, 1); - e })?; // stack[top] = fiber.join(f_id) // 3 values on the stack that need to be dropped: @@ -1049,7 +1047,7 @@ pub struct JoinHandle<'f, T> { marker: PhantomData<&'f ()>, } -impl<'f, T> std::fmt::Debug for JoinHandle<'f, T> { +impl std::fmt::Debug for JoinHandle<'_, T> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("JoinHandle").finish_non_exhaustive() } @@ -1082,7 +1080,7 @@ enum JoinHandleImpl { type FiberResultCell = Box>>; -impl<'f, T> JoinHandle<'f, T> { +impl JoinHandle<'_, T> { #[inline(always)] fn ffi(fiber: NonNull, result_cell: Option>) -> Self { Self { @@ -1253,7 +1251,7 @@ impl<'f, T> JoinHandle<'f, T> { } } -impl<'f, T> Drop for JoinHandle<'f, T> { +impl Drop for JoinHandle<'_, T> { fn drop(&mut self) { if let Some(mut inner) = self.inner.take() { if let JoinHandleImpl::Ffi { result_cell, .. } = &mut inner { diff --git a/tarantool/src/fiber/async/mutex.rs b/tarantool/src/fiber/async/mutex.rs index a3daffd8..e1f8ac3f 100644 --- a/tarantool/src/fiber/async/mutex.rs +++ b/tarantool/src/fiber/async/mutex.rs @@ -191,7 +191,7 @@ impl From for Mutex { } } -impl Default for Mutex { +impl Default for Mutex { /// Creates a `Mutex`, with the `Default` value for T. fn default() -> Mutex { Mutex::new(Default::default()) diff --git a/tarantool/src/fiber/async/watch.rs b/tarantool/src/fiber/async/watch.rs index 1168ad7d..d1a75de1 100644 --- a/tarantool/src/fiber/async/watch.rs +++ b/tarantool/src/fiber/async/watch.rs @@ -204,7 +204,7 @@ impl Drop for Sender { /// to keep the borrow as short lived as possible. pub struct ValueRef<'a, T>(Ref<'a, Value>); -impl<'a, T> Deref for ValueRef<'a, T> { +impl Deref for ValueRef<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -312,7 +312,7 @@ impl Clone for Receiver { } } -impl<'a, T> Future for Notification<'a, T> { +impl Future for Notification<'_, T> { type Output = Result<(), RecvError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/tarantool/src/fiber/mutex.rs b/tarantool/src/fiber/mutex.rs index 3ca71b73..29e51b0c 100644 --- a/tarantool/src/fiber/mutex.rs +++ b/tarantool/src/fiber/mutex.rs @@ -208,7 +208,7 @@ impl From for Mutex { } } -impl Default for Mutex { +impl Default for Mutex { /// Creates a `Mutex`, with the `Default` value for T. fn default() -> Mutex { Mutex::new(Default::default()) diff --git a/tarantool/src/index.rs b/tarantool/src/index.rs index 6b3ee91b..a6f53f60 100644 --- a/tarantool/src/index.rs +++ b/tarantool/src/index.rs @@ -378,12 +378,18 @@ crate::define_str_enum! { #[deprecated = "Use `index::Part` instead"] pub type IndexPart = Part; -/// Index part. +/// Index part +/// +/// The `T` generic decides what is used to identify the field. It can be either a [`u32`] for +/// field index or [`String`] for field name, or [`NumOrStr`] for either. +/// +/// Field names are used in picodata metadata, indices are used in tarantool's metadata, while +/// tarantool's index creation API can accept either. #[derive( Clone, Default, Debug, Serialize, Deserialize, tlua::Push, tlua::LuaRead, PartialEq, Eq, )] -pub struct Part { - pub field: NumOrStr, +pub struct Part { + pub field: T, #[serde(default)] pub r#type: Option, #[serde(default)] @@ -406,9 +412,9 @@ macro_rules! define_setters { } } -impl Part { +impl Part { #[inline(always)] - pub fn field(field: impl Into) -> Self { + pub fn field(field: impl Into) -> Self { Self { field: field.into(), r#type: None, @@ -426,36 +432,50 @@ impl Part { } #[inline(always)] - pub fn new(fi: impl Into, ft: FieldType) -> Self { + pub fn new(fi: impl Into, ft: FieldType) -> Self { Self::field(fi).field_type(ft) } } -impl From<&str> for Part { +impl From<&str> for Part { #[inline(always)] fn from(f: &str) -> Self { Self::field(f.to_string()) } } -impl From for Part { +impl From for Part { #[inline(always)] fn from(f: String) -> Self { Self::field(f) } } -impl From for Part { +impl From<(String, FieldType)> for Part { #[inline(always)] - fn from(f: u32) -> Self { - Self::field(f) + fn from((f, t): (String, FieldType)) -> Self { + Self::field(f).field_type(t) } } -impl From<(u32, FieldType)> for Part { +impl From<(&str, FieldType)> for Part { #[inline(always)] - fn from((f, t): (u32, FieldType)) -> Self { - Self::field(f).field_type(t) + fn from((f, t): (&str, FieldType)) -> Self { + Self::field(f.to_string()).field_type(t) + } +} + +impl From<&str> for Part { + #[inline(always)] + fn from(f: &str) -> Self { + Self::field(f.to_string()) + } +} + +impl From for Part { + #[inline(always)] + fn from(f: String) -> Self { + Self::field(f) } } @@ -473,6 +493,46 @@ impl From<(&str, FieldType)> for Part { } } +impl From for Part { + #[inline(always)] + fn from(value: u32) -> Self { + Self::field(value) + } +} + +impl From<(u32, FieldType)> for Part { + #[inline(always)] + fn from((f, t): (u32, FieldType)) -> Self { + Self::field(f).field_type(t) + } +} + +impl From> for Part { + #[inline(always)] + fn from(value: Part) -> Self { + Part { + field: value.field.into(), + r#type: value.r#type, + collation: value.collation, + is_nullable: value.is_nullable, + path: value.path, + } + } +} + +impl From> for Part { + #[inline(always)] + fn from(value: Part) -> Self { + Part { + field: value.field.into(), + r#type: value.r#type, + collation: value.collation, + is_nullable: value.is_nullable, + path: value.path, + } + } +} + //////////////////////////////////////////////////////////////////////////////// // ... //////////////////////////////////////////////////////////////////////////////// @@ -936,41 +996,21 @@ pub struct Metadata<'a> { pub name: Cow<'a, str>, pub r#type: IndexType, pub opts: BTreeMap, Value<'a>>, - pub parts: Vec, + pub parts: Vec>, } impl Encode for Metadata<'_> {} -#[derive(thiserror::Error, Debug)] -#[error("field number expected, got string '{0}'")] -pub struct FieldMustBeNumber(pub String); - impl Metadata<'_> { /// Construct a [`KeyDef`] instance from index parts. - /// - /// # Panicking - /// Will panic if any of the parts have field name instead of field number. - /// Normally this doesn't happen, because `Metadata` returned from - /// `_index` always has field number, but if you got this metadata from - /// somewhere else, use [`Self::try_to_key_def`] instead, to check for this - /// error. #[inline(always)] pub fn to_key_def(&self) -> KeyDef { // TODO: we could optimize by caching these things and only recreating // then once box_schema_version changes. - self.try_to_key_def().unwrap() - } - - /// Construct a [`KeyDef`] instance from index parts. Returns error in case - /// any of the parts had field name instead of field number. - #[inline] - pub fn try_to_key_def(&self) -> Result { let mut kd_parts = Vec::with_capacity(self.parts.len()); for p in &self.parts { - let kd_p = KeyDefPart::try_from_index_part(p) - .ok_or_else(|| FieldMustBeNumber(p.field.clone().into()))?; - kd_parts.push(kd_p); + kd_parts.push(KeyDefPart::from_index_part(p)); } - Ok(KeyDef::new(&kd_parts).unwrap()) + KeyDef::new(&kd_parts).unwrap() } /// Construct a [`KeyDef`] instance from index parts for comparing keys only. @@ -1064,7 +1104,7 @@ mod tests { r#type: IndexType::Hash, opts: BTreeMap::from([("unique".into(), Value::from(true)),]), parts: vec![Part { - field: 0.into(), + field: 0, r#type: Some(FieldType::Unsigned), ..Default::default() }], @@ -1096,19 +1136,19 @@ mod tests { opts: BTreeMap::from([("unique".into(), Value::from(false)),]), parts: vec![ Part { - field: 1.into(), + field: 1, r#type: Some(FieldType::String), ..Default::default() }, Part { - field: 2.into(), + field: 2, r#type: Some(FieldType::Unsigned), is_nullable: Some(true), path: Some(".key".into()), ..Default::default() }, Part { - field: 2.into(), + field: 2, r#type: Some(FieldType::String), path: Some(".value[1]".into()), ..Default::default() diff --git a/tarantool/src/lib.rs b/tarantool/src/lib.rs index 1c08a8b5..f131f746 100644 --- a/tarantool/src/lib.rs +++ b/tarantool/src/lib.rs @@ -8,6 +8,7 @@ #![allow(clippy::needless_late_init)] #![allow(clippy::bool_assert_comparison)] #![allow(clippy::field_reassign_with_default)] +#![allow(clippy::manual_unwrap_or)] #![allow(rustdoc::redundant_explicit_links)] //! Tarantool C API bindings for Rust. //! This library contains the following Tarantool API's: @@ -63,14 +64,12 @@ pub mod clock; pub mod coio; pub mod datetime; pub mod decimal; -#[doc(hidden)] pub mod define_str_enum; pub mod error; pub mod ffi; pub mod fiber; pub mod index; pub mod log; -#[doc(hidden)] pub mod msgpack; pub mod net_box; pub mod network; @@ -231,7 +230,7 @@ pub mod vclock; /// [`Display`]), the return values read as follows: /// - `Ok(v)`: the stored procedure will return `v` /// - `Err(e)`: the stored procedure will fail and `e` will be set as the last -/// Tarantool error (see also [`TarantoolError::last`]) +/// Tarantool error (see also [`TarantoolError::last`]) /// ```no_run /// use tarantool::{error::Error, index::IteratorType::Eq, space::Space}; /// diff --git a/tarantool/src/log.rs b/tarantool/src/log.rs index 2f86cda1..45e71455 100644 --- a/tarantool/src/log.rs +++ b/tarantool/src/log.rs @@ -50,6 +50,13 @@ impl TarantoolLogger { } } +impl Default for TarantoolLogger { + #[inline(always)] + fn default() -> Self { + Self::new() + } +} + impl Log for TarantoolLogger { #[inline(always)] fn enabled(&self, metadata: &Metadata) -> bool { diff --git a/tarantool/src/msgpack.rs b/tarantool/src/msgpack.rs index a8cee9cf..cc2c3137 100644 --- a/tarantool/src/msgpack.rs +++ b/tarantool/src/msgpack.rs @@ -906,6 +906,7 @@ mod test { } #[cfg(feature = "internal_test")] +#[allow(clippy::disallowed_names)] mod tests { use super::*; use pretty_assertions::assert_eq; diff --git a/tarantool/src/msgpack/encode.rs b/tarantool/src/msgpack/encode.rs index 237ebcfd..18302d45 100644 --- a/tarantool/src/msgpack/encode.rs +++ b/tarantool/src/msgpack/encode.rs @@ -115,6 +115,24 @@ pub enum StructStyle { // TODO AllowDecodeAny - to allow decoding both arrays & maps } +//////////////////////////////////////////////////////////////////////////////// +// ExtStruct +//////////////////////////////////////////////////////////////////////////////// + +/// ExtStruct for serialization and deserialization MessagePack extension type +#[derive(Clone)] +pub struct ExtStruct<'a> { + pub tag: i8, + pub data: &'a [u8], +} + +impl<'a> ExtStruct<'a> { + #[inline(always)] + pub fn new(tag: i8, data: &'a [u8]) -> Self { + Self { tag, data } + } +} + //////////////////////////////////////////////////////////////////////////////// // Decode //////////////////////////////////////////////////////////////////////////////// @@ -398,9 +416,9 @@ where } } -impl<'a, 'de, T> Decode<'de> for Cow<'a, T> +impl<'de, T> Decode<'de> for Cow<'_, T> where - T: Decode<'de> + ToOwned + ?Sized, + T: Decode<'de> + ToOwned, { // Clippy doesn't notice the type difference #[allow(clippy::redundant_clone)] @@ -553,7 +571,60 @@ impl_simple_decode! { (bool, read_bool) } -// TODO: Provide decode for tuples and serde json value +macro_rules! impl_tuple_decode { + () => {}; + ($h:ident $($t:ident)*) => { + #[allow(non_snake_case)] + impl<'de, $h, $($t),*> Decode<'de> for ($h, $($t),*) + where + $h: Decode<'de>, + $($t: Decode<'de>,)* + { + fn decode(r: &mut &'de [u8], context: &Context) -> Result { + let array_len = rmp::decode::read_array_len(r) + .map_err(DecodeError::from_vre::)?; + let expected_len = crate::expr_count!($h $(, $t)*); + + if array_len != expected_len { + return Err(DecodeError::new::(format!( + "Tuple length mismatch: expected: {}; actual: {}", + expected_len, array_len + ))); + } + let $h : $h = Decode::decode(r, context)?; + $( + let $t : $t = Decode::decode(r, context)?; + )* + Ok(($h, $($t),*)) + } + } + impl_tuple_decode! { $($t)* } + } +} + +impl_tuple_decode! { A B C D E F G H I J K L M N O P } + +// TODO: Provide decode for serde json value + +impl<'de> Decode<'de> for ExtStruct<'de> { + #[inline] + fn decode(r: &mut &'de [u8], _context: &Context) -> Result { + let meta = rmp::decode::read_ext_meta(r).map_err(DecodeError::from_vre::)?; + let expected = meta.size as usize; + if r.len() < expected { + let actual = r.len(); + *r = &r[actual..]; + return Err(DecodeError::new::(format!( + "unexpected end of buffer (expected: {expected}, actual: {actual})" + ))); + } + + let (a, b) = r.split_at(expected); + *r = b; + + Ok(Self::new(meta.typeid, a)) + } +} //////////////////////////////////////////////////////////////////////////////// // Encode @@ -746,7 +817,7 @@ where } } -impl<'a, T> Encode for Cow<'a, T> +impl Encode for Cow<'_, T> where T: Encode + ToOwned + ?Sized, { @@ -880,6 +951,15 @@ macro_rules! impl_tuple_encode { impl_tuple_encode! { A B C D E F G H I J K L M N O P } +impl Encode for ExtStruct<'_> { + #[inline] + fn encode(&self, w: &mut impl Write, _context: &Context) -> Result<(), EncodeError> { + rmp::encode::write_ext_meta(w, self.data.len() as u32, self.tag)?; + w.write_all(self.data).map_err(EncodeError::from)?; + Ok(()) + } +} + impl Encode for serde_json::Value { #[inline] fn encode(&self, w: &mut impl Write, _context: &Context) -> Result<(), EncodeError> { @@ -2223,4 +2303,15 @@ mod tests { assert_eq!(decode::(b"\xce\xff\xff\xff\xff").unwrap(), u32::MAX); assert_eq!(decode::(b"\xcf\xff\xff\xff\xff\xff\xff\xff\xff").unwrap(), u64::MAX); } + + #[test] + fn decode_tuple() { + let value: (Option, Vec, String) = (None, vec![0, 1, 2], "hello".to_string()); + + let encoded = rmp_serde::encode::to_vec(&value).unwrap(); + + let actual: (Option, Vec, String) = decode(&encoded).unwrap(); + + assert_eq!(actual, value); + } } diff --git a/tarantool/src/net_box/mod.rs b/tarantool/src/net_box/mod.rs index eab3b6ab..2a72ccb7 100644 --- a/tarantool/src/net_box/mod.rs +++ b/tarantool/src/net_box/mod.rs @@ -31,9 +31,9 @@ //! - The state machine starts in the `initial` state. //! - [Conn::new()](struct.Conn.html#method.new) method changes the state to `connecting` and spawns a worker fiber. //! - If authentication and schema upload are required, it’s possible later on to re-enter the `fetch_schema` state -//! from `active` if a request fails due to a schema version mismatch error, so schema reload is triggered. +//! from `active` if a request fails due to a schema version mismatch error, so schema reload is triggered. //! - [conn.close()](struct.Conn.html#method.close) method sets the state to `closed` and kills the worker. If the -//! transport is already in the `error` state, [close()](struct.Conn.html#method.close) does nothing. +//! transport is already in the `error` state, [close()](struct.Conn.html#method.close) does nothing. //! //! See also: //! - [Lua reference: Module net.box](https://www.tarantool.io/en/doc/latest/reference/reference_lua/net_box/) @@ -237,6 +237,7 @@ mod tests { ConnOptions { user: "test_user".into(), password: "password".into(), + auth_method: crate::auth::AuthMethod::ChapSha1, ..ConnOptions::default() }, None, @@ -442,7 +443,7 @@ mod tests { .to_string(); assert_eq!( err, - "server responded with error: PasswordMismatch: User not found or supplied credentials are invalid" + "server responded with error: System: Invalid credentials" ); } diff --git a/tarantool/src/network/client/mod.rs b/tarantool/src/network/client/mod.rs index 71f39e3c..a684b8af 100644 --- a/tarantool/src/network/client/mod.rs +++ b/tarantool/src/network/client/mod.rs @@ -478,6 +478,7 @@ mod tests { listen_port(), protocol::Config { creds: Some(("test_user".into(), "password".into())), + auth_method: crate::auth::AuthMethod::ChapSha1, ..Default::default() }, ) @@ -891,7 +892,7 @@ mod tests { // first request let err = client.eval("return", &()).await.unwrap_err().to_string(); #[rustfmt::skip] - assert_eq!(err, "server responded with error: PasswordMismatch: User not found or supplied credentials are invalid"); + assert_eq!(err, "server responded with error: System: Invalid credentials"); } // Wrong auth method @@ -943,7 +944,12 @@ mod tests { assert_eq!(e.error_type(), "ClientError"); assert_eq!(e.file(), Some("eval")); assert_eq!(e.line(), Some(1)); - assert_eq!(e.fields().len(), 0); + let fields_len = e.fields().len(); + // Starting from tarantool 2.11.5 it will contain `name` field + assert!(fields_len == 1 || fields_len == 0); + if fields_len == 1 { + assert_eq!(e.fields()["name"], rmpv::Value::from("UNSUPPORTED")); + } let e = e.cause().unwrap(); diff --git a/tarantool/src/network/client/reconnect.rs b/tarantool/src/network/client/reconnect.rs index 7935b92c..6aa91359 100644 --- a/tarantool/src/network/client/reconnect.rs +++ b/tarantool/src/network/client/reconnect.rs @@ -177,6 +177,7 @@ mod tests { listen_port(), protocol::Config { creds: Some(("test_user".into(), "password".into())), + auth_method: crate::auth::AuthMethod::ChapSha1, ..Default::default() }, ) diff --git a/tarantool/src/network/client/tcp.rs b/tarantool/src/network/client/tcp.rs index c9f4be7f..0c4fae75 100644 --- a/tarantool/src/network/client/tcp.rs +++ b/tarantool/src/network/client/tcp.rs @@ -1,3 +1,5 @@ +#![allow(deprecated)] + //! Contains an implementation of a custom async coio based [`TcpStream`]. //! //! ## Example @@ -16,17 +18,17 @@ //! # }; //! ``` -use std::cell::{self, Cell}; +use std::cell::Cell; use std::ffi::{CString, NulError}; use std::future::{self}; use std::mem::{self, MaybeUninit}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd}; use std::os::unix::io::RawFd; use std::pin::Pin; -use std::rc::{self, Rc}; +use std::rc::Rc; use std::task::{Context, Poll}; use std::time::Duration; -use std::{io, marker, vec}; +use std::{io, marker}; #[cfg(feature = "async-std")] use async_std::io::{Read as AsyncRead, Write as AsyncWrite}; @@ -104,6 +106,54 @@ impl Drop for AutoCloseFd { } } +/// A store for raw file descriptor so we can allow cloning actual `TcpStream` properly. +#[derive(Debug)] +struct TcpInner { + /// A raw tcp socket file descriptor. Replaced with `None` when the stream + /// is closed. + fd: Cell>, +} + +impl TcpInner { + #[inline(always)] + #[track_caller] + fn close(&self) -> io::Result<()> { + let Some(fd) = self.fd.take() else { + return Ok(()); + }; + // SAFETY: safe because we close the `fd` only once + let rc = unsafe { ffi::coio_close(fd) }; + if rc != 0 { + let e = io::Error::last_os_error(); + if e.raw_os_error() == Some(libc::EBADF) { + crate::say_error!("close({fd}): Bad file descriptor"); + if cfg!(debug_assertions) { + panic!("close({}): Bad file descriptor", fd); + } + } + return Err(e); + } + Ok(()) + } + + #[inline(always)] + fn fd(&self) -> io::Result { + let Some(fd) = self.fd.get() else { + let e = io::Error::new(io::ErrorKind::Other, "socket closed already"); + return Err(e); + }; + Ok(fd) + } +} + +impl Drop for TcpInner { + fn drop(&mut self) { + if let Err(e) = self.close() { + crate::say_error!("TcpInner::drop: closing tcp stream inner failed: {e}"); + } + } +} + /// Async TcpStream based on fibers and coio. /// /// Use [timeout][t] on top of read or write operations on [`TcpStream`] @@ -118,15 +168,8 @@ impl Drop for AutoCloseFd { /// [t]: crate::fiber::async::timeout::timeout #[derive(Debug, Clone)] pub struct TcpStream { - /// A raw tcp socket file descriptor. Replaced with `None` when the stream - /// is closed. - /// - /// Note that it's wrapped in a `Rc`, because the outer `TcpStream` needs to - /// be mutably borrowable (thanks to AsyncWrite & AsyncRead traits) and it - /// doesn't make sense to wrap it in a Mutex of any sort, because it's - /// perfectly safe to read & write on a tcp socket even from concurrent threads, - /// but we only use it from different fibers. - fd: Rc>>, + /// An actual fd which also stored it's open/close state. + inner: Rc, } impl TcpStream { @@ -248,32 +291,22 @@ impl TcpStream { #[inline(always)] #[track_caller] - pub fn close(&mut self) -> io::Result<()> { - let Some(fd) = self.fd.take() else { - // Already closed. - return Ok(()); - }; - - // SAFETY: safe because we close the `fd` only once - let rc = unsafe { ffi::coio_close(fd) }; - if rc != 0 { - let e = io::Error::last_os_error(); - if e.raw_os_error() == Some(libc::EBADF) { - crate::say_error!("close({fd}): Bad file descriptor"); - if cfg!(debug_assertions) { - panic!("close({}): Bad file descriptor", fd); - } - } - return Err(e); - } - Ok(()) + pub fn close(&self) -> io::Result<()> { + self.inner.close() } } +/// SAFETY: completely unsafe, but we are allowed to do this cause sending/sharing following stream to/from another thread +/// SAFETY: will take no effect due to no runtime within it +unsafe impl Send for TcpStream {} +unsafe impl Sync for TcpStream {} + impl From for TcpStream { fn from(value: RawFd) -> Self { Self { - fd: rc::Rc::new(cell::Cell::new(Some(value))), + inner: Rc::new(TcpInner { + fd: Cell::new(Some(value)), + }), } } } @@ -290,10 +323,7 @@ impl AsyncWrite for TcpStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let Some(fd) = self.fd.get() else { - let e = io::Error::new(io::ErrorKind::Other, "socket closed already"); - return Poll::Ready(Err(e)); - }; + let fd = self.inner.fd()?; let (result, err) = ( // `self.fd` must be nonblocking for this to work correctly @@ -325,11 +355,7 @@ impl AsyncWrite for TcpStream { } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if self.fd.get().is_none() { - let e = io::Error::new(io::ErrorKind::Other, "socket closed already"); - return Poll::Ready(Err(e)); - }; - + self.inner.fd()?; // [`TcpStream`] similarily to std does not buffer anything, // so there is nothing to flush. // @@ -337,13 +363,9 @@ impl AsyncWrite for TcpStream { Poll::Ready(Ok(())) } - fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if self.fd.get().is_none() { - let e = io::Error::new(io::ErrorKind::Other, "socket closed already"); - return Poll::Ready(Err(e)); - }; - - let res = self.close(); + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.inner.fd()?; + let res = self.inner.close(); Poll::Ready(res) } } @@ -354,13 +376,10 @@ impl AsyncRead for TcpStream { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let Some(fd) = self.fd.get() else { - let e = io::Error::new(io::ErrorKind::Other, "socket closed already"); - return Poll::Ready(Err(e)); - }; + let fd = self.inner.fd()?; let (result, err) = ( - // `self.fd` must be nonblocking for this to work correctly + // `self.inner.fd` must be nonblocking for this to work correctly unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) }, io::Error::last_os_error(), ); @@ -389,14 +408,6 @@ impl AsyncRead for TcpStream { } } -impl Drop for TcpStream { - fn drop(&mut self) { - if let Err(e) = self.close() { - crate::say_error!("TcpStream::drop: closing tcp stream failed: {e}"); - } - } -} - /// Resolves provided url and port to a sequence of sock addrs. /// /// # Returns @@ -628,6 +639,7 @@ impl<'a> From<&'a SockAddr> for AddrInfo<'a> { /// necessary when working with our async runtime, which is single threaded. #[derive(Debug, Clone)] #[repr(transparent)] +#[deprecated = "Use `TcpStream` instead"] pub struct UnsafeSendSyncTcpStream(pub TcpStream); unsafe impl Send for UnsafeSendSyncTcpStream {} @@ -810,6 +822,16 @@ mod tests { stream.read_exact(&mut buf).timeout(_10_SEC).await.unwrap(); } + #[crate::test(tarantool = "crate")] + async fn read_clone() { + let mut stream = TcpStream::connect_timeout("localhost", listen_port(), _10_SEC).unwrap(); + let cloned = stream.clone(); + drop(cloned); + // Read greeting + let mut buf = vec![0; 128]; + stream.read_exact(&mut buf).timeout(_10_SEC).await.unwrap(); + } + #[crate::test(tarantool = "crate")] async fn read_timeout() { let mut stream = TcpStream::connect_timeout("localhost", listen_port(), _10_SEC).unwrap(); @@ -855,6 +877,39 @@ mod tests { assert_eq!(buf, vec![1, 2, 3, 4, 5]) } + #[crate::test(tarantool = "crate")] + fn write_clone() { + let (sender, receiver) = std::sync::mpsc::channel(); + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + // Spawn listener + thread::spawn(move || { + for stream in listener.incoming() { + let mut stream = stream.unwrap(); + let mut buf = vec![]; + ::read_to_end(&mut stream, &mut buf).unwrap(); + sender.send(buf).unwrap(); + } + }); + // Send data + { + fiber::block_on(async { + let mut stream = + TcpStream::connect_timeout("localhost", addr.port(), _10_SEC).unwrap(); + let cloned = stream.clone(); + drop(cloned); + timeout::timeout(_10_SEC, stream.write_all(&[1, 2, 3])) + .await + .unwrap(); + timeout::timeout(_10_SEC, stream.write_all(&[4, 5])) + .await + .unwrap(); + }); + } + let buf = receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + assert_eq!(buf, vec![1, 2, 3, 4, 5]) + } + #[crate::test(tarantool = "crate")] fn split() { let (sender, receiver) = std::sync::mpsc::channel(); diff --git a/tarantool/src/network/protocol/api.rs b/tarantool/src/network/protocol/api.rs index 2c52cffc..674a4b1b 100644 --- a/tarantool/src/network/protocol/api.rs +++ b/tarantool/src/network/protocol/api.rs @@ -54,7 +54,7 @@ pub struct Call<'a, 'b, T: ?Sized> { pub args: &'b T, } -impl<'a, 'b, T> Request for Call<'a, 'b, T> +impl Request for Call<'_, '_, T> where T: ToTupleBuffer + ?Sized, { @@ -77,7 +77,7 @@ pub struct Eval<'a, 'b, T: ?Sized> { pub args: &'b T, } -impl<'a, 'b, T> Request for Eval<'a, 'b, T> +impl Request for Eval<'_, '_, T> where T: ToTupleBuffer + ?Sized, { @@ -100,7 +100,7 @@ pub struct Execute<'a, 'b, T: ?Sized> { pub bind_params: &'b T, } -impl<'a, 'b, T> Request for Execute<'a, 'b, T> +impl Request for Execute<'_, '_, T> where T: ToTupleBuffer + ?Sized, { @@ -125,7 +125,7 @@ pub struct Auth<'u, 'p, 's> { pub method: crate::auth::AuthMethod, } -impl<'u, 'p, 's> Request for Auth<'u, 'p, 's> { +impl Request for Auth<'_, '_, '_> { const TYPE: IProtoType = IProtoType::Auth; type Response = (); @@ -149,7 +149,7 @@ pub struct Select<'a, T: ?Sized> { pub key: &'a T, } -impl<'a, T> Request for Select<'a, T> +impl Request for Select<'_, T> where T: ToTupleBuffer + ?Sized, { @@ -183,7 +183,7 @@ where pub value: &'a T, } -impl<'a, T> Request for Insert<'a, T> +impl Request for Insert<'_, T> where T: ToTupleBuffer + ?Sized, { @@ -210,7 +210,7 @@ where pub value: &'a T, } -impl<'a, T> Request for Replace<'a, T> +impl Request for Replace<'_, T> where T: ToTupleBuffer + ?Sized, { @@ -239,7 +239,7 @@ where pub ops: &'a [Op], } -impl<'a, T, Op> Request for Update<'a, T, Op> +impl Request for Update<'_, T, Op> where T: ToTupleBuffer + ?Sized, Op: Encode, @@ -269,7 +269,7 @@ where pub ops: &'a [Op], } -impl<'a, T, Op> Request for Upsert<'a, T, Op> +impl Request for Upsert<'_, T, Op> where T: ToTupleBuffer + ?Sized, Op: Encode, @@ -298,7 +298,7 @@ where pub key: &'a T, } -impl<'a, T> Request for Delete<'a, T> +impl Request for Delete<'_, T> where T: ToTupleBuffer + ?Sized, { diff --git a/tarantool/src/network/protocol/codec.rs b/tarantool/src/network/protocol/codec.rs index 25b51bfe..29cc5b2b 100644 --- a/tarantool/src/network/protocol/codec.rs +++ b/tarantool/src/network/protocol/codec.rs @@ -11,6 +11,8 @@ use crate::tuple::{ToTupleBuffer, Tuple}; use super::SyncIndex; +const MP_STR_MAX_HEADER_SIZE: usize = 5; + /// Keys of the HEADER and BODY maps in the iproto packets. /// /// See `enum iproto_key` in \/src/box/iproto_constants.h for source @@ -105,7 +107,9 @@ pub fn encode_header( helper.encode(stream) } -pub fn chap_sha1_auth_data(password: &str, salt: &[u8]) -> Vec { +/// Prepares (hashes) password with salt according to CHAP-SHA1 algorithm. +#[inline] +pub fn chap_sha1_prepare(password: impl AsRef<[u8]>, salt: &[u8]) -> Vec { // prepare 'chap-sha1' scramble: // salt = base64_decode(encoded_salt); // step_1 = sha1(password); @@ -116,7 +120,7 @@ pub fn chap_sha1_auth_data(password: &str, salt: &[u8]) -> Vec { use sha1::{Digest as Sha1Digest, Sha1}; let mut hasher = Sha1::new(); - hasher.update(password.as_bytes()); + hasher.update(password); let mut step_1_and_scramble = hasher.finalize(); let mut hasher = Sha1::new(); @@ -133,29 +137,51 @@ pub fn chap_sha1_auth_data(password: &str, salt: &[u8]) -> Vec { .zip(step_3.iter()) .for_each(|(a, b)| *a ^= *b); - let scramble_bytes = &step_1_and_scramble.as_slice(); + let scramble_bytes = step_1_and_scramble.to_vec(); debug_assert_eq!(scramble_bytes.len(), 20); + scramble_bytes +} + +/// Prepares (hashes) password with salt according to CHAP-SHA1 algorithm and encodes into MessagePack. +// TODO(kbezuglyi): password should be `impl AsRef<[u8]>`, not `&str`. +#[inline] +pub fn chap_sha1_auth_data(password: &str, salt: &[u8]) -> Vec { + let hashed_data = chap_sha1_prepare(password, salt); + let hashed_len = hashed_data.len(); - // 5 is the maximum possible MP_STR header size - let mut res = Vec::with_capacity(scramble_bytes.len() + 5); - rmp::encode::write_str_len(&mut res, scramble_bytes.len() as _).expect("Can't fail for a Vec"); - res.write_all(scramble_bytes).expect("Can't fail for a Vec"); - return res; + let mut res = Vec::with_capacity(hashed_len + MP_STR_MAX_HEADER_SIZE); + rmp::encode::write_str_len(&mut res, hashed_len as _).expect("Can't fail for a Vec"); + res.write_all(&hashed_data).expect("Can't fail for a Vec"); + res } +/// Prepares password according to LDAP. +#[cfg(feature = "picodata")] +#[inline] +pub fn ldap_prepare(password: impl AsRef<[u8]>) -> Vec { + password.as_ref().to_vec() +} + +/// Prepares password according to LDAP and encodes into MessagePack. +/// WARNING: data is sent without any encryption, it is recommended +/// to use SSH tunnel/SSL/else to make communication secure. +// TODO(kbezuglyi): password should be `impl AsRef<[u8]>`, not `&str`. #[cfg(feature = "picodata")] #[inline] pub fn ldap_auth_data(password: &str) -> Vec { - // 5 is the maximum possible MP_STR header size - let mut res = Vec::with_capacity(password.len() + 5); - // Hopefully you're using an ssh tunnel or something ¯\_(ツ)_/¯ - rmp::encode::write_str(&mut res, password).expect("Can't fail for a Vec"); - return res; + let hashed_data = ldap_prepare(password); + let hashed_len = hashed_data.len(); + + let mut res = Vec::with_capacity(hashed_len + MP_STR_MAX_HEADER_SIZE); + rmp::encode::write_str_len(&mut res, hashed_len as _).expect("Can't fail for a Vec"); + res.write_all(&hashed_data).expect("Can't fail for a Vec"); + res } +/// Prepares (hashes) password with salt according to MD5. #[cfg(feature = "picodata")] #[inline] -pub fn md5_auth_data(user: &str, password: &str, salt: [u8; 4]) -> Vec { +pub fn md5_prepare(user: &str, password: impl AsRef<[u8]>, salt: [u8; 4]) -> Vec { // recv_from_db(salt) // recv_from_user(name, password) // shadow_pass = md5(name + password), do not add "md5" prefix @@ -165,7 +191,6 @@ pub fn md5_auth_data(user: &str, password: &str, salt: [u8; 4]) -> Vec { use md5::{Digest as Md5Digest, Md5}; - let mut res = Vec::new(); let mut md5 = Md5::new(); md5.update(password); @@ -176,8 +201,21 @@ pub fn md5_auth_data(user: &str, password: &str, salt: [u8; 4]) -> Vec { md5.update(salt); let client_pass = format!("md5{:x}", md5.finalize()); - rmp::encode::write_str(&mut res, &client_pass).expect("Can't fail for a Vec"); - return res; + client_pass.into_bytes() +} + +/// Prepares (hashes) password with salt according to MD5 and encodes into MessagePack. +// TODO(kbezuglyi): password should be `impl AsRef<[u8]>`, not `&str`. +#[cfg(feature = "picodata")] +#[inline] +pub fn md5_auth_data(user: &str, password: &str, salt: [u8; 4]) -> Vec { + let hashed_data = md5_prepare(user, password, salt); + let hashed_len = hashed_data.len(); + + let mut res = Vec::with_capacity(hashed_len + MP_STR_MAX_HEADER_SIZE); + rmp::encode::write_str_len(&mut res, hashed_len as _).expect("Can't fail for a Vec"); + res.write_all(&hashed_data).expect("Can't fail for a Vec"); + res } pub fn encode_auth( @@ -405,7 +443,7 @@ impl Header { pub fn encode(&self, stream: &mut impl Write) -> Result<(), Error> { rmp::encode::write_map_len(stream, 2)?; rmp::encode::write_pfix(stream, REQUEST_TYPE)?; - rmp::encode::write_pfix(stream, self.iproto_type as u8)?; + rmp::encode::write_uint(stream, self.iproto_type as _)?; rmp::encode::write_pfix(stream, SYNC)?; rmp::encode::write_uint(stream, self.sync.0)?; Ok(()) diff --git a/tarantool/src/read_view.rs b/tarantool/src/read_view.rs index f46146d7..845f93b1 100644 --- a/tarantool/src/read_view.rs +++ b/tarantool/src/read_view.rs @@ -129,7 +129,7 @@ impl<'a> Iterator for ReadViewIterator<'a> { } } -impl<'a> Drop for ReadViewIterator<'a> { +impl Drop for ReadViewIterator<'_> { #[inline(always)] fn drop(&mut self) { unsafe { ffi::box_read_view_iterator_free(self.inner.as_ptr()) } diff --git a/tarantool/src/sequence.rs b/tarantool/src/sequence.rs index 1cbda430..35de853b 100644 --- a/tarantool/src/sequence.rs +++ b/tarantool/src/sequence.rs @@ -28,8 +28,8 @@ impl Sequence { /// The generation algorithm is simple: /// - If this is the first time, then return the `start` value. /// - If the previous value plus the `increment` value is less than the `minimum` value or greater than the - /// `maximum` value, that is "overflow", so either raise an error (if `cycle = false`) or return the `maximum` value - /// (if `cycle = true` and `step < 0`) or return the `minimum` value (if `cycle = true` and `step > 0`). + /// `maximum` value, that is "overflow", so either raise an error (if `cycle = false`) or return the `maximum` value + /// (if `cycle = true` and `step < 0`) or return the `minimum` value (if `cycle = true` and `step > 0`). /// /// If there was no error, then save the returned result, it is now the "previous value". pub fn next(&mut self) -> Result { diff --git a/tarantool/src/space.rs b/tarantool/src/space.rs index cd45eaae..45fe4308 100644 --- a/tarantool/src/space.rs +++ b/tarantool/src/space.rs @@ -6,13 +6,13 @@ //! See also: //! - [Lua reference: Submodule box.space](https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_space/) //! - [C API reference: Module box](https://www.tarantool.io/en/doc/latest/dev_guide/reference_capi/box/) -use crate::error::{Error, TarantoolError}; +use crate::error::{Error, TarantoolError, TarantoolErrorCode}; use crate::ffi::tarantool as ffi; use crate::index::{Index, IndexIterator, IteratorType}; use crate::tuple::{Encode, ToTupleBuffer, Tuple, TupleBuffer}; -use crate::unwrap_or; use crate::util::Value; use crate::{msgpack, tuple_from_box_api}; +use crate::{set_error, unwrap_or}; use serde::{Deserialize, Serialize}; use serde_json::Map; use std::borrow::Cow; @@ -445,9 +445,8 @@ impl SpaceCache { // TODO: clear the cache if box_schema_version changes. let mut cache = self.spaces.borrow_mut(); cache.get(name).cloned().or_else(|| { - Space::find(name).map(|space| { + Space::find(name).inspect(|space| { cache.insert(name.to_string(), space.clone()); - space }) }) } @@ -459,9 +458,8 @@ impl SpaceCache { .get(&(space.id, name.to_string())) .cloned() .or_else(|| { - space.index(name).map(|index| { + space.index(name).inspect(|index| { cache.insert((space.id, name.to_string()), index.clone()); - index }) }) } @@ -747,7 +745,16 @@ impl Space { /// excluding index keys. For a measure of index size, see [index.bsize()](../index/struct.Index.html#method.bsize). #[inline(always)] pub fn bsize(&self) -> Result { - self.primary_key().bsize() + let space = unsafe { ffi::space_by_id(self.id) }; + if space.is_null() { + set_error!( + TarantoolErrorCode::NoSuchSpace, + "Space {} does not exist", + self.id + ); + return Err(TarantoolError::last().into()); + } + Ok(unsafe { ffi::space_bsize(space) }) } /// Search for a tuple in the given space. @@ -1370,7 +1377,7 @@ pub fn space_id_temporary_min() -> Option { static mut VALUE: Option> = None; // SAFETY: this is safe as we only call this from tx thread. unsafe { - if VALUE.is_none() { + if (*std::ptr::addr_of!(VALUE)).is_none() { VALUE = Some( crate::lua_state() .eval("return box.schema.SPACE_ID_TEMPORARY_MIN") diff --git a/tarantool/src/sql.rs b/tarantool/src/sql.rs index f8327d45..d96685b2 100644 --- a/tarantool/src/sql.rs +++ b/tarantool/src/sql.rs @@ -2,12 +2,15 @@ use crate::error::TarantoolError; use crate::ffi; -use crate::ffi::sql::ObufWrapper; +use crate::ffi::sql::{ObufWrapper, PortC}; use serde::Serialize; +use std::borrow::Cow; use std::io::Read; use std::os::raw::c_char; use std::str; +const MP_EMPTY_ARRAY: &[u8] = &[0x90]; + /// Returns the hash, used as the statement ID, generated from the SQL query text. pub fn calculate_hash(sql: &str) -> u32 { unsafe { ffi::sql::sql_stmt_calculate_id(sql.as_ptr() as *const c_char, sql.len()) } @@ -24,10 +27,9 @@ where IN: Serialize, { let mut buf = ObufWrapper::new(1024); - // 0x90 is an empty mp array - let mut param_data = vec![0x90]; + let mut param_data = Cow::from(MP_EMPTY_ARRAY); if std::mem::size_of::() != 0 { - param_data = rmp_serde::to_vec(bind_params)?; + param_data = Cow::from(rmp_serde::to_vec(bind_params)?); debug_assert!(crate::msgpack::skip_value(&mut std::io::Cursor::new(¶m_data)).is_ok()); } let param_ptr = param_data.as_ptr() as *const u8; @@ -46,6 +48,36 @@ where Ok(buf) } +pub fn sql_execute_into_port( + query: &str, + bind_params: &IN, + vdbe_max_steps: u64, + port: &mut PortC, +) -> crate::Result<()> +where + IN: Serialize, +{ + let mut param_data = Cow::from(MP_EMPTY_ARRAY); + if std::mem::size_of::() != 0 { + param_data = Cow::from(rmp_serde::to_vec(bind_params)?); + debug_assert!(crate::msgpack::skip_value(&mut std::io::Cursor::new(¶m_data)).is_ok()); + } + let param_ptr = param_data.as_ptr() as *const u8; + let execute_result = unsafe { + ffi::sql::sql_execute_into_port( + query.as_ptr() as *const u8, + query.len() as i32, + param_ptr, + vdbe_max_steps, + port.as_mut_ptr(), + ) + }; + if execute_result < 0 { + return Err(TarantoolError::last().into()); + } + Ok(()) +} + /// Creates new SQL prepared statement and stores it in the session. /// query - SQL query. /// @@ -55,7 +87,6 @@ where /// already existing statement within the same session does not increase the /// instance cache counter. However, calling prepare on the statement in a /// different session without the statement does increase the counter. - pub fn prepare(query: String) -> crate::Result { let mut stmt_id: u32 = 0; let mut session_id: u64 = 0; @@ -121,10 +152,9 @@ impl Statement { IN: Serialize, { let mut buf = ObufWrapper::new(1024); - // 0x90 is an empty mp array - let mut param_data = vec![0x90]; + let mut param_data = Cow::from(MP_EMPTY_ARRAY); if std::mem::size_of::() != 0 { - param_data = rmp_serde::to_vec(bind_params)?; + param_data = Cow::from(rmp_serde::to_vec(bind_params)?); debug_assert!( crate::msgpack::skip_value(&mut std::io::Cursor::new(¶m_data)).is_ok() ); @@ -139,4 +169,36 @@ impl Statement { } Ok(buf) } + + pub fn execute_into_port( + &self, + bind_params: &IN, + vdbe_max_steps: u64, + port: &mut PortC, + ) -> crate::Result<()> + where + IN: Serialize, + { + let mut param_data = Cow::from(MP_EMPTY_ARRAY); + if std::mem::size_of::() != 0 { + param_data = Cow::from(rmp_serde::to_vec(bind_params)?); + debug_assert!( + crate::msgpack::skip_value(&mut std::io::Cursor::new(¶m_data)).is_ok() + ); + } + let param_ptr = param_data.as_ptr() as *const u8; + let execute_result = unsafe { + ffi::sql::stmt_execute_into_port( + self.id(), + param_ptr, + vdbe_max_steps, + port.as_mut_ptr(), + ) + }; + + if execute_result < 0 { + return Err(TarantoolError::last().into()); + } + Ok(()) + } } diff --git a/tarantool/src/test.rs b/tarantool/src/test.rs index f316474a..1c201a0a 100644 --- a/tarantool/src/test.rs +++ b/tarantool/src/test.rs @@ -308,6 +308,7 @@ pub mod util { let guard = on_scope_exit(move || { crate::say_info!("killing ldap server"); ldap_server_process.kill().unwrap(); + ldap_server_process.wait().unwrap(); // Remove the temporary directory with it's contents drop(tempdir); diff --git a/tarantool/src/tuple.rs b/tarantool/src/tuple.rs index bfc5146a..82e7816a 100644 --- a/tarantool/src/tuple.rs +++ b/tarantool/src/tuple.rs @@ -22,10 +22,11 @@ use rmp::Marker; use serde::Serialize; use crate::error::{self, Error, Result, TarantoolError}; +#[cfg(feature = "picodata")] +use crate::ffi::sql::PortC; use crate::ffi::tarantool as ffi; use crate::index; use crate::tlua; -use crate::util::NumOrStr; /// Tuple pub struct Tuple { @@ -299,7 +300,7 @@ impl Tuple { } //////////////////////////////////////////////////////////////////////////////// -/// TupleIndex +// TupleIndex //////////////////////////////////////////////////////////////////////////////// /// Types implementing this trait can be used as arguments for the @@ -330,14 +331,13 @@ impl TupleIndex for &str { { use once_cell::sync::Lazy; use std::io::{Error as IOError, ErrorKind}; - static API: Lazy> = Lazy::new(|| unsafe { - let c_str = std::ffi::CStr::from_bytes_with_nul_unchecked; - let lib = dlopen::symbor::Library::open_self()?; - let err = match lib.symbol_cstr(c_str(ffi::TUPLE_FIELD_BY_PATH_NEW_API.as_bytes())) { + static API: Lazy> = Lazy::new(|| unsafe { + let lib = libloading::os::unix::Library::this(); + let err = match lib.get(ffi::TUPLE_FIELD_BY_PATH_NEW_API.as_bytes()) { Ok(api) => return Ok(Api::New(*api)), Err(e) => e, }; - if let Ok(api) = lib.symbol_cstr(c_str(ffi::TUPLE_FIELD_BY_PATH_OLD_API.as_bytes())) { + if let Ok(api) = lib.get(ffi::TUPLE_FIELD_BY_PATH_OLD_API.as_bytes()) { return Ok(Api::Old(*api)); } Err(err) @@ -423,7 +423,7 @@ impl<'de> serde_bytes::Deserialize<'de> for Tuple { } //////////////////////////////////////////////////////////////////////////////// -/// ToTupleBuffer +// ToTupleBuffer //////////////////////////////////////////////////////////////////////////////// /// Types implementing this trait can be converted to tarantool tuple (msgpack @@ -486,7 +486,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// Encode +// Encode //////////////////////////////////////////////////////////////////////////////// /// Types implementing this trait can be serialized into a valid tarantool tuple @@ -500,7 +500,7 @@ pub trait Encode: Serialize { } } -impl<'a, T> Encode for &'a T +impl Encode for &'_ T where T: Encode, { @@ -550,7 +550,7 @@ macro_rules! impl_tuple { impl_tuple! { A B C D E F G H I J K L M N O P } //////////////////////////////////////////////////////////////////////////////// -/// TupleBuffer +// TupleBuffer //////////////////////////////////////////////////////////////////////////////// /// Buffer containing tuple contents (MsgPack array) @@ -690,7 +690,7 @@ impl<'de> serde_bytes::Deserialize<'de> for TupleBuffer { } //////////////////////////////////////////////////////////////////////////////// -/// TupleFormat +// TupleFormat //////////////////////////////////////////////////////////////////////////////// /// Tuple format @@ -741,7 +741,7 @@ impl Debug for TupleFormat { } //////////////////////////////////////////////////////////////////////////////// -/// TupleIterator +// TupleIterator //////////////////////////////////////////////////////////////////////////////// /// Tuple iterator @@ -936,12 +936,7 @@ impl<'a> KeyDefPart<'a> { } } - pub fn try_from_index_part(p: &'a index::Part) -> Option { - let field_no = match p.field { - NumOrStr::Num(field_no) => field_no, - NumOrStr::Str(_) => return None, - }; - + pub fn from_index_part(p: &'a index::Part) -> Self { let collation = p.collation.as_deref().map(|s| { CString::new(s) .expect("it's your fault if you put '\0' in collation") @@ -952,13 +947,13 @@ impl<'a> KeyDefPart<'a> { .expect("it's your fault if you put '\0' in collation") .into() }); - Some(Self { - field_no, + Self { + field_no: p.field, field_type: p.r#type.map(From::from).unwrap_or(FieldType::Any), is_nullable: p.is_nullable.unwrap_or(false), collation, path, - }) + } } } @@ -1112,12 +1107,10 @@ impl Drop for KeyDef { } } -impl std::convert::TryFrom<&index::Metadata<'_>> for KeyDef { - type Error = index::FieldMustBeNumber; - +impl From<&index::Metadata<'_>> for KeyDef { #[inline(always)] - fn try_from(meta: &index::Metadata<'_>) -> std::result::Result { - meta.try_to_key_def() + fn from(meta: &index::Metadata<'_>) -> Self { + meta.to_key_def() } } @@ -1146,7 +1139,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// FunctionCtx +// FunctionCtx //////////////////////////////////////////////////////////////////////////////// #[repr(C)] @@ -1208,10 +1201,21 @@ impl FunctionCtx { Ok(result) } } + + #[cfg(feature = "picodata")] + #[inline] + pub fn mut_port_c(&mut self) -> &mut PortC { + unsafe { + let mut ctx = NonNull::new_unchecked(self.inner); + NonNull::new_unchecked(ctx.as_mut().port) + .as_mut() + .as_mut_port_c() + } + } } //////////////////////////////////////////////////////////////////////////////// -/// FunctionArgs +// FunctionArgs //////////////////////////////////////////////////////////////////////////////// #[repr(C)] @@ -1370,7 +1374,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// Decode +// Decode //////////////////////////////////////////////////////////////////////////////// /// Types implementing this trait can be decoded from msgpack. @@ -1409,7 +1413,7 @@ pub trait DecodeOwned: for<'de> Decode<'de> {} impl DecodeOwned for T where T: for<'de> Decode<'de> {} //////////////////////////////////////////////////////////////////////////////// -/// RawBytes +// RawBytes //////////////////////////////////////////////////////////////////////////////// /// A wrapper type for reading raw bytes from a tuple. @@ -1486,7 +1490,7 @@ impl std::borrow::ToOwned for RawBytes { } //////////////////////////////////////////////////////////////////////////////// -/// RawByteBuf +// RawByteBuf //////////////////////////////////////////////////////////////////////////////// /// A wrapper type for reading raw bytes from a tuple. @@ -1576,15 +1580,18 @@ impl std::borrow::Borrow for RawByteBuf { #[cfg(feature = "picodata")] mod picodata { use super::*; - use crate::Result; - use std::ffi::CStr; - use std::io::{Cursor, Write}; + use crate::static_ref; //////////////////////////////////////////////////////////////////////////// // Tuple picodata extensions //////////////////////////////////////////////////////////////////////////// impl Tuple { + /// !!! + /// WARNING: + /// NO LONGER SUPPORTED - PANICS WHEN USED! + /// !!! + /// /// Returns messagepack encoded tuple with named fields (messagepack map). /// /// Returned map has only numeric keys if tuple has default tuple format (see [TupleFormat](struct.TupleFormat.html)), @@ -1592,50 +1599,20 @@ mod picodata { /// fields in tuple format - then additional fields are presents in the map with numeric keys. /// /// This function is useful if there is no information about tuple fields in program runtime. + #[deprecated = "did not find its use"] + #[inline(always)] pub fn as_named_buffer(&self) -> Result> { - let format = self.format(); - let buff = self.to_vec(); - - let field_count = self.len(); - let mut named_buffer = Vec::with_capacity(buff.len()); - - let mut cursor = Cursor::new(&buff); - - rmp::encode::write_map_len(&mut named_buffer, field_count)?; - rmp::decode::read_array_len(&mut cursor)?; - format.names().try_for_each(|field_name| -> Result<()> { - let value_start = cursor.position() as usize; - crate::msgpack::skip_value(&mut cursor)?; - let value_end = cursor.position() as usize; - - rmp::encode::write_str(&mut named_buffer, field_name)?; - Ok(named_buffer.write_all(&buff[value_start..value_end])?) - })?; - - for i in 0..field_count - format.name_count() { - let value_start = cursor.position() as usize; - crate::msgpack::skip_value(&mut cursor)?; - let value_end = cursor.position() as usize; - - rmp::encode::write_u32(&mut named_buffer, i)?; - named_buffer.write_all(&buff[value_start..value_end])?; - } - - Ok(named_buffer) + crate::say_error!("Tuple::as_named_buffer is no longer supported"); + panic!("Tuple::as_named_buffer is no longer supported"); } /// Returns a slice of data contained in the tuple. #[inline] pub fn data(&self) -> &[u8] { - // Safety: safe because we only construct `Tuple` from valid pointers to `box_tuple_t`. - let tuple = unsafe { self.ptr.as_ref() }; - // Safety: this is how tuple data is stored in picodata's tarantool-2.11.2-137-ga0f7c15f75 unsafe { - let data_offset = tuple.data_offset(); - let data = (tuple as *const ffi::BoxTuple) - .cast::() - .offset(data_offset as _); - std::slice::from_raw_parts(data, tuple.bsize()) + // Safety: safe because we only construct `Tuple` from valid pointers to `box_tuple_t`. + let tuple = self.ptr.as_ref(); + tuple.data() } } } @@ -1677,7 +1654,7 @@ mod picodata { static mut SINGLETON: Option = None; // Safety: only makes sense to call this from tx thread - if unsafe { SINGLETON.is_none() } { + if unsafe { static_ref!(mut SINGLETON) }.is_none() { // Safety: this code is valid for picodata's tarantool-2.11.2-137-ga0f7c15f75. unsafe { let inner = ffi::box_tuple_format_new(std::ptr::null_mut(), 0); @@ -1686,26 +1663,40 @@ mod picodata { SINGLETON = Some(Self { inner }); } } - unsafe { SINGLETON.as_ref().expect("just made sure it's there") } + + unsafe { static_ref!(const SINGLETON) } + .as_ref() + .expect("just made sure it's there") } + /// !!! + /// WARNING: + /// NO LONGER SUPPORTED - PANICS WHEN USED! + /// !!! + /// /// Return tuple field names count. + #[deprecated = "did not find its use"] pub fn name_count(&self) -> u32 { - unsafe { (*(*self.inner).dict).name_count } + crate::say_error!("TupleFormat::name_count is no longer supported"); + panic!("TupleFormat::name_count is no longer supported"); } + /// !!! + /// WARNING: + /// NO LONGER SUPPORTED - PANICS WHEN USED! + /// !!! + /// /// Return tuple field names. + #[deprecated = "did not find its use"] pub fn names(&self) -> impl Iterator { - // Safety: this code is valid for picodata's tarantool-2.11.2-137-ga0f7c15f75. - let slice = unsafe { - std::slice::from_raw_parts((*(*self.inner).dict).names, self.name_count() as _) - }; - slice.iter().copied().map(|ptr| { - // Safety: this code is valid for picodata's tarantool-2.11.2-137-ga0f7c15f75. - let cstr = unsafe { CStr::from_ptr(ptr) }; - let s = cstr.to_str().expect("tuple fields should be in utf-8"); - s - }) + crate::say_error!("TupleFormat::names is no longer supported"); + // weird hack over "never type does not implement all traits". + // if we just panic, it will infer the type of a function as a never type + // and it won't compile because `!` does not implement `Iterator` trait. + // wrapping in a closure will make it behave like `unwrap` on a failure + (|| panic!("TupleFormat::names is no longer supported"))(); + // this is also part of a hack because we need correct type inference + Vec::new().into_iter() } } } @@ -2016,7 +2007,7 @@ mod test { .part("id") .part("s") .part( - index::Part::field("nested") + index::Part::::field("nested") .field_type(index::FieldType::Unsigned) .path("[2].blabla"), ) diff --git a/tarantool/src/util.rs b/tarantool/src/util.rs index 02f492e2..121a1d0a 100644 --- a/tarantool/src/util.rs +++ b/tarantool/src/util.rs @@ -3,6 +3,29 @@ use serde::{Deserialize, Serialize}; use std::borrow::Cow; use std::ffi::CString; +/// Exists to optimize code burden with access to mutable statics +/// and avoid annoying warnings by Clippy about in-place dereference. +/// It is pretty much a `addr_of*`/`&raw` wrapper without warnings, +/// so it is developer responsibility to use it correctly, but it +/// is not marked as unsafe in an expansion to prevent silly mistakes +/// and to have an opportunity for reviewers to blame someone easily. +/// +/// Expects `mut` or `const` as a second parameter, to state whether +/// you want mutable (exclusive) or constant (shared) reference to +/// the provided mutable static variable as a first parameter. +/// +/// NOTE: this macro is not that useful if you need a shared reference +/// to an immutable static, as it is considered safe by default. +#[macro_export] +macro_rules! static_ref { + (const $var:ident) => { + &*&raw const $var + }; + (mut $var:ident) => { + &mut *&raw mut $var + }; +} + pub trait IntoClones: Clone { fn into_clones(self) -> Tuple; } @@ -36,6 +59,7 @@ macro_rules! tuple_from_box_api { { let mut result = ::std::mem::MaybeUninit::uninit(); #[allow(unused_unsafe)] + #[allow(clippy::macro_metavars_in_unsafe)] unsafe { if $f($($args),*, result.as_mut_ptr()) < 0 { return Err($crate::error::TarantoolError::last().into()); diff --git a/tarantool/src/uuid.rs b/tarantool/src/uuid.rs index b64f7129..de7186ee 100644 --- a/tarantool/src/uuid.rs +++ b/tarantool/src/uuid.rs @@ -1,5 +1,9 @@ use crate::ffi::uuid as ffi; +use std::convert::{TryFrom, TryInto}; +use std::io::Write; +use crate::msgpack; +use crate::msgpack::{Context, Decode, DecodeError, Encode, EncodeError}; pub use ::uuid::{adapter, Error}; use serde::{Deserialize, Serialize}; @@ -53,7 +57,8 @@ impl Uuid { tt.tl = tt.tl.swap_bytes(); tt.tm = tt.tm.swap_bytes(); tt.th = tt.th.swap_bytes(); - Self::from_bytes(std::mem::transmute(tt)) + let bytes: [u8; 16] = std::mem::transmute(tt); + Self::from_bytes(bytes) } } @@ -164,6 +169,28 @@ impl Uuid { pub const fn to_urn_ref(&self) -> adapter::UrnRef<'_> { self.inner.to_urn_ref() } + + fn from_ext_structure(tag: i8, bytes: &[u8]) -> Result { + if tag != ffi::MP_UUID { + return Err(format!("Expected UUID, found msgpack ext #{}", tag)); + } + + Self::try_from_slice(bytes).ok_or_else(|| { + format!( + "Not enough bytes for UUID: expected 16, got {}", + bytes.len() + ) + }) + } +} + +impl<'a> TryFrom> for Uuid { + type Error = String; + + #[inline(always)] + fn try_from(value: msgpack::ExtStruct<'a>) -> Result { + Self::from_ext_structure(value.tag, value.data) + } } impl From for Uuid { @@ -207,19 +234,33 @@ impl std::str::FromStr for Uuid { } //////////////////////////////////////////////////////////////////////////////// -/// Tuple +// Tuple //////////////////////////////////////////////////////////////////////////////// +impl Encode for Uuid { + fn encode(&self, w: &mut impl Write, context: &Context) -> Result<(), EncodeError> { + msgpack::ExtStruct::new(ffi::MP_UUID, self.as_bytes()).encode(w, context) + } +} + +impl<'de> Decode<'de> for Uuid { + fn decode(r: &mut &'de [u8], context: &Context) -> Result { + msgpack::ExtStruct::decode(r, context)? + .try_into() + .map_err(DecodeError::new::) + } +} + impl serde::Serialize for Uuid { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { #[derive(Serialize)] - struct _ExtStruct((i8, serde_bytes::ByteBuf)); + struct _ExtStruct<'a>((i8, &'a serde_bytes::Bytes)); let data = self.as_bytes(); - _ExtStruct((ffi::MP_UUID, serde_bytes::ByteBuf::from(data as &[_]))).serialize(serializer) + _ExtStruct((ffi::MP_UUID, serde_bytes::Bytes::new(data))).serialize(serializer) } } @@ -231,34 +272,22 @@ impl<'de> serde::Deserialize<'de> for Uuid { #[derive(Deserialize)] struct _ExtStruct((i8, serde_bytes::ByteBuf)); - let _ExtStruct((kind, bytes)) = serde::Deserialize::deserialize(deserializer)?; + let _ExtStruct((tag, bytes)) = Deserialize::deserialize(deserializer)?; - if kind != ffi::MP_UUID { - return Err(serde::de::Error::custom(format!( - "Expected UUID, found msgpack ext #{}", - kind - ))); - } - - let data = bytes.into_vec(); - Self::try_from_slice(&data).ok_or_else(|| { - serde::de::Error::custom(format!( - "Not enough bytes for UUID: expected 16, got {}", - data.len() - )) - }) + Self::from_ext_structure(tag, bytes.as_slice()).map_err(serde::de::Error::custom) } } //////////////////////////////////////////////////////////////////////////////// -/// Lua +// Lua //////////////////////////////////////////////////////////////////////////////// static mut CTID_UUID: Option = None; fn ctid_uuid() -> u32 { + // SAFETY: only safe to call this from tx thread unsafe { - if CTID_UUID.is_none() { + if (*std::ptr::addr_of!(CTID_UUID)).is_none() { let lua = crate::global_lua(); let ctid_uuid = tlua::ffi::luaL_ctypeid(tlua::AsLua::as_lua(&lua), crate::c_ptr!("struct tt_uuid")); @@ -305,3 +334,36 @@ impl tlua::PushInto for Uuid { } impl tlua::PushOneInto for Uuid {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::msgpack; + + #[test] + fn serialize() { + let result = [ + 216, 2, 227, 6, 158, 228, 241, 69, 78, 73, 160, 56, 166, 128, 14, 42, 161, 217, + ]; + let uuid = Uuid::parse_str("e3069ee4-f145-4e49-a038-a6800e2aa1d9").unwrap(); + let serde_data = rmp_serde::to_vec(&uuid).unwrap(); + assert_eq!(serde_data, result); + + let msgpack_data = msgpack::encode(&uuid); + assert_eq!(msgpack_data, result); + } + + #[test] + fn deserialize() { + let uuid = Uuid { + inner: uuid::Uuid::NAMESPACE_DNS, + }; + let serde_data = rmp_serde::to_vec(&uuid).unwrap(); + + let msgpack_uuid = msgpack::decode(serde_data.as_slice()).unwrap(); + assert_eq!(uuid, msgpack_uuid); + + let serde_uuid = rmp_serde::from_slice(serde_data.as_slice()).unwrap(); + assert_eq!(uuid, serde_uuid); + } +} diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 1c874539..33d9349c 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -9,7 +9,7 @@ authors = [ ] edition = "2018" license = "BSD-2-Clause" -rust-version = "1.61" +rust-version = "1.82" [dependencies] log = "0.4.11" @@ -26,8 +26,8 @@ rmpv = { version = "1", features = ["with-serde"] } libc = "*" futures = "0.3.25" linkme = "0.3.0" -time = "=0.3.17" -time-macros = "=0.2.6" +time = "0.3.37" # not used directly, but referenced by macro expansions from time-macros +time-macros = "0.2.6" [dependencies.tarantool] path = "../tarantool" diff --git a/tests/src/box.rs b/tests/src/box.rs index 77f56e1d..8b0c3657 100644 --- a/tests/src/box.rs +++ b/tests/src/box.rs @@ -1,6 +1,7 @@ use rand::Rng; use std::borrow::Cow; use std::collections::BTreeMap; +use tarantool::error::{IntoBoxError, TarantoolErrorCode}; use tarantool::index::{self, IndexOptions, IteratorType}; use tarantool::sequence::Sequence; @@ -1098,3 +1099,36 @@ pub fn fully_temporary_space() { space_5.drop().unwrap(); space_6.drop().unwrap(); } + +pub fn space_bsize() { + let space_name = "space_bsize_test"; + let space = Space::builder(space_name).create().unwrap(); + space + .index_builder("pk") + .parts([(1, index::FieldType::String), (2, index::FieldType::String)]) + .create() + .unwrap(); + + let bsize = space.bsize().expect("space should exist"); + assert_eq!(bsize, 0); + + space.insert(&("Hello", "world")).unwrap(); + + let bsize = space.bsize().expect("space should exist"); + assert_eq!(bsize, 13); //? Hello + \0 + next_item_char + world + \0 + + let lua = tarantool::lua_state(); + lua.exec(&format!("box.space.{}:drop()", space_name)) + .expect("lua exec failed"); + + let bsize = space.bsize(); + assert!( + bsize.is_err(), + "space.bsize should return an error, because the space does not exist anymore" + ); + assert_eq!( + bsize.err().unwrap().error_code(), + TarantoolErrorCode::NoSuchSpace as u32, + "the error is not equal to box.error.NO_SUCH_SPACE" + ) +} diff --git a/tests/src/common.rs b/tests/src/common.rs index 641d306e..e7c58768 100644 --- a/tests/src/common.rs +++ b/tests/src/common.rs @@ -80,7 +80,7 @@ use once_cell::unsync::OnceCell; pub fn lib_name() -> String { thread_local! { - static LIB_NAME: OnceCell = OnceCell::new(); + static LIB_NAME: OnceCell = const { OnceCell::new() }; } LIB_NAME.with(|lib_name| { lib_name diff --git a/tests/src/define_str_enum.rs b/tests/src/define_str_enum.rs index da11d13c..ee7f2b52 100644 --- a/tests/src/define_str_enum.rs +++ b/tests/src/define_str_enum.rs @@ -94,7 +94,8 @@ pub fn basic() { ); // other claimed traits - impl<'de, L: tlua::AsLua> AssertImpl<'de, L> for Color {} + impl AssertImpl<'_, L> for Color {} + #[allow(unused)] trait AssertImpl<'de, L: tlua::AsLua>: AsRef + Into diff --git a/tests/src/fiber/channel.rs b/tests/src/fiber/channel.rs index c9636797..b56fca08 100644 --- a/tests/src/fiber/channel.rs +++ b/tests/src/fiber/channel.rs @@ -303,6 +303,7 @@ pub fn into_clones() { struct NonClonable(); #[derive(Clone)] + #[allow(unused)] struct MyChannel(fiber::Channel); let (_, _) = MyChannel(fiber::Channel::new(1)).into_clones(); diff --git a/tests/src/lib.rs b/tests/src/lib.rs index cc0a4ecc..b3639a3a 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -8,6 +8,7 @@ #![allow(clippy::redundant_pattern_matching)] #![allow(clippy::useless_vec)] #![allow(clippy::get_first)] +#![allow(clippy::unused_unit)] use std::io; use serde::Deserialize; @@ -439,6 +440,7 @@ fn run_tests(cfg: TestConfig) -> Result { r#box::space_drop, r#box::index_create_drop, r#box::index_parts, + r#box::space_bsize, tuple::tuple_new_from_struct, tuple::new_tuple_from_flatten_struct, tuple::tuple_field_count, @@ -536,6 +538,8 @@ fn run_tests(cfg: TestConfig) -> Result { #[cfg(feature = "picodata")] { tests.append(&mut tests![ + proc::return_port, + proc::dump_port_to_lua, sql::prepared_source_query, sql::prepared_invalid_query, sql::prepared_no_params, @@ -543,8 +547,8 @@ fn run_tests(cfg: TestConfig) -> Result { sql::prepared_with_unnamed_params, sql::prepared_with_named_params, sql::prepared_invalid_params, - tuple_picodata::tuple_format_get_names, - tuple_picodata::tuple_as_named_buffer, + sql::port_c, + sql::port_c_vtab, tuple_picodata::tuple_hash, ]) } diff --git a/tests/src/net_box.rs b/tests/src/net_box.rs index 0d6312d0..a7401a5c 100644 --- a/tests/src/net_box.rs +++ b/tests/src/net_box.rs @@ -26,6 +26,7 @@ fn test_user_conn() -> Conn { ConnOptions { user: "test_user".into(), password: "password".into(), + auth_method: tarantool::auth::AuthMethod::ChapSha1, ..ConnOptions::default() }, None, @@ -635,6 +636,7 @@ pub fn triggers_schema_sync() { ConnOptions { user: "test_user".to_string(), password: "password".to_string(), + auth_method: tarantool::auth::AuthMethod::ChapSha1, ..ConnOptions::default() }, Some(Rc::new(TriggersMock { diff --git a/tests/src/proc.rs b/tests/src/proc.rs index 9f7f8c2c..34d9273d 100644 --- a/tests/src/proc.rs +++ b/tests/src/proc.rs @@ -1,14 +1,11 @@ use crate::common::lib_name; +use ::tarantool::proc::ReturnMsgpack; +use ::tarantool::tlua::{ + self, AsTable, Call, CallError, LuaFunction, LuaRead, LuaState, LuaThread, PushGuard, PushInto, +}; +use ::tarantool::tuple::{RawByteBuf, RawBytes, Tuple, TupleBuffer}; use rmpv::Value; use std::ffi::OsStr; -use tarantool::{ - proc::ReturnMsgpack, - tlua::{ - self, AsTable, Call, CallError, LuaFunction, LuaRead, LuaState, LuaThread, PushGuard, - PushInto, - }, - tuple::{RawByteBuf, RawBytes, Tuple, TupleBuffer}, -}; fn call_proc(name: &str, args: A) -> Result> where @@ -74,6 +71,104 @@ pub fn return_tuple() { assert_eq!(data, ["hello", "sailor"]); } +#[cfg(feature = "picodata")] +pub fn return_port() { + use tarantool::error::TarantoolErrorCode; + use tarantool::set_error; + use tarantool::tuple::{FunctionArgs, FunctionCtx}; + + #[no_mangle] + unsafe extern "C" fn proc_port(mut ctx: FunctionCtx, args: FunctionArgs) -> i32 { + let (a, b) = match args.decode::<(i32, String)>() { + Ok(v) => v, + Err(e) => { + set_error!(TarantoolErrorCode::ProcC, "decode error: {}", e); + return -1; + } + }; + let tuple = Tuple::new(&(a, b)).expect("tuple creation failed"); + ctx.mut_port_c().add_tuple(&tuple); + ctx.mut_port_c().add_mp(b"\x91\xa5hello"); + ctx.mut_port_c().add_mp(b"\xa6sailor"); + 0 + } + + let data: (Tuple, [String; 1], String) = call_proc("proc_port", (42, "magic")).unwrap(); + assert_eq!( + data.0.decode::<(i32, String)>().unwrap(), + (42, "magic".to_string()) + ); + assert_eq!(data.1, ["hello"]); + assert_eq!(data.2, "sailor"); +} + +#[cfg(feature = "picodata")] +pub fn dump_port_to_lua() { + use core::ffi::c_char; + use std::ptr::NonNull; + use tarantool::ffi::sql::{Obuf, Port, PortVTable}; + use tarantool::ffi::tarantool::luaT_pushtuple; + use tarantool::tlua::ffi::{self, lua_State}; + use tarantool::tuple::{FunctionArgs, FunctionCtx}; + + const VTAB_LUA: PortVTable = PortVTable::new(dump_msgpack_with_panic, dump_lua_with_header); + + #[no_mangle] + unsafe extern "C" fn dump_msgpack_with_panic(_port: *mut Port, _out: *mut Obuf) { + unimplemented!(); + } + + #[no_mangle] + unsafe extern "C" fn dump_lua_with_header(port: *mut Port, l: *mut lua_State, _is_flat: bool) { + // Create the map with two keys. + ffi::lua_createtable(l, 0, 2); + // Push the "header" key and value ("greeting"). + ffi::lua_pushstring(l, b"greeting\0".as_ptr() as *const c_char); + ffi::lua_setfield(l, -2, b"header\0".as_ptr() as *const c_char); + // Push the "data" key and value (array of tuples from the port). + let port_c = unsafe { + let port: &mut Port = NonNull::new_unchecked(port).as_mut(); + port.as_mut_port_c() + }; + // Create the array of tuples. + ffi::lua_createtable(l, port_c.size(), 0); + for (idx, mp_bytes) in port_c.iter().enumerate() { + let tuple = Tuple::try_from_slice(mp_bytes).unwrap(); + luaT_pushtuple(l, tuple.as_ptr()); + ffi::lua_rawseti(l, -2, idx as i32 + 1); + } + ffi::lua_setfield(l, -2, b"data\0".as_ptr() as *const c_char); + } + + #[no_mangle] + unsafe extern "C" fn proc_dump_lua(mut ctx: FunctionCtx, _args: FunctionArgs) -> i32 { + ctx.mut_port_c().vtab = &VTAB_LUA; + // Pay attention that we use msgpack wrapped with array. + // It is required to build the tuple "in place" from the port msgpack + // in the dump_lua callback. + ctx.mut_port_c().add_mp(b"\x91\xa5hello"); + ctx.mut_port_c().add_mp(b"\x91\xa5world"); + 0 + } + + #[derive(tlua::LuaRead)] + struct Data { + header: String, + data: Vec, + } + + let data: (Data,) = call_proc("proc_dump_lua", ()).unwrap(); + assert_eq!(data.0.header, "greeting"); + assert_eq!( + data.0.data[0].decode::<(String,)>().unwrap(), + ("hello".into(),) + ); + assert_eq!( + data.0.data[1].decode::<(String,)>().unwrap(), + ("world".into(),) + ); +} + pub fn with_error() { #[tarantool::proc] fn proc_with_error(x: i32, y: String) -> Result<(i32, i32), String> { @@ -256,7 +351,7 @@ pub fn inject() { fn global() -> &'static GlobalData { static mut GLOBAL: Option = None; unsafe { - GLOBAL.get_or_insert_with(|| GlobalData { + (*std::ptr::addr_of_mut!(GLOBAL)).get_or_insert_with(|| GlobalData { data: vec!["some".into(), "global".into(), "data".into()], }) } diff --git a/tests/src/sql.rs b/tests/src/sql.rs index f82493f5..a5b8abe4 100644 --- a/tests/src/sql.rs +++ b/tests/src/sql.rs @@ -2,11 +2,16 @@ use serde::de::DeserializeOwned; use std::collections::HashMap; -use std::io::Read; +use std::io::{Cursor, Read}; +use std::ptr::NonNull; use tarantool::error::{Error, TarantoolError}; -use tarantool::ffi::sql::IPROTO_DATA; +use tarantool::ffi::lua::lua_State; +use tarantool::ffi::sql::{obuf_append, Obuf, ObufWrapper, Port, PortC, PortVTable, IPROTO_DATA}; use tarantool::index::IndexType; +use tarantool::msgpack::write_array_len; use tarantool::space::{Field, Space}; +use tarantool::sql::{sql_execute_into_port, unprepare}; +use tarantool::tuple::Tuple; fn create_sql_test_space(name: &str) -> tarantool::Result { let space = Space::builder(name) @@ -52,6 +57,18 @@ where rmpv::ext::from_value::(data).unwrap() } +fn decode_port(port: &PortC) -> Vec +where + OUT: DeserializeOwned, +{ + let mut result = Vec::new(); + for mp_bytes in port.iter() { + let entry: OUT = rmp_serde::from_slice(mp_bytes).unwrap(); + result.push(entry); + } + result +} + pub fn prepared_invalid_query() { let maybe_stmt = tarantool::sql::prepare("SELECT * FROM UNKNOWN_SPACE".to_string()); assert!(maybe_stmt.is_err()); @@ -67,6 +84,7 @@ pub fn prepared_source_query() { let stmt = tarantool::sql::prepare("SELECT * FROM SQL_TEST".to_string()).unwrap(); assert_eq!(stmt.source(), "SELECT * FROM SQL_TEST"); + unprepare(stmt).unwrap(); drop_sql_test_space(sp).unwrap(); } @@ -89,12 +107,25 @@ pub fn prepared_no_params() { assert_eq!((3, "three".to_string()), result[2]); assert_eq!((4, "four".to_string()), result[3]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt.execute_into_port(&(), 100, &mut port_c).unwrap(); + let decoded_port: Vec<(u64, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let sql = "SELECT * FROM SQL_TEST WHERE ID = 1"; let mut stream = tarantool::sql::prepare_and_execute_raw(sql, &(), 100).unwrap(); let result = decode_dql_result::>(&mut stream); assert_eq!(1, result.len()); assert_eq!((1, "one".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + sql_execute_into_port(sql, &(), 100, &mut port_c).unwrap(); + let decoded_port: Vec<(u64, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + + unprepare(stmt).unwrap(); drop_sql_test_space(sp).unwrap(); } @@ -124,6 +155,13 @@ pub fn prepared_large_query() { i += 4; } + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt.execute_into_port(&(), 0, &mut port_c).unwrap(); + let decoded_port: Vec<(u64, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + + unprepare(stmt).unwrap(); drop_sql_test_space(sp).unwrap(); } @@ -139,6 +177,15 @@ pub fn prepared_invalid_params() { Error::Tarantool(TarantoolError { .. }) )); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + let result = stmt.execute_into_port(&("not uint value",), 0, &mut port_c); + assert!(result.is_err()); + assert!(matches!( + result.err().unwrap(), + Error::Tarantool(TarantoolError { .. }) + )); + let result = tarantool::sql::prepare_and_execute_raw( "SELECT * FROM SQL_TEST WHERE ID = ?", &("not uint value"), @@ -150,6 +197,21 @@ pub fn prepared_invalid_params() { Error::Tarantool(TarantoolError { .. }) )); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + let result = sql_execute_into_port( + "SELECT * FROM SQL_TEST WHERE ID = ?", + &("not uint value"), + 0, + &mut port_c, + ); + assert!(result.is_err()); + assert!(matches!( + result.err().unwrap(), + Error::Tarantool(TarantoolError { .. }) + )); + + unprepare(stmt).unwrap(); drop_sql_test_space(sp).unwrap(); } @@ -169,11 +231,23 @@ pub fn prepared_with_unnamed_params() { assert_eq!((103, "three".to_string()), result[0]); assert_eq!((104, "four".to_string()), result[1]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt.execute_into_port(&(102,), 0, &mut port_c).unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let mut stream = stmt.execute_raw(&(103,), 0).unwrap(); let result = decode_dql_result::>(&mut stream); assert_eq!(1, result.len()); assert_eq!((104, "four".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt.execute_into_port(&(103,), 0, &mut port_c).unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let stmt2 = tarantool::sql::prepare("SELECT * FROM SQL_TEST WHERE ID > ? AND VALUE = ?".to_string()) .unwrap(); @@ -182,6 +256,14 @@ pub fn prepared_with_unnamed_params() { assert_eq!(1, result.len()); assert_eq!((103, "three".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt2 + .execute_into_port(&(102, "three"), 0, &mut port_c) + .unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let mut stream = tarantool::sql::prepare_and_execute_raw( "SELECT * FROM SQL_TEST WHERE ID = ? AND VALUE = ?", &(101, "one"), @@ -192,6 +274,19 @@ pub fn prepared_with_unnamed_params() { assert_eq!(1, result.len()); assert_eq!((101, "one".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + sql_execute_into_port( + "SELECT * FROM SQL_TEST WHERE ID = ? AND VALUE = ?", + &(101, "one"), + 0, + &mut port_c, + ) + .unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + + unprepare(stmt).unwrap(); drop_sql_test_space(sp).unwrap(); } @@ -224,11 +319,25 @@ pub fn prepared_with_named_params() { assert_eq!((3, "three".to_string()), result[0]); assert_eq!((4, "four".to_string()), result[1]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt.execute_into_port(&[bind_id(2)], 0, &mut port_c) + .unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let mut stream = stmt.execute_raw(&[bind_id(3)], 0).unwrap(); let result = decode_dql_result::>(&mut stream); assert_eq!(1, result.len()); assert_eq!((4, "four".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt.execute_into_port(&[bind_id(3)], 0, &mut port_c) + .unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let stmt2 = tarantool::sql::prepare( "SELECT * FROM SQL_TEST WHERE ID > :ID AND VALUE = :NAME".to_string(), ) @@ -240,6 +349,14 @@ pub fn prepared_with_named_params() { assert_eq!(1, result.len()); assert_eq!((3, "three".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + stmt2 + .execute_into_port(&(bind_id(2), bind_name("three")), 0, &mut port_c) + .unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + let mut stream = tarantool::sql::prepare_and_execute_raw( "SELECT * FROM SQL_TEST WHERE ID = :ID AND VALUE = :NAME", &(bind_id(1), bind_name("one")), @@ -250,5 +367,187 @@ pub fn prepared_with_named_params() { assert_eq!(1, result.len()); assert_eq!((1, "one".to_string()), result[0]); + let mut port = Port::new_port_c(); + let mut port_c = unsafe { port.as_mut_port_c() }; + sql_execute_into_port( + "SELECT * FROM SQL_TEST WHERE ID = :ID AND VALUE = :NAME", + &(bind_id(1), bind_name("one")), + 0, + &mut port_c, + ) + .unwrap(); + let decoded_port: Vec<(u8, String)> = decode_port(&port_c); + assert_eq!(decoded_port, result); + + unprepare(stmt).unwrap(); drop_sql_test_space(sp).unwrap(); } + +pub fn port_c() { + let tuple_refs = |tuple: &Tuple| unsafe { NonNull::new(tuple.as_ptr()).unwrap().as_ref() }.refs; + let mut port = Port::new_port_c(); + let port_c = unsafe { port.as_mut_port_c() }; + + // Check that we can iterate over an empty port. + let mut iter = port_c.iter(); + assert_eq!(iter.next(), None); + + // Let's check that the data in the port can outlive + // the original tuples after dropping them. + { + let tuple1 = Tuple::new(&("A",)).unwrap(); + port_c.add_tuple(&tuple1); + assert_eq!(port_c.size(), 1); + let mp1 = b"\x91\xa1B"; + unsafe { port_c.add_mp(mp1.as_slice()) }; + assert_eq!(port_c.size(), 2); + let tuple2 = Tuple::new(&("C", "D")).unwrap(); + port_c.add_tuple(&tuple2); + assert_eq!(port_c.size(), 3); + let mp2 = b"\x91\xa1E"; + unsafe { port_c.add_mp(mp2.as_slice()) }; + assert_eq!(port_c.size(), 4); + } + let tuple3 = Tuple::new(&("F",)).unwrap(); + // The tuple has two references and it should not surprise you. + // The first one is a long-live reference produces by the box_tuple_ref. + // The second is temporary produced by tuple_bless when the tuple is added + // to the output box_tuple_last pointer. The next tuple put to the + // box_tuple_last decreases the reference count of the previous tuple. + assert_eq!(tuple_refs(&tuple3), 2); + let _ = Tuple::new(&("G",)).unwrap(); + assert_eq!(tuple_refs(&tuple3), 1); + port_c.add_tuple(&tuple3); + assert_eq!(tuple_refs(&tuple3), 2); + + let expected: Vec> = vec![ + vec!["A".into()], + vec!["B".into()], + vec!["C".into(), "D".into()], + vec!["E".into()], + vec!["F".into()], + ]; + let mut result = Vec::new(); + for mp_bytes in port_c.iter() { + let entry: Vec = rmp_serde::from_slice(mp_bytes).unwrap(); + result.push(entry); + } + assert_eq!(result, expected); + + // Check the last msgpack in the port. + let last_mp = port_c.last_mp().unwrap(); + assert_eq!(last_mp, b"\x91\xa1F"); + + // Check the first msgpack in the port. + let first_mp = port_c.first_mp().unwrap(); + assert_eq!(first_mp, b"\x91\xa1A"); + + // Check port destruction and the amount of references + // in the tuples. + drop(port); + assert_eq!(tuple_refs(&tuple3), 1); +} + +pub fn port_c_vtab() { + #[no_mangle] + unsafe extern "C" fn dump_msgpack_with_header(port: *mut Port, out: *mut Obuf) { + // When we write data from the port to the out buffer we treat + // the first msgpack as a header. All the other ones are treated + // as an array of data. So, the algorithm: + // 1. Write the first msgpack from the port. + // 2. Write an array with the size of all other msgpacks. + // 3. Write all other msgpacks to the out buffer. + // If the port is empty, write MP_NULL. + // If the port has only a single msgpack, write the msgpack and an empty array. + + let port_c: &PortC = NonNull::new_unchecked(port as *mut PortC).as_ref(); + if port_c.size() == 0 { + obuf_append(out, b"\xC0").expect("Failed to append MP_NULL"); + return; + } + + // Write the first msgpack from the port. + let first_mp = port_c.first_mp().expect("Failed to get first msgpack"); + obuf_append(out, first_mp).expect("Failed to append first msgpack"); + + // Write an array with the size of all other msgpacks. + let size = (port_c.size() - 1) as u32; + let mut array_len_buf = [0u8; 5]; + let mut cursor = Cursor::new(&mut array_len_buf[..]); + write_array_len(&mut cursor, size).expect("Failed to write array length"); + let buf_len = cursor.position() as usize; + obuf_append(out, &array_len_buf[..buf_len]).expect("Failed to append array length"); + + for (idx, mp_bytes) in port_c.iter().enumerate() { + // Skip the first msgpack. + if idx == 0 { + continue; + } + obuf_append(out, mp_bytes).expect("Failed to append msgpack"); + } + } + + #[no_mangle] + unsafe extern "C" fn dump_lua_with_panic(_port: *mut Port, _l: *mut lua_State, _is_flat: bool) { + unimplemented!(); + } + + let vtab = PortVTable::new(dump_msgpack_with_header, dump_lua_with_panic); + let mut out = ObufWrapper::new(100); + + // Check an empty port. + let mut port = Port::new_port_c(); + let port_c = unsafe { port.as_mut_port_c() }; + port_c.vtab = &vtab as *const PortVTable; + unsafe { dump_msgpack_with_header(port_c.as_mut_ptr(), out.obuf()) }; + let mut result = [0u8; 1]; + let len = out + .read(&mut result) + .expect("Failed to read from out buffer"); + assert_eq!(len, 1); + assert_eq!(result[0], 0xC0); + out.reset(); + + // Check a port with a single msgpack. + let header_mp = b"\xd96HEADER"; + unsafe { port_c.add_mp(header_mp) }; + unsafe { ((*port_c.vtab).dump_msgpack)(port_c.as_mut_ptr(), out.obuf()) }; + let expected = b"\xd96HEADER\x90"; + let mut result = [0u8; 9]; + let len = out + .read(&mut result) + .expect("Failed to read from out buffer"); + assert_eq!(len, expected.len()); + assert_eq!(&result[..], expected); + out.reset(); + drop(port); + + // Check a port with multiple msgpacks. + let mut port = Port::new_port_c(); + let port_c = unsafe { port.as_mut_port_c() }; + port_c.vtab = &vtab as *const PortVTable; + let header_mp = b"\xd96HEADER"; + unsafe { port_c.add_mp(header_mp) }; + let mp1 = b"\xd95DATA1"; + unsafe { port_c.add_mp(mp1) }; + let mp2 = b"\xd95DATA2"; + unsafe { port_c.add_mp(mp2) }; + // Check that the C wrapper over the virtual `dump_msgpack` method works. + unsafe { dump_msgpack_with_header(port_c.as_mut_ptr(), out.obuf()) }; + let expected = b"\xd96HEADER\x92\xd95DATA1\xd95DATA2"; + let mut result = [0u8; 23]; + let len = out + .read(&mut result) + .expect("Failed to read from out buffer"); + assert_eq!(len, expected.len()); + assert_eq!(&result[..], expected); + + // Check a manual drop of the port. + let mut port = unsafe { Port::zeroed() }; + port.vtab = &vtab as *const PortVTable; + let port_c = unsafe { port.as_mut_port_c() }; + unsafe { port_c.add_mp(b"\xd94DATA") }; + unsafe { ((*port.vtab).destroy)(port.as_mut()) }; + // Avoid double free. + std::mem::forget(port); +} diff --git a/tests/src/tlua/functions_write.rs b/tests/src/tlua/functions_write.rs index 0c042fdf..c575e0e0 100644 --- a/tests/src/tlua/functions_write.rs +++ b/tests/src/tlua/functions_write.rs @@ -120,9 +120,9 @@ pub fn closures() { } pub fn closures_lifetime() { - fn t(f: F) + fn t(f: F) where - F: Fn(i32, i32) -> i32, + F: Fn(i32, i32) -> i32 + 'static, { let lua = Lua::new(); @@ -246,7 +246,10 @@ pub fn closures_must_be_static() { } let f: LuaFunction<_> = lua.get("a").unwrap(); let () = f.call().unwrap(); - assert_eq!(unsafe { &GLOBAL }, &Some(vec![1, 2, 3])); + assert_eq!( + unsafe { &*std::ptr::addr_of!(GLOBAL) }, + &Some(vec![1, 2, 3]) + ); } pub fn pcall() { @@ -263,20 +266,20 @@ pub fn pcall() { pub fn error() { let lua = tarantool::lua_state(); lua.set("error_callback", - tlua::function1(|lua: tlua::LuaState| tlua::error!(lua, "but it compiled :(")) + tlua::function1(|lua: tlua::LuaState| -> () { tlua::error!(lua, "but it compiled :(") }) ); let msg = lua.exec("return error_callback()").unwrap_err().to_string(); assert_eq!(msg, "but it compiled :("); lua.set("error_callback_2", - tlua::function2(|msg: String, lua: tlua::LuaState| tlua::error!(lua, "your message: {}", msg)) + tlua::function2(|msg: String, lua: tlua::LuaState| -> () { tlua::error!(lua, "your message: {}", msg) }) ); let msg = lua.exec("return error_callback_2('my message')").unwrap_err().to_string(); assert_eq!(msg, "your message: my message"); lua.set("error_callback_3", tlua::Function::new( - |qualifier: String, lua: tlua::StaticLua| { + |qualifier: String, lua: tlua::StaticLua| -> () { tlua::error!(lua, "this way is {qualifier}") } ) diff --git a/tests/src/tlua/lua_functions.rs b/tests/src/tlua/lua_functions.rs index 5cbe2f39..a7792156 100644 --- a/tests/src/tlua/lua_functions.rs +++ b/tests/src/tlua/lua_functions.rs @@ -344,10 +344,10 @@ pub fn non_string_error() { } match lua - .exec("error(box.error.new(box.error.SYSTEM, 'oops'))") + .exec("error(box.error.new(box.error.NO_SUCH_USER, 'John'))") .unwrap_err() { - LuaError::ExecutionError(msg) => assert_eq!(msg, "oops"), + LuaError::ExecutionError(msg) => assert_eq!(msg, "User 'John' is not found"), _ => unreachable!(), } } diff --git a/tests/src/tlua/lua_tables.rs b/tests/src/tlua/lua_tables.rs index 218ac8ac..ee43cf90 100644 --- a/tests/src/tlua/lua_tables.rs +++ b/tests/src/tlua/lua_tables.rs @@ -94,6 +94,7 @@ pub fn get_or_create_metatable() { { let table = lua.get::, _>("a").unwrap(); + #[allow(deprecated)] let metatable = table.get_or_create_metatable(); fn handler() -> i32 { 5 @@ -180,6 +181,7 @@ pub fn registry_metatable() { let lua = Lua::new(); let registry = LuaTable::registry(&lua); + #[allow(deprecated)] let metatable = registry.get_or_create_metatable(); metatable.set(3, "hello"); } diff --git a/tests/src/tlua/rust_tables.rs b/tests/src/tlua/rust_tables.rs index df194b68..c2da9b6f 100644 --- a/tests/src/tlua/rust_tables.rs +++ b/tests/src/tlua/rust_tables.rs @@ -219,7 +219,7 @@ pub fn read_array_partial() { None::<[DropCheck; 4]> ); let dropped = unsafe { - DROPPED + (*std::ptr::addr_of!(DROPPED)) .as_ref() .unwrap() .borrow() @@ -241,7 +241,11 @@ pub fn read_array_partial() { impl Drop for DropCheck { fn drop(&mut self) { unsafe { - DROPPED.as_ref().unwrap().borrow_mut().insert(self.0); + (*std::ptr::addr_of!(DROPPED)) + .as_ref() + .unwrap() + .borrow_mut() + .insert(self.0); } } } diff --git a/tests/src/tlua/values.rs b/tests/src/tlua/values.rs index 0a22a14a..d908cb88 100644 --- a/tests/src/tlua/values.rs +++ b/tests/src/tlua/values.rs @@ -808,6 +808,7 @@ pub fn push_cdata() { } pub fn as_cdata_wrong_size() { + #[allow(unused)] #[derive(Debug)] struct WrongSize(u64); unsafe impl AsCData for WrongSize { @@ -888,6 +889,7 @@ pub fn readwrite_strings() { let lua = Lua::new(); + #[allow(unused)] #[derive(tlua::Push)] struct S<'a, 'b, 'c> { os: &'a OsStr, diff --git a/tests/src/tuple_picodata.rs b/tests/src/tuple_picodata.rs index 542fca1d..ed703db2 100644 --- a/tests/src/tuple_picodata.rs +++ b/tests/src/tuple_picodata.rs @@ -2,56 +2,6 @@ use tarantool::tuple::{FieldType, KeyDef, KeyDefPart, Tuple}; -pub fn tuple_format_get_names() { - let space = tarantool::space::Space::find("test_s2").unwrap(); - let idx_1 = space.index("idx_1").unwrap(); - let tuple = idx_1.get(&("key_16",)).unwrap().unwrap(); - let format = tuple.format(); - - let names = format.names().collect::>(); - assert_eq!(vec!["id", "key", "value", "a", "b"], names); -} - -pub fn tuple_as_named_buffer() { - let space = tarantool::space::Space::find("test_s2").unwrap(); - let idx_1 = space.index("idx_1").unwrap(); - let tuple = idx_1.get(&("key_16",)).unwrap().unwrap(); - - let mp_map = tuple.as_named_buffer().unwrap(); - let map: rmpv::Value = rmp_serde::from_slice(&mp_map).unwrap(); - let map = map.as_map().unwrap(); - - assert_eq!(5, map.len()); - for (k, v) in map { - match k.as_str().unwrap() { - "id" => assert_eq!(16, v.as_u64().unwrap()), - "key" => assert_eq!("key_16", v.as_str().unwrap()), - "value" => assert_eq!("value_16", v.as_str().unwrap()), - "a" => assert_eq!(1, v.as_u64().unwrap()), - "b" => assert_eq!(3, v.as_u64().unwrap()), - _ => { - unreachable!() - } - } - } - - let tuple = Tuple::new(&(1, "foo")).unwrap(); - let mp_map = tuple.as_named_buffer().unwrap(); - let map: rmpv::Value = rmp_serde::from_slice(&mp_map).unwrap(); - let map = map.as_map().unwrap(); - - assert_eq!(2, map.len()); - for (k, v) in map { - match k.as_u64().unwrap() { - 0 => assert_eq!(1, v.as_u64().unwrap()), - 1 => assert_eq!("foo", v.as_str().unwrap()), - _ => { - unreachable!() - } - } - } -} - pub fn tuple_hash() { let tuple = Tuple::new(&(1, 2, 3)).unwrap(); let key = KeyDef::new(vec![ diff --git a/tlua-derive/Cargo.toml b/tlua-derive/Cargo.toml index 91724d45..01e0b5bd 100644 --- a/tlua-derive/Cargo.toml +++ b/tlua-derive/Cargo.toml @@ -4,12 +4,12 @@ authors = [ ] name = "tlua-derive" description = "Tlua derive macro definitions" -version = "0.2.1" +version = "1.0.1" edition = "2018" license = "BSD-2-Clause" documentation = "https://docs.rs/tlua-derive/" repository = "https://github.com/picodata/tarantool-module" -rust-version = "1.61" +rust-version = "1.82" [lib] proc-macro = true diff --git a/tlua-derive/src/lib.rs b/tlua-derive/src/lib.rs index 2d775d65..7e88623e 100644 --- a/tlua-derive/src/lib.rs +++ b/tlua-derive/src/lib.rs @@ -2,8 +2,8 @@ use std::io::Write; use proc_macro::TokenStream as TokenStream1; use proc_macro2::{Span, TokenStream}; -use quote::quote; -use syn::{parse_macro_input, DeriveInput, Ident, Lifetime, Type}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{parse_macro_input, AttrStyle, DeriveInput, Ident, Lifetime, Type}; #[proc_macro_attribute] pub fn test(_attr: TokenStream1, item: TokenStream1) -> TokenStream1 { @@ -34,18 +34,6 @@ fn proc_macro_derive_push_impl( let info = Info::new(&input); let ctx = Context::with_generics(&input.generics).set_is_push_into(is_push_into); let (lifetimes, types, consts) = split_generics(&input.generics); - // We skip default values for constant generics parameters - // because they are not supported in impl blocks. You may want to check - // https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/219 - let consts: Vec = consts - .into_iter() - .cloned() - .map(|mut param| { - param.eq_token = None; - param.default = None; - param - }) - .collect(); let (_, generics, where_clause) = input.generics.split_for_impl(); let type_bounds = where_clause.map(|w| &w.predicates); let as_lua_bounds = info.push_bounds(&ctx); @@ -598,7 +586,7 @@ impl<'a> VariantsInfo<'a> { } } -impl<'a> VariantInfo<'a> { +impl VariantInfo<'_> { fn push(&self) -> TokenStream { let Self { name, info } = self; if let Some(info) = info { @@ -793,7 +781,7 @@ impl<'a> Context<'a> { is_generic: bool, type_params: &'a [&'a Ident], } - impl<'a, 'ast> syn::visit::Visit<'ast> for GenericTypeVisitor<'a> { + impl<'ast> syn::visit::Visit<'ast> for GenericTypeVisitor<'_> { // These cannot actually appear in struct/enum field types, // but who cares fn visit_type_impl_trait(&mut self, _: &'ast syn::TypeImplTrait) { @@ -820,19 +808,60 @@ impl<'a> Context<'a> { } } +#[derive(Copy, Clone)] +struct ImplTypeParam<'a>(&'a syn::TypeParam); +impl ToTokens for ImplTypeParam<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + // Leave off the type parameter defaults + let param = self.0; + tokens.append_all( + param + .attrs + .iter() + .filter(|attr| matches!(attr.style, AttrStyle::Outer)), + ); + param.ident.to_tokens(tokens); + if !param.bounds.is_empty() { + if let Some(colon) = ¶m.colon_token { + colon.to_tokens(tokens); + } + param.bounds.to_tokens(tokens); + } + } +} + +#[derive(Copy, Clone)] +struct ImplConstParam<'a>(&'a syn::ConstParam); +impl ToTokens for ImplConstParam<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + // Leave off the type parameter defaults + let param = self.0; + tokens.append_all( + param + .attrs + .iter() + .filter(|attr| matches!(attr.style, AttrStyle::Outer)), + ); + param.const_token.to_tokens(tokens); + param.ident.to_tokens(tokens); + param.colon_token.to_tokens(tokens); + param.ty.to_tokens(tokens); + } +} + fn split_generics( generics: &syn::Generics, ) -> ( Vec<&syn::LifetimeDef>, - Vec<&syn::TypeParam>, - Vec<&syn::ConstParam>, + Vec, + Vec, ) { let mut res = (vec![], vec![], vec![]); for param in &generics.params { match param { syn::GenericParam::Lifetime(l) => res.0.push(l), - syn::GenericParam::Type(t) => res.1.push(t), - syn::GenericParam::Const(c) => res.2.push(c), + syn::GenericParam::Type(t) => res.1.push(ImplTypeParam(t)), + syn::GenericParam::Const(c) => res.2.push(ImplConstParam(c)), } } res diff --git a/tlua/Cargo.toml b/tlua/Cargo.toml index 6461e850..29d94b8f 100644 --- a/tlua/Cargo.toml +++ b/tlua/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tlua" -version = "4.0.0" +version = "6.0.1" edition = "2018" authors = [ "Georgy Moshkin ", @@ -12,11 +12,11 @@ keywords = ["lua"] repository = "https://github.com/picodata/tarantool-module" documentation = "http://docs.rs/tlua" license = "MIT" -rust-version = "1.61" +rust-version = "1.82" [dependencies] libc = "0.2" -tlua-derive = { path = "../tlua-derive", version = "0.2.1" } +tlua-derive = { path = "../tlua-derive", version = "1.0.1" } serde = { version = "1.0", features = ["derive"] } linkme = { version = "0.2.10", optional = true } tester = { version = "0.7.0", optional = true } diff --git a/tlua/src/cdata.rs b/tlua/src/cdata.rs index 272b3745..39c2229a 100644 --- a/tlua/src/cdata.rs +++ b/tlua/src/cdata.rs @@ -8,7 +8,7 @@ use std::num::NonZeroI32; use std::os::raw::{c_char, c_void}; //////////////////////////////////////////////////////////////////////////////// -/// CDataOnStack +// CDataOnStack //////////////////////////////////////////////////////////////////////////////// /// Represents a reference to the underlying cdata value corresponding to a @@ -33,7 +33,7 @@ enum CDataRef<'l> { /// /// // check CTypeID /// assert_eq!(cdata.ctypeid(), ffi::CTID_UINT8); - +/// /// // check raw bytes /// assert_eq!(cdata.data(), [69]); /// @@ -278,7 +278,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// AsCData +// AsCData //////////////////////////////////////////////////////////////////////////////// /// Types implementing this trait can be represented as luajit's cdata. @@ -338,7 +338,7 @@ impl_builtin_as_cdata! { } //////////////////////////////////////////////////////////////////////////////// -/// CData +// CData //////////////////////////////////////////////////////////////////////////////// /// A wrapper type for reading/writing rust values as luajit cdata. diff --git a/tlua/src/ffi.rs b/tlua/src/ffi.rs index 49e26940..aa8a2e47 100644 --- a/tlua/src/ffi.rs +++ b/tlua/src/ffi.rs @@ -138,7 +138,7 @@ pub type lua_Writer = extern "C" fn( ud: *mut libc::c_void, ) -> libc::c_int; -extern "C" { +extern "C-unwind" { // Lua C API functions. pub fn lua_newstate(f: lua_Alloc, ud: *mut libc::c_void) -> *mut lua_State; pub fn lua_close(l: *mut lua_State); @@ -389,6 +389,7 @@ extern "C" { /// - `0`: no errors; /// - [`LUA_ERRSYNTAX`]: syntax error during pre-compilation; /// - [`LUA_ERRMEM`]: memory allocation error. + /// /// This function only loads a chunk; it does not run it. /// /// `lua_load` automatically detects whether the chunk is text or binary, @@ -657,7 +658,7 @@ pub const CTID_P_CCHAR: CTypeID = 19; pub const CTID_A_CCHAR: CTypeID = 20; pub const CTID_CTYPEID: CTypeID = 21; -extern "C" { +extern "C-unwind" { /// Push `u64` onto the stack /// *[-0, +1, -]* pub fn luaL_pushuint64(l: *mut lua_State, val: u64); @@ -685,7 +686,9 @@ extern "C" { /// uninitialized. Only numbers and pointers are supported. /// - `l`: Lua State /// - `ctypeid`: FFI's CTypeID of this cdata + /// /// See also: [`luaL_checkcdata`] + /// /// **Returns** memory associated with this cdata pub fn luaL_pushcdata(l: *mut lua_State, ctypeid: CTypeID) -> *mut c_void; @@ -693,13 +696,16 @@ extern "C" { /// * `l`: Lua State /// * `idx`: stack index /// * `ctypeid`: FFI's CTypeID of this cdata + /// /// See also: [`luaL_pushcdata`] + /// /// **Returns** memory associated with this cdata pub fn luaL_checkcdata(l: *mut lua_State, idx: c_int, ctypeid: *mut CTypeID) -> *mut c_void; /// Return CTypeID (FFI) of given CDATA type /// `ctypename` is a C type name as string (e.g. "struct request", /// "uint32_t", etc.). + /// /// See also: [`luaL_pushcdata`], [`luaL_checkcdata`] pub fn luaL_ctypeid(l: *mut lua_State, ctypename: *const c_char) -> CTypeID; @@ -724,7 +730,7 @@ pub unsafe fn luaL_hasmetafield(l: *mut lua_State, index: i32, field: *const c_c } } -extern "C" { +extern "C-unwind" { /// Convert the value at `idx` to string using `__tostring` metamethod if /// other measures didn't work and return it. Sets the `len` if it's not /// `NULL`. The newly created string is left on top of the stack. diff --git a/tlua/src/lib.rs b/tlua/src/lib.rs index fc081a65..da5915b7 100644 --- a/tlua/src/lib.rs +++ b/tlua/src/lib.rs @@ -409,10 +409,8 @@ pub trait AsLua { /// Push `v` onto the lua stack. /// /// This method is only available if - /// - `T` implements `PushOneInto`, which means that it pushes a single - /// value onto the stack - /// - `T::Err` implements `Into`, which means that no error can happen - /// during the attempt to push + /// - `T` implements `PushOneInto`, which means that it pushes a single value onto the stack + /// - `T::Err` implements `Into`, which means that no error can happen during the attempt to push /// /// Returns a `PushGuard` which captures `self` by value and stores the /// amount of values pushed onto the stack (ideally this will be 1, but it @@ -431,9 +429,9 @@ pub trait AsLua { /// /// This method is only available if /// - `I::Item` implements `PushInto`, which means that it can be - /// pushed onto the lua stack by value + /// pushed onto the lua stack by value /// - `I::Item::Err` implements `Into`, which means that no error can - /// happen during the attempt to push + /// happen during the attempt to push /// /// If `I::Item` pushes a single value onto the stack, the resulting lua /// table is a lua sequence (a table with 1-based integer keys). diff --git a/tlua/src/lua_functions.rs b/tlua/src/lua_functions.rs index 27463ba5..9b41436c 100644 --- a/tlua/src/lua_functions.rs +++ b/tlua/src/lua_functions.rs @@ -33,7 +33,7 @@ use crate::{ #[derive(Debug)] pub struct LuaCode<'a>(pub &'a str); -impl<'c, L> Push for LuaCode<'c> +impl Push for LuaCode<'_> where L: AsLua, { @@ -47,7 +47,7 @@ where } } -impl<'c, L> PushOne for LuaCode<'c> where L: AsLua {} +impl PushOne for LuaCode<'_> where L: AsLua {} /// Wrapper around a `Read` object. When pushed, the content will be parsed as Lua code and turned /// into a function. @@ -254,8 +254,7 @@ where /// Returns an error if there is an error while executing the Lua code (eg. a function call /// returns an error), or if the requested return type doesn't match the actual return type. /// - /// > **Note**: In order to pass parameters, see `into_call_with_args` - /// instead. + /// > **Note**: In order to pass parameters, see `into_call_with_args` instead. #[track_caller] #[inline] pub fn into_call(self) -> Result diff --git a/tlua/src/lua_tables.rs b/tlua/src/lua_tables.rs index 82eeb92e..300f9529 100644 --- a/tlua/src/lua_tables.rs +++ b/tlua/src/lua_tables.rs @@ -407,7 +407,7 @@ where } } -impl<'t, L, K, V> Drop for LuaTableIterator<'t, L, K, V> +impl Drop for LuaTableIterator<'_, L, K, V> where L: AsLua, { diff --git a/tlua/src/macros.rs b/tlua/src/macros.rs index b6ff4682..647c6ca0 100644 --- a/tlua/src/macros.rs +++ b/tlua/src/macros.rs @@ -114,11 +114,14 @@ macro_rules! c_ptr { macro_rules! error { ($l:expr, $($args:tt)+) => {{ let msg = ::std::format!($($args)+); + // Bind the metavariable outside unsafe block to prevent users + // accidentally doing unsafe things + let l = &$l; #[allow(unused_unsafe)] unsafe { - let lua = $crate::AsLua::as_lua(&$l); + let lua = $crate::AsLua::as_lua(l); $crate::ffi::lua_pushlstring(lua, msg.as_ptr() as _, msg.len()); - $crate::ffi::lua_error($crate::AsLua::as_lua(&$l)); + $crate::ffi::lua_error($crate::AsLua::as_lua(l)); } unreachable!("luaL_error never returns") }}; diff --git a/tlua/src/object.rs b/tlua/src/object.rs index 2474810c..4b40e848 100644 --- a/tlua/src/object.rs +++ b/tlua/src/object.rs @@ -438,10 +438,10 @@ where /// /// # Possible errors /// - Returns an error if pushing `index` or `value` failed. This can only - /// happen for a limited set of types. You are encouraged to use the - /// [`NewIndex::set`] method if pushing cannot fail. + /// happen for a limited set of types. You are encouraged to use the + /// [`NewIndex::set`] method if pushing cannot fail. /// - Returns a `LuaError::ExecutionError` in case an error happened during - /// an attempt to set value. + /// an attempt to set value. #[track_caller] #[inline(always)] fn try_checked_set( diff --git a/tlua/src/rust_tables.rs b/tlua/src/rust_tables.rs index 9f80903b..56a09ca0 100644 --- a/tlua/src/rust_tables.rs +++ b/tlua/src/rust_tables.rs @@ -166,7 +166,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// Vec +// Vec //////////////////////////////////////////////////////////////////////////////// impl Push for Vec @@ -219,10 +219,7 @@ where // We need this as iteration order isn't guaranteed to match order of // keys, even if they're numeric // https://www.lua.org/manual/5.2/manual.html#pdf-next - let table = match LuaTable::lua_read_at_position(lua, index) { - Ok(table) => table, - Err(lua) => return Err(lua), - }; + let table = LuaTable::lua_read_at_position(lua, index)?; let mut dict: BTreeMap = BTreeMap::new(); let mut max_key = i32::MIN; @@ -286,7 +283,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// \[T] +// \[T] //////////////////////////////////////////////////////////////////////////////// impl Push for [T] @@ -310,7 +307,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// [T; N] +// [T; N] //////////////////////////////////////////////////////////////////////////////// impl Push for [T; N] @@ -360,11 +357,7 @@ where T: 'static, { fn lua_read_at_position(lua: L, index: NonZeroI32) -> ReadResult { - let table = match LuaTable::lua_read_at_position(lua, index) { - Ok(table) => table, - Err(lua) => return Err(lua), - }; - + let table = LuaTable::lua_read_at_position(lua, index)?; let mut res = std::mem::MaybeUninit::uninit(); let ptr = &mut res as *mut _ as *mut [T; N] as *mut T; let mut was_assigned = [false; N]; @@ -402,7 +395,7 @@ where for i in IntoIterator::into_iter(was_assigned) .enumerate() - .flat_map(|(i, was_assigned)| was_assigned.then(|| i)) + .flat_map(|(i, was_assigned)| was_assigned.then_some(i)) { unsafe { std::ptr::drop_in_place(ptr.add(i)) } } @@ -428,7 +421,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// HashMap +// HashMap //////////////////////////////////////////////////////////////////////////////// impl LuaRead for HashMap @@ -509,7 +502,7 @@ where } //////////////////////////////////////////////////////////////////////////////// -/// HashSet +// HashSet //////////////////////////////////////////////////////////////////////////////// macro_rules! push_hashset_impl { diff --git a/tlua/src/tuples.rs b/tlua/src/tuples.rs index 2fba7b0a..0354ec02 100644 --- a/tlua/src/tuples.rs +++ b/tlua/src/tuples.rs @@ -659,8 +659,7 @@ impl From> for AsTablePushError { /// - `0` values => nothing happens /// - `1` value: `v` => `table[i] = v` /// - `2` values: `k` & `v` => `table[k] = v` -/// - any other number => nothing is inserted into table, -/// `AsTablePushError::TooManyValues(n)` is returned +/// - any other number => nothing is inserted into table, `AsTablePushError::TooManyValues(n)` is returned /// /// If an error happens during attempt to push `T`, /// `AsTablePushError::ValuePushError(e)` is returned diff --git a/tlua/src/values.rs b/tlua/src/values.rs index 02eefa0d..625eacf4 100644 --- a/tlua/src/values.rs +++ b/tlua/src/values.rs @@ -523,7 +523,7 @@ where } } -impl<'a, L> Deref for StringInLua<'a, L> { +impl Deref for StringInLua<'_, L> { type Target = str; #[inline] @@ -871,7 +871,7 @@ impl From for String { } } -impl<'a> From for Cow<'a, str> { +impl From for Cow<'_, str> { fn from(other: ToString) -> Self { Cow::Owned(other.0) }