diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index efa24bff78..2054edec76 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -53,7 +53,8 @@ jobs: cargo test -p light-account-checks --all-features cargo test -p light-verifier --all-features cargo test -p light-merkle-tree-metadata --all-features - cargo test -p light-zero-copy --features std + cargo test -p light-zero-copy --features "std, mut, derive" + cargo test -p light-zero-copy-derive --features "mut" cargo test -p light-hash-set --all-features - name: program-libs-slow packages: light-bloom-filter light-indexed-merkle-tree light-batched-merkle-tree diff --git a/.gitignore b/.gitignore index 9943415324..8fa4a89ce0 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,5 @@ output1.txt **/.claude/**/* **/~/ + +expand.rs diff --git a/Cargo.lock b/Cargo.lock index 301eae03fd..94d24af349 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2389,6 +2389,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "governor" version = "0.6.3" @@ -3810,6 +3816,8 @@ dependencies = [ name = "light-zero-copy" version = "0.2.0" dependencies = [ + "borsh 0.10.4", + "light-zero-copy-derive", "pinocchio", "rand 0.8.5", "solana-program-error", @@ -3817,6 +3825,21 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "light-zero-copy-derive" +version = "0.1.0" +dependencies = [ + "borsh 0.10.4", + "lazy_static", + "light-zero-copy", + "proc-macro2", + "quote", + "rand 0.8.5", + "syn 2.0.103", + "trybuild", + "zerocopy", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -9078,6 +9101,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "tarpc" version = "0.29.0" @@ -9650,6 +9679,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c9bf9513a2f4aeef5fdac8677d7d349c79fdbcc03b9c86da6e9d254f1e43be2" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml 0.8.23", +] + [[package]] name = "tungstenite" version = "0.20.1" diff --git a/Cargo.toml b/Cargo.toml index b58870a631..34b2d10bc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "program-libs/hash-set", "program-libs/indexed-merkle-tree", "program-libs/indexed-array", + "program-libs/zero-copy-derive", "programs/account-compression", "programs/system", "programs/compressed-token", @@ -167,6 +168,7 @@ light-compressed-account = { path = "program-libs/compressed-account", version = light-account-checks = { path = "program-libs/account-checks", version = "0.3.0" } light-verifier = { path = "program-libs/verifier", version = "2.1.0" } light-zero-copy = { path = "program-libs/zero-copy", version = "0.2.0" } +light-zero-copy-derive = { path = "program-libs/zero-copy-derive", version = "0.1.0" } photon-api = { path = "sdk-libs/photon-api", version = "0.51.0" } forester-utils = { path = "forester-utils", version = "2.0.0" } account-compression = { path = "programs/account-compression", version = "2.0.0", features = [ diff --git a/program-libs/compressed-account/src/instruction_data/with_account_info.rs b/program-libs/compressed-account/src/instruction_data/with_account_info.rs index 599ad9cd0b..57b49e5e78 100644 --- a/program-libs/compressed-account/src/instruction_data/with_account_info.rs +++ b/program-libs/compressed-account/src/instruction_data/with_account_info.rs @@ -399,9 +399,13 @@ impl<'a> Deserialize<'a> for InstructionDataInvokeCpiWithAccountInfo { let (account_infos, bytes) = { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory let mut slices = Vec::with_capacity(num_slices); + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } for _ in 0..num_slices { let (slice, _bytes) = CompressedAccountInfo::zero_copy_at_with_owner( bytes, diff --git a/program-libs/compressed-account/src/instruction_data/with_readonly.rs b/program-libs/compressed-account/src/instruction_data/with_readonly.rs index 59b9c27bd7..e591f45444 100644 --- a/program-libs/compressed-account/src/instruction_data/with_readonly.rs +++ b/program-libs/compressed-account/src/instruction_data/with_readonly.rs @@ -347,8 +347,14 @@ impl<'a> Deserialize<'a> for InstructionDataInvokeCpiWithReadOnly { let (input_compressed_accounts, bytes) = { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } let mut slices = Vec::with_capacity(num_slices); for _ in 0..num_slices { let (slice, _bytes) = diff --git a/program-libs/zero-copy-derive/Cargo.toml b/program-libs/zero-copy-derive/Cargo.toml new file mode 100644 index 0000000000..1cdc8254e8 --- /dev/null +++ b/program-libs/zero-copy-derive/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "light-zero-copy-derive" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" +description = "Proc macro for zero-copy deserialization" + +[features] +default = [] +mut = [] + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "2.0", features = ["full", "extra-traits"] } +lazy_static = "1.4" + +[dev-dependencies] +trybuild = "1.0" +rand = "0.8" +borsh = { workspace = true } +light-zero-copy = { workspace = true, features = ["std", "derive"] } +zerocopy = { workspace = true, features = ["derive"] } diff --git a/program-libs/zero-copy-derive/README.md b/program-libs/zero-copy-derive/README.md new file mode 100644 index 0000000000..8e17fbbb25 --- /dev/null +++ b/program-libs/zero-copy-derive/README.md @@ -0,0 +1,103 @@ +# Light-Zero-Copy-Derive + +A procedural macro for deriving zero-copy deserialization for Rust structs used with Solana programs. + +## Features + +This crate provides two key derive macros: + +1. `#[derive(ZeroCopy)]` - Implements zero-copy deserialization with: + - The `zero_copy_at` and `zero_copy_at_mut` methods for deserialization + - Full Borsh compatibility for serialization/deserialization + - Efficient memory representation with no copying of data + - `From>` and `FromMut>` implementations for easy conversion back to the original struct + +2. `#[derive(ZeroCopyEq)]` - Adds equality comparison support: + - Compare zero-copy instances with regular struct instances + - Can be used alongside `ZeroCopy` for complete functionality + - Derivation for Options is not robust and may not compile. + +## Rules for Zero-Copy Deserialization + +The macro follows these rules when generating code: + +1. Creates a `ZStruct` for your struct that follows zero-copy principles + 1. Fields are extracted into a meta struct until reaching a `Vec`, `Option` or non-`Copy` type + 2. Vectors are represented as `ZeroCopySlice` and not included in the meta struct + 3. Integer types are replaced with their zerocopy equivalents (e.g., `u16` → `U16`) + 4. Fields after the first vector are directly included in the `ZStruct` and deserialized one by one + 5. If a vector contains a nested vector (non-`Copy` type), it must implement `Deserialize` + 6. Elements in an `Option` must implement `Deserialize` + 7. Types that don't implement `Copy` must implement `Deserialize` and are deserialized one by one + +## Usage + +### Basic Usage + +```rust +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::ZeroCopy; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut}; + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} +let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, +}; +// Use the struct with zero-copy deserialization +let mut bytes = my_struct.try_to_vec().unwrap(); + +// Immutable zero-copy deserialization +let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + +// Convert back to original struct using From implementation +let converted: MyStruct = zero_copy.clone().into(); +assert_eq!(converted, my_struct); + +// Mutable zero-copy deserialization with modification +let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap(); +zero_copy_mut.a = 42; + +// The change is reflected when we convert back to the original struct +let modified: MyStruct = zero_copy_mut.into(); +assert_eq!(modified.a, 42); + +// And also when we deserialize directly from the modified bytes +let borsh = MyStruct::try_from_slice(&bytes).unwrap(); +assert_eq!(borsh.a, 42u8); +``` + +### With Equality Comparison + +```rust +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::ZeroCopy; + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} +let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, +}; +// Use the struct with zero-copy deserialization +let mut bytes = my_struct.try_to_vec().unwrap(); +let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); +assert_eq!(zero_copy, my_struct); +``` diff --git a/program-libs/zero-copy-derive/src/lib.rs b/program-libs/zero-copy-derive/src/lib.rs new file mode 100644 index 0000000000..becac18087 --- /dev/null +++ b/program-libs/zero-copy-derive/src/lib.rs @@ -0,0 +1,166 @@ +//! Procedural macros for zero-copy deserialization. +//! +//! This crate provides derive macros that generate efficient zero-copy data structures +//! and deserialization code, eliminating the need for data copying during parsing. +//! +//! ## Main Macros +//! +//! - `ZeroCopy`: Generates zero-copy structs and deserialization traits +//! - `ZeroCopyMut`: Adds mutable zero-copy support +//! - `ZeroCopyEq`: Adds PartialEq implementation for comparing with original structs +//! - `ZeroCopyNew`: Generates configuration structs for initialization + +use proc_macro::TokenStream; + +mod shared; +mod zero_copy; +mod zero_copy_eq; +#[cfg(feature = "mut")] +mod zero_copy_mut; + +/// ZeroCopy derivation macro for zero-copy deserialization +/// +/// # Usage +/// +/// Basic usage: +/// ```rust +/// use light_zero_copy_derive::ZeroCopy; +/// #[derive(ZeroCopy)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +/// +/// To derive PartialEq as well, use ZeroCopyEq in addition to ZeroCopy: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +/// +/// # Macro Rules +/// 1. Create zero copy structs Z and ZMut for the struct +/// 1.1. The first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy +/// 1.2. Represent vectors to ZeroCopySlice & don't include these into the meta struct +/// 1.3. Replace u16 with U16, u32 with U32, etc +/// 1.4. Every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +/// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +/// 1.6. Elements in an Option must implement Deserialize +/// 1.7. A type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 +/// 1.8. is u8 deserialized as u8::zero_copy_at instead of Ref<&'a [u8], u8> for non mut, for mut it is Ref<&'a mut [u8], u8> +/// 2. Implement Deserialize and DeserializeMut which return Z and ZMut +/// 3. Implement From> for StructName and FromMut> for StructName +/// +/// Note: Options are not supported in ZeroCopyEq +#[proc_macro_derive(ZeroCopy)] +pub fn derive_zero_copy(input: TokenStream) -> TokenStream { + let res = zero_copy::derive_zero_copy_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +/// ZeroCopyEq implementation to add PartialEq for zero-copy structs. +/// +/// Use this in addition to ZeroCopy when you want the generated struct to implement PartialEq: +/// +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +/// #[derive(ZeroCopy, ZeroCopyEq)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[proc_macro_derive(ZeroCopyEq)] +pub fn derive_zero_copy_eq(input: TokenStream) -> TokenStream { + let res = zero_copy_eq::derive_zero_copy_eq_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +/// ZeroCopyMut derivation macro for mutable zero-copy deserialization +/// +/// This macro generates mutable zero-copy implementations including: +/// - DeserializeMut trait implementation +/// - Mutable Z-struct with `Mut` suffix +/// - byte_len() method implementation +/// - Mutable ZeroCopyStructInner implementation +/// +/// # Usage +/// +/// ```rust +/// use light_zero_copy_derive::ZeroCopyMut; +/// +/// #[derive(ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// pub vec: Vec, +/// } +/// ``` +/// +/// This will generate: +/// - `MyStruct::zero_copy_at_mut()` method +/// - `ZMyStructMut<'a>` type for mutable zero-copy access +/// - `MyStruct::byte_len()` method +/// +/// For both immutable and mutable functionality, use both derives: +/// ```rust +/// use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; +/// +/// #[derive(ZeroCopy, ZeroCopyMut)] +/// pub struct MyStruct { +/// pub a: u8, +/// } +/// ``` +#[cfg(feature = "mut")] +#[proc_macro_derive(ZeroCopyMut)] +pub fn derive_zero_copy_mut(input: TokenStream) -> TokenStream { + let res = zero_copy_mut::derive_zero_copy_mut_impl(input); + TokenStream::from(match res { + Ok(res) => res, + Err(err) => err.to_compile_error(), + }) +} + +// /// ZeroCopyNew derivation macro for configuration-based zero-copy initialization +// /// +// /// This macro generates configuration structs and initialization methods for structs +// /// with Vec and Option fields that need to be initialized with specific configurations. +// /// +// /// # Usage +// /// +// /// ```ignore +// /// use light_zero_copy_derive::ZeroCopyNew; +// /// +// /// #[derive(ZeroCopyNew)] +// /// pub struct MyStruct { +// /// pub a: u8, +// /// pub vec: Vec, +// /// pub option: Option, +// /// } +// /// ``` +// /// +// /// This will generate: +// /// - `MyStructConfig` struct with configuration fields +// /// - `ZeroCopyNew` implementation for `MyStruct` +// /// - `new_zero_copy(bytes, config)` method for initialization +// /// +// /// The configuration struct will have fields based on the complexity of the original fields: +// /// - `Vec` → `field_name: u32` (length) +// /// - `Option` → `field_name: bool` (is_some) +// /// - `Vec` → `field_name: Vec` (config per element) +// /// - `Option` → `field_name: Option` (config if some) +// #[cfg(feature = "mut")] +// #[proc_macro_derive(ZeroCopyNew)] +// pub fn derive_zero_copy_config(input: TokenStream) -> TokenStream { +// let res = zero_copy_new::derive_zero_copy_config_impl(input); +// TokenStream::from(match res { +// Ok(res) => res, +// Err(err) => err.to_compile_error(), +// }) +// } diff --git a/program-libs/zero-copy-derive/src/shared/from_impl.rs b/program-libs/zero-copy-derive/src/shared/from_impl.rs new file mode 100644 index 0000000000..1aab0eb9b3 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/from_impl.rs @@ -0,0 +1,242 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Field, Ident}; + +use super::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generates code for the From> for StructName implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_from_impl( + name: &Ident, + z_struct_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + + // Generate the conversion code for meta fields + let meta_field_conversions = if !meta_fields.is_empty() { + let field_types = analyze_struct_fields(meta_fields)?; + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => { + quote! { #field_name: value.__meta.#field_name, } + } + _ if utils::is_specific_primitive_type(field_type, "bool") => { + quote! { #field_name: value.__meta.#field_name > 0, } + } + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { #field_name: #field_type::from(value.__meta.#field_name), } + } + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: value.__meta.#field_name, } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: value.__meta.#field_name, } + } + _ => { + let field_name = field_type.name(); + quote! { #field_name: value.__meta.#field_name.into(), } + } + } + }); + conversions.collect::>() + } else { + vec![] + }; + + // Generate the conversion code for struct fields + let struct_field_conversions = if !struct_fields.is_empty() { + let field_types = analyze_struct_fields(struct_fields)?; + let conversions = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecCopy(field_name, _) => { + quote! { #field_name: value.#field_name.to_vec(), } + } + FieldType::VecDynamicZeroCopy(field_name, _) => { + // For non-copy vectors, clone each element directly + // We need to convert into() for Zstructs + quote! { + #field_name: { + value.#field_name.iter().map(|item| (*item).clone().into()).collect() + }, + } + } + FieldType::Array(field_name, _) => { + // For arrays, just copy the value + quote! { #field_name: *value.#field_name, } + } + FieldType::Option(field_name, field_type) => { + // Extract inner type from Option + let inner_type = utils::get_option_inner_type(field_type).expect( + "Failed to extract inner type from Option - expected Option format", + ); + let field_type = inner_type; + // For Option types, use a direct copy of the value when possible + quote! { + #field_name: if value.#field_name.is_some() { + // Create a clone of the Some value - for compressed proofs and other structs + // For instruction_data.rs, we just need to clone the value directly + Some((#field_type::from(*value.#field_name.as_ref().unwrap()).clone())) + } else { + None + }, + } + } + FieldType::Pubkey(field_name) => { + quote! { #field_name: *value.#field_name, } + } + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => { + if MUT { + quote! { #field_name: *value.#field_name, } + } else { + quote! { #field_name: value.#field_name, } + } + } + _ if utils::is_specific_primitive_type(field_type, "bool") => { + if MUT { + quote! { #field_name: *value.#field_name > 0, } + } else { + quote! { #field_name: value.#field_name > 0, } + } + } + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { #field_name: #field_type::from(*value.#field_name), } + } + } + } + FieldType::Copy(field_name, _) => { + quote! { #field_name: value.#field_name, } + } + FieldType::OptionU64(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u64::from(**x)), } + } + FieldType::OptionU32(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u32::from(**x)), } + } + FieldType::OptionU16(field_name) => { + quote! { #field_name: value.#field_name.as_ref().map(|x| u16::from(**x)), } + } + FieldType::DynamicZeroCopy(field_name, field_type) => { + // For complex non-copy types, dereference and clone directly + quote! { #field_name: #field_type::from(&value.#field_name), } + } + } + }); + conversions.collect::>() + } else { + vec![] + }; + + // Combine all the field conversions + let all_field_conversions = [meta_field_conversions, struct_field_conversions].concat(); + + // Return the final From implementation without generic From implementations + let result = quote! { + impl<'a> From<#z_struct_name<'a>> for #name { + fn from(value: #z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + + impl<'a> From<&#z_struct_name<'a>> for #name { + fn from(value: &#z_struct_name<'a>) -> Self { + Self { + #(#all_field_conversions)* + } + } + } + }; + Ok(result) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use syn::{parse_quote, Field}; + + use super::*; + + #[test] + fn test_generate_from_impl() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: u16); + let field_c: Field = parse_quote!(pub c: Vec); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation + let result = + generate_from_impl::(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.unwrap().to_string(); + + // Check that the implementation contains required elements + assert!(result_str.contains("impl < 'a > From < ZTestStruct < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For u16 fields + assert!(result_str.contains("c :")); // For Vec fields + } + + #[test] + fn test_generate_from_impl_mut() { + // Create a struct for testing + let name = format_ident!("TestStruct"); + let z_struct_name = format_ident!("ZTestStruct"); + + // Create some test fields + let field_a: Field = parse_quote!(pub a: u8); + let field_b: Field = parse_quote!(pub b: bool); + let field_c: Field = parse_quote!(pub c: Option); + + // Split into meta and struct fields + let meta_fields = vec![&field_a, &field_b]; + let struct_fields = vec![&field_c]; + + // Generate the implementation for mutable version + let result = + generate_from_impl::(&name, &z_struct_name, &meta_fields, &struct_fields); + + // Convert to string for testing + let result_str = result.unwrap().to_string(); + + // Check that the implementation contains required elements + assert!(result_str.contains("impl < 'a > From < ZTestStructMut < 'a >> for TestStruct")); + + // Check field handling + assert!(result_str.contains("a :")); // For u8 fields + assert!(result_str.contains("b :")); // For bool fields + assert!(result_str.contains("c :")); // For Option fields + } +} diff --git a/program-libs/zero-copy-derive/src/shared/meta_struct.rs b/program-libs/zero-copy-derive/src/shared/meta_struct.rs new file mode 100644 index 0000000000..0dbf9cda3a --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/meta_struct.rs @@ -0,0 +1,57 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::Field; + +use super::utils::convert_to_zerocopy_type; + +/// Generates the meta struct definition as a TokenStream +/// The `MUT` parameter determines if the struct should be generated for mutable access +pub fn generate_meta_struct( + z_struct_meta_name: &syn::Ident, + meta_fields: &[&Field], + hasher: bool, +) -> syn::Result { + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + + // Generate the meta struct fields with converted types + let meta_fields_with_converted_types = meta_fields.iter().map(|field| { + let field_name = &field.ident; + let attributes = if hasher { + field + .attrs + .iter() + .map(|attr| { + quote! { #attr } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let field_type = convert_to_zerocopy_type(&field.ty); + quote! { + #(#attributes)* + pub #field_name: #field_type + } + }); + let hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + + // Return the complete meta struct definition + let result = quote! { + #[repr(C)] + #[derive(Debug, PartialEq, light_zero_copy::KnownLayout, light_zero_copy::Immutable, light_zero_copy::Unaligned, light_zero_copy::FromBytes, light_zero_copy::IntoBytes #hasher)] + pub struct #z_struct_meta_name { + #(#meta_fields_with_converted_types,)* + } + }; + Ok(result) +} diff --git a/program-libs/zero-copy-derive/src/shared/mod.rs b/program-libs/zero-copy-derive/src/shared/mod.rs new file mode 100644 index 0000000000..c7b406b530 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/mod.rs @@ -0,0 +1,6 @@ +pub mod from_impl; +pub mod meta_struct; +pub mod utils; +pub mod z_struct; +#[cfg(feature = "mut")] +pub mod zero_copy_new; diff --git a/program-libs/zero-copy-derive/src/shared/utils.rs b/program-libs/zero-copy-derive/src/shared/utils.rs new file mode 100644 index 0000000000..e92e56bb29 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/utils.rs @@ -0,0 +1,437 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Attribute, Data, DeriveInput, Field, Fields, FieldsNamed, Ident, Type, TypePath}; + +// Global cache for storing whether a struct implements Copy +lazy_static::lazy_static! { + static ref COPY_IMPL_CACHE: Arc>> = Arc::new(Mutex::new(HashMap::new())); +} + +/// Creates a unique cache key for a type using span information to avoid collisions +/// between types with the same name from different modules/locations +fn create_unique_type_key(ident: &Ident) -> String { + format!("{}:{:?}", ident, ident.span()) +} + +/// Process the derive input to extract the struct information +pub fn process_input( + input: &DeriveInput, +) -> syn::Result<( + &Ident, // Original struct name + proc_macro2::Ident, // Z-struct name + proc_macro2::Ident, // Z-struct meta name + &FieldsNamed, // Struct fields +)> { + let name = &input.ident; + let z_struct_name = format_ident!("Z{}", name); + let z_struct_meta_name = format_ident!("Z{}Meta", name); + + // Populate the cache by checking if this struct implements Copy + let _ = struct_implements_copy(input); + + let fields = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => fields, + _ => { + return Err(syn::Error::new_spanned( + &data.fields, + "ZeroCopy only supports structs with named fields", + )) + } + }, + _ => { + return Err(syn::Error::new_spanned( + input, + "ZeroCopy only supports structs", + )) + } + }; + + Ok((name, z_struct_name, z_struct_meta_name, fields)) +} + +pub fn process_fields(fields: &FieldsNamed) -> (Vec<&Field>, Vec<&Field>) { + let mut meta_fields = Vec::new(); + let mut struct_fields = Vec::new(); + let mut reached_vec_or_option = false; + + for field in fields.named.iter() { + if !reached_vec_or_option { + if is_vec_or_option(&field.ty) || !is_copy_type(&field.ty) { + reached_vec_or_option = true; + struct_fields.push(field); + } else { + meta_fields.push(field); + } + } else { + struct_fields.push(field); + } + } + + (meta_fields, struct_fields) +} + +pub fn is_vec_or_option(ty: &Type) -> bool { + is_vec_type(ty) || is_option_type(ty) +} + +pub fn is_vec_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Vec"; + } + } + false +} + +pub fn is_option_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Option"; + } + } + false +} + +pub fn get_vec_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + if segment.ident == "Vec" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn get_option_inner_type(ty: &Type) -> Option<&Type> { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + if segment.ident == "Option" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { + return Some(inner_ty); + } + } + } + } + } + None +} + +pub fn is_primitive_integer(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + return ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "u8" + || ident == "i8"; + } + } + false +} + +pub fn is_bool_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "bool"; + } + } + false +} + +/// Check if a type is a specific primitive type (u8, u16, u32, u64, bool, etc.) +pub fn is_specific_primitive_type(ty: &Type, type_name: &str) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == type_name; + } + } + false +} + +pub fn is_pubkey_type(ty: &Type) -> bool { + if let Type::Path(TypePath { path, .. }) = ty { + if let Some(segment) = path.segments.last() { + return segment.ident == "Pubkey"; + } + } + false +} + +pub fn convert_to_zerocopy_type(ty: &Type) -> TokenStream { + match ty { + Type::Path(TypePath { path, .. }) => { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + + // Handle primitive types first + match ident.to_string().as_str() { + "u16" => quote! { light_zero_copy::little_endian::U16 }, + "u32" => quote! { light_zero_copy::little_endian::U32 }, + "u64" => quote! { light_zero_copy::little_endian::U64 }, + "bool" => quote! { u8 }, + _ => { + // Handle container types recursively + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + let transformed_args: Vec = args + .args + .iter() + .map(|arg| { + if let syn::GenericArgument::Type(inner_type) = arg { + convert_to_zerocopy_type(inner_type) + } else { + quote! { #arg } + } + }) + .collect(); + + quote! { #ident<#(#transformed_args),*> } + } else { + quote! { #ty } + } + } + } + } else { + quote! { #ty } + } + } + _ => { + quote! { #ty } + } + } +} + +/// Checks if a struct has a derive(Copy) attribute +fn struct_has_copy_derive(attrs: &[Attribute]) -> bool { + attrs.iter().any(|attr| { + attr.path().is_ident("derive") && { + let mut found_copy = false; + // Use parse_nested_meta as the primary and only approach - it's the syn 2.0 standard + // for parsing comma-separated derive items like #[derive(Copy, Clone, Debug)] + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("Copy") { + found_copy = true; + } + Ok(()) // Continue parsing other derive items + }) + .is_ok() + && found_copy + } + }) +} + +/// Determines whether a struct implements Copy by checking for the #[derive(Copy)] attribute. +/// Results are cached for performance. +/// +/// In Rust, a struct can only implement Copy if: +/// 1. It explicitly has a #[derive(Copy)] attribute, AND +/// 2. All of its fields implement Copy +/// +/// The Rust compiler will enforce the second condition at compile time, so we only need to check +/// for the derive attribute here. +pub fn struct_implements_copy(input: &DeriveInput) -> bool { + let cache_key = create_unique_type_key(&input.ident); + + // Check the cache first + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&cache_key) { + return *implements_copy; + } + + // Check if the struct has a derive(Copy) attribute + let implements_copy = struct_has_copy_derive(&input.attrs); + + // Cache the result + COPY_IMPL_CACHE + .lock() + .unwrap() + .insert(cache_key, implements_copy); + + implements_copy +} + +/// Determines whether a type implements Copy +/// 1. check whether type is a primitive type that implements Copy +/// 2. check whether type is an array type (which is always Copy if the element type is Copy) +/// 3. check whether type is struct -> check in the COPY_IMPL_CACHE if we know whether it has a #[derive(Copy)] attribute +/// +/// For struct types, this relies on the cache populated by struct_implements_copy. If we don't have cached +/// information, it assumes the type does not implement Copy. This is a limitation of our approach, but it +/// works well in practice because process_input will call struct_implements_copy for all structs before +/// they might be referenced by other structs. +pub fn is_copy_type(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { path, .. }) => { + if let Some(segment) = path.segments.last() { + let ident = &segment.ident; + + // Check if it's a primitive type that implements Copy + if ident == "u8" + || ident == "u16" + || ident == "u32" + || ident == "u64" + || ident == "i8" + || ident == "i16" + || ident == "i32" + || ident == "i64" + || ident == "bool" // bool is a Copy type + || ident == "char" + || ident == "Pubkey" + // Pubkey is hardcoded as copy type for now. + { + return true; + } + + // Check if we have cached information about this type + let cache_key = create_unique_type_key(ident); + if let Some(implements_copy) = COPY_IMPL_CACHE.lock().unwrap().get(&cache_key) { + return *implements_copy; + } + } + } + // Handle array types (which are always Copy if the element type is Copy) + Type::Array(array) => { + // Arrays are Copy if their element type is Copy + return is_copy_type(&array.elem); + } + // For struct types not in cache, we'd need the derive input to check attributes + _ => {} + } + false +} + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::*; + + // Helper function to check if a struct implements Copy + fn check_struct_implements_copy(input: syn::DeriveInput) -> bool { + struct_implements_copy(&input) + } + + #[test] + fn test_struct_implements_copy() { + // Ensure the cache is cleared and the lock is released immediately + COPY_IMPL_CACHE.lock().unwrap().clear(); + // Test case 1: Empty struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct EmptyStruct {} + }; + assert!( + check_struct_implements_copy(input), + "EmptyStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 2: Simple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct SimpleStruct { + a: u8, + b: u16, + } + }; + assert!( + check_struct_implements_copy(input), + "SimpleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 3: Struct with #[derive(Clone)] but not Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Clone)] + struct StructWithoutCopy { + a: u8, + b: u16, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 4: Struct with a non-Copy field but with derive(Copy) + // Note: In real Rust code, this would not compile, but for our test we only check attributes + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct StructWithVec { + a: u8, + b: Vec, + } + }; + assert!( + check_struct_implements_copy(input), + "StructWithVec has #[derive(Copy)] so our function returns true" + ); + + // Test case 5: Struct with all Copy fields but without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct StructWithCopyFields { + a: u8, + b: u16, + c: i32, + d: bool, + } + }; + assert!( + !check_struct_implements_copy(input), + "StructWithCopyFields should not implement Copy without #[derive(Copy)]" + ); + + // Test case 6: Unit struct without #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + struct UnitStructWithoutCopy; + }; + assert!( + !check_struct_implements_copy(input), + "UnitStructWithoutCopy should not implement Copy without #[derive(Copy)]" + ); + + // Test case 7: Unit struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct UnitStructWithCopy; + }; + assert!( + check_struct_implements_copy(input), + "UnitStructWithCopy should implement Copy with #[derive(Copy)]" + ); + + // Test case 8: Tuple struct with #[derive(Copy)] + let input: syn::DeriveInput = parse_quote! { + #[derive(Copy, Clone)] + struct TupleStruct(u32, bool, char); + }; + assert!( + check_struct_implements_copy(input), + "TupleStruct should implement Copy with #[derive(Copy)]" + ); + + // Test case 9: Multiple derives including Copy + let input: syn::DeriveInput = parse_quote! { + #[derive(Debug, PartialEq, Copy, Clone)] + struct MultipleDerivesStruct { + a: u8, + } + }; + assert!( + check_struct_implements_copy(input), + "MultipleDerivesStruct should implement Copy with #[derive(Copy)]" + ); + } +} diff --git a/program-libs/zero-copy-derive/src/shared/z_struct.rs b/program-libs/zero-copy-derive/src/shared/z_struct.rs new file mode 100644 index 0000000000..e7cf42b4e1 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/z_struct.rs @@ -0,0 +1,624 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, TokenStreamExt}; +use syn::{parse_quote, Field, Ident, Type}; + +use super::utils; + +/// Enum representing the different field types for zero-copy struct +/// (Name, Type) +/// Note: Arrays with Option elements are not currently supported +#[derive(Debug)] +pub enum FieldType<'a> { + VecU8(&'a Ident), + VecCopy(&'a Ident, &'a Type), + VecDynamicZeroCopy(&'a Ident, &'a Type), + Array(&'a Ident, &'a Type), // Static arrays only - no Option elements supported + Option(&'a Ident, &'a Type), + OptionU64(&'a Ident), + OptionU32(&'a Ident), + OptionU16(&'a Ident), + Pubkey(&'a Ident), + Primitive(&'a Ident, &'a Type), + Copy(&'a Ident, &'a Type), + DynamicZeroCopy(&'a Ident, &'a Type), +} + +impl<'a> FieldType<'a> { + /// Get the name of the field + pub fn name(&self) -> &'a Ident { + match self { + FieldType::VecU8(name) => name, + FieldType::VecCopy(name, _) => name, + FieldType::VecDynamicZeroCopy(name, _) => name, + FieldType::Array(name, _) => name, + FieldType::Option(name, _) => name, + FieldType::OptionU64(name) => name, + FieldType::OptionU32(name) => name, + FieldType::OptionU16(name) => name, + FieldType::Pubkey(name) => name, + FieldType::Primitive(name, _) => name, + FieldType::Copy(name, _) => name, + FieldType::DynamicZeroCopy(name, _) => name, + } + } +} + +/// Classify a Vec type based on its inner type +fn classify_vec_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, + inner_type: &'a Type, +) -> FieldType<'a> { + if utils::is_specific_primitive_type(inner_type, "u8") { + FieldType::VecU8(field_name) + } else if utils::is_copy_type(inner_type) { + FieldType::VecCopy(field_name, inner_type) + } else { + FieldType::VecDynamicZeroCopy(field_name, field_type) + } +} + +/// Classify an Option type based on its inner type +fn classify_option_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, + inner_type: &'a Type, +) -> FieldType<'a> { + if utils::is_primitive_integer(inner_type) { + match () { + _ if utils::is_specific_primitive_type(inner_type, "u64") => { + FieldType::OptionU64(field_name) + } + _ if utils::is_specific_primitive_type(inner_type, "u32") => { + FieldType::OptionU32(field_name) + } + _ if utils::is_specific_primitive_type(inner_type, "u16") => { + FieldType::OptionU16(field_name) + } + _ => FieldType::Option(field_name, field_type), + } + } else { + FieldType::Option(field_name, field_type) + } +} + +/// Classify a primitive integer type +fn classify_integer_type<'a>( + field_name: &'a Ident, + field_type: &'a Type, +) -> syn::Result> { + match () { + _ if utils::is_specific_primitive_type(field_type, "u64") + | utils::is_specific_primitive_type(field_type, "u32") + | utils::is_specific_primitive_type(field_type, "u16") + | utils::is_specific_primitive_type(field_type, "u8") => + { + Ok(FieldType::Primitive(field_name, field_type)) + } + _ => Err(syn::Error::new_spanned( + field_type, + "Unsupported integer type. Only u8, u16, u32, and u64 are supported", + )), + } +} + +/// Classify a Copy type +fn classify_copy_type<'a>(field_name: &'a Ident, field_type: &'a Type) -> FieldType<'a> { + if utils::is_specific_primitive_type(field_type, "u8") + || utils::is_specific_primitive_type(field_type, "bool") + { + FieldType::Primitive(field_name, field_type) + } else { + FieldType::Copy(field_name, field_type) + } +} + +/// Classify a single field into its FieldType +fn classify_field<'a>(field_name: &'a Ident, field_type: &'a Type) -> syn::Result> { + // Vec types + if utils::is_vec_type(field_type) { + return match utils::get_vec_inner_type(field_type) { + Some(inner_type) => Ok(classify_vec_type(field_name, field_type, inner_type)), + None => Err(syn::Error::new_spanned( + field_type, + "Could not determine inner type of Vec", + )), + }; + } + + // Array types + if let Type::Array(_) = field_type { + return Ok(FieldType::Array(field_name, field_type)); + } + + // Option types + if utils::is_option_type(field_type) { + return match utils::get_option_inner_type(field_type) { + Some(inner_type) => Ok(classify_option_type(field_name, field_type, inner_type)), + None => Ok(FieldType::Option(field_name, field_type)), + }; + } + + // Simple type dispatch + match () { + _ if utils::is_pubkey_type(field_type) => Ok(FieldType::Pubkey(field_name)), + _ if utils::is_bool_type(field_type) => Ok(FieldType::Primitive(field_name, field_type)), + _ if utils::is_primitive_integer(field_type) => { + classify_integer_type(field_name, field_type) + } + _ if utils::is_copy_type(field_type) => Ok(classify_copy_type(field_name, field_type)), + _ => Ok(FieldType::DynamicZeroCopy(field_name, field_type)), + } +} + +/// Analyze struct fields and return vector of FieldType enums +pub fn analyze_struct_fields<'a>( + struct_fields: &'a [&'a Field], +) -> syn::Result>> { + struct_fields + .iter() + .map(|field| { + let field_name = field + .ident + .as_ref() + .ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?; + classify_field(field_name, &field.ty) + }) + .collect() +} + +/// Generate struct fields with zerocopy types based on field type enum +fn generate_struct_fields_with_zerocopy_types<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], + hasher: &'a bool, +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + let iterator = field_types + .into_iter() + .zip(struct_fields.iter()) + .map(|(field_type, field)| { + let attributes = if *hasher { + field + .attrs + .iter() + .map(|attr| { + quote! { #attr } + }) + .collect::>() + } else { + vec![quote! {}] + }; + let (mutability, import_path, import_slice, camel_case_suffix): ( + syn::Type, + syn::Ident, + syn::Ident, + String, + ) = if MUT { + ( + parse_quote!(&'a mut [u8]), + format_ident!("borsh_mut"), + format_ident!("slice_mut"), + String::from("Mut"), + ) + } else { + ( + parse_quote!(&'a [u8]), + format_ident!("borsh"), + format_ident!("slice"), + String::new(), + ) + }; + let deserialize_ident = format_ident!("Deserialize{}", camel_case_suffix); + let trait_name: syn::Type = parse_quote!(light_zero_copy::#import_path::#deserialize_ident); + let slice_ident = format_ident!("ZeroCopySlice{}Borsh", camel_case_suffix); + let slice_name: syn::Type = parse_quote!(light_zero_copy::#import_slice::#slice_ident); + let struct_inner_ident = format_ident!("ZeroCopyStructInner{}", camel_case_suffix); + let inner_ident = format_ident!("ZeroCopyInner{}", camel_case_suffix); + let struct_inner_trait_name: syn::Type = parse_quote!(light_zero_copy::#import_path::#struct_inner_ident::#inner_ident); + match field_type { + FieldType::VecU8(field_name) => { + quote! { + #(#attributes)* + pub #field_name: #mutability + } + } + FieldType::VecCopy(field_name, inner_type) => { + // For primitive Copy types, use the zerocopy converted type directly + // For complex Copy types, use the ZeroCopyStructInner trait + if utils::is_primitive_integer(inner_type) || utils::is_bool_type(inner_type) || utils::is_pubkey_type(inner_type) { + let zerocopy_type = utils::convert_to_zerocopy_type(inner_type); + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, #zerocopy_type> + } + } else { + let inner_type = utils::convert_to_zerocopy_type(inner_type); + quote! { + #(#attributes)* + pub #field_name: #slice_name<'a, <#inner_type as #struct_inner_trait_name>> + } + } + } + FieldType::VecDynamicZeroCopy(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::Array(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #field_type> + } + } + FieldType::Option(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + quote! { + #(#attributes)* + pub #field_name: Option> + } + } + FieldType::Pubkey(field_name) => { + quote! { + #(#attributes)* + pub #field_name: >::Output + } + } + FieldType::Primitive(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + FieldType::Copy(field_name, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + #(#attributes)* + pub #field_name: light_zero_copy::Ref<#mutability , #zerocopy_type> + } + } + FieldType::DynamicZeroCopy(field_name, field_type) => { + quote! { + #(#attributes)* + pub #field_name: <#field_type as #trait_name<'a>>::Output + } + } + } + }); + Ok(iterator) +} + +/// Generate accessor methods for boolean fields in struct_fields. +/// We need accessors because booleans are stored as u8. +fn generate_bool_accessor_methods<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().filter_map(|field| { + let field_name = &field.ident; + let field_type = &field.ty; + + if utils::is_bool_type(field_type) { + let comparison = if MUT { + quote! { *self.#field_name > 0 } + } else { + quote! { self.#field_name > 0 } + }; + + Some(quote! { + pub fn #field_name(&self) -> bool { + #comparison + } + }) + } else { + None + } + }) +} + +/// Generates the ZStruct definition as a TokenStream +pub fn generate_z_struct( + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_fields: &[&Field], + hasher: bool, +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + let mutability: syn::Type = if MUT { + parse_quote!(&'a mut [u8]) + } else { + parse_quote!(&'a [u8]) + }; + + let derive_clone = if MUT { + quote! {} + } else { + quote! {, Clone } + }; + let struct_fields_with_zerocopy_types: Vec = + generate_struct_fields_with_zerocopy_types::(struct_fields, &hasher)?.collect(); + + let derive_hasher = if hasher { + quote! { + , LightHasher + } + } else { + quote! {} + }; + let hasher_flatten = if hasher { + quote! { + #[flatten] + } + } else { + quote! {} + }; + + let partial_eq_derive = if MUT { quote!() } else { quote!(, PartialEq) }; + + let mut z_struct = if meta_fields.is_empty() { + quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #(#struct_fields_with_zerocopy_types,)* + } + } + } else { + let mut tokens = quote! { + // ZStruct + #[derive(Debug #partial_eq_derive #derive_clone #derive_hasher)] + pub struct #z_struct_name<'a> { + #hasher_flatten + __meta: light_zero_copy::Ref<#mutability, #z_struct_meta_name>, + #(#struct_fields_with_zerocopy_types,)* + } + impl<'a> core::ops::Deref for #z_struct_name<'a> { + type Target = light_zero_copy::Ref<#mutability , #z_struct_meta_name>; + + fn deref(&self) -> &Self::Target { + &self.__meta + } + } + }; + + if MUT { + tokens.append_all(quote! { + impl<'a> core::ops::DerefMut for #z_struct_name<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.__meta + } + } + }); + } + tokens + }; + + if !meta_fields.is_empty() { + let meta_bool_accessor_methods = generate_bool_accessor_methods::(meta_fields); + z_struct.append_all(quote! { + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#meta_bool_accessor_methods)* + } + }) + }; + + if !struct_fields.is_empty() { + let bool_accessor_methods = generate_bool_accessor_methods::(struct_fields); + z_struct.append_all(quote! { + // Implement methods for ZStruct + impl<'a> #z_struct_name<'a> { + #(#bool_accessor_methods)* + } + + }); + } + Ok(z_struct) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; + use syn::parse_quote; + + use super::*; + + /// Generate a safe field name for testing + fn random_ident(rng: &mut StdRng) -> String { + // Use predetermined safe field names + const FIELD_NAMES: &[&str] = &[ + "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", + "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", + "status", + ]; + + FIELD_NAMES.choose(rng).unwrap().to_string() + } + + /// Generate a random Rust type + fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { + // Define our available types + let types = [0, 1, 2, 3, 4, 5, 6, 7]; + + // Randomly select a type index + let selected = *types.choose(rng).unwrap(); + + // Return the corresponding type + match selected { + 0 => parse_quote!(u8), + 1 => parse_quote!(u16), + 2 => parse_quote!(u32), + 3 => parse_quote!(u64), + 4 => parse_quote!(bool), + 5 => parse_quote!(Vec), + 6 => parse_quote!(Vec), + 7 => parse_quote!(Vec), + _ => unreachable!(), + } + } + + /// Generate a random field + fn random_field(rng: &mut StdRng) -> Field { + let name = random_ident(rng); + let ty = random_type(rng, 0); + + // Use a safer approach to create the field + let name_ident = format_ident!("{}", name); + parse_quote!(pub #name_ident: #ty) + } + + /// Generate a list of random fields + fn random_fields(rng: &mut StdRng, count: usize) -> Vec { + (0..count).map(|_| random_field(rng)).collect() + } + + #[test] + fn test_fuzz_generate_z_struct() { + // Set up RNG with a seed for reproducibility + let seed = thread_rng().gen(); + println!("seed {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + // Now that the test is working, run with 10,000 iterations + let num_iters = 10000; + + for i in 0..num_iters { + // Generate a random struct name + let struct_name = format_ident!("{}", random_ident(&mut rng)); + let z_struct_name = format_ident!("Z{}", struct_name); + let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + + // Generate random number of fields (1-10) + let field_count = rng.gen_range(1..11); + let fields = random_fields(&mut rng, field_count); + + // Create a named fields collection that lives longer than the process_fields call + let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); + let fields_named = syn::FieldsNamed { + brace_token: syn::token::Brace::default(), + named: syn_fields, + }; + + // Split into meta fields and struct fields + let (meta_fields, struct_fields) = crate::shared::utils::process_fields(&fields_named); + + // Call the function we're testing + let result = generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + false, + ); + + // Get the generated code as a string for validation + let result_str = result.unwrap().to_string(); + + // Validate the generated code + + // Verify the result contains expected struct elements + // Basic validation - must be non-empty + assert!( + !result_str.is_empty(), + "Failed to generate TokenStream for iteration {}", + i + ); + + // Validate that the generated code contains the expected struct definition + let struct_pattern = format!("struct {} < 'a >", z_struct_name); + assert!( + result_str.contains(&struct_pattern), + "Generated code missing struct definition for iteration {}. Expected: {}", + i, + struct_pattern + ); + + if meta_fields.is_empty() { + // Validate the meta field is present + assert!( + !result_str.contains("meta :"), + "Generated code had meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + !result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + } else { + // Validate the meta field is present + assert!( + result_str.contains("meta :"), + "Generated code missing meta field for iteration {}", + i + ); + // Validate Deref implementation + assert!( + result_str.contains("impl < 'a > core :: ops :: Deref"), + "Generated code missing Deref implementation for iteration {}", + i + ); + // Validate Target type + assert!( + result_str.contains("type Target"), + "Generated code missing Target type for iteration {}", + i + ); + // Check that the deref method is implemented + assert!( + result_str.contains("fn deref (& self)"), + "Generated code missing deref method for iteration {}", + i + ); + + // Check for light_zero_copy::Ref reference + assert!( + result_str.contains("light_zero_copy :: Ref"), + "Generated code missing light_zero_copy::Ref for iteration {}", + i + ); + } + + // Make sure derive attributes are present + assert!( + result_str.contains("# [derive (Debug , PartialEq , Clone)]"), + "Generated code missing derive attributes for iteration {}", + i + ); + } + } +} diff --git a/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs b/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs new file mode 100644 index 0000000000..495977cbf0 --- /dev/null +++ b/program-libs/zero-copy-derive/src/shared/zero_copy_new.rs @@ -0,0 +1,391 @@ +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::Ident; + +use crate::shared::{ + utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generate ZeroCopyNew implementation with new_at method for a struct +pub fn generate_init_mut_impl( + struct_name: &syn::Ident, + meta_fields: &[&syn::Field], + struct_fields: &[&syn::Field], +) -> syn::Result { + let config_name = quote::format_ident!("{}Config", struct_name); + let z_meta_name = quote::format_ident!("Z{}MetaMut", struct_name); + let z_struct_mut_name = quote::format_ident!("Z{}Mut", struct_name); + + // Use the pre-separated fields from utils::process_fields (consistent with other derives) + let struct_field_types = analyze_struct_fields(struct_fields)?; + + // Generate field initialization code for struct fields only (meta fields are part of __meta) + let field_initializations: Result, syn::Error> = + struct_field_types + .iter() + .map(|field_type| generate_field_initialization(field_type)) + .collect(); + let field_initializations = field_initializations?; + + // Generate struct construction - only include struct fields that were initialized + // Meta fields are accessed via __meta.field_name in the generated ZStruct + let struct_field_names: Vec = struct_field_types + .iter() + .map(|field_type| { + let field_name = field_type.name(); + quote! { #field_name, } + }) + .collect(); + + // Check if there are meta fields to determine whether to include __meta + let has_meta_fields = !meta_fields.is_empty(); + + let meta_initialization = if has_meta_fields { + quote! { + // Handle the meta struct (fixed-size fields at the beginning) + let (__meta, bytes) = Ref::<&mut [u8], #z_meta_name>::from_prefix(bytes)?; + } + } else { + quote! { + // No meta fields, skip meta struct initialization + } + }; + + let struct_construction = if has_meta_fields { + quote! { + let result = #z_struct_mut_name { + __meta, + #(#struct_field_names)* + }; + } + } else { + quote! { + let result = #z_struct_mut_name { + #(#struct_field_names)* + }; + } + }; + + // Generate byte_len calculation for each field type + let byte_len_calculations: Result, syn::Error> = + struct_field_types + .iter() + .map(|field_type| generate_byte_len_calculation(field_type)) + .collect(); + let byte_len_calculations = byte_len_calculations?; + + // Calculate meta size if there are meta fields + let meta_size_calculation = if has_meta_fields { + quote! { + core::mem::size_of::<#z_meta_name>() + } + } else { + quote! { 0 } + }; + + let result = quote! { + impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for #struct_name { + type Config = #config_name; + type Output = >::Output; + + fn byte_len(config: &Self::Config) -> usize { + #meta_size_calculation #(+ #byte_len_calculations)* + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { + use zerocopy::Ref; + + #meta_initialization + + #(#field_initializations)* + + #struct_construction + + Ok((result, bytes)) + } + } + }; + Ok(result) +} + +// Configuration system functions moved from config.rs + +/// Determine if this field type requires configuration for initialization +pub fn requires_config(field_type: &FieldType) -> bool { + match field_type { + // Vec types always need length configuration + FieldType::VecU8(_) | FieldType::VecCopy(_, _) | FieldType::VecDynamicZeroCopy(_, _) => { + true + } + // Option types need Some/None configuration + FieldType::Option(_, _) => true, + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::Primitive(_, _) + | FieldType::Copy(_, _) => false, + // DynamicZeroCopy types might need configuration if they contain Vec/Option + FieldType::DynamicZeroCopy(_, _) => true, // Conservative: assume they need config + // Option integer types need config to determine if they're enabled + FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => true, + } +} + +/// Generate the config type for this field +pub fn config_type(field_type: &FieldType) -> syn::Result { + let result = match field_type { + // Simple Vec types: just need length + FieldType::VecU8(_) => quote! { u32 }, + FieldType::VecCopy(_, _) => quote! { u32 }, + + // Complex Vec types: need config for each element + FieldType::VecDynamicZeroCopy(_, vec_type) => { + if let Some(inner_type) = utils::get_vec_inner_type(vec_type) { + quote! { Vec<<#inner_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config> } + } else { + return Err(syn::Error::new_spanned( + vec_type, + "Could not determine inner type for VecDynamicZeroCopy config", + )); + } + } + + // Option types: delegate to the Option's Config type + FieldType::Option(_, option_type) => { + quote! { <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config } + } + + // Fixed-size types don't need configuration + FieldType::Array(_, _) + | FieldType::Pubkey(_) + | FieldType::Primitive(_, _) + | FieldType::Copy(_, _) => quote! { () }, + + // Option integer types: use bool config to determine if enabled + FieldType::OptionU64(_) | FieldType::OptionU32(_) | FieldType::OptionU16(_) => { + quote! { bool } + } + + // DynamicZeroCopy types: delegate to their Config type (Config is typically 'static) + FieldType::DynamicZeroCopy(_, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::Config } + } + }; + Ok(result) +} + +/// Generate a configuration struct for a given struct +pub fn generate_config_struct( + struct_name: &Ident, + field_types: &[FieldType], +) -> syn::Result { + let config_name = quote::format_ident!("{}Config", struct_name); + + // Generate config fields only for fields that require configuration + let config_fields: Result, syn::Error> = field_types + .iter() + .filter(|field_type| requires_config(field_type)) + .map(|field_type| -> syn::Result { + let field_name = field_type.name(); + let config_type = config_type(field_type)?; + Ok(quote! { + pub #field_name: #config_type, + }) + }) + .collect(); + let config_fields = config_fields?; + + let result = if config_fields.is_empty() { + // If no fields require configuration, create an empty config struct + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name; + } + } else { + quote! { + #[derive(Debug, Clone, PartialEq)] + pub struct #config_name { + #(#config_fields)* + } + } + }; + Ok(result) +} + +/// Generate initialization logic for a field based on its configuration +pub fn generate_field_initialization(field_type: &FieldType) -> syn::Result { + let result = match field_type { + FieldType::VecU8(field_name) => { + quote! { + // Initialize the length prefix but don't use the returned ZeroCopySliceMut + { + light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( + config.#field_name.into(), + bytes + )?; + } + // Split off the length prefix (4 bytes) and get the slice + let (_, bytes) = bytes.split_at_mut(4); + let (#field_name, bytes) = bytes.split_at_mut(config.#field_name as usize); + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<#inner_type>::new_at( + config.#field_name.into(), + bytes + )?; + } + } + + FieldType::VecDynamicZeroCopy(field_name, vec_type) + | FieldType::DynamicZeroCopy(field_name, vec_type) + | FieldType::Option(field_name, vec_type) => { + quote! { + let (#field_name, bytes) = <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'a>>::new_zero_copy( + bytes, + config.#field_name + )?; + } + } + + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + let option_type = match field_type { + FieldType::OptionU64(_) => quote! { Option }, + FieldType::OptionU32(_) => quote! { Option }, + FieldType::OptionU16(_) => quote! { Option }, + _ => unreachable!(), + }; + quote! { + let (#field_name, bytes) = <#option_type as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( + bytes, + (config.#field_name, ()) + )?; + } + } + + // Fixed-size types that are struct fields (not meta fields) need initialization with () config + FieldType::Primitive(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut(bytes)?; + } + } + + // Array fields that are struct fields (come after Vec/Option) + FieldType::Array(field_name, array_type) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + #array_type + >::from_prefix(bytes)?; + } + } + + FieldType::Pubkey(field_name) => { + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::< + &'a mut [u8], + Pubkey + >::from_prefix(bytes)?; + } + } + + FieldType::Copy(field_name, field_type) => { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy(bytes)?; + } + } + }; + Ok(result) +} + +/// Generate byte length calculation for a field based on its configuration +pub fn generate_byte_len_calculation(field_type: &FieldType) -> syn::Result { + let result = match field_type { + // Vec types that require configuration + FieldType::VecU8(field_name) => { + quote! { + (4 + config.#field_name as usize) // 4 bytes for length + actual data + } + } + + FieldType::VecCopy(field_name, inner_type) => { + quote! { + (4 + (config.#field_name as usize * core::mem::size_of::<#inner_type>())) + } + } + + FieldType::VecDynamicZeroCopy(field_name, vec_type) => { + quote! { + <#vec_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + // Option types + FieldType::Option(field_name, option_type) => { + quote! { + <#option_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + + FieldType::OptionU64(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU32(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + FieldType::OptionU16(field_name) => { + quote! { + as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&(config.#field_name, ())) + } + } + + // Fixed-size types don't need configuration and have known sizes + FieldType::Primitive(_, field_type) => { + let zerocopy_type = utils::convert_to_zerocopy_type(field_type); + quote! { + core::mem::size_of::<#zerocopy_type>() + } + } + + FieldType::Array(_, array_type) => { + quote! { + core::mem::size_of::<#array_type>() + } + } + + FieldType::Pubkey(_) => { + quote! { + 32 // Pubkey is always 32 bytes + } + } + + // Meta field types (should not appear in struct fields, but handle gracefully) + FieldType::Copy(_, field_type) => { + quote! { + core::mem::size_of::<#field_type>() + } + } + + FieldType::DynamicZeroCopy(field_name, field_type) => { + quote! { + <#field_type as light_zero_copy::init_mut::ZeroCopyNew<'static>>::byte_len(&config.#field_name) + } + } + }; + Ok(result) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy.rs b/program-libs/zero-copy-derive/src/zero_copy.rs new file mode 100644 index 0000000000..7e89094a6f --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy.rs @@ -0,0 +1,466 @@ +use proc_macro::TokenStream as ProcTokenStream; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput, Field, Ident}; + +use crate::shared::{ + meta_struct, utils, + z_struct::{analyze_struct_fields, generate_z_struct, FieldType}, +}; + +/// Helper function to generate deserialize call pattern for a given type +fn generate_deserialize_call( + field_name: &syn::Ident, + field_type: &syn::Type, +) -> TokenStream { + let field_type = utils::convert_to_zerocopy_type(field_type); + let trait_path = if MUT { + quote!( as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut) + } else { + quote!( as light_zero_copy::borsh::Deserialize>::zero_copy_at) + }; + + quote! { + let (#field_name, bytes) = <#field_type #trait_path(bytes)?; + } +} + +/// Generates field deserialization code for the Deserialize implementation +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_fields<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + + let iterator = field_types.into_iter().map(move |field_type| { + let mutability_tokens = if MUT { + quote!(&'a mut [u8]) + } else { + quote!(&'a [u8]) + }; + match field_type { + FieldType::VecU8(field_name) => { + if MUT { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh_mut::borsh_vec_u8_as_slice_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = light_zero_copy::borsh::borsh_vec_u8_as_slice(bytes)?; + } + } + }, + FieldType::VecCopy(field_name, inner_type) => { + let inner_type = utils::convert_to_zerocopy_type(inner_type); + + let trait_path = if MUT { + quote!(light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::<'a, <#inner_type as light_zero_copy::borsh_mut::ZeroCopyStructInnerMut>::ZeroCopyInnerMut>) + } else { + quote!(light_zero_copy::slice::ZeroCopySliceBorsh::<'a, <#inner_type as light_zero_copy::borsh::ZeroCopyStructInner>::ZeroCopyInner>) + }; + quote! { + let (#field_name, bytes) = #trait_path::from_bytes_at(bytes)?; + } + }, + FieldType::VecDynamicZeroCopy(field_name, field_type) => { + generate_deserialize_call::(field_name, field_type) + }, + FieldType::Array(field_name, field_type) => { + let field_type = utils::convert_to_zerocopy_type(field_type); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_type>::from_prefix(bytes)?; + } + }, + FieldType::Option(field_name, field_type) => { + generate_deserialize_call::(field_name, field_type) + }, + FieldType::Pubkey(field_name) => { + generate_deserialize_call::(field_name, &parse_quote!(Pubkey)) + }, + FieldType::Primitive(field_name, field_type) => { + if MUT { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh_mut::DeserializeMut>::zero_copy_at_mut(bytes)?; + } + } else { + quote! { + let (#field_name, bytes) = <#field_type as light_zero_copy::borsh::Deserialize>::zero_copy_at(bytes)?; + } + } + }, + FieldType::Copy(field_name, field_type) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(field_type); + quote! { + let (#field_name, bytes) = light_zero_copy::Ref::<#mutability_tokens, #field_ty_zerocopy>::from_prefix(bytes)?; + } + }, + FieldType::DynamicZeroCopy(field_name, field_type) => { + generate_deserialize_call::(field_name, field_type) + }, + FieldType::OptionU64(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u64)); + generate_deserialize_call::(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + }, + FieldType::OptionU32(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u32)); + generate_deserialize_call::(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + }, + FieldType::OptionU16(field_name) => { + let field_ty_zerocopy = utils::convert_to_zerocopy_type(&parse_quote!(u16)); + generate_deserialize_call::(field_name, &parse_quote!(Option<#field_ty_zerocopy>)) + } + } + }); + Ok(iterator) +} + +/// Generates field initialization code for the Deserialize implementation +pub fn generate_init_fields<'a>( + struct_fields: &'a [&'a Field], +) -> impl Iterator + 'a { + struct_fields.iter().map(|field| { + let field_name = &field.ident; + quote! { #field_name } + }) +} + +/// Generates the Deserialize implementation as a TokenStream +/// The `MUT` parameter controls whether to generate code for mutable or immutable references +pub fn generate_deserialize_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + struct_fields: &[&Field], + meta_is_empty: bool, + byte_len_impl: TokenStream, +) -> syn::Result { + let z_struct_name = if MUT { + format_ident!("{}Mut", z_struct_name) + } else { + z_struct_name.clone() + }; + let z_struct_meta_name = if MUT { + format_ident!("{}Mut", z_struct_meta_name) + } else { + z_struct_meta_name.clone() + }; + + // Define trait and types based on mutability + let (trait_name, mutability, method_name) = if MUT { + ( + quote!(light_zero_copy::borsh_mut::DeserializeMut), + quote!(mut), + quote!(zero_copy_at_mut), + ) + } else { + ( + quote!(light_zero_copy::borsh::Deserialize), + quote!(), + quote!(zero_copy_at), + ) + }; + let (meta_des, meta) = if meta_is_empty { + (quote!(), quote!()) + } else { + ( + quote! { + let (__meta, bytes) = light_zero_copy::Ref::< &'a #mutability [u8], #z_struct_meta_name>::from_prefix(bytes)?; + }, + quote!(__meta,), + ) + }; + let deserialize_fields = generate_deserialize_fields::(struct_fields)?; + let init_fields = generate_init_fields(struct_fields); + + let result = quote! { + impl<'a> #trait_name<'a> for #name { + type Output = #z_struct_name<'a>; + + fn #method_name(bytes: &'a #mutability [u8]) -> Result<(Self::Output, &'a #mutability [u8]), light_zero_copy::errors::ZeroCopyError> { + #meta_des + #(#deserialize_fields)* + Ok(( + #z_struct_name { + #meta + #(#init_fields,)* + }, + bytes + )) + } + + #byte_len_impl + } + }; + Ok(result) +} + +/// Generates the ZeroCopyStructInner implementation as a TokenStream +pub fn generate_zero_copy_struct_inner( + name: &Ident, + z_struct_name: &Ident, +) -> syn::Result { + let result = if MUT { + quote! { + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh_mut::ZeroCopyStructInnerMut for #name { + type ZeroCopyInnerMut = #z_struct_name<'static>; + } + } + } else { + quote! { + // ZeroCopyStructInner implementation + impl light_zero_copy::borsh::ZeroCopyStructInner for #name { + type ZeroCopyInner = #z_struct_name<'static>; + } + } + }; + Ok(result) +} + +pub fn derive_zero_copy_impl(input: ProcTokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(input)?; + + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + let meta_struct_def = if !meta_fields.is_empty() { + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher)? + } else { + quote! {} + }; + + let z_struct_def = generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + )?; + + let zero_copy_struct_inner_impl = + generate_zero_copy_struct_inner::(name, &z_struct_name)?; + + let deserialize_impl = generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! {}, + )?; + + // Combine all implementations + let expanded = quote! { + #meta_struct_def + #z_struct_def + #zero_copy_struct_inner_impl + #deserialize_impl + }; + + Ok(expanded) +} + +#[cfg(test)] +mod tests { + use quote::format_ident; + use rand::{prelude::SliceRandom, rngs::StdRng, thread_rng, Rng, SeedableRng}; + use syn::parse_quote; + + use super::*; + + /// Generate a safe field name for testing + fn random_ident(rng: &mut StdRng) -> String { + // Use predetermined safe field names + const FIELD_NAMES: &[&str] = &[ + "field1", "field2", "field3", "field4", "field5", "value", "data", "count", "size", + "flag", "name", "id", "code", "index", "key", "amount", "balance", "total", "result", + "status", + ]; + + FIELD_NAMES.choose(rng).unwrap().to_string() + } + + /// Generate a random Rust type + fn random_type(rng: &mut StdRng, _depth: usize) -> syn::Type { + // Define our available types + let types = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + // Randomly select a type index + let selected = *types.choose(rng).unwrap(); + + // Return the corresponding type + match selected { + 0 => parse_quote!(u8), + 1 => parse_quote!(u16), + 2 => parse_quote!(u32), + 3 => parse_quote!(u64), + 4 => parse_quote!(bool), + 5 => parse_quote!(Vec), + 6 => parse_quote!(Vec), + 7 => parse_quote!(Vec), + 8 => parse_quote!([u32; 12]), + 9 => parse_quote!([Vec; 12]), + 10 => parse_quote!([Vec; 20]), + _ => unreachable!(), + } + } + + /// Generate a random field + fn random_field(rng: &mut StdRng) -> Field { + let name = random_ident(rng); + let ty = random_type(rng, 0); + + // Use a safer approach to create the field + let name_ident = format_ident!("{}", name); + parse_quote!(pub #name_ident: #ty) + } + + /// Generate a list of random fields + fn random_fields(rng: &mut StdRng, count: usize) -> Vec { + (0..count).map(|_| random_field(rng)).collect() + } + + // Test for field initialization code generation - behavioral test + #[test] + fn test_init_fields() { + let field1: Field = parse_quote!(pub id: u32); + let field2: Field = parse_quote!(pub name: String); + let struct_fields = vec![&field1, &field2]; + + let result = generate_init_fields(&struct_fields).collect::>(); + assert_eq!( + result.len(), + 2, + "Should generate exactly 2 field initializations" + ); + + let result_str = format!("{} {}", result[0], result[1]); + assert!(result_str.contains("id"), "Should contain 'id' field"); + assert!(result_str.contains("name"), "Should contain 'name' field"); + } + + #[test] + fn test_fuzz_generate_deserialize_impl() { + // Set up RNG with a seed for reproducibility + let seed = thread_rng().gen(); + println!("seed {}", seed); + let mut rng = StdRng::seed_from_u64(seed); + + // Number of iterations for the test + let num_iters = 10000; + + for i in 0..num_iters { + // Generate a random struct name + let struct_name = format_ident!("{}", random_ident(&mut rng)); + let z_struct_name = format_ident!("Z{}", struct_name); + let z_struct_meta_name = format_ident!("Z{}Meta", struct_name); + + // Generate random number of fields (1-10) + let field_count = rng.gen_range(1..11); + let fields = random_fields(&mut rng, field_count); + + // Create a named fields collection + let syn_fields = syn::punctuated::Punctuated::from_iter(fields.iter().cloned()); + let fields_named = syn::FieldsNamed { + brace_token: syn::token::Brace::default(), + named: syn_fields, + }; + + // Split into meta fields and struct fields + let (_, struct_fields) = crate::shared::utils::process_fields(&fields_named); + + // Call the function we're testing + let result = generate_deserialize_impl::( + &struct_name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + false, + quote! {}, + ); + + // Get the generated code as a string for validation + let result_str = result.unwrap().to_string(); + + // Print the first result for debugging + if i == 0 { + println!("Generated deserialize_impl code format:\n{}", result_str); + } + + // Verify the result contains expected elements + // Basic validation - must be non-empty + assert!( + !result_str.is_empty(), + "Failed to generate TokenStream for iteration {}", + i + ); + + // Validate that the generated code contains the expected impl definition + let impl_pattern = format!( + "impl < 'a > light_zero_copy :: borsh :: Deserialize < 'a > for {}", + struct_name + ); + assert!( + result_str.contains(&impl_pattern), + "Generated code missing impl definition for iteration {}. Expected: {}", + i, + impl_pattern + ); + + // Validate type Output is defined + let output_pattern = format!("type Output = {} < 'a >", z_struct_name); + assert!( + result_str.contains(&output_pattern), + "Generated code missing Output type for iteration {}. Expected: {}", + i, + output_pattern + ); + + // Validate the zero_copy_at method is present + assert!( + result_str.contains("fn zero_copy_at (bytes : & 'a [u8])"), + "Generated code missing zero_copy_at method for iteration {}", + i + ); + + // Check for meta field extraction + let meta_extraction_pattern = format!( + "let (__meta , bytes) = light_zero_copy :: Ref :: < & 'a [u8] , {} > :: from_prefix (bytes) ?", + z_struct_meta_name + ); + assert!( + result_str.contains(&meta_extraction_pattern), + "Generated code missing meta field extraction for iteration {}", + i + ); + + // Check for return with Ok pattern + assert!( + result_str.contains("Ok (("), + "Generated code missing Ok return statement for iteration {}", + i + ); + + // Check for the struct initialization + let struct_init_pattern = format!("{} {{", z_struct_name); + assert!( + result_str.contains(&struct_init_pattern), + "Generated code missing struct initialization for iteration {}", + i + ); + + // Check for meta field in the returned struct + assert!( + result_str.contains("__meta ,"), + "Generated code missing meta field in struct initialization for iteration {}", + i + ); + } + } +} diff --git a/program-libs/zero-copy-derive/src/zero_copy_eq.rs b/program-libs/zero-copy-derive/src/zero_copy_eq.rs new file mode 100644 index 0000000000..94b06b51a6 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_eq.rs @@ -0,0 +1,265 @@ +use proc_macro::TokenStream as ProcTokenStream; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{DeriveInput, Field, Ident}; + +use crate::shared::{ + from_impl, utils, + z_struct::{analyze_struct_fields, FieldType}, +}; + +/// Generates meta field comparisons for PartialEq implementation +pub fn generate_meta_field_comparisons<'a>( + meta_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(meta_fields)?; + + let iterator = field_types.into_iter().map(|field_type| match field_type { + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => quote! { + if other.#field_name != meta.#field_name { + return false; + } + }, + _ if utils::is_specific_primitive_type(field_type, "bool") => quote! { + if other.#field_name != (meta.#field_name > 0) { + return false; + } + }, + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { + if other.#field_name != #field_type::from(meta.#field_name) { + return false; + } + } + } + } + } + _ => { + let field_name = field_type.name(); + quote! { + if other.#field_name != meta.#field_name { + return false; + } + } + } + }); + Ok(iterator) +} + +/// Generates struct field comparisons for PartialEq implementation +pub fn generate_struct_field_comparisons<'a, const MUT: bool>( + struct_fields: &'a [&'a Field], +) -> syn::Result + 'a> { + let field_types = analyze_struct_fields(struct_fields)?; + if field_types + .iter() + .any(|x| matches!(x, FieldType::Option(_, _))) + { + return Err(syn::Error::new_spanned( + struct_fields[0], + "Options are not supported in ZeroCopyEq", + )); + } + + let iterator = field_types.into_iter().map(|field_type| { + match field_type { + FieldType::VecU8(field_name) => { + quote! { + if self.#field_name != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::VecDynamicZeroCopy(field_name, _) => { + quote! { + if self.#field_name.as_slice() != other.#field_name.as_slice() { + return false; + } + } + } + FieldType::Array(field_name, _) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Option(field_name, field_type) => { + if utils::is_specific_primitive_type(field_type, "u8") { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if self.#field_name.as_ref().unwrap() != other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + // TODO: handle issue that structs need * == *, arrays need ** == * + // else if crate::utils::is_copy_type(field_type) { + // quote! { + // if self.#field_name.is_some() && other.#field_name.is_some() { + // if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + // return false; + // } + // } else if self.#field_name.is_some() || other.#field_name.is_some() { + // return false; + // } + // } + // } + else { + quote! { + if self.#field_name.is_some() && other.#field_name.is_some() { + if **self.#field_name.as_ref().unwrap() != *other.#field_name.as_ref().unwrap() { + return false; + } + } else if self.#field_name.is_some() || other.#field_name.is_some() { + return false; + } + } + } + + } + FieldType::Pubkey(field_name) => { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } + FieldType::Primitive(field_name, field_type) => { + match () { + _ if utils::is_specific_primitive_type(field_type, "u8") => + if MUT { + quote! { + if *self.#field_name != other.#field_name { + return false; + } + } + } else { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + _ if utils::is_specific_primitive_type(field_type, "bool") => + if MUT { + quote! { + if (*self.#field_name > 0) != other.#field_name { + return false; + } + } + } else { + quote! { + if (self.#field_name > 0) != other.#field_name { + return false; + } + } + }, + _ => { + // For u64, u32, u16 - use the type's from() method + quote! { + if #field_type::from(*self.#field_name) != other.#field_name { + return false; + } + } + } + } + } + FieldType::Copy(field_name, _) + | FieldType::DynamicZeroCopy(field_name, _) => { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + }, + FieldType::OptionU64(field_name) + | FieldType::OptionU32(field_name) + | FieldType::OptionU16(field_name) => { + quote! { + if self.#field_name != other.#field_name { + return false; + } + } + } + } + }); + Ok(iterator) +} + +/// Generates the PartialEq implementation as a TokenStream +pub fn generate_partial_eq_impl( + name: &Ident, + z_struct_name: &Ident, + z_struct_meta_name: &Ident, + meta_fields: &[&Field], + struct_fields: &[&Field], +) -> syn::Result { + let struct_field_comparisons = generate_struct_field_comparisons::(struct_fields)?; + let result = if !meta_fields.is_empty() { + let meta_field_comparisons = generate_meta_field_comparisons(meta_fields)?; + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + let meta: &#z_struct_meta_name = &self.__meta; + #(#meta_field_comparisons)* + #(#struct_field_comparisons)* + true + } + } + } + } else { + quote! { + impl<'a> PartialEq<#name> for #z_struct_name<'a> { + fn eq(&self, other: &#name) -> bool { + #(#struct_field_comparisons)* + true + } + } + + } + }; + Ok(result) +} + +pub fn derive_zero_copy_eq_impl(input: ProcTokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(input)?; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Generate the PartialEq implementation + let partial_eq_impl = generate_partial_eq_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &meta_fields, + &struct_fields, + )?; + + // Generate From implementations + let from_impl = + from_impl::generate_from_impl::(name, &z_struct_name, &meta_fields, &struct_fields)?; + + Ok(quote! { + #partial_eq_impl + #from_impl + }) +} diff --git a/program-libs/zero-copy-derive/src/zero_copy_mut.rs b/program-libs/zero-copy-derive/src/zero_copy_mut.rs new file mode 100644 index 0000000000..ad52bba4d5 --- /dev/null +++ b/program-libs/zero-copy-derive/src/zero_copy_mut.rs @@ -0,0 +1,93 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::DeriveInput; + +use crate::{ + shared::{ + meta_struct, utils, + z_struct::{self, analyze_struct_fields}, + zero_copy_new::{generate_config_struct, generate_init_mut_impl}, + }, + zero_copy, +}; + +pub fn derive_zero_copy_mut_impl(fn_input: TokenStream) -> syn::Result { + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(fn_input.clone())?; + + let hasher = false; + + // Process the input to extract struct information + let (name, z_struct_name, z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Process the fields to separate meta fields and struct fields + let (meta_fields, struct_fields) = utils::process_fields(fields); + + let meta_struct_def_mut = if !meta_fields.is_empty() { + meta_struct::generate_meta_struct::(&z_struct_meta_name, &meta_fields, hasher)? + } else { + quote! {} + }; + + let z_struct_def_mut = z_struct::generate_z_struct::( + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + &meta_fields, + hasher, + )?; + + let zero_copy_struct_inner_impl_mut = zero_copy::generate_zero_copy_struct_inner::( + name, + &format_ident!("{}Mut", z_struct_name), + )?; + + let deserialize_impl_mut = zero_copy::generate_deserialize_impl::( + name, + &z_struct_name, + &z_struct_meta_name, + &struct_fields, + meta_fields.is_empty(), + quote! {}, + )?; + + // Parse the input DeriveInput + let input: DeriveInput = syn::parse(fn_input)?; + + // Process the input to extract struct information + let (name, _z_struct_name, _z_struct_meta_name, fields) = utils::process_input(&input)?; + + // Use the same field processing logic as other derive macros for consistency + let (meta_fields, struct_fields) = utils::process_fields(fields); + + // Process ALL fields uniformly by type (no position dependency for config generation) + let all_fields: Vec<&syn::Field> = meta_fields + .iter() + .chain(struct_fields.iter()) + .cloned() + .collect(); + let all_field_types = analyze_struct_fields(&all_fields)?; + + // Generate configuration struct based on all fields that need config (type-based) + let config_struct = generate_config_struct(name, &all_field_types)?; + + // Generate ZeroCopyNew implementation using the existing field separation + let init_mut_impl = generate_init_mut_impl(name, &meta_fields, &struct_fields)?; + + // Combine all mutable implementations + let expanded = quote! { + #config_struct + + #init_mut_impl + + #meta_struct_def_mut + + #z_struct_def_mut + + #zero_copy_struct_inner_impl_mut + + #deserialize_impl_mut + }; + + Ok(expanded) +} diff --git a/program-libs/zero-copy-derive/tests/config_test.rs b/program-libs/zero-copy-derive/tests/config_test.rs new file mode 100644 index 0000000000..990b2f6a18 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/config_test.rs @@ -0,0 +1,430 @@ +#![cfg(feature = "mut")] + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::borsh_mut::DeserializeMut; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + +/// Simple struct with just a Vec field to test basic config functionality +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleVecStruct { + pub a: u8, + pub vec: Vec, + pub b: u16, +} + +#[test] +fn test_simple_config_generation() { + // This test verifies that the ZeroCopyNew derive macro generates the expected config struct + // and ZeroCopyNew implementation + + // The config should have been generated as SimpleVecStructConfig + let config = SimpleVecStructConfig { + vec: 10, // Vec should have u32 config (length) + }; + + // Test that we can create a configuration + assert_eq!(config.vec, 10); + + println!("Config generation test passed!"); +} + +#[test] +fn test_simple_vec_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test the new_zero_copy method generated by ZeroCopyNew + let config = SimpleVecStructConfig { + vec: 5, // Vec with capacity 5 + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleVecStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // Use the generated new_zero_copy method + let result = SimpleVecStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta fields + simple_struct.__meta.a = 42; + + // Test that we can write to the vec slice + simple_struct.vec[0] = 10; + simple_struct.vec[1] = 20; + simple_struct.vec[2] = 30; + + // Test that we can set the b field + *simple_struct.b = 12345u16.into(); + + // Verify the values we set + assert_eq!(simple_struct.__meta.a, 42); + assert_eq!(simple_struct.vec[0], 10); + assert_eq!(simple_struct.vec[1], 20); + assert_eq!(simple_struct.vec[2], 30); + assert_eq!(u16::from(*simple_struct.b), 12345); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = SimpleVecStruct::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[0], 10); + assert_eq!(deserialized.vec[1], 20); + assert_eq!(deserialized.vec[2], 30); + assert_eq!(u16::from(*deserialized.b), 12345); + + println!("new_zero_copy initialization test passed!"); +} + +/// Struct with Option field to test Option config +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct SimpleOptionStruct { + pub a: u8, + pub option: Option, +} + +#[test] +fn test_simple_option_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option enabled + let config = SimpleOptionStructConfig { + option: true, // Option should have bool config (enabled/disabled) + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set meta field + simple_struct.__meta.a = 123; + + // Test that option is Some and we can set its value + assert!(simple_struct.option.is_some()); + if let Some(ref mut opt_val) = simple_struct.option { + **opt_val = 98765u64.into(); + } + + // Verify the values + assert_eq!(simple_struct.__meta.a, 123); + if let Some(ref opt_val) = simple_struct.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 123); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 98765); + } + + println!("Option new_zero_copy test passed!"); +} + +#[test] +fn test_simple_option_struct_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with option disabled + let config = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = SimpleOptionStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = SimpleOptionStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut simple_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + simple_struct.__meta.a = 200; + + // Test that option is None + assert!(simple_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = SimpleOptionStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 200); + assert!(deserialized.option.is_none()); + + println!("Option disabled new_zero_copy test passed!"); +} + +/// Test both Vec and Option in one struct +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct MixedStruct { + pub a: u8, + pub vec: Vec, + pub option: Option, + pub b: u16, +} + +#[test] +fn test_mixed_struct_new_zero_copy() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with both vec and option enabled + let config = MixedStructConfig { + vec: 8, // Vec -> u32 length + option: true, // Option -> bool enabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set meta field + mixed_struct.__meta.a = 77; + + // Set vec data + mixed_struct.vec[0] = 11; + mixed_struct.vec[3] = 44; + mixed_struct.vec[7] = 88; + + // Set option value + assert!(mixed_struct.option.is_some()); + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 123456789u64.into(); + } + + // Set b field + *mixed_struct.b = 54321u16.into(); + + // Verify all values + assert_eq!(mixed_struct.__meta.a, 77); + assert_eq!(mixed_struct.vec[0], 11); + assert_eq!(mixed_struct.vec[3], 44); + assert_eq!(mixed_struct.vec[7], 88); + if let Some(ref opt_val) = mixed_struct.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*mixed_struct.b), 54321); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 77); + assert_eq!(deserialized.vec[0], 11); + assert_eq!(deserialized.vec[3], 44); + assert_eq!(deserialized.vec[7], 88); + assert!(deserialized.option.is_some()); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 123456789); + } + assert_eq!(u16::from(*deserialized.b), 54321); + + println!("Mixed struct new_zero_copy test passed!"); +} + +#[test] +fn test_mixed_struct_option_disabled() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test with vec enabled but option disabled + let config = MixedStructConfig { + vec: 3, // Vec -> u32 length + option: false, // Option -> bool disabled + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = MixedStruct::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mixed_struct.__meta.a = 99; + mixed_struct.vec[0] = 255; + mixed_struct.vec[2] = 128; + *mixed_struct.b = 9999u16.into(); + + // Verify option is None + assert!(mixed_struct.option.is_none()); + + // Test deserializing + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 99); + assert_eq!(deserialized.vec[0], 255); + assert_eq!(deserialized.vec[2], 128); + assert!(deserialized.option.is_none()); + assert_eq!(u16::from(*deserialized.b), 9999); + + println!("Mixed struct option disabled test passed!"); +} + +#[test] +fn test_byte_len_calculation() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Test SimpleVecStruct byte_len calculation + let config = SimpleVecStructConfig { + vec: 10, // Vec with capacity 10 + }; + + let expected_size = 1 + // a: u8 (meta field) + 4 + 10 + // vec: 4 bytes length + 10 bytes data + 2; // b: u16 + + let calculated_size = SimpleVecStruct::byte_len(&config); + assert_eq!(calculated_size, expected_size); + println!( + "SimpleVecStruct byte_len: calculated={}, expected={}", + calculated_size, expected_size + ); + + // Test SimpleOptionStruct byte_len calculation + let config_some = SimpleOptionStructConfig { + option: true, // Option enabled + }; + + let expected_size_some = 1 + // a: u8 (meta field) + 1 + 8; // option: 1 byte discriminant + 8 bytes u64 + + let calculated_size_some = SimpleOptionStruct::byte_len(&config_some); + assert_eq!(calculated_size_some, expected_size_some); + println!( + "SimpleOptionStruct (Some) byte_len: calculated={}, expected={}", + calculated_size_some, expected_size_some + ); + + let config_none = SimpleOptionStructConfig { + option: false, // Option disabled + }; + + let expected_size_none = 1 + // a: u8 (meta field) + 1; // option: 1 byte discriminant for None + + let calculated_size_none = SimpleOptionStruct::byte_len(&config_none); + assert_eq!(calculated_size_none, expected_size_none); + println!( + "SimpleOptionStruct (None) byte_len: calculated={}, expected={}", + calculated_size_none, expected_size_none + ); + + // Test MixedStruct byte_len calculation + let config_mixed = MixedStructConfig { + vec: 5, // Vec with capacity 5 + option: true, // Option enabled + }; + + let expected_size_mixed = 1 + // a: u8 (meta field) + 4 + 5 + // vec: 4 bytes length + 5 bytes data + 1 + 8 + // option: 1 byte discriminant + 8 bytes u64 + 2; // b: u16 + + let calculated_size_mixed = MixedStruct::byte_len(&config_mixed); + assert_eq!(calculated_size_mixed, expected_size_mixed); + println!( + "MixedStruct byte_len: calculated={}, expected={}", + calculated_size_mixed, expected_size_mixed + ); + + println!("All byte_len calculation tests passed!"); +} + +#[test] +fn test_dynamic_buffer_allocation_with_byte_len() { + use light_zero_copy::init_mut::ZeroCopyNew; + + // Example of how to use byte_len for dynamic buffer allocation + let config = MixedStructConfig { + vec: 12, // Vec with capacity 12 + option: true, // Option enabled + }; + + // Calculate the exact buffer size needed + let required_size = MixedStruct::byte_len(&config); + println!("Required buffer size: {} bytes", required_size); + + // Allocate exactly the right amount of memory + let mut bytes = vec![0u8; required_size]; + + // Initialize the structure + let result = MixedStruct::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mixed_struct, remaining) = result.unwrap(); + + // Verify we used exactly the right amount of bytes (no remaining bytes) + assert_eq!( + remaining.len(), + 0, + "Should have used exactly the calculated number of bytes" + ); + + // Set some values to verify it works + mixed_struct.__meta.a = 42; + mixed_struct.vec[5] = 123; + if let Some(ref mut opt_val) = mixed_struct.option { + **opt_val = 9999u64.into(); + } + *mixed_struct.b = 7777u16.into(); + + // Verify round-trip works + let (deserialized, _) = MixedStruct::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.a, 42); + assert_eq!(deserialized.vec[5], 123); + if let Some(ref opt_val) = deserialized.option { + assert_eq!(u64::from(**opt_val), 9999); + } + assert_eq!(u16::from(*deserialized.b), 7777); + + println!("Dynamic buffer allocation test passed!"); +} diff --git a/program-libs/zero-copy-derive/tests/cross_crate_copy.rs b/program-libs/zero-copy-derive/tests/cross_crate_copy.rs new file mode 100644 index 0000000000..dad348eee2 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/cross_crate_copy.rs @@ -0,0 +1,295 @@ +#![cfg(feature = "mut")] +//! Test cross-crate Copy identification functionality +//! +//! This test validates that the zero-copy derive macro correctly identifies +//! which types implement Copy, both for built-in types and user-defined types. + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + +// Test struct with primitive Copy types that should be in meta fields +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct PrimitiveCopyStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, + pub f: Vec, // Split point - this and following fields go to struct_fields + pub g: u32, // Should be in struct_fields due to field ordering rules +} + +// Test struct with primitive Copy types that should be in meta fields +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyEq, ZeroCopyMut)] +pub struct PrimitiveCopyStruct2 { + pub f: Vec, // Split point - this and following fields go to struct_fields + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, + pub g: u32, +} + +// Test struct with arrays that use u8 (which supports Unaligned) +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct ArrayCopyStruct { + pub fixed_u8: [u8; 4], + pub another_u8: [u8; 8], + pub data: Vec, // Split point + pub more_data: [u8; 3], // Should be in struct_fields due to field ordering +} + +// Test struct with Vec of primitive Copy types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy)] +pub struct VecPrimitiveStruct { + pub header: u32, + pub data: Vec, // Vec - special case + pub numbers: Vec, // Vec of Copy type + pub footer: u64, +} + +#[cfg(test)] +mod tests { + use light_zero_copy::borsh::Deserialize; + + use super::*; + + #[test] + fn test_primitive_copy_field_splitting() { + // This test validates that primitive Copy types are correctly + // identified and placed in meta_fields until we hit a Vec + + let data = PrimitiveCopyStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + f: vec![5, 6, 7], + g: 8, + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = PrimitiveCopyStruct::zero_copy_at(&serialized).unwrap(); + + // Verify we can access meta fields (should be zero-copy references) + assert_eq!(deserialized.a, 1); + assert_eq!(deserialized.b.get(), 2); // U16 type, use .get() + assert_eq!(deserialized.c.get(), 3); // U32 type, use .get() + assert_eq!(deserialized.d.get(), 4); // U64 type, use .get() + assert!(deserialized.e()); // bool accessor method + + // Verify we can access struct fields + assert_eq!(deserialized.f, &[5, 6, 7]); + assert_eq!(deserialized.g.get(), 8); // U32 type in struct fields + } + + #[test] + fn test_array_copy_field_splitting() { + // Arrays should be treated as Copy types + let data = ArrayCopyStruct { + fixed_u8: [1, 2, 3, 4], + another_u8: [10, 20, 30, 40, 50, 60, 70, 80], + data: vec![5, 6], + more_data: [30, 40, 50], + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = ArrayCopyStruct::zero_copy_at(&serialized).unwrap(); + + // Arrays should be accessible (in meta_fields before Vec split) + assert_eq!(deserialized.fixed_u8.as_ref(), &[1, 2, 3, 4]); + assert_eq!( + deserialized.another_u8.as_ref(), + &[10, 20, 30, 40, 50, 60, 70, 80] + ); + + // After Vec split + assert_eq!(deserialized.data, &[5, 6]); + assert_eq!(deserialized.more_data.as_ref(), &[30, 40, 50]); + } + + #[test] + fn test_vec_primitive_types() { + // Test Vec with various primitive Copy element types + let data = VecPrimitiveStruct { + header: 1, + data: vec![10, 20, 30], + numbers: vec![100, 200, 300], + footer: 999, + }; + + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = VecPrimitiveStruct::zero_copy_at(&serialized).unwrap(); + + assert_eq!(deserialized.header.get(), 1); + + // Vec is special case - stored as slice + assert_eq!(deserialized.data, &[10, 20, 30]); + + // Vec should use ZeroCopySliceBorsh + assert_eq!(deserialized.numbers.len(), 3); + assert_eq!(deserialized.numbers[0].get(), 100); + assert_eq!(deserialized.numbers[1].get(), 200); + assert_eq!(deserialized.numbers[2].get(), 300); + + assert_eq!(deserialized.footer.get(), 999); + } + + #[test] + fn test_all_derives_with_vec_first() { + // This test validates PrimitiveCopyStruct2 which has Vec as the first field + // This means NO meta fields (all fields go to struct_fields due to field ordering) + // Also tests all derive macros: ZeroCopy, ZeroCopyEq, ZeroCopyMut + + use light_zero_copy::{borsh_mut::DeserializeMut, init_mut::ZeroCopyNew}; + + let data = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], // Vec first - causes all fields to be in struct_fields + a: 10, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, + }; + + // Test ZeroCopy (immutable) + let serialized = borsh::to_vec(&data).unwrap(); + let (deserialized, _) = PrimitiveCopyStruct2::zero_copy_at(&serialized).unwrap(); + + // Since Vec is first, ALL fields should be in struct_fields (no meta fields) + assert_eq!(deserialized.f, &[1, 2, 3]); + assert_eq!(deserialized.a, 10); // u8 direct access + assert_eq!(deserialized.b.get(), 20); // U16 via .get() + assert_eq!(deserialized.c.get(), 30); // U32 via .get() + assert_eq!(deserialized.d.get(), 40); // U64 via .get() + assert_eq!(deserialized.g.get(), 50); // U32 via .get() + assert!(deserialized.e()); // bool accessor method + + // Test ZeroCopyEq (PartialEq implementation) + let original = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], + a: 10, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, + }; + + // Should be equal to original + assert_eq!(deserialized, original); + + // Test inequality + let different = PrimitiveCopyStruct2 { + f: vec![1, 2, 3], + a: 11, + b: 20, + c: 30, + d: 40, + e: true, + g: 50, // Different 'a' + }; + assert_ne!(deserialized, different); + + // Test ZeroCopyMut (mutable zero-copy) + #[cfg(feature = "mut")] + { + let mut serialized_mut = borsh::to_vec(&data).unwrap(); + let (deserialized_mut, _) = + PrimitiveCopyStruct2::zero_copy_at_mut(&mut serialized_mut).unwrap(); + + // Test mutable access + assert_eq!(deserialized_mut.f, &[1, 2, 3]); + assert_eq!(*deserialized_mut.a, 10); // Mutable u8 field + assert_eq!(deserialized_mut.b.get(), 20); + let (deserialized_mut, _) = + PrimitiveCopyStruct2::zero_copy_at(&serialized_mut).unwrap(); + + // Test From implementation (ZeroCopyEq generates this for immutable version) + let converted: PrimitiveCopyStruct2 = deserialized_mut.into(); + assert_eq!(converted.a, 10); + assert_eq!(converted.b, 20); + assert_eq!(converted.c, 30); + assert_eq!(converted.d, 40); + assert!(converted.e); + assert_eq!(converted.f, vec![1, 2, 3]); + assert_eq!(converted.g, 50); + } + + // Test ZeroCopyNew (configuration-based initialization) + let config = super::PrimitiveCopyStruct2Config { + f: 3, // Vec length + // Other fields don't need config (they're primitives) + }; + + // Calculate required buffer size + let buffer_size = PrimitiveCopyStruct2::byte_len(&config); + let mut buffer = vec![0u8; buffer_size]; + + // Initialize the zero-copy struct + let (mut initialized, _) = + PrimitiveCopyStruct2::new_zero_copy(&mut buffer, config).unwrap(); + + // Verify we can access the initialized fields + assert_eq!(initialized.f.len(), 3); // Vec should have correct length + + // Set some values in the Vec + initialized.f[0] = 100; + initialized.f[1] = 101; + initialized.f[2] = 102; + *initialized.a = 200; + + // Verify the values were set correctly + assert_eq!(initialized.f, &[100, 101, 102]); + assert_eq!(*initialized.a, 200); + + println!("All derive macros (ZeroCopy, ZeroCopyEq, ZeroCopyMut) work correctly with Vec-first struct!"); + } + + #[test] + fn test_copy_identification_compilation() { + // The primary test is that our macro successfully processes all struct definitions + // above without panicking or generating invalid code. The fact that compilation + // succeeds demonstrates that our Copy identification logic works correctly. + + // Test basic functionality to ensure the generated code is sound + let primitive_data = PrimitiveCopyStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + f: vec![1, 2], + g: 5, + }; + + let array_data = ArrayCopyStruct { + fixed_u8: [1, 2, 3, 4], + another_u8: [5, 6, 7, 8, 9, 10, 11, 12], + data: vec![13, 14], + more_data: [15, 16, 17], + }; + + let vec_data = VecPrimitiveStruct { + header: 42, + data: vec![1, 2, 3], + numbers: vec![10, 20], + footer: 99, + }; + + // Serialize and deserialize to verify the generated code works + let serialized = borsh::to_vec(&primitive_data).unwrap(); + let (_, _) = PrimitiveCopyStruct::zero_copy_at(&serialized).unwrap(); + + let serialized = borsh::to_vec(&array_data).unwrap(); + let (_, _) = ArrayCopyStruct::zero_copy_at(&serialized).unwrap(); + + let serialized = borsh::to_vec(&vec_data).unwrap(); + let (_, _) = VecPrimitiveStruct::zero_copy_at(&serialized).unwrap(); + + println!("Cross-crate Copy identification test passed - all structs compiled and work correctly!"); + } +} diff --git a/program-libs/zero-copy-derive/tests/from_test.rs b/program-libs/zero-copy-derive/tests/from_test.rs new file mode 100644 index 0000000000..20391c36dd --- /dev/null +++ b/program-libs/zero-copy-derive/tests/from_test.rs @@ -0,0 +1,77 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, ZeroCopyEq}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyMut}; + +// Simple struct with a primitive field and a vector +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct SimpleStruct { + pub a: u8, + pub b: Vec, +} + +// Basic struct with all basic numeric types +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct NumericStruct { + pub a: u8, + pub b: u16, + pub c: u32, + pub d: u64, + pub e: bool, +} + +// use light_zero_copy::borsh_mut::DeserializeMut; // Not needed for non-mut derivations + +#[test] +fn test_simple_from_implementation() { + // Create an instance of our struct + let original = SimpleStruct { + a: 42, + b: vec![1, 2, 3, 4, 5], + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = SimpleStruct::zero_copy_at(&bytes).unwrap(); + let converted: SimpleStruct = zero_copy.into(); + assert_eq!(converted.a, 42); + assert_eq!(converted.b, vec![1, 2, 3, 4, 5]); + assert_eq!(converted, original); +} + +#[test] +fn test_numeric_from_implementation() { + // Create a struct with different primitive types + let original = NumericStruct { + a: 1, + b: 2, + c: 3, + d: 4, + e: true, + }; + + // Serialize it + let bytes = original.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), original.byte_len()); + + // Test From implementation for immutable struct + let (zero_copy, _) = NumericStruct::zero_copy_at(&bytes).unwrap(); + let converted: NumericStruct = zero_copy.clone().into(); + + // Verify all fields + assert_eq!(converted.a, 1); + assert_eq!(converted.b, 2); + assert_eq!(converted.c, 3); + assert_eq!(converted.d, 4); + assert!(converted.e); + + // Verify complete struct + assert_eq!(converted, original); +} diff --git a/program-libs/zero-copy-derive/tests/instruction_data.rs b/program-libs/zero-copy-derive/tests/instruction_data.rs new file mode 100644 index 0000000000..094248e4c8 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/instruction_data.rs @@ -0,0 +1,1401 @@ +#![cfg(feature = "mut")] +use std::vec::Vec; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{borsh::Deserialize, borsh_mut::DeserializeMut, errors::ZeroCopyError}; +use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned}; + +#[derive( + Debug, + Copy, + PartialEq, + Clone, + Immutable, + FromBytes, + IntoBytes, + KnownLayout, + BorshDeserialize, + BorshSerialize, + Default, + Unaligned, +)] +#[repr(C)] +pub struct Pubkey(pub(crate) [u8; 32]); + +impl Pubkey { + pub fn new_unique() -> Self { + use rand::Rng; + let mut rng = rand::thread_rng(); + let bytes = rng.gen::<[u8; 32]>(); + Pubkey(bytes) + } + + pub fn to_bytes(self) -> [u8; 32] { + self.0 + } +} + +impl<'a> Deserialize<'a> for Pubkey { + type Output = Ref<&'a [u8], Pubkey>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + Ok(Ref::<&'a [u8], Pubkey>::from_prefix(bytes)?) + } +} + +impl<'a> DeserializeMut<'a> for Pubkey { + type Output = Ref<&'a mut [u8], Pubkey>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(Ref::<&'a mut [u8], Pubkey>::from_prefix(bytes)?) + } +} + +// We should not implement DeserializeMut for primitive types directly +// The implementation should be in the zero-copy crate + +impl PartialEq<>::Output> for Pubkey { + fn eq(&self, other: &>::Output) -> bool { + self.0 == other.0 + } +} + +impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for Pubkey { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + 32 // Pubkey is always 32 bytes + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Self::zero_copy_at_mut(bytes) + } +} + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct InstructionDataInvoke { + pub proof: Option, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub new_address_params: Vec, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for InstructionDataInvoke { +// type Config = InstructionDataInvokeConfig; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), light_zero_copy::errors::ZeroCopyError> { +// use zerocopy::Ref; +// +// // First handle the meta struct (empty for InstructionDataInvoke) +// let (__meta, bytes) = Ref::<&mut [u8], ZInstructionDataInvokeMetaMut>::from_prefix(bytes)?; +// +// // Initialize each field using the corresponding config, following DeserializeMut order +// let (proof, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.proof_config.is_some(), CompressedProofConfig {}) +// )?; +// +// let input_configs: Vec = config.input_accounts_configs +// .into_iter() +// .map(|compressed_account_config| PackedCompressedAccountWithMerkleContextConfig { +// compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// merkle_context: PackedMerkleContextConfig {}, +// }) +// .collect(); +// let (input_compressed_accounts_with_merkle_context, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// input_configs +// )?; +// +// let output_configs: Vec = config.output_accounts_configs +// .into_iter() +// .map(|compressed_account_config| OutputCompressedAccountWithPackedContextConfig { +// compressed_account: CompressedAccountConfig { +// address: (compressed_account_config.address_enabled, ()), +// data: (compressed_account_config.data_enabled, CompressedAccountDataConfig { data: compressed_account_config.data_capacity }), +// }, +// }) +// .collect(); +// let (output_compressed_accounts, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// output_configs +// )?; +// +// let (relay_fee, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.relay_fee_config.is_some(), ()) +// )?; +// +// let new_address_configs: Vec = config.new_address_configs +// .into_iter() +// .map(|_| NewAddressParamsPackedConfig {}) +// .collect(); +// let (new_address_params, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// new_address_configs +// )?; +// +// let (compress_or_decompress_lamports, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.decompress_lamports_config.is_some(), ()) +// )?; +// +// let (is_compress, bytes) = ::new_zero_copy( +// bytes, +// () +// )?; +// +// Ok(( +// ZInstructionDataInvokeMut { +// proof, +// input_compressed_accounts_with_merkle_context, +// output_compressed_accounts, +// relay_fee, +// new_address_params, +// compress_or_decompress_lamports, +// is_compress, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithContext { + pub compressed_account: CompressedAccount, + pub merkle_tree: Pubkey, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct OutputCompressedAccountWithPackedContext { + pub compressed_account: CompressedAccount, + pub merkle_tree_index: u8, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for OutputCompressedAccountWithPackedContext { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZOutputCompressedAccountWithPackedContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_tree_index, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZOutputCompressedAccountWithPackedContextMut { +// compressed_account, +// merkle_tree_index, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct NewAddressParamsPacked { + pub seed: [u8; 32], + pub address_queue_account_index: u8, + pub address_merkle_tree_account_index: u8, + pub address_merkle_tree_root_index: u16, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for NewAddressParamsPacked { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZNewAddressParamsPackedMetaMut>::from_prefix(bytes)?; +// Ok((ZNewAddressParamsPackedMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct NewAddressParams { + pub seed: [u8; 32], + pub address_queue_pubkey: Pubkey, + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, + Copy, +)] +pub struct PackedReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_root_index: u16, + pub address_merkle_tree_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyAddress { + pub address: [u8; 32], + pub address_merkle_tree_pubkey: Pubkey, + pub address_merkle_tree_root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Clone, + Copy, +)] +pub struct CompressedProof { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +impl Default for CompressedProof { + fn default() -> Self { + Self { + a: [0; 32], + b: [0; 64], + c: [0; 32], + } + } +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedProof { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedProofMetaMut>::from_prefix(bytes)?; +// Ok((ZCompressedProofMut { __meta }, bytes)) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Default, +)] +pub struct CompressedCpiContext { + /// Is set by the program that is invoking the CPI to signal that is should + /// set the cpi context. + pub set_context: bool, + /// Is set to clear the cpi context since someone could have set it before + /// with unrelated data. + pub first_set_context: bool, + /// Index of cpi context account in remaining accounts. + pub cpi_context_account_index: u8, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedCompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: PackedMerkleContext, + /// Index of root used in inclusion validity proof. + pub root_index: u16, + /// Placeholder to mark accounts read-only unimplemented set to false. + pub read_only: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedCompressedAccountWithMerkleContext { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedCompressedAccountWithMerkleContextMetaMut>::from_prefix(bytes)?; +// let (compressed_account, bytes) = ::new_zero_copy(bytes, config)?; +// let (merkle_context, bytes) = ::new_zero_copy(bytes, ())?; +// let (root_index, bytes) = ::new_zero_copy(bytes, ())?; +// let (read_only, bytes) = ::new_zero_copy(bytes, ())?; +// +// Ok(( +// ZPackedCompressedAccountWithMerkleContextMut { +// compressed_account, +// merkle_context, +// root_index, +// read_only, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, +)] +pub struct MerkleContext { + pub merkle_tree_pubkey: Pubkey, + pub nullifier_queue_pubkey: Pubkey, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for MerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZMerkleContextMetaMut>::from_prefix(bytes)?; +// +// Ok(( +// ZMerkleContextMut { +// __meta, +// }, +// bytes, +// )) +// } +// } + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountWithMerkleContext { + pub compressed_account: CompressedAccount, + pub merkle_context: MerkleContext, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct ReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: MerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct PackedReadOnlyCompressedAccount { + pub account_hash: [u8; 32], + pub merkle_context: PackedMerkleContext, + pub root_index: u16, +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + Clone, + Copy, + PartialEq, + Default, +)] +pub struct PackedMerkleContext { + pub merkle_tree_pubkey_index: u8, + pub nullifier_queue_pubkey_index: u8, + pub leaf_index: u32, + pub prove_by_index: bool, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for PackedMerkleContext { +// type Config = (); +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// _config: Self::Config +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZPackedMerkleContextMetaMut>::from_prefix(bytes)?; +// Ok((ZPackedMerkleContextMut { __meta }, bytes)) +// } +// } + +#[derive(Debug, PartialEq, Default, Clone, Copy)] +pub struct CompressedAccountZeroCopyNew { + pub address_enabled: bool, + pub data_enabled: bool, + pub data_capacity: u32, +} + +// Manual InstructionDataInvokeConfig removed - now using generated config from ZeroCopyNew derive + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct CompressedAccount { + pub owner: [u8; 32], + pub lamports: u64, + pub address: Option<[u8; 32]>, + pub data: Option, +} + +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccount { +// type Config = CompressedAccountZeroCopyNew; +// type Output = >::Output; +// +// fn new_zero_copy( +// bytes: &'a mut [u8], +// config: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountMetaMut>::from_prefix(bytes)?; +// +// // Use generic Option implementation for address field +// let (address, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.address_enabled, ()) +// )?; +// +// // Use generic Option implementation for data field +// let (data, bytes) = as light_zero_copy::init_mut::ZeroCopyNew>::new_zero_copy( +// bytes, +// (config.data_enabled, CompressedAccountDataConfig { data: config.data_capacity }) +// )?; +// +// Ok(( +// ZCompressedAccountMut { +// __meta, +// address, +// data, +// }, +// bytes, +// )) +// } +// } + +impl<'a> From> for CompressedAccount { + fn from(value: ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.map(|x| *x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl<'a> From<&ZCompressedAccount<'a>> for CompressedAccount { + fn from(value: &ZCompressedAccount<'a>) -> Self { + Self { + owner: value.__meta.owner, + lamports: u64::from(value.__meta.lamports), + address: value.address.as_ref().map(|x| **x), + data: value.data.as_ref().map(|x| x.into()), + } + } +} + +impl PartialEq for ZCompressedAccount<'_> { + fn eq(&self, other: &CompressedAccount) -> bool { + // Check address: if both Some and unequal, return false + if self.address.is_some() + && other.address.is_some() + && *self.address.unwrap() != other.address.unwrap() + { + return false; + } + // Check address: if exactly one is Some, return false + if self.address.is_some() != other.address.is_some() { + return false; + } + + // Check data: if both Some and unequal, return false + if self.data.is_some() + && other.data.is_some() + && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() + { + return false; + } + // Check data: if exactly one is Some, return false + if self.data.is_some() != other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == other.lamports + } +} + +// Commented out because mutable derivation is disabled +// impl PartialEq for ZCompressedAccountMut<'_> { +// fn eq(&self, other: &CompressedAccount) -> bool { +// if self.address.is_some() +// && other.address.is_some() +// && **self.address.as_ref().unwrap() != *other.address.as_ref().unwrap() +// { +// return false; +// } +// if self.address.is_some() || other.address.is_some() { +// return false; +// } +// if self.data.is_some() +// && other.data.is_some() +// && self.data.as_ref().unwrap() != other.data.as_ref().unwrap() +// { +// return false; +// } +// if self.data.is_some() || other.data.is_some() { +// return false; +// } + +// self.owner == other.owner && self.lamports == other.lamports +// } +// } +impl PartialEq> for CompressedAccount { + fn eq(&self, other: &ZCompressedAccount) -> bool { + // Check address: if both Some and unequal, return false + if self.address.is_some() + && other.address.is_some() + && self.address.unwrap() != *other.address.unwrap() + { + return false; + } + // Check address: if exactly one is Some, return false + if self.address.is_some() != other.address.is_some() { + return false; + } + + // Check data: if both Some and unequal, return false + if self.data.is_some() + && other.data.is_some() + && other.data.as_ref().unwrap() != self.data.as_ref().unwrap() + { + return false; + } + // Check data: if exactly one is Some, return false + if self.data.is_some() != other.data.is_some() { + return false; + } + + self.owner == other.owner && self.lamports == u64::from(other.lamports) + } +} + +#[derive( + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, + BorshDeserialize, + BorshSerialize, + Debug, + PartialEq, + Default, + Clone, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +// COMMENTED OUT: Now using ZeroCopyNew derive macro instead +// impl<'a> light_zero_copy::init_mut::ZeroCopyNew<'a> for CompressedAccountData { +// type Config = u32; // data_capacity +// type Output = >::Output; + +// fn new_zero_copy( +// bytes: &'a mut [u8], +// data_capacity: Self::Config, +// ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { +// let (__meta, bytes) = Ref::<&mut [u8], ZCompressedAccountDataMetaMut>::from_prefix(bytes)?; +// // For u8 slices we just use &mut [u8] so we init the len and the split mut separately. +// { +// light_zero_copy::slice_mut::ZeroCopySliceMutBorsh::::new_at( +// data_capacity.into(), +// bytes, +// )?; +// } +// // Split off len for +// let (_, bytes) = bytes.split_at_mut(4); +// let (data, bytes) = bytes.split_at_mut(data_capacity as usize); +// let (data_hash, bytes) = Ref::<&mut [u8], [u8; 32]>::from_prefix(bytes)?; +// Ok(( +// ZCompressedAccountDataMut { +// __meta, +// data, +// data_hash, +// }, +// bytes, +// )) +// } +// } + +#[test] +fn test_compressed_account_data_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountDataConfig { data: 10 }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccountData::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccountData::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test that we can set discriminator + mut_account.__meta.discriminator = [1, 2, 3, 4, 5, 6, 7, 8]; + + // Test that we can write to data + mut_account.data[0] = 42; + mut_account.data[1] = 43; + + // Test that we can set data_hash + mut_account.data_hash[0] = 99; + mut_account.data_hash[1] = 100; + + assert_eq!(mut_account.__meta.discriminator, [1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(mut_account.data[0], 42); + assert_eq!(mut_account.data[1], 43); + assert_eq!(mut_account.data_hash[0], 99); + assert_eq!(mut_account.data_hash[1], 100); + + // Test deserializing the initialized bytes with zero_copy_at_mut + let deserialize_result = CompressedAccountData::zero_copy_at_mut(&mut bytes); + assert!(deserialize_result.is_ok()); + let (deserialized_account, _remaining) = deserialize_result.unwrap(); + + // Verify the deserialized data matches what we set + assert_eq!( + deserialized_account.__meta.discriminator, + [1, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(deserialized_account.data.len(), 10); + assert_eq!(deserialized_account.data[0], 42); + assert_eq!(deserialized_account.data[1], 43); + assert_eq!(deserialized_account.data_hash[0], 99); + assert_eq!(deserialized_account.data_hash[1], 100); +} + +#[test] +fn test_compressed_account_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + let config = CompressedAccountConfig { + address: (true, ()), + data: (true, CompressedAccountDataConfig { data: 10 }), + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = CompressedAccount::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + let result = CompressedAccount::new_zero_copy(&mut bytes, config); + assert!(result.is_ok()); + let (mut mut_account, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Set values + mut_account.__meta.owner = [1u8; 32]; + mut_account.__meta.lamports = 12345u64.into(); + mut_account.address.as_mut().unwrap()[0] = 42; + mut_account.data.as_mut().unwrap().data[0] = 99; + + // Test deserialize + let (deserialized, _) = CompressedAccount::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(deserialized.__meta.owner, [1u8; 32]); + assert_eq!(u64::from(deserialized.__meta.lamports), 12345u64); + assert_eq!(deserialized.address.as_ref().unwrap()[0], 42); + assert_eq!(deserialized.data.as_ref().unwrap().data[0], 99); +} + +#[test] +fn test_instruction_data_invoke_new_at() { + use light_zero_copy::init_mut::ZeroCopyNew; + // Create different configs to test various combinations + let compressed_account_config1 = CompressedAccountZeroCopyNew { + address_enabled: true, + data_enabled: true, + data_capacity: 10, + }; + + let compressed_account_config2 = CompressedAccountZeroCopyNew { + address_enabled: false, + data_enabled: true, + data_capacity: 5, + }; + + let compressed_account_config3 = CompressedAccountZeroCopyNew { + address_enabled: true, + data_enabled: false, + data_capacity: 0, + }; + + let compressed_account_config4 = CompressedAccountZeroCopyNew { + address_enabled: false, + data_enabled: false, + data_capacity: 0, + }; + + let config = InstructionDataInvokeConfig { + proof: (true, CompressedProofConfig {}), // Enable proof + input_compressed_accounts_with_merkle_context: vec![ + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config1.address_enabled, ()), + data: ( + compressed_account_config1.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config1.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config2.address_enabled, ()), + data: ( + compressed_account_config2.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config2.data_capacity, + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }, + ], + output_compressed_accounts: vec![ + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config3.address_enabled, ()), + data: ( + compressed_account_config3.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config3.data_capacity, + }, + ), + }, + }, + OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (compressed_account_config4.address_enabled, ()), + data: ( + compressed_account_config4.data_enabled, + CompressedAccountDataConfig { + data: compressed_account_config4.data_capacity, + }, + ), + }, + }, + ], + relay_fee: true, // Enable relay fee + new_address_params: vec![ + NewAddressParamsPackedConfig {}, + NewAddressParamsPackedConfig {}, + ], // Length 2 + compress_or_decompress_lamports: true, // Enable decompress lamports + }; + + // Calculate exact buffer size needed and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + if let Err(ref e) = result { + eprintln!("Error: {:?}", e); + } + assert!(result.is_ok()); + let (_instruction_data, remaining) = result.unwrap(); + + // Verify we used exactly the calculated number of bytes + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // Test deserialization round-trip first + let (mut deserialized, _) = InstructionDataInvoke::zero_copy_at_mut(&mut bytes).unwrap(); + + // Now set values and test again + *deserialized.is_compress = 1; + + // Set proof values + if let Some(proof) = &mut deserialized.proof { + proof.a[0] = 42; + proof.b[0] = 43; + proof.c[0] = 44; + } + + // Set relay fee value + if let Some(relay_fee) = &mut deserialized.relay_fee { + **relay_fee = 12345u64.into(); + } + + // Set decompress lamports value + if let Some(decompress_lamports) = &mut deserialized.compress_or_decompress_lamports { + **decompress_lamports = 67890u64.into(); + } + + // Set first input account values + let first_input = &mut deserialized.input_compressed_accounts_with_merkle_context[0]; + first_input.compressed_account.__meta.owner[0] = 11; + first_input.compressed_account.__meta.lamports = 1000u64.into(); + if let Some(address) = &mut first_input.compressed_account.address { + address[0] = 22; + } + if let Some(data) = &mut first_input.compressed_account.data { + data.__meta.discriminator[0] = 33; + data.data[0] = 99; + data.data_hash[0] = 55; + } + + // Set first output account values + let first_output = &mut deserialized.output_compressed_accounts[0]; + first_output.compressed_account.__meta.owner[0] = 77; + first_output.compressed_account.__meta.lamports = 2000u64.into(); + if let Some(address) = &mut first_output.compressed_account.address { + address[0] = 88; + } + + // Verify basic structure with vectors of length 2 + assert_eq!( + deserialized + .input_compressed_accounts_with_merkle_context + .len(), + 2 + ); // Length 2 + assert_eq!(deserialized.output_compressed_accounts.len(), 2); // Length 2 + assert_eq!(deserialized.new_address_params.len(), 2); // Length 2 + assert!(deserialized.proof.is_some()); // Enabled + assert!(deserialized.relay_fee.is_some()); // Enabled + assert!(deserialized.compress_or_decompress_lamports.is_some()); // Enabled + assert_eq!(*deserialized.is_compress, 1); + + // Test data access and modification + if let Some(proof) = &deserialized.proof { + // Verify we can access proof fields and our written values + assert_eq!(proof.a[0], 42); + assert_eq!(proof.b[0], 43); + assert_eq!(proof.c[0], 44); + } + + // Verify option integer values + if let Some(relay_fee) = &deserialized.relay_fee { + assert_eq!(u64::from(**relay_fee), 12345); + } + + if let Some(decompress_lamports) = &deserialized.compress_or_decompress_lamports { + assert_eq!(u64::from(**decompress_lamports), 67890); + } + + // Test accessing first input account (config1: address=true, data=true, capacity=10) + let first_input = &deserialized.input_compressed_accounts_with_merkle_context[0]; + assert_eq!(first_input.compressed_account.__meta.owner[0], 11); // Our written value + assert_eq!( + u64::from(first_input.compressed_account.__meta.lamports), + 1000 + ); // Our written value + assert!(first_input.compressed_account.address.is_some()); // Should be enabled + assert!(first_input.compressed_account.data.is_some()); // Should be enabled + if let Some(address) = &first_input.compressed_account.address { + assert_eq!(address[0], 22); // Our written value + } + if let Some(data) = &first_input.compressed_account.data { + assert_eq!(data.data.len(), 10); // Should have capacity 10 + assert_eq!(data.__meta.discriminator[0], 33); // Our written value + assert_eq!(data.data[0], 99); // Our written value + assert_eq!(data.data_hash[0], 55); // Our written value + } + + // Test accessing second input account (config2: address=false, data=true, capacity=5) + let second_input = &deserialized.input_compressed_accounts_with_merkle_context[1]; + assert_eq!(second_input.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_input.compressed_account.address.is_none()); // Should be disabled + assert!(second_input.compressed_account.data.is_some()); // Should be enabled + if let Some(data) = &second_input.compressed_account.data { + assert_eq!(data.data.len(), 5); // Should have capacity 5 + } + + // Test accessing first output account (config3: address=true, data=false, capacity=0) + let first_output = &deserialized.output_compressed_accounts[0]; + assert_eq!(first_output.compressed_account.__meta.owner[0], 77); // Our written value + assert_eq!( + u64::from(first_output.compressed_account.__meta.lamports), + 2000 + ); // Our written value + assert!(first_output.compressed_account.address.is_some()); // Should be enabled + assert!(first_output.compressed_account.data.is_none()); // Should be disabled + if let Some(address) = &first_output.compressed_account.address { + assert_eq!(address[0], 88); // Our written value + } + + // Test accessing second output account (config4: address=false, data=false, capacity=0) + let second_output = &deserialized.output_compressed_accounts[1]; + assert_eq!(second_output.compressed_account.__meta.owner[0], 0); // Should be zero (not written) + assert!(second_output.compressed_account.address.is_none()); // Should be disabled + assert!(second_output.compressed_account.data.is_none()); // Should be disabled +} + +#[test] +fn readme() { + use borsh::{BorshDeserialize, BorshSerialize}; + use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq, ZeroCopyMut}; + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct MyStructOption { + pub a: u8, + pub b: u16, + pub vec: Vec>, + pub c: Option, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, + )] + pub struct MyStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + // Test the new ZeroCopyNew functionality + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] + pub struct TestConfigStruct { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub option: Option, + } + + let my_struct = MyStruct { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + // Use the struct with zero-copy deserialization + let bytes = my_struct.try_to_vec().unwrap(); + // byte_len not available for non-mut derivations + // assert_eq!(bytes.len(), my_struct.byte_len()); + let (zero_copy, _remaining) = MyStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1); + let org_struct: MyStruct = zero_copy.into(); + assert_eq!(org_struct, my_struct); + // { + // let (mut zero_copy_mut, _remaining) = MyStruct::zero_copy_at_mut(&mut bytes).unwrap(); + // zero_copy_mut.a = 42; + // } + // let borsh = MyStruct::try_from_slice(&bytes).unwrap(); + // assert_eq!(borsh.a, 42u8); +} + +#[derive( + ZeroCopy, ZeroCopyMut, BorshDeserialize, BorshSerialize, Debug, PartialEq, Default, Clone, +)] +pub struct InstructionDataInvokeCpi { + pub proof: Option, + pub new_address_params: Vec, + pub input_compressed_accounts_with_merkle_context: + Vec, + pub output_compressed_accounts: Vec, + pub relay_fee: Option, + pub compress_or_decompress_lamports: Option, + pub is_compress: bool, + pub cpi_context: Option, +} + +impl PartialEq> for InstructionDataInvokeCpi { + fn eq(&self, other: &ZInstructionDataInvokeCpi) -> bool { + // Compare proof + match (&self.proof, &other.proof) { + (Some(ref self_proof), Some(ref other_proof)) => { + if self_proof.a != other_proof.a + || self_proof.b != other_proof.b + || self_proof.c != other_proof.c + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare vectors lengths first + if self.new_address_params.len() != other.new_address_params.len() + || self.input_compressed_accounts_with_merkle_context.len() + != other.input_compressed_accounts_with_merkle_context.len() + || self.output_compressed_accounts.len() != other.output_compressed_accounts.len() + { + return false; + } + + // Compare new_address_params + for (self_param, other_param) in self + .new_address_params + .iter() + .zip(other.new_address_params.iter()) + { + if self_param.seed != other_param.seed + || self_param.address_queue_account_index != other_param.address_queue_account_index + || self_param.address_merkle_tree_account_index + != other_param.address_merkle_tree_account_index + || self_param.address_merkle_tree_root_index + != u16::from(other_param.address_merkle_tree_root_index) + { + return false; + } + } + + // Compare input accounts + for (self_input, other_input) in self + .input_compressed_accounts_with_merkle_context + .iter() + .zip(other.input_compressed_accounts_with_merkle_context.iter()) + { + if self_input != other_input { + return false; + } + } + + // Compare output accounts + for (self_output, other_output) in self + .output_compressed_accounts + .iter() + .zip(other.output_compressed_accounts.iter()) + { + if self_output != other_output { + return false; + } + } + + // Compare relay_fee + match (&self.relay_fee, &other.relay_fee) { + (Some(self_fee), Some(other_fee)) => { + if *self_fee != u64::from(**other_fee) { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare compress_or_decompress_lamports + match ( + &self.compress_or_decompress_lamports, + &other.compress_or_decompress_lamports, + ) { + (Some(self_lamports), Some(other_lamports)) => { + if *self_lamports != u64::from(**other_lamports) { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare is_compress (bool vs u8) + if self.is_compress != (other.is_compress != 0) { + return false; + } + + // Compare cpi_context + match (&self.cpi_context, &other.cpi_context) { + (Some(self_ctx), Some(other_ctx)) => { + if self_ctx.set_context != (other_ctx.set_context != 0) + || self_ctx.first_set_context != (other_ctx.first_set_context != 0) + || self_ctx.cpi_context_account_index != other_ctx.cpi_context_account_index + { + return false; + } + } + (None, None) => {} + _ => return false, + } + + true + } +} + +impl PartialEq for ZInstructionDataInvokeCpi<'_> { + fn eq(&self, other: &InstructionDataInvokeCpi) -> bool { + other.eq(self) + } +} + +impl PartialEq> + for PackedCompressedAccountWithMerkleContext +{ + fn eq(&self, other: &ZPackedCompressedAccountWithMerkleContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_context + if self.merkle_context.merkle_tree_pubkey_index + != other.merkle_context.__meta.merkle_tree_pubkey_index + || self.merkle_context.nullifier_queue_pubkey_index + != other.merkle_context.__meta.nullifier_queue_pubkey_index + || self.merkle_context.leaf_index != u32::from(other.merkle_context.__meta.leaf_index) + || self.merkle_context.prove_by_index != other.merkle_context.prove_by_index() + { + return false; + } + + // Compare root_index and read_only + if self.root_index != u16::from(*other.root_index) + || self.read_only != (other.read_only != 0) + { + return false; + } + + true + } +} + +impl PartialEq> + for OutputCompressedAccountWithPackedContext +{ + fn eq(&self, other: &ZOutputCompressedAccountWithPackedContext) -> bool { + // Compare compressed_account + if self.compressed_account.owner != other.compressed_account.__meta.owner + || self.compressed_account.lamports + != u64::from(other.compressed_account.__meta.lamports) + { + return false; + } + + // Compare optional address + match ( + &self.compressed_account.address, + &other.compressed_account.address, + ) { + (Some(self_addr), Some(other_addr)) => { + if *self_addr != **other_addr { + return false; + } + } + (None, None) => {} + _ => return false, + } + + // Compare optional data + match ( + &self.compressed_account.data, + &other.compressed_account.data, + ) { + (Some(self_data), Some(other_data)) => { + if self_data.discriminator != other_data.__meta.discriminator + || self_data.data_hash != *other_data.data_hash + || self_data.data.len() != other_data.data.len() + { + return false; + } + // Compare data contents + for (self_byte, other_byte) in self_data.data.iter().zip(other_data.data.iter()) { + if *self_byte != *other_byte { + return false; + } + } + } + (None, None) => {} + _ => return false, + } + + // Compare merkle_tree_index + if self.merkle_tree_index != other.merkle_tree_index { + return false; + } + + true + } +} diff --git a/program-libs/zero-copy-derive/tests/random.rs b/program-libs/zero-copy-derive/tests/random.rs new file mode 100644 index 0000000000..993adef704 --- /dev/null +++ b/program-libs/zero-copy-derive/tests/random.rs @@ -0,0 +1,651 @@ +#![cfg(feature = "mut")] +use std::assert_eq; + +use borsh::BorshDeserialize; +use light_zero_copy::{borsh::Deserialize, init_mut::ZeroCopyNew}; +use rand::{ + rngs::{StdRng, ThreadRng}, + Rng, +}; + +mod instruction_data; +use instruction_data::{ + CompressedAccount, + CompressedAccountConfig, + CompressedAccountData, + CompressedAccountDataConfig, + CompressedCpiContext, + CompressedCpiContextConfig, + CompressedProof, + CompressedProofConfig, + InstructionDataInvoke, + // Config types (generated by ZeroCopyNew derive) + InstructionDataInvokeConfig, + InstructionDataInvokeCpi, + InstructionDataInvokeCpiConfig, + NewAddressParamsPacked, + NewAddressParamsPackedConfig, + OutputCompressedAccountWithPackedContext, + OutputCompressedAccountWithPackedContextConfig, + PackedCompressedAccountWithMerkleContext, + PackedCompressedAccountWithMerkleContextConfig, + PackedMerkleContext, + PackedMerkleContextConfig, + Pubkey, + // Zero-copy mutable types + ZInstructionDataInvokeCpiMut, + ZInstructionDataInvokeMut, +}; + +// Function to populate mutable zero-copy structure with data from InstructionDataInvokeCpi +fn populate_invoke_cpi_zero_copy( + src: &InstructionDataInvokeCpi, + dst: &mut ZInstructionDataInvokeCpiMut, +) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); + dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + .compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } + + // Copy cpi_context if present + if let (Some(src_ctx), Some(dst_ctx)) = (&src.cpi_context, &mut dst.cpi_context) { + dst_ctx.set_context = if src_ctx.set_context { 1 } else { 0 }; + dst_ctx.first_set_context = if src_ctx.first_set_context { 1 } else { 0 }; + dst_ctx.cpi_context_account_index = src_ctx.cpi_context_account_index; + } +} + +// Function to populate mutable zero-copy structure with data from InstructionDataInvoke +fn populate_invoke_zero_copy(src: &InstructionDataInvoke, dst: &mut ZInstructionDataInvokeMut) { + *dst.is_compress = if src.is_compress { 1 } else { 0 }; + + // Copy proof if present + if let (Some(src_proof), Some(dst_proof)) = (&src.proof, &mut dst.proof) { + dst_proof.a.copy_from_slice(&src_proof.a); + dst_proof.b.copy_from_slice(&src_proof.b); + dst_proof.c.copy_from_slice(&src_proof.c); + } + + // Copy new_address_params + for (src_param, dst_param) in src + .new_address_params + .iter() + .zip(dst.new_address_params.iter_mut()) + { + dst_param.seed.copy_from_slice(&src_param.seed); + dst_param.address_queue_account_index = src_param.address_queue_account_index; + dst_param.address_merkle_tree_account_index = src_param.address_merkle_tree_account_index; + dst_param.address_merkle_tree_root_index = src_param.address_merkle_tree_root_index.into(); + } + + // Copy input_compressed_accounts_with_merkle_context + for (src_input, dst_input) in src + .input_compressed_accounts_with_merkle_context + .iter() + .zip(dst.input_compressed_accounts_with_merkle_context.iter_mut()) + { + // Copy compressed account + dst_input + .compressed_account + .owner + .copy_from_slice(&src_input.compressed_account.owner); + dst_input.compressed_account.lamports = src_input.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_input.compressed_account.address, + &mut dst_input.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_input.compressed_account.data, + &mut dst_input.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + // Copy merkle context + dst_input.merkle_context.merkle_tree_pubkey_index = + src_input.merkle_context.merkle_tree_pubkey_index; + dst_input.merkle_context.nullifier_queue_pubkey_index = + src_input.merkle_context.nullifier_queue_pubkey_index; + dst_input.merkle_context.leaf_index = src_input.merkle_context.leaf_index.into(); + dst_input.merkle_context.prove_by_index = if src_input.merkle_context.prove_by_index { + 1 + } else { + 0 + }; + + *dst_input.root_index = src_input.root_index.into(); + *dst_input.read_only = if src_input.read_only { 1 } else { 0 }; + } + + // Copy output_compressed_accounts + for (src_output, dst_output) in src + .output_compressed_accounts + .iter() + .zip(dst.output_compressed_accounts.iter_mut()) + { + // Copy compressed account + dst_output + .compressed_account + .owner + .copy_from_slice(&src_output.compressed_account.owner); + dst_output.compressed_account.lamports = src_output.compressed_account.lamports.into(); + + // Copy address if present + if let (Some(src_addr), Some(dst_addr)) = ( + &src_output.compressed_account.address, + &mut dst_output.compressed_account.address, + ) { + dst_addr.copy_from_slice(src_addr); + } + + // Copy data if present + if let (Some(src_data), Some(dst_data)) = ( + &src_output.compressed_account.data, + &mut dst_output.compressed_account.data, + ) { + dst_data + .discriminator + .copy_from_slice(&src_data.discriminator); + dst_data.data_hash.copy_from_slice(&src_data.data_hash); + for (src_byte, dst_byte) in src_data.data.iter().zip(dst_data.data.iter_mut()) { + *dst_byte = *src_byte; + } + } + + *dst_output.merkle_tree_index = src_output.merkle_tree_index; + } + + // Copy relay_fee if present + if let (Some(src_fee), Some(dst_fee)) = (&src.relay_fee, &mut dst.relay_fee) { + **dst_fee = (*src_fee).into(); + } + + // Copy compress_or_decompress_lamports if present + if let (Some(src_lamports), Some(dst_lamports)) = ( + &src.compress_or_decompress_lamports, + &mut dst.compress_or_decompress_lamports, + ) { + **dst_lamports = (*src_lamports).into(); + } +} + +fn get_rnd_instruction_data_invoke_cpi(rng: &mut StdRng) -> InstructionDataInvokeCpi { + InstructionDataInvokeCpi { + proof: Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }), + new_address_params: vec![get_rnd_new_address_params(rng); rng.gen_range(0..10)], + input_compressed_accounts_with_merkle_context: vec![ + get_rnd_test_input_account(rng); + rng.gen_range(0..10) + ], + output_compressed_accounts: vec![get_rnd_test_output_account(rng); rng.gen_range(0..10)], + relay_fee: None, + compress_or_decompress_lamports: rng.gen(), + is_compress: rng.gen(), + cpi_context: Some(get_rnd_cpi_context(rng)), + } +} + +fn get_rnd_cpi_context(rng: &mut StdRng) -> CompressedCpiContext { + CompressedCpiContext { + first_set_context: rng.gen(), + set_context: rng.gen(), + cpi_context_account_index: rng.gen(), + } +} + +fn get_rnd_test_account_data(rng: &mut StdRng) -> CompressedAccountData { + CompressedAccountData { + discriminator: rng.gen(), + data: (0..100).map(|_| rng.gen()).collect::>(), + data_hash: rng.gen(), + } +} + +fn get_rnd_test_account(rng: &mut StdRng) -> CompressedAccount { + CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: rng.gen(), + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + } +} + +fn get_rnd_test_output_account(rng: &mut StdRng) -> OutputCompressedAccountWithPackedContext { + OutputCompressedAccountWithPackedContext { + compressed_account: get_rnd_test_account(rng), + merkle_tree_index: rng.gen(), + } +} + +fn get_rnd_test_input_account(rng: &mut StdRng) -> PackedCompressedAccountWithMerkleContext { + PackedCompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: Pubkey::new_unique().to_bytes(), + lamports: 100, + address: Some(Pubkey::new_unique().to_bytes()), + data: Some(get_rnd_test_account_data(rng)), + }, + merkle_context: PackedMerkleContext { + merkle_tree_pubkey_index: rng.gen(), + nullifier_queue_pubkey_index: rng.gen(), + leaf_index: rng.gen(), + prove_by_index: rng.gen(), + }, + root_index: rng.gen(), + read_only: false, + } +} + +fn get_rnd_new_address_params(rng: &mut StdRng) -> NewAddressParamsPacked { + NewAddressParamsPacked { + seed: rng.gen(), + address_queue_account_index: rng.gen(), + address_merkle_tree_account_index: rng.gen(), + address_merkle_tree_root_index: rng.gen(), + } +} + +// Generate config for InstructionDataInvoke based on the actual data +fn generate_random_invoke_config( + invoke_ref: &InstructionDataInvoke, +) -> InstructionDataInvokeConfig { + InstructionDataInvokeConfig { + proof: (invoke_ref.proof.is_some(), CompressedProofConfig {}), + input_compressed_accounts_with_merkle_context: invoke_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_ref.relay_fee.is_some(), + new_address_params: invoke_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + compress_or_decompress_lamports: invoke_ref.compress_or_decompress_lamports.is_some(), + } +} + +// Generate config for InstructionDataInvokeCpi based on the actual data +fn generate_random_invoke_cpi_config( + invoke_cpi_ref: &InstructionDataInvokeCpi, +) -> InstructionDataInvokeCpiConfig { + InstructionDataInvokeCpiConfig { + proof: (invoke_cpi_ref.proof.is_some(), CompressedProofConfig {}), + new_address_params: invoke_cpi_ref + .new_address_params + .iter() + .map(|_| NewAddressParamsPackedConfig {}) + .collect(), + input_compressed_accounts_with_merkle_context: invoke_cpi_ref + .input_compressed_accounts_with_merkle_context + .iter() + .map(|account| PackedCompressedAccountWithMerkleContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + merkle_context: PackedMerkleContextConfig {}, + }) + .collect(), + output_compressed_accounts: invoke_cpi_ref + .output_compressed_accounts + .iter() + .map(|account| OutputCompressedAccountWithPackedContextConfig { + compressed_account: CompressedAccountConfig { + address: (account.compressed_account.address.is_some(), ()), + data: ( + account.compressed_account.data.is_some(), + CompressedAccountDataConfig { + data: account + .compressed_account + .data + .as_ref() + .map(|d| d.data.len() as u32) + .unwrap_or(0), + }, + ), + }, + }) + .collect(), + relay_fee: invoke_cpi_ref.relay_fee.is_some(), + compress_or_decompress_lamports: invoke_cpi_ref.compress_or_decompress_lamports.is_some(), + cpi_context: ( + invoke_cpi_ref.cpi_context.is_some(), + CompressedCpiContextConfig {}, + ), + } +} + +#[test] +fn test_invoke_ix_data_deserialize_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. + println!("\n\ne2e test seed for invoke_ix_data {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 1000; + for i in 0..num_iters { + // Create randomized instruction data + let invoke_ref = InstructionDataInvoke { + proof: if rng.gen() { + Some(CompressedProof { + a: rng.gen(), + b: (0..64) + .map(|_| rng.gen()) + .collect::>() + .try_into() + .unwrap(), + c: rng.gen(), + }) + } else { + None + }, + input_compressed_accounts_with_merkle_context: if i % 5 == 0 { + // Only add inputs occasionally to keep test manageable + vec![get_rnd_test_input_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + output_compressed_accounts: if i % 4 == 0 { + vec![get_rnd_test_output_account(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + relay_fee: None, // Relay fee is currently not supported + new_address_params: if i % 3 == 0 { + vec![get_rnd_new_address_params(&mut rng); rng.gen_range(1..3)] + } else { + vec![] + }, + compress_or_decompress_lamports: if rng.gen() { Some(rng.gen()) } else { None }, + is_compress: rng.gen(), + }; + + // 1. Generate config based on the random data + let config = generate_random_invoke_config(&invoke_ref); + + // 2. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvoke::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 3. Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvoke::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 4. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 5. Populate the mutable zero-copy structure with random data + populate_invoke_zero_copy(&invoke_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvoke::deserialize(&mut bytes.as_slice()).unwrap(); + // 6. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvoke::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_ref, borsh_ref); + + // 7. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvoke with {} inputs, {} outputs, {} new_addresses", + invoke_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_ref.output_compressed_accounts.len(), + invoke_ref.new_address_params.len()); + } +} + +#[test] +fn test_instruction_data_invoke_cpi_rnd() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut thread_rng = ThreadRng::default(); + let seed = thread_rng.gen(); + // Keep this print so that in case the test fails + // we can use the seed to reproduce the error. + println!("\n\ne2e test seed {}\n\n", seed); + let mut rng = StdRng::seed_from_u64(seed); + + let num_iters = 10_000; + for _ in 0..num_iters { + // 1. Generate random CPI instruction data + let invoke_cpi_ref = get_rnd_instruction_data_invoke_cpi(&mut rng); + + // 2. Generate config based on the random data + let config = generate_random_invoke_cpi_config(&invoke_cpi_ref); + + // 3. Calculate exact buffer size and allocate + let buffer_size = InstructionDataInvokeCpi::byte_len(&config); + let mut bytes = vec![0u8; buffer_size]; + + // 4. Create mutable zero-copy structure and verify exact allocation + { + let result = InstructionDataInvokeCpi::new_zero_copy(&mut bytes, config); + assert!(result.is_ok(), "Failed to create CPI zero-copy structure"); + let (mut zero_copy_mut, remaining) = result.unwrap(); + + // 5. Verify exact buffer allocation + assert_eq!( + remaining.len(), + 0, + "Should have used exactly {} bytes", + buffer_size + ); + + // 6. Populate the mutable zero-copy structure with random data + populate_invoke_cpi_zero_copy(&invoke_cpi_ref, &mut zero_copy_mut); + }; // Mutable borrow ends here + + let borsh_ref = InstructionDataInvokeCpi::deserialize(&mut bytes.as_slice()).unwrap(); + // 7. Test immutable deserialization to verify round-trip functionality + let result_immut = InstructionDataInvokeCpi::zero_copy_at(&bytes); + assert!( + result_immut.is_ok(), + "Immutable deserialization should succeed" + ); + assert_eq!(invoke_cpi_ref, borsh_ref); + + // 8. Test that basic zero-copy deserialization works without crashing + // The main goal is to verify the zero-copy derive macro functionality + println!("✓ Successfully tested InstructionDataInvokeCpi with {} inputs, {} outputs, {} new_addresses", + invoke_cpi_ref.input_compressed_accounts_with_merkle_context.len(), + invoke_cpi_ref.output_compressed_accounts.len(), + invoke_cpi_ref.new_address_params.len()); + } +} diff --git a/program-libs/zero-copy/Cargo.toml b/program-libs/zero-copy/Cargo.toml index ea683e5e48..7b67262bb2 100644 --- a/program-libs/zero-copy/Cargo.toml +++ b/program-libs/zero-copy/Cargo.toml @@ -11,13 +11,17 @@ default = [] solana = ["solana-program-error"] pinocchio = ["dep:pinocchio"] std = [] +derive = ["light-zero-copy-derive"] +mut = ["light-zero-copy-derive/mut"] [dependencies] solana-program-error = { workspace = true, optional = true } pinocchio = { workspace = true, optional = true } thiserror = { workspace = true } zerocopy = { workspace = true } +light-zero-copy-derive = { workspace = true, optional = true } [dev-dependencies] rand = { workspace = true } zerocopy = { workspace = true, features = ["derive"] } +borsh = { workspace = true } diff --git a/program-libs/zero-copy/README.md b/program-libs/zero-copy/README.md index d82ee39232..e28f535469 100644 --- a/program-libs/zero-copy/README.md +++ b/program-libs/zero-copy/README.md @@ -37,6 +37,3 @@ light-zero-copy = { version = "0.1.0", features = ["anchor"] } ### Security Considerations - do not use on a 32 bit target with length greater than u32 - only length until u64 is supported - -### Tests -- `cargo test --features std` diff --git a/program-libs/zero-copy/src/borsh.rs b/program-libs/zero-copy/src/borsh.rs index c7e4fbe4db..a60d73f2ba 100644 --- a/program-libs/zero-copy/src/borsh.rs +++ b/program-libs/zero-copy/src/borsh.rs @@ -5,7 +5,7 @@ use core::{ use std::vec::Vec; use zerocopy::{ - little_endian::{U16, U32, U64}, + little_endian::{I16, I32, I64, U16, U32, U64}, FromBytes, Immutable, KnownLayout, Ref, }; @@ -52,8 +52,6 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for Option { impl Deserialize<'_> for u8 { type Output = Self; - /// Not a zero copy but cheaper. - /// A u8 should not be deserialized on it's own but as part of a struct. #[inline] fn zero_copy_at(bytes: &[u8]) -> Result<(u8, &[u8]), ZeroCopyError> { if bytes.len() < size_of::() { @@ -64,23 +62,59 @@ impl Deserialize<'_> for u8 { } } +impl<'a> Deserialize<'a> for bool { + type Output = u8; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + let (bytes, remaining_bytes) = bytes.split_at(size_of::()); + Ok((bytes[0], remaining_bytes)) + } +} + macro_rules! impl_deserialize_for_primitive { - ($($t:ty),*) => { + ($(($native:ty, $zerocopy:ty)),*) => { $( - impl<'a> Deserialize<'a> for $t { - type Output = Ref<&'a [u8], $t>; + impl<'a> Deserialize<'a> for $native { + type Output = Ref<&'a [u8], $zerocopy>; #[inline] fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { - Self::Output::zero_copy_at(bytes) + Ref::<&'a [u8], $zerocopy>::from_prefix(bytes).map_err(ZeroCopyError::from) } } )* }; } -impl_deserialize_for_primitive!(u16, i16, u32, i32, u64, i64); -impl_deserialize_for_primitive!(U16, U32, U64); +impl_deserialize_for_primitive!( + (u16, U16), + (u32, U32), + (u64, U64), + (i16, I16), + (i32, I32), + (i64, I64), + (U16, U16), + (U32, U32), + (U64, U64), + (I16, I16), + (I32, I32), + (I64, I64) +); + +// Implement Deserialize for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> Deserialize<'a> for [T; N] { + type Output = Ref<&'a [u8], [T; N]>; + + #[inline] + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { type Output = Vec; @@ -88,8 +122,14 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for Vec { fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { let (num_slices, mut bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; let num_slices = u32::from(*num_slices) as usize; - // TODO: add check that remaining data is enough to read num_slices - // This prevents agains invalid data allocating a lot of heap memory + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } let mut slices = Vec::with_capacity(num_slices); for _ in 0..num_slices { let (slice, _bytes) = T::zero_copy_at(bytes)?; @@ -138,6 +178,55 @@ impl<'a, T: Deserialize<'a>> Deserialize<'a> for VecU8 { } } +pub trait ZeroCopyStructInner { + type ZeroCopyInner; +} + +impl ZeroCopyStructInner for u64 { + type ZeroCopyInner = U64; +} +impl ZeroCopyStructInner for u32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for u16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for u8 { + type ZeroCopyInner = u8; +} + +impl ZeroCopyStructInner for U16 { + type ZeroCopyInner = U16; +} +impl ZeroCopyStructInner for U32 { + type ZeroCopyInner = U32; +} +impl ZeroCopyStructInner for U64 { + type ZeroCopyInner = U64; +} + +impl ZeroCopyStructInner for Vec { + type ZeroCopyInner = Vec; +} + +impl ZeroCopyStructInner for Option { + type ZeroCopyInner = Option; +} + +// Add ZeroCopyStructInner for array types +impl ZeroCopyStructInner for [u8; N] { + type ZeroCopyInner = Ref<&'static [u8], [u8; N]>; +} + +pub fn borsh_vec_u8_as_slice(bytes: &[u8]) -> Result<(&[u8], &[u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&[u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + if num_slices > bytes.len() { + return Err(ZeroCopyError::ArraySize(num_slices, bytes.len())); + } + Ok(bytes.split_at(num_slices)) +} + #[test] fn test_vecu8() { use std::vec; @@ -224,3 +313,561 @@ fn test_deserialize_vecu8() { assert_eq!(vec, std::vec![4, 5, 6]); assert_eq!(remaining, &[]); } + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::{ZeroCopyStructInner, *}; + use crate::slice::ZeroCopySliceBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize + // 1.6. Elements in an Option must implement Deserialize + // 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInner + // 2. Deserialize + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. + + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + // pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + + // } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, + } + + impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + } + + impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + #[test] + fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, + } + + impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::ZeroCopyInner>>::zero_copy_at(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); + } + + // If a struct inside a vector contains a vector it must implement Deserialize. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, + } + + impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); + } +} diff --git a/program-libs/zero-copy/src/borsh_mut.rs b/program-libs/zero-copy/src/borsh_mut.rs new file mode 100644 index 0000000000..38d5df2b65 --- /dev/null +++ b/program-libs/zero-copy/src/borsh_mut.rs @@ -0,0 +1,965 @@ +use core::{ + mem::size_of, + ops::{Deref, DerefMut}, +}; +use std::vec::Vec; + +use zerocopy::{ + little_endian::{I16, I32, I64, U16, U32, U64}, + FromBytes, Immutable, KnownLayout, Ref, +}; + +use crate::errors::ZeroCopyError; + +pub trait DeserializeMut<'a> +where + Self: Sized, +{ + // TODO: rename to ZeroCopy, can be used as ::ZeroCopy + type Output; + fn zero_copy_at_mut(bytes: &'a mut [u8]) + -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Implement DeserializeMut for fixed-size array types +impl<'a, T: KnownLayout + Immutable + FromBytes, const N: usize> DeserializeMut<'a> for [T; N] { + type Output = Ref<&'a mut [u8], [T; N]>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&'a mut [u8], [T; N]>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Option { + type Output = Option; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.len() < size_of::() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + let (option_byte, bytes) = bytes.split_at_mut(1); + Ok(match option_byte[0] { + 0u8 => (None, bytes), + 1u8 => { + let (value, bytes) = T::zero_copy_at_mut(bytes)?; + (Some(value), bytes) + } + _ => return Err(ZeroCopyError::InvalidOptionByte(option_byte[0])), + }) + } +} + +impl<'a> DeserializeMut<'a> for u8 { + type Output = Ref<&'a mut [u8], u8>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], u8>::from_prefix(bytes).map_err(ZeroCopyError::from) + } +} + +impl<'a> DeserializeMut<'a> for bool { + type Output = Ref<&'a mut [u8], u8>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], u8>::from_prefix(bytes).map_err(ZeroCopyError::from) + } +} + +// Implementation for specific zerocopy little-endian types +impl<'a, T: KnownLayout + Immutable + FromBytes> DeserializeMut<'a> for Ref<&'a mut [u8], T> { + type Output = Ref<&'a mut [u8], T>; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], T>::from_prefix(bytes)?; + Ok((bytes, remaining_bytes)) + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for Vec { + type Output = Vec; + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + // Prevent heap exhaustion attacks by checking if num_slices is reasonable + // Each element needs at least 1 byte when serialized + if bytes.len() < num_slices { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + num_slices, + )); + } + let mut slices = Vec::with_capacity(num_slices); + for _ in 0..num_slices { + let (slice, _bytes) = T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +macro_rules! impl_deserialize_for_primitive { + ($(($native:ty, $zerocopy:ty)),*) => { + $( + impl<'a> DeserializeMut<'a> for $native { + type Output = Ref<&'a mut [u8], $zerocopy>; + + #[inline] + fn zero_copy_at_mut(bytes: &'a mut [u8]) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ref::<&'a mut [u8], $zerocopy>::from_prefix(bytes).map_err(ZeroCopyError::from) + } + } + )* + }; +} + +impl_deserialize_for_primitive!( + (u16, U16), + (u32, U32), + (u64, U64), + (i16, I16), + (i32, I32), + (i64, I64), + (U16, U16), + (U32, U32), + (U64, U64), + (I16, I16), + (I32, I32), + (I64, I64) +); + +pub fn borsh_vec_u8_as_slice_mut( + bytes: &mut [u8], +) -> Result<(&mut [u8], &mut [u8]), ZeroCopyError> { + let (num_slices, bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + let num_slices = u32::from(*num_slices) as usize; + if num_slices > bytes.len() { + return Err(ZeroCopyError::ArraySize(num_slices, bytes.len())); + } + Ok(bytes.split_at_mut(num_slices)) +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct VecU8(Vec); +impl VecU8 { + pub fn new() -> Self { + Self(Vec::new()) + } +} + +impl Deref for VecU8 { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for VecU8 { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, T: DeserializeMut<'a>> DeserializeMut<'a> for VecU8 { + type Output = Vec; + + #[inline] + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (num_slices, mut bytes) = Ref::<&mut [u8], u8>::from_prefix(bytes)?; + let mut slices = Vec::with_capacity(*num_slices as usize); + for _ in 0..(*num_slices as usize) { + let (slice, _bytes) = T::zero_copy_at_mut(bytes)?; + bytes = _bytes; + slices.push(slice); + } + Ok((slices, bytes)) + } +} + +pub trait ZeroCopyStructInnerMut { + type ZeroCopyInnerMut; +} + +impl ZeroCopyStructInnerMut for u64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for u32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for u16 { + type ZeroCopyInnerMut = U16; +} +impl ZeroCopyStructInnerMut for u8 { + type ZeroCopyInnerMut = u8; +} +impl ZeroCopyStructInnerMut for U64 { + type ZeroCopyInnerMut = U64; +} +impl ZeroCopyStructInnerMut for U32 { + type ZeroCopyInnerMut = U32; +} +impl ZeroCopyStructInnerMut for U16 { + type ZeroCopyInnerMut = U16; +} + +impl ZeroCopyStructInnerMut for Vec { + type ZeroCopyInnerMut = Vec; +} + +impl ZeroCopyStructInnerMut for Option { + type ZeroCopyInnerMut = Option; +} + +impl ZeroCopyStructInnerMut for [u8; N] { + type ZeroCopyInnerMut = Ref<&'static mut [u8], [u8; N]>; +} + +#[test] +fn test_vecu8() { + use std::vec; + let mut bytes = vec![8, 1u8, 2, 3, 4, 5, 6, 7, 8]; + let (vec, remaining_bytes) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + vec![1u8, 2, 3, 4, 5, 6, 7, 8] + ); + assert_eq!(remaining_bytes, &mut []); +} + +#[test] +fn test_deserialize_mut_ref() { + let mut bytes = [1, 0, 0, 0]; // Little-endian representation of 1 + let (ref_data, remaining) = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(u32::from(*ref_data), 1); + assert_eq!(remaining, &mut []); + let res = Ref::<&mut [u8], U32>::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_option_some() { + let mut bytes = [1, 2]; // 1 indicates Some, followed by the value 2 + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value.map(|x| *x), Some(2)); + assert_eq!(remaining, &mut []); + let res = Option::::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::ArraySize(1, 0))); + let mut bytes = [2, 0]; // 2 indicates invalid option byte + let res = Option::::zero_copy_at_mut(&mut bytes); + assert_eq!(res, Err(ZeroCopyError::InvalidOptionByte(2))); +} + +#[test] +fn test_deserialize_mut_option_none() { + let mut bytes = [0]; // 0 indicates None + let (option_value, remaining) = Option::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(option_value, None); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_deserialize_mut_u8() { + let mut bytes = [0xFF]; // Value 255 + let (value, remaining) = u8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(*value, 255); + assert_eq!(remaining, &mut []); + let res = u8::zero_copy_at_mut(&mut []); + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_u16() { + let mut bytes = 2323u16.to_le_bytes(); + let (value, remaining) = u16::zero_copy_at_mut(bytes.as_mut_slice()).unwrap(); + assert_eq!(*value, 2323u16); + assert_eq!(remaining, &mut []); + let mut value = [0u8]; + let res = u16::zero_copy_at_mut(&mut value); + + assert_eq!(res, Err(ZeroCopyError::Size)); +} + +#[test] +fn test_deserialize_mut_vec() { + let mut bytes = [2, 0, 0, 0, 1, 2]; // Length 2, followed by values 1 and 2 + let (vec, remaining) = Vec::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + std::vec![1u8, 2] + ); + assert_eq!(remaining, &mut []); +} + +#[test] +fn test_vecu8_deref() { + let data = std::vec![1, 2, 3]; + let vec_u8 = VecU8(data.clone()); + assert_eq!(&*vec_u8, &data); + + let mut vec = VecU8::new(); + vec.push(1u8); + assert_eq!(*vec, std::vec![1u8]); +} + +#[test] +fn test_deserialize_mut_vecu8() { + let mut bytes = [3, 4, 5, 6]; // Length 3, followed by values 4, 5, 6 + let (vec, remaining) = VecU8::::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + vec.iter().map(|x| **x).collect::>(), + std::vec![4, 5, 6] + ); + assert_eq!(remaining, &mut []); +} + +#[cfg(test)] +pub mod test { + use std::vec; + + use borsh::{BorshDeserialize, BorshSerialize}; + use zerocopy::{ + little_endian::{U16, U64}, + IntoBytes, Ref, Unaligned, + }; + + use super::*; + use crate::slice_mut::ZeroCopySliceMutBorsh; + + // Rules: + // 1. create ZStruct for the struct + // 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct + // 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct + // 1.3. replace u16 with U16, u32 with U32, etc + // 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 + // 1.5. If a vector contains a nested vector (does not implement Copy) it must implement DeserializeMut + // 1.6. Elements in an Option must implement DeserializeMut + // 1.7. a type that does not implement Copy must implement DeserializeMut, and is deserialized 1 by 1 + + // Derive Macro needs to derive: + // 1. ZeroCopyStructInnerMut + // 2. DeserializeMut + // 3. PartialEq for ZStruct<'_> + // + // For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. + + // Tests for manually implemented structures (without derive macro) + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct1 { + pub a: u8, + pub b: u16, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes, IntoBytes)] + pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct1<'a> { + pub meta: Ref<&'a mut [u8], ZStruct1Meta>, + } + impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a mut [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl DerefMut for ZStruct1<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } + } + + #[test] + fn test_struct_1() { + let ref_struct = Struct1 { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (mut struct1, remaining) = Struct1::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &mut []); + struct1.meta.a = 2; + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct2<'a> { + meta: Ref<&'a mut [u8], ZStruct2Meta>, + pub vec: &'a mut [u8], + } + + impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec == other.vec.as_slice() + } + } + + impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a mut [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct2Meta>::from_prefix(bytes)?; + let (len, bytes) = bytes.split_at_mut(4); + let len = U32::from_bytes( + len.try_into() + .map_err(|_| ZeroCopyError::ArraySize(4, len.len()))?, + ); + let (vec, bytes) = bytes.split_at_mut(u32::from(len) as usize); + Ok((ZStruct2 { meta, vec }, bytes)) + } + } + + #[test] + fn test_struct_2() { + let ref_struct = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct3<'a> { + meta: Ref<&'a mut [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, u8>, + pub c: Ref<&'a mut [u8], U64>, + } + + impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a mut [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::zero_copy_at_mut(bytes)?; + let (c, bytes) = Ref::<&mut [u8], U64>::from_prefix(bytes)?; + Ok((Self::Output { meta, vec, c }, bytes)) + } + } + + #[test] + fn test_struct_3() { + let ref_struct = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] + pub struct Struct4Nested { + a: u8, + b: u16, + } + + impl<'a> DeserializeMut<'a> for Struct4Nested { + type Output = ZStruct4Nested; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (bytes, remaining_bytes) = Ref::<&mut [u8], ZStruct4Nested>::from_prefix(bytes)?; + Ok((*bytes, remaining_bytes)) + } + } + + #[repr(C)] + #[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, + )] + pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, + } + + impl ZeroCopyStructInnerMut for Struct4Nested { + type ZeroCopyInnerMut = ZStruct4Nested; + } + + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] + pub struct ZStruct4Meta { + pub a: ::ZeroCopyInnerMut, + pub b: ::ZeroCopyInnerMut, + } + + #[derive(Debug, PartialEq)] + pub struct ZStruct4<'a> { + meta: Ref<&'a mut [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + pub c: Ref<&'a mut [u8], ::ZeroCopyInnerMut>, + pub vec_2: + ZeroCopySliceMutBorsh<'a, ::ZeroCopyInnerMut>, + } + + impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a mut [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&mut [u8], ::ZeroCopyInnerMut>::from_prefix( + bytes, + )?; + let (vec_2, bytes) = ZeroCopySliceMutBorsh::from_bytes_at(bytes)?; + Ok(( + Self::Output { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } + } + + /// TODO: + /// - add SIZE const generic DeserializeMut trait + /// - add new with data function + impl Struct4 { + // pub fn byte_len(&self) -> usize { + // size_of::() + // + size_of::() + // + size_of::() * self.vec.len() + // + size_of::() + // + size_of::() * self.vec_2.len() + // } + + pub fn new_with_data<'a>( + bytes: &'a mut [u8], + data: &Struct4, + ) -> (ZStruct4<'a>, &'a mut [u8]) { + let (mut zero_copy, bytes) = + ::zero_copy_at_mut(bytes).unwrap(); + zero_copy.meta.a = data.a; + zero_copy.meta.b = data.b.into(); + zero_copy + .vec + .iter_mut() + .zip(data.vec.iter()) + .for_each(|(x, y)| *x = *y); + (zero_copy, bytes) + } + } + + #[test] + fn test_struct_4() { + let ref_struct = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct5 { + pub a: Vec>, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInnerMut>>, + } + + impl<'a> DeserializeMut<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::< + ZeroCopySliceMutBorsh<::ZeroCopyInnerMut>, + >::zero_copy_at_mut(bytes)?; + Ok((ZStruct5 { a }, bytes)) + } + } + + #[test] + fn test_struct_5() { + let ref_struct = Struct5 { + a: vec![vec![1u8; 32]; 32], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct6 { + pub a: Vec, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } + } + + #[test] + fn test_struct_6() { + let ref_struct = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } + + #[repr(C)] + #[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] + pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, + } + + #[repr(C)] + #[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] + pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct7<'a> { + meta: Ref<&'a mut [u8], ZStruct7Meta>, + pub option: Option<>::Output>, + } + + impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option.as_ref().map(|x| **x) == other.option + } + } + + impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a mut [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } + } + + impl<'a> DeserializeMut<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&mut [u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as DeserializeMut<'a>>::zero_copy_at_mut(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } + } + + #[test] + fn test_struct_7() { + let ref_struct = Struct7 { + a: 1, + b: 2, + option: Some(3), + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option.map(|x| *x), Some(3)); + assert_eq!(remaining, &mut []); + + let ref_struct = Struct7 { + a: 1, + b: 2, + option: None, + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &mut []); + } + + // If a struct inside a vector contains a vector it must implement DeserializeMut. + #[repr(C)] + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct Struct8 { + pub a: Vec, + } + + #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] + pub struct NestedStruct { + pub a: u8, + pub b: Struct2, + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZNestedStruct<'a> { + pub a: >::Output, + pub b: >::Output, + } + + impl<'a> DeserializeMut<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = + ::ZeroCopyInnerMut::zero_copy_at_mut(bytes)?; + let (b, bytes) = >::zero_copy_at_mut(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } + } + + impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + *self.a == other.a && self.b == other.b + } + } + + #[repr(C)] + #[derive(Debug, PartialEq)] + pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, + } + + impl<'a> DeserializeMut<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at_mut(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } + } + + #[test] + fn test_struct_8() { + let ref_struct = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at_mut(&mut bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &mut []); + } +} diff --git a/program-libs/zero-copy/src/init_mut.rs b/program-libs/zero-copy/src/init_mut.rs new file mode 100644 index 0000000000..c16d371176 --- /dev/null +++ b/program-libs/zero-copy/src/init_mut.rs @@ -0,0 +1,268 @@ +use core::mem::size_of; +use std::vec::Vec; + +use crate::{borsh_mut::DeserializeMut, errors::ZeroCopyError}; + +/// Trait for types that can be initialized in mutable byte slices with configuration +/// +/// This trait provides a way to initialize structures in pre-allocated byte buffers +/// with specific configuration parameters that determine Vec lengths, Option states, etc. +pub trait ZeroCopyNew<'a> +where + Self: Sized, +{ + /// Configuration type needed to initialize this type + type Config; + + /// Output type - the mutable zero-copy view of this type + type Output; + + /// Calculate the byte length needed for this type with the given configuration + /// + /// This is essential for allocating the correct buffer size before calling new_zero_copy + fn byte_len(config: &Self::Config) -> usize; + + /// Initialize this type in a mutable byte slice with the given configuration + /// + /// Returns the initialized mutable view and remaining bytes + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError>; +} + +// Generic implementation for Option +impl<'a, T> ZeroCopyNew<'a> for Option +where + T: ZeroCopyNew<'a>, +{ + type Config = (bool, T::Config); // (enabled, inner_config) + type Output = Option; + + fn byte_len(config: &Self::Config) -> usize { + let (enabled, inner_config) = config; + if *enabled { + // 1 byte for Some discriminant + inner type's byte_len + 1 + T::byte_len(inner_config) + } else { + // Just 1 byte for None discriminant + 1 + } + } + + fn new_zero_copy( + bytes: &'a mut [u8], + config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + if bytes.is_empty() { + return Err(ZeroCopyError::ArraySize(1, bytes.len())); + } + + let (enabled, inner_config) = config; + + if enabled { + bytes[0] = 1; // Some discriminant + let (_, bytes) = bytes.split_at_mut(1); + let (value, bytes) = T::new_zero_copy(bytes, inner_config)?; + Ok((Some(value), bytes)) + } else { + bytes[0] = 0; // None discriminant + let (_, bytes) = bytes.split_at_mut(1); + Ok((None, bytes)) + } + } +} + +// Implementation for primitive types (no configuration needed) +impl<'a> ZeroCopyNew<'a> for u64 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U64 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u32 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U32 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u16 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Return U16 little-endian type for generated structs + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for u8 { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Use the DeserializeMut trait to create the proper output + Self::zero_copy_at_mut(bytes) + } +} + +impl<'a> ZeroCopyNew<'a> for bool { + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() // bool is serialized as u8 + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Treat bool as u8 + u8::zero_copy_at_mut(bytes) + } +} + +// Implementation for fixed-size arrays +impl< + 'a, + T: Copy + Default + zerocopy::KnownLayout + zerocopy::Immutable + zerocopy::FromBytes, + const N: usize, + > ZeroCopyNew<'a> for [T; N] +{ + type Config = (); + type Output = >::Output; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + // Use the DeserializeMut trait to create the proper output + Self::zero_copy_at_mut(bytes) + } +} + +// Implementation for zerocopy little-endian types +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U16 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U16>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U16>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U32 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U32>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U32>::from_prefix(bytes)?) + } +} + +impl<'a> ZeroCopyNew<'a> for zerocopy::little_endian::U64 { + type Config = (); + type Output = zerocopy::Ref<&'a mut [u8], zerocopy::little_endian::U64>; + + fn byte_len(_config: &Self::Config) -> usize { + size_of::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + _config: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + Ok(zerocopy::Ref::<&mut [u8], zerocopy::little_endian::U64>::from_prefix(bytes)?) + } +} + +// Implementation for Vec +impl<'a, T: ZeroCopyNew<'a>> ZeroCopyNew<'a> for Vec { + type Config = Vec; // Vector of configs for each item + type Output = Vec; + + fn byte_len(config: &Self::Config) -> usize { + // 4 bytes for length prefix + sum of byte_len for each element config + 4 + config + .iter() + .map(|config| T::byte_len(config)) + .sum::() + } + + fn new_zero_copy( + bytes: &'a mut [u8], + configs: Self::Config, + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + use zerocopy::{little_endian::U32, Ref}; + + // Write length as U32 + let len = configs.len() as u32; + let (mut len_ref, mut bytes) = Ref::<&mut [u8], U32>::from_prefix(bytes)?; + *len_ref = U32::new(len); + + // Initialize each item with its config + let mut items = Vec::with_capacity(configs.len()); + for config in configs { + let (item, remaining_bytes) = T::new_zero_copy(bytes, config)?; + bytes = remaining_bytes; + items.push(item); + } + + Ok((items, bytes)) + } +} diff --git a/program-libs/zero-copy/src/lib.rs b/program-libs/zero-copy/src/lib.rs index 297c849d53..3ac6a38948 100644 --- a/program-libs/zero-copy/src/lib.rs +++ b/program-libs/zero-copy/src/lib.rs @@ -10,8 +10,24 @@ pub mod vec; use core::mem::{align_of, size_of}; #[cfg(feature = "std")] pub mod borsh; - -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; +#[cfg(feature = "std")] +pub mod borsh_mut; +#[cfg(feature = "std")] +pub mod init_mut; +#[cfg(feature = "std")] +pub use borsh::ZeroCopyStructInner; +#[cfg(feature = "std")] +pub use init_mut::ZeroCopyNew; +#[cfg(all(feature = "derive", feature = "mut"))] +pub use light_zero_copy_derive::ZeroCopyMut; +#[cfg(feature = "derive")] +pub use light_zero_copy_derive::{ZeroCopy, ZeroCopyEq}; +#[cfg(feature = "derive")] +pub use zerocopy::{ + little_endian::{self, U16, U32, U64}, + Ref, Unaligned, +}; +pub use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; #[cfg(feature = "std")] extern crate std; diff --git a/program-libs/zero-copy/src/slice_mut.rs b/program-libs/zero-copy/src/slice_mut.rs index 27cd2f776a..7a50b7e44d 100644 --- a/program-libs/zero-copy/src/slice_mut.rs +++ b/program-libs/zero-copy/src/slice_mut.rs @@ -276,3 +276,16 @@ where write!(f, "{:?}", self.as_slice()) } } + +#[cfg(feature = "std")] +impl<'a, T: ZeroCopyTraits + crate::borsh_mut::DeserializeMut<'a>> + crate::borsh_mut::DeserializeMut<'a> for ZeroCopySliceMutBorsh<'_, T> +{ + type Output = ZeroCopySliceMutBorsh<'a, T>; + + fn zero_copy_at_mut( + bytes: &'a mut [u8], + ) -> Result<(Self::Output, &'a mut [u8]), ZeroCopyError> { + ZeroCopySliceMutBorsh::from_bytes_at(bytes) + } +} diff --git a/program-libs/zero-copy/tests/borsh.rs b/program-libs/zero-copy/tests/borsh.rs new file mode 100644 index 0000000000..071b4e8df2 --- /dev/null +++ b/program-libs/zero-copy/tests/borsh.rs @@ -0,0 +1,335 @@ +#![cfg(all(feature = "std", feature = "derive", feature = "mut"))] +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, borsh_mut::DeserializeMut, ZeroCopy, ZeroCopyEq, ZeroCopyMut, +}; + +#[repr(C)] +#[derive(Debug, PartialEq, ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshDeserialize, BorshSerialize)] +pub struct Struct1Derived { + pub a: u8, + pub b: u16, +} + +#[test] +fn test_struct_1_derived() { + let ref_struct = Struct1Derived { a: 1, b: 2 }; + let mut bytes = ref_struct.try_to_vec().unwrap(); + + { + let (struct1, remaining) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(struct1, ref_struct); + assert_eq!(remaining, &[]); + } + { + let (mut struct1, _) = Struct1Derived::zero_copy_at_mut(&mut bytes).unwrap(); + struct1.a = 2; + struct1.b = 3.into(); + } + let borsh = Struct1Derived::deserialize(&mut &bytes[..]).unwrap(); + let (struct_1, _) = Struct1Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct_1.a, 2); // Modified value from mutable operations + assert_eq!(struct_1.b, 3); // Modified value from mutable operations + assert_eq!(struct_1, borsh); +} + +// Struct2 equivalent: Manual implementation that should match Struct2 +#[repr(C)] +#[derive( + Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct2Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[test] +fn test_struct_2_derived() { + let ref_struct = Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (struct2, remaining) = Struct2Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); + assert_eq!(struct2, ref_struct); +} + +// Struct3 equivalent: fields should match Struct3 +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct3Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[test] +fn test_struct_3_derived() { + let ref_struct = Struct3Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct3Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4NestedDerived { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct4Derived { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[test] +fn test_struct_4_derived() { + let ref_struct = Struct4Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4NestedDerived { a: 1, b: 2 }; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct4Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + // Check vec_2 length is correct + assert_eq!(zero_copy.vec_2.len(), 32); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct Struct5Derived { + pub a: Vec>, +} + +#[test] +fn test_struct_5_derived() { + let ref_struct = Struct5Derived { + a: vec![vec![1u8; 32]; 32], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct5Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct6Derived { + pub a: Vec, +} + +#[test] +fn test_struct_6_derived() { + let ref_struct = Struct6Derived { + a: vec![ + Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct6Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(zero_copy, ref_struct); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut)] +pub struct Struct7Derived { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[test] +fn test_struct_7_derived() { + let ref_struct = Struct7Derived { + a: 1, + b: 2, + option: Some(3), + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7Derived { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7Derived::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq)] +pub struct Struct8Derived { + pub a: Vec, +} + +#[derive( + Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize, ZeroCopy, ZeroCopyMut, ZeroCopyEq, +)] +pub struct NestedStructDerived { + pub a: u8, + pub b: Struct2Derived, +} + +#[test] +fn test_struct_8_derived() { + let ref_struct = Struct8Derived { + a: vec![ + NestedStructDerived { + a: 1, + b: Struct2Derived { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + }; + let bytes = ref_struct.try_to_vec().unwrap(); + + let (zero_copy, remaining) = Struct8Derived::zero_copy_at(&bytes).unwrap(); + // Check length of vec matches + assert_eq!(zero_copy.a.len(), 32); + assert_eq!(zero_copy, ref_struct); + + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(ZeroCopy, ZeroCopyMut, ZeroCopyEq, BorshSerialize, BorshDeserialize, PartialEq, Debug)] +pub struct ArrayStruct { + pub a: [u8; 32], + pub b: [u8; 64], + pub c: [u8; 32], +} + +#[test] +fn test_array_struct() -> Result<(), Box> { + let array_struct = ArrayStruct { + a: [1u8; 32], + b: [2u8; 64], + c: [3u8; 32], + }; + let bytes = array_struct.try_to_vec()?; + + let (zero_copy, remaining) = ArrayStruct::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, [1u8; 32]); + assert_eq!(zero_copy.b, [2u8; 64]); + assert_eq!(zero_copy.c, [3u8; 32]); + assert_eq!(zero_copy, array_struct); + assert_eq!(remaining, &[]); + Ok(()) +} + +#[derive( + Debug, + PartialEq, + Default, + Clone, + BorshSerialize, + BorshDeserialize, + ZeroCopy, + ZeroCopyMut, + ZeroCopyEq, +)] +pub struct CompressedAccountData { + pub discriminator: [u8; 8], + pub data: Vec, + pub data_hash: [u8; 32], +} + +#[test] +fn test_compressed_account_data() { + let compressed_account_data = CompressedAccountData { + discriminator: [1u8; 8], + data: vec![2u8; 32], + data_hash: [3u8; 32], + }; + let bytes = compressed_account_data.try_to_vec().unwrap(); + + let (zero_copy, remaining) = CompressedAccountData::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.discriminator, [1u8; 8]); + // assert_eq!(zero_copy.data, compressed_account_data.data.as_slice()); + assert_eq!(*zero_copy.data_hash, [3u8; 32]); + assert_eq!(zero_copy, compressed_account_data); + assert_eq!(remaining, &[]); +} diff --git a/program-libs/zero-copy/tests/borsh_2.rs b/program-libs/zero-copy/tests/borsh_2.rs new file mode 100644 index 0000000000..aece86bb1c --- /dev/null +++ b/program-libs/zero-copy/tests/borsh_2.rs @@ -0,0 +1,559 @@ +#![cfg(all(feature = "std", feature = "derive"))] + +use std::{ops::Deref, vec}; + +use borsh::{BorshDeserialize, BorshSerialize}; +use light_zero_copy::{ + borsh::Deserialize, errors::ZeroCopyError, slice::ZeroCopySliceBorsh, ZeroCopyStructInner, +}; +use zerocopy::{ + little_endian::{U16, U64}, + FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned, +}; + +// Rules: +// 1. create ZStruct for the struct +// 1.1. the first fields are extracted into a meta struct until we reach a Vec, Option or type that does not implement Copy, and we implement deref for the meta struct +// 1.2. represent vectors to ZeroCopySlice & don't include these into the meta struct +// 1.3. replace u16 with U16, u32 with U32, etc +// 1.4. every field after the first vector is directly included in the ZStruct and deserialized 1 by 1 +// 1.5. If a vector contains a nested vector (does not implement Copy) it must implement Deserialize +// 1.6. Elements in an Option must implement Deserialize +// 1.7. a type that does not implement Copy must implement Deserialize, and is deserialized 1 by 1 + +// Derive Macro needs to derive: +// 1. ZeroCopyStructInner +// 2. Deserialize +// 3. PartialEq for ZStruct<'_> +// +// For every struct1 - struct7 create struct_derived1 - struct_derived7 and replicate the tests for the new structs. + +// Tests for manually implemented structures (without derive macro) + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct1 { + pub a: u8, + pub b: u16, +} + +// pub fn data_hash_struct_1(a: u8, b: u16) -> [u8; 32] { + +// } + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct1Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct1<'a> { + meta: Ref<&'a [u8], ZStruct1Meta>, +} +impl<'a> Deref for ZStruct1<'a> { + type Target = Ref<&'a [u8], ZStruct1Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct1 { + type Output = ZStruct1<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct1Meta>::from_prefix(bytes)?; + Ok((ZStruct1 { meta }, bytes)) + } +} + +#[test] +fn test_struct_1() { + let bytes = Struct1 { a: 1, b: 2 }.try_to_vec().unwrap(); + let (struct1, remaining) = Struct1::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct1.a, 1u8); + assert_eq!(struct1.b, 2u16); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct2 { + pub a: u8, + pub b: u16, + pub vec: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct2Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct2<'a> { + meta: Ref<&'a [u8], ZStruct2Meta>, + pub vec: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct2<'_> { + fn eq(&self, other: &Struct2) -> bool { + let meta: &ZStruct2Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.vec.as_slice() == other.vec.as_slice() + } +} + +impl<'a> Deref for ZStruct2<'a> { + type Target = Ref<&'a [u8], ZStruct2Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct2 { + type Output = ZStruct2<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct2Meta>::from_prefix(bytes)?; + let (vec, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct2 { meta, vec }, bytes)) + } +} + +#[test] +fn test_struct_2() { + let bytes = Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + } + .try_to_vec() + .unwrap(); + let (struct2, remaining) = Struct2::zero_copy_at(&bytes).unwrap(); + assert_eq!(struct2.a, 1u8); + assert_eq!(struct2.b, 2u16); + assert_eq!(struct2.vec.to_vec(), vec![1u8; 32]); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct3 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct3Meta { + pub a: u8, + pub b: U16, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct3<'a> { + meta: Ref<&'a [u8], ZStruct3Meta>, + pub vec: ZeroCopySliceBorsh<'a, u8>, + pub c: Ref<&'a [u8], U64>, +} + +impl<'a> Deref for ZStruct3<'a> { + type Target = Ref<&'a [u8], ZStruct3Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct3 { + type Output = ZStruct3<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct3Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::zero_copy_at(bytes)?; + let (c, bytes) = Ref::<&[u8], U64>::from_prefix(bytes)?; + Ok((ZStruct3 { meta, vec, c }, bytes)) + } +} + +#[test] +fn test_struct_3() { + let bytes = Struct3 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct3::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, Clone)] +pub struct Struct4Nested { + a: u8, + b: u16, +} + +#[repr(C)] +#[derive( + Debug, PartialEq, Copy, Clone, KnownLayout, Immutable, IntoBytes, Unaligned, FromBytes, +)] +pub struct ZStruct4Nested { + pub a: u8, + pub b: U16, +} + +impl ZeroCopyStructInner for Struct4Nested { + type ZeroCopyInner = ZStruct4Nested; +} + +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct4 { + pub a: u8, + pub b: u16, + pub vec: Vec, + pub c: u64, + pub vec_2: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, IntoBytes, FromBytes)] +pub struct ZStruct4Meta { + pub a: ::ZeroCopyInner, + pub b: ::ZeroCopyInner, +} + +#[derive(Debug, PartialEq)] +pub struct ZStruct4<'a> { + meta: Ref<&'a [u8], ZStruct4Meta>, + pub vec: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, + pub c: Ref<&'a [u8], ::ZeroCopyInner>, + pub vec_2: ZeroCopySliceBorsh<'a, ::ZeroCopyInner>, +} + +impl<'a> Deref for ZStruct4<'a> { + type Target = Ref<&'a [u8], ZStruct4Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct4 { + type Output = ZStruct4<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct4Meta>::from_prefix(bytes)?; + let (vec, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + let (c, bytes) = + Ref::<&[u8], ::ZeroCopyInner>::from_prefix(bytes)?; + let (vec_2, bytes) = ZeroCopySliceBorsh::from_bytes_at(bytes)?; + Ok(( + ZStruct4 { + meta, + vec, + c, + vec_2, + }, + bytes, + )) + } +} + +#[test] +fn test_struct_4() { + let bytes = Struct4 { + a: 1, + b: 2, + vec: vec![1u8; 32], + c: 3, + vec_2: vec![Struct4Nested { a: 1, b: 2 }; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct4::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.vec.to_vec(), vec![1u8; 32]); + assert_eq!(u64::from(*zero_copy.c), 3); + assert_eq!( + zero_copy.vec_2.to_vec(), + vec![ZStruct4Nested { a: 1, b: 2.into() }; 32] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct5 { + pub a: Vec>, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct5<'a> { + pub a: Vec::ZeroCopyInner>>, +} + +impl<'a> Deserialize<'a> for Struct5 { + type Output = ZStruct5<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = + Vec::::ZeroCopyInner>>::zero_copy_at( + bytes, + )?; + Ok((ZStruct5 { a }, bytes)) + } +} + +#[test] +fn test_struct_5() { + let bytes = Struct5 { + a: vec![vec![1u8; 32]; 32], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct5::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().map(|x| x.to_vec()).collect::>(), + vec![vec![1u8; 32]; 32] + ); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct6 { + pub a: Vec, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct6<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct6 { + type Output = ZStruct6<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct6 { a }, bytes)) + } +} + +#[test] +fn test_struct_6() { + let bytes = Struct6 { + a: vec![ + Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct6::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +} + +#[repr(C)] +#[derive(Debug, PartialEq, Clone, BorshSerialize, BorshDeserialize)] +pub struct Struct7 { + pub a: u8, + pub b: u16, + pub option: Option, +} + +#[repr(C)] +#[derive(Debug, PartialEq, KnownLayout, Immutable, Unaligned, FromBytes)] +pub struct ZStruct7Meta { + pub a: u8, + pub b: U16, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct7<'a> { + meta: Ref<&'a [u8], ZStruct7Meta>, + pub option: as ZeroCopyStructInner>::ZeroCopyInner, +} + +impl PartialEq for ZStruct7<'_> { + fn eq(&self, other: &Struct7) -> bool { + let meta: &ZStruct7Meta = &self.meta; + if meta.a != other.a || other.b != meta.b.into() { + return false; + } + self.option == other.option + } +} + +impl<'a> Deref for ZStruct7<'a> { + type Target = Ref<&'a [u8], ZStruct7Meta>; + + fn deref(&self) -> &Self::Target { + &self.meta + } +} + +impl<'a> Deserialize<'a> for Struct7 { + type Output = ZStruct7<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (meta, bytes) = Ref::<&[u8], ZStruct7Meta>::from_prefix(bytes)?; + let (option, bytes) = as Deserialize>::zero_copy_at(bytes)?; + Ok((ZStruct7 { meta, option }, bytes)) + } +} + +#[test] +fn test_struct_7() { + let bytes = Struct7 { + a: 1, + b: 2, + option: Some(3), + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, Some(3)); + assert_eq!(remaining, &[]); + + let bytes = Struct7 { + a: 1, + b: 2, + option: None, + } + .try_to_vec() + .unwrap(); + let (zero_copy, remaining) = Struct7::zero_copy_at(&bytes).unwrap(); + assert_eq!(zero_copy.a, 1u8); + assert_eq!(zero_copy.b, 2u16); + assert_eq!(zero_copy.option, None); + assert_eq!(remaining, &[]); +} + +// If a struct inside a vector contains a vector it must implement Deserialize. +#[repr(C)] +#[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct Struct8 { + pub a: Vec, +} + +#[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)] +pub struct NestedStruct { + pub a: u8, + pub b: Struct2, +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZNestedStruct<'a> { + pub a: ::ZeroCopyInner, + pub b: >::Output, +} + +impl<'a> Deserialize<'a> for NestedStruct { + type Output = ZNestedStruct<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = ::ZeroCopyInner::zero_copy_at(bytes)?; + let (b, bytes) = ::zero_copy_at(bytes)?; + Ok((ZNestedStruct { a, b }, bytes)) + } +} + +impl PartialEq for ZNestedStruct<'_> { + fn eq(&self, other: &NestedStruct) -> bool { + self.a == other.a && self.b == other.b + } +} + +#[repr(C)] +#[derive(Debug, PartialEq)] +pub struct ZStruct8<'a> { + pub a: Vec<>::Output>, +} + +impl<'a> Deserialize<'a> for Struct8 { + type Output = ZStruct8<'a>; + + fn zero_copy_at(bytes: &'a [u8]) -> Result<(Self::Output, &'a [u8]), ZeroCopyError> { + let (a, bytes) = Vec::::zero_copy_at(bytes)?; + Ok((ZStruct8 { a }, bytes)) + } +} + +#[test] +fn test_struct_8() { + let bytes = Struct8 { + a: vec![ + NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ], + } + .try_to_vec() + .unwrap(); + + let (zero_copy, remaining) = Struct8::zero_copy_at(&bytes).unwrap(); + assert_eq!( + zero_copy.a.iter().collect::>(), + vec![ + &NestedStruct { + a: 1, + b: Struct2 { + a: 1, + b: 2, + vec: vec![1u8; 32], + }, + }; + 32 + ] + ); + assert_eq!(remaining, &[]); +}