From 3572ce8d28696bcd8fbc884ffc4711e7c393a200 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Fri, 6 Jun 2025 14:11:46 +0200 Subject: [PATCH 1/4] Add support for PodList --- Cargo.lock | 32 +-- pod/Cargo.toml | 2 +- pod/src/lib.rs | 1 + pod/src/list.rs | 334 ++++++++++++++++++++++++++++ pod/src/slice.rs | 9 +- tlv-account-resolution/Cargo.toml | 2 +- tlv-account-resolution/src/state.rs | 7 +- type-length-value/Cargo.toml | 2 +- 8 files changed, 366 insertions(+), 23 deletions(-) create mode 100644 pod/src/list.rs diff --git a/Cargo.lock b/Cargo.lock index 83fc27ed..49bbf209 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6667,7 +6667,7 @@ dependencies = [ "solana-system-interface", "solana-sysvar", "solana-zk-sdk", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "spl-token-confidential-transfer-proof-extraction", ] @@ -6708,15 +6708,14 @@ dependencies = [ [[package]] name = "spl-pod" version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d994afaf86b779104b4a95ba9ca75b8ced3fdb17ee934e38cb69e72afbe17799" dependencies = [ - "base64 0.22.1", "borsh 1.5.7", "bytemuck", "bytemuck_derive", "num-derive", "num-traits", - "serde", - "serde_json", "solana-decode-error", "solana-msg", "solana-program-error", @@ -6728,15 +6727,16 @@ dependencies = [ [[package]] name = "spl-pod" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d994afaf86b779104b4a95ba9ca75b8ced3fdb17ee934e38cb69e72afbe17799" +version = "0.6.0" dependencies = [ + "base64 0.22.1", "borsh 1.5.7", "bytemuck", "bytemuck_derive", "num-derive", "num-traits", + "serde", + "serde_json", "solana-decode-error", "solana-msg", "solana-program-error", @@ -6820,7 +6820,7 @@ dependencies = [ "solana-pubkey", "solana-sdk", "spl-discriminator 0.4.1", - "spl-pod 0.5.1", + "spl-pod 0.6.0", "spl-program-error 0.7.0", "spl-type-length-value 0.8.0", "thiserror 2.0.12", @@ -6842,7 +6842,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "spl-program-error 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "spl-type-length-value 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", @@ -6908,7 +6908,7 @@ dependencies = [ "solana-zk-sdk", "spl-elgamal-registry", "spl-memo", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "spl-token", "spl-token-confidential-transfer-ciphertext-arithmetic", "spl-token-confidential-transfer-proof-extraction", @@ -6948,7 +6948,7 @@ dependencies = [ "solana-pubkey", "solana-sdk-ids", "solana-zk-sdk", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "thiserror 2.0.12", ] @@ -6978,7 +6978,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "thiserror 2.0.12", ] @@ -6998,7 +6998,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "spl-type-length-value 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", ] @@ -7021,7 +7021,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "spl-program-error 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "spl-tlv-account-resolution 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", "spl-type-length-value 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -7040,7 +7040,7 @@ dependencies = [ "solana-msg", "solana-program-error", "spl-discriminator 0.4.1", - "spl-pod 0.5.1", + "spl-pod 0.6.0", "spl-type-length-value-derive", "thiserror 2.0.12", ] @@ -7059,7 +7059,7 @@ dependencies = [ "solana-msg", "solana-program-error", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.5.1", "thiserror 2.0.12", ] diff --git a/pod/Cargo.toml b/pod/Cargo.toml index 5fdf70aa..388d1d66 100644 --- a/pod/Cargo.toml +++ b/pod/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spl-pod" -version = "0.5.1" +version = "0.6.0" description = "Solana Program Library Plain Old Data (Pod)" authors = ["Anza Maintainers "] repository = "https://github.com/solana-program/libraries" diff --git a/pod/src/lib.rs b/pod/src/lib.rs index 292f1c09..ab6ec113 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -2,6 +2,7 @@ pub mod bytemuck; pub mod error; +pub mod list; pub mod option; pub mod optional_keys; pub mod primitives; diff --git a/pod/src/list.rs b/pod/src/list.rs new file mode 100644 index 00000000..eae3e1d2 --- /dev/null +++ b/pod/src/list.rs @@ -0,0 +1,334 @@ +use crate::bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}; +use crate::error::PodSliceError; +use crate::primitives::PodU32; +use crate::slice::max_len_for_type; +use bytemuck::Pod; +use solana_program_error::ProgramError; + +const LENGTH_SIZE: usize = std::mem::size_of::(); + +/// A mutable, variable-length collection of `Pod` types backed by a byte buffer. +/// +/// `PodList` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. +/// It does not own the buffer itself, but acts as a mutable view over it. +/// +/// This is useful in environments where allocations are restricted or expensive, +/// such as Solana programs, allowing for dynamic-length data structures within a +/// fixed-size account. +/// +/// ## Memory Layout +/// +/// The structure assumes the underlying byte buffer is formatted as follows: +/// 1. **Length**: A `u32` value (`PodU32`) at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. +/// 2. **Data**: The remaining part of the buffer, which is treated as a slice +/// of `T` elements. The capacity of the collection is the number of `T` +/// elements that can fit into this data portion. +pub struct PodList<'data, T: Pod> { + length: &'data mut PodU32, + data: &'data mut [T], + max_length: usize, +} + +impl<'data, T: Pod> PodList<'data, T> { + /// Unpack the mutable buffer into a mutable slice, with the option to + /// initialize the data + fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result + where + 'a: 'data, + { + if data.len() < LENGTH_SIZE { + return Err(PodSliceError::BufferTooSmall.into()); + } + let (length, data) = data.split_at_mut(LENGTH_SIZE); + let length = pod_from_bytes_mut::(length)?; + if init { + *length = 0.into(); + } + let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; + let data = pod_slice_from_bytes_mut(data)?; + Ok(Self { + length, + data, + max_length, + }) + } + + /// Unpack the mutable buffer into a mutable slice + pub fn unpack<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, /* init */ false) + } + + /// Unpack the mutable buffer into a mutable slice, and initialize the + /// slice to 0-length + pub fn init<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, /* init */ true) + } + + /// Add another item to the slice + pub fn push(&mut self, t: T) -> Result<(), ProgramError> { + let length = u32::from(*self.length); + if length as usize == self.max_length { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length as usize] = t; + *self.length = length.saturating_add(1).into(); + Ok(()) + } + } + + /// Remove and return the element at `index`, shifting all later + /// elements one position to the left. + pub fn remove_at(&mut self, index: usize) -> Result { + let len = u32::from(*self.length) as usize; + if index >= len { + return Err(ProgramError::InvalidArgument); + } + + let removed_item = self.data[index]; + + // Move the tail left by one + let tail_start = index + .checked_add(1) + .ok_or(ProgramError::ArithmeticOverflow)?; + self.data.copy_within(tail_start..len, index); + + // Zero-fill the now-unused slot at the end + let last = len.checked_sub(1).ok_or(ProgramError::ArithmeticOverflow)?; + self.data[last] = T::zeroed(); + + // Store the new length (len - 1) + *self.length = (last as u32).into(); + + Ok(removed_item) + } + + /// Find the first element that satisfies `predicate` and remove it, + /// returning the element. + pub fn remove_first_where

(&mut self, mut predicate: P) -> Result + where + P: FnMut(&T) -> bool, + { + if let Some(index) = self.data.iter().position(&mut predicate) { + self.remove_at(index) + } else { + Err(ProgramError::InvalidArgument) + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + test_field: u8, + test_pubkey: [u8; 32], + } + + #[test] + fn test_pod_collection() { + // slice can fit 2 `TestStruct` + let mut pod_slice_bytes = [0; 70]; + // set length to 1, so we have room to push 1 more item + let len_bytes = [1, 0, 0, 0]; + pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + + let mut pod_slice = PodList::::unpack(&mut pod_slice_bytes).unwrap(); + + assert_eq!(*pod_slice.length, PodU32::from(1)); + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU32::from(2)); + let err = pod_slice + .push(TestStruct::default()) + .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + fn make_buffer(capacity: usize, items: &[u8]) -> Vec { + let buff_len = LENGTH_SIZE.checked_add(capacity).unwrap(); + let mut buf = vec![0u8; buff_len]; + buf[..LENGTH_SIZE].copy_from_slice(&(items.len() as u32).to_le_bytes()); + let end = LENGTH_SIZE.checked_add(items.len()).unwrap(); + buf[LENGTH_SIZE..end].copy_from_slice(items); + buf + } + + #[test] + fn remove_at_first_item() { + let mut buff = make_buffer(15, &[10, 20, 30, 40]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(0).unwrap(); + assert_eq!(removed, 10); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[20, 30, 40]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_at_middle_item() { + let mut buff = make_buffer(15, &[10, 20, 30, 40]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(2).unwrap(); + assert_eq!(removed, 30); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 40]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_at_last_item() { + let mut buff = make_buffer(15, &[10, 20, 30, 40]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(3).unwrap(); + assert_eq!(removed, 40); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 30]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_at_out_of_bounds() { + let mut buff = make_buffer(3, &[1, 2, 3]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_at(3).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + // pod_list should be unchanged + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), vec![1, 2, 3]); + } + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_at_single_element() { + let mut buff = make_buffer(1, &[10]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(0).unwrap(); + assert_eq!(removed, 10); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 0); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[] as &[u8]); + assert_eq!(pod_list.data[0], 0); + } + + #[test] + fn remove_at_empty_slice() { + let mut buff = make_buffer(0, &[]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_at(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + // Assert list state is unchanged + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 0); + } + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_first_where_first_item() { + let mut buff = make_buffer(3, &[5, 10, 15]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|&x| x == 5).unwrap(); + assert_eq!(removed, 5); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 2); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 15]); + assert_eq!(pod_list.data[2], 0); + } + + #[test] + fn remove_first_where_middle_item() { + let mut buff = make_buffer(4, &[1, 2, 3, 4]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|v| *v == 3).unwrap(); + assert_eq!(removed, 3); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[1, 2, 4]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_first_where_last_item() { + let mut buff = make_buffer(3, &[5, 10, 15]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|&x| x == 15).unwrap(); + assert_eq!(removed, 15); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 2); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[5, 10]); + assert_eq!(pod_list.data[2], 0); + } + + #[test] + fn remove_first_where_multiple_matches() { + let mut buff = make_buffer(5, &[7, 8, 8, 9, 10]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|v| *v == 8).unwrap(); + assert_eq!(removed, 8); // Removed *first* 8 + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 4); + // Should remove only the *first* match. + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[7, 8, 9, 10]); + assert_eq!(pod_list.data[4], 0); + } + + #[test] + fn remove_first_where_not_found() { + let mut buff = make_buffer(3, &[5, 6, 7]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_first_where(|v| *v == 42).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(u32::from(*pod_list.length) as usize, 3); + } + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_first_where_empty_slice() { + let mut buff = make_buffer(0, &[]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_first_where(|_| true).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(u32::from(*pod_list.length) as usize, 0); + } + + assert_eq!(buff, original_buff); + } +} diff --git a/pod/src/slice.rs b/pod/src/slice.rs index ca885aba..0082c7a4 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -49,12 +49,18 @@ impl<'data, T: Pod> PodSlice<'data, T> { } } +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). Please use `PodList` instead." +)] /// Special type for using a slice of mutable `Pod`s in a zero-copy way pub struct PodSliceMut<'data, T: Pod> { length: &'data mut PodU32, data: &'data mut [T], max_length: usize, } + +#[allow(deprecated)] impl<'data, T: Pod> PodSliceMut<'data, T> { /// Unpack the mutable buffer into a mutable slice, with the option to /// initialize the data @@ -109,7 +115,7 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { } } -fn max_len_for_type(data_len: usize, length_val: usize) -> Result { +pub fn max_len_for_type(data_len: usize, length_val: usize) -> Result { let item_size = std::mem::size_of::(); let max_len = data_len .checked_div(item_size) @@ -136,6 +142,7 @@ fn max_len_for_type(data_len: usize, length_val: usize) -> Result::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = PodList::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -188,7 +189,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = PodList::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } diff --git a/type-length-value/Cargo.toml b/type-length-value/Cargo.toml index 9f8fc535..39b930d6 100644 --- a/type-length-value/Cargo.toml +++ b/type-length-value/Cargo.toml @@ -21,7 +21,7 @@ solana-msg = "2.2.1" solana-program-error = "2.2.1" spl-discriminator = { version = "0.4.0", path = "../discriminator" } spl-type-length-value-derive = { version = "0.2", path = "./derive", optional = true } -spl-pod = { version = "0.5.1", path = "../pod" } +spl-pod = { version = "0.6.0", path = "../pod" } thiserror = "2.0" [lib] From 56057f1a0cea8da680c8525c581ae866074e02db Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Wed, 9 Jul 2025 13:53:33 +0200 Subject: [PATCH 2/4] Review updates --- Cargo.lock | 32 +- pod/Cargo.toml | 2 +- pod/src/list.rs | 451 +++++++++++++++++++--------- pod/src/primitives.rs | 41 +++ pod/src/slice.rs | 82 +++-- tlv-account-resolution/Cargo.toml | 2 +- tlv-account-resolution/src/state.rs | 7 +- type-length-value/Cargo.toml | 2 +- 8 files changed, 407 insertions(+), 212 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 49bbf209..83fc27ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6667,7 +6667,7 @@ dependencies = [ "solana-system-interface", "solana-sysvar", "solana-zk-sdk", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "spl-token-confidential-transfer-proof-extraction", ] @@ -6708,14 +6708,15 @@ dependencies = [ [[package]] name = "spl-pod" version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d994afaf86b779104b4a95ba9ca75b8ced3fdb17ee934e38cb69e72afbe17799" dependencies = [ + "base64 0.22.1", "borsh 1.5.7", "bytemuck", "bytemuck_derive", "num-derive", "num-traits", + "serde", + "serde_json", "solana-decode-error", "solana-msg", "solana-program-error", @@ -6727,16 +6728,15 @@ dependencies = [ [[package]] name = "spl-pod" -version = "0.6.0" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d994afaf86b779104b4a95ba9ca75b8ced3fdb17ee934e38cb69e72afbe17799" dependencies = [ - "base64 0.22.1", "borsh 1.5.7", "bytemuck", "bytemuck_derive", "num-derive", "num-traits", - "serde", - "serde_json", "solana-decode-error", "solana-msg", "solana-program-error", @@ -6820,7 +6820,7 @@ dependencies = [ "solana-pubkey", "solana-sdk", "spl-discriminator 0.4.1", - "spl-pod 0.6.0", + "spl-pod 0.5.1", "spl-program-error 0.7.0", "spl-type-length-value 0.8.0", "thiserror 2.0.12", @@ -6842,7 +6842,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "spl-program-error 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "spl-type-length-value 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", @@ -6908,7 +6908,7 @@ dependencies = [ "solana-zk-sdk", "spl-elgamal-registry", "spl-memo", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "spl-token", "spl-token-confidential-transfer-ciphertext-arithmetic", "spl-token-confidential-transfer-proof-extraction", @@ -6948,7 +6948,7 @@ dependencies = [ "solana-pubkey", "solana-sdk-ids", "solana-zk-sdk", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", ] @@ -6978,7 +6978,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", ] @@ -6998,7 +6998,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "spl-type-length-value 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", ] @@ -7021,7 +7021,7 @@ dependencies = [ "solana-program-error", "solana-pubkey", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "spl-program-error 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "spl-tlv-account-resolution 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", "spl-type-length-value 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -7040,7 +7040,7 @@ dependencies = [ "solana-msg", "solana-program-error", "spl-discriminator 0.4.1", - "spl-pod 0.6.0", + "spl-pod 0.5.1", "spl-type-length-value-derive", "thiserror 2.0.12", ] @@ -7059,7 +7059,7 @@ dependencies = [ "solana-msg", "solana-program-error", "spl-discriminator 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "spl-pod 0.5.1", + "spl-pod 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.12", ] diff --git a/pod/Cargo.toml b/pod/Cargo.toml index 388d1d66..5fdf70aa 100644 --- a/pod/Cargo.toml +++ b/pod/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spl-pod" -version = "0.6.0" +version = "0.5.1" description = "Solana Program Library Plain Old Data (Pod)" authors = ["Anza Maintainers "] repository = "https://github.com/solana-program/libraries" diff --git a/pod/src/list.rs b/pod/src/list.rs index eae3e1d2..480761bd 100644 --- a/pod/src/list.rs +++ b/pod/src/list.rs @@ -1,15 +1,37 @@ -use crate::bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}; -use crate::error::PodSliceError; -use crate::primitives::PodU32; -use crate::slice::max_len_for_type; -use bytemuck::Pod; -use solana_program_error::ProgramError; - -const LENGTH_SIZE: usize = std::mem::size_of::(); +use { + crate::{ + bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}, + error::PodSliceError, + primitives::{PodLength, PodU64}, + }, + bytemuck::Pod, + core::mem::{align_of, size_of}, + solana_program_error::ProgramError, +}; + +/// Calculate padding needed between types for alignment +#[inline] +fn calculate_padding() -> Result { + let length_size = size_of::(); + let data_align = align_of::(); + + // Calculate how many bytes we need to add to length_size + // to make it a multiple of data_align + let remainder = length_size + .checked_rem(data_align) + .ok_or(ProgramError::ArithmeticOverflow)?; + if remainder == 0 { + Ok(0) + } else { + data_align + .checked_sub(remainder) + .ok_or(ProgramError::ArithmeticOverflow) + } +} /// A mutable, variable-length collection of `Pod` types backed by a byte buffer. /// -/// `PodList` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of /// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. /// It does not own the buffer itself, but acts as a mutable view over it. /// @@ -20,34 +42,54 @@ const LENGTH_SIZE: usize = std::mem::size_of::(); /// ## Memory Layout /// /// The structure assumes the underlying byte buffer is formatted as follows: -/// 1. **Length**: A `u32` value (`PodU32`) at the beginning of the buffer, -/// indicating the number of currently active elements in the collection. -/// 2. **Data**: The remaining part of the buffer, which is treated as a slice +/// 1. **Length**: A length field of type `L` at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. Defaults to `PodU64` so the offset is then compatible with 1, 2, 4 and 8 bytes. +/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. +/// 3. **Data**: The remaining part of the buffer, which is treated as a slice /// of `T` elements. The capacity of the collection is the number of `T` /// elements that can fit into this data portion. -pub struct PodList<'data, T: Pod> { - length: &'data mut PodU32, +pub struct ListView<'data, T: Pod, L: PodLength = PodU64> { + length: &'data mut L, data: &'data mut [T], max_length: usize, } -impl<'data, T: Pod> PodList<'data, T> { +impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { /// Unpack the mutable buffer into a mutable slice, with the option to /// initialize the data - fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result - where - 'a: 'data, - { - if data.len() < LENGTH_SIZE { + #[inline(always)] + fn unpack_internal(buf: &'data mut [u8], init: bool) -> Result { + // Split the buffer to get the length prefix. + // buf: [ L L L L | P P D D D D D D D D ...] + // <-------> <----------------------> + // len_bytes tail + let length_size = size_of::(); + if buf.len() < length_size { return Err(PodSliceError::BufferTooSmall.into()); } - let (length, data) = data.split_at_mut(LENGTH_SIZE); - let length = pod_from_bytes_mut::(length)?; + let (len_bytes, tail) = buf.split_at_mut(length_size); + + // Skip alignment padding to find the start of the data. + // tail: [P P | D D D D D D D D ...] + // <-> <-------------------> + // padding data_bytes + let padding = calculate_padding::()?; + let data_bytes = tail + .get_mut(padding..) + .ok_or(PodSliceError::BufferTooSmall)?; + + // Cast the bytes to typed data + let length = pod_from_bytes_mut::(len_bytes)?; + let data = pod_slice_from_bytes_mut::(data_bytes)?; + let max_length = data.len(); + + // Initialize the list or validate its current length. if init { - *length = 0.into(); + *length = L::from_usize(0)?; + } else if length.as_usize() > max_length { + return Err(PodSliceError::BufferTooSmall.into()); } - let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes_mut(data)?; + Ok(Self { length, data, @@ -60,7 +102,7 @@ impl<'data, T: Pod> PodList<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ false) + Self::unpack_internal(data, false) } /// Unpack the mutable buffer into a mutable slice, and initialize the @@ -69,17 +111,17 @@ impl<'data, T: Pod> PodList<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ true) + Self::unpack_internal(data, true) } /// Add another item to the slice pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = u32::from(*self.length); - if length as usize == self.max_length { + let length = self.length.as_usize(); + if length == self.max_length { Err(PodSliceError::BufferTooSmall.into()) } else { - self.data[length as usize] = t; - *self.length = length.saturating_add(1).into(); + self.data[length] = t; + *self.length = L::from_usize(length.saturating_add(1))?; Ok(()) } } @@ -87,7 +129,7 @@ impl<'data, T: Pod> PodList<'data, T> { /// Remove and return the element at `index`, shifting all later /// elements one position to the left. pub fn remove_at(&mut self, index: usize) -> Result { - let len = u32::from(*self.length) as usize; + let len = self.length.as_usize(); if index >= len { return Err(ProgramError::InvalidArgument); } @@ -101,33 +143,58 @@ impl<'data, T: Pod> PodList<'data, T> { self.data.copy_within(tail_start..len, index); // Zero-fill the now-unused slot at the end - let last = len.checked_sub(1).ok_or(ProgramError::ArithmeticOverflow)?; + let last = len.saturating_sub(1); self.data[last] = T::zeroed(); // Store the new length (len - 1) - *self.length = (last as u32).into(); + *self.length = L::from_usize(last)?; Ok(removed_item) } /// Find the first element that satisfies `predicate` and remove it, /// returning the element. - pub fn remove_first_where

(&mut self, mut predicate: P) -> Result + pub fn remove_first_where

(&mut self, predicate: P) -> Result where P: FnMut(&T) -> bool, { - if let Some(index) = self.data.iter().position(&mut predicate) { + if let Some(index) = self.data.iter().position(predicate) { self.remove_at(index) } else { Err(ProgramError::InvalidArgument) } } + + /// Get the amount of bytes used by `num_items` + pub fn size_of(num_items: usize) -> Result { + let padding_size = calculate_padding::()?; + let header_size = size_of::().saturating_add(padding_size); + + let data_size = size_of::() + .checked_mul(num_items) + .ok_or(PodSliceError::CalculationFailure)?; + + header_size + .checked_add(data_size) + .ok_or(PodSliceError::CalculationFailure.into()) + } + + /// Get the current number of items in collection + pub fn len(&self) -> usize { + self.length.as_usize() + } + + /// Returns true if the collection is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } #[cfg(test)] mod tests { use { super::*, + crate::primitives::{PodU16, PodU32, PodU64}, bytemuck_derive::{Pod, Zeroable}, }; @@ -139,196 +206,292 @@ mod tests { } #[test] - fn test_pod_collection() { - // slice can fit 2 `TestStruct` - let mut pod_slice_bytes = [0; 70]; - // set length to 1, so we have room to push 1 more item - let len_bytes = [1, 0, 0, 0]; - pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + fn init_and_push() { + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; - let mut pod_slice = PodList::::unpack(&mut pod_slice_bytes).unwrap(); + let mut pod_slice = ListView::::init(&mut buffer).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(1)); pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(2)); - let err = pod_slice - .push(TestStruct::default()) - .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); + assert_eq!(*pod_slice.length, PodU64::from(1)); + assert_eq!(pod_slice.len(), 1); + + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU64::from(2)); + assert_eq!(pod_slice.len(), 2); + + // Buffer should be full now + let err = pod_slice.push(TestStruct::default()).unwrap_err(); assert_eq!(err, PodSliceError::BufferTooSmall.into()); } - fn make_buffer(capacity: usize, items: &[u8]) -> Vec { - let buff_len = LENGTH_SIZE.checked_add(capacity).unwrap(); + fn make_buffer(capacity: usize, items: &[u8]) -> Vec { + let length_size = size_of::(); + let padding_size = calculate_padding::().unwrap(); + let header_size = length_size.saturating_add(padding_size); + let buff_len = header_size.checked_add(capacity).unwrap(); let mut buf = vec![0u8; buff_len]; - buf[..LENGTH_SIZE].copy_from_slice(&(items.len() as u32).to_le_bytes()); - let end = LENGTH_SIZE.checked_add(items.len()).unwrap(); - buf[LENGTH_SIZE..end].copy_from_slice(items); + + // Write the length + let length = L::from_usize(items.len()).unwrap(); + let length_bytes = bytemuck::bytes_of(&length); + buf[..length_size].copy_from_slice(length_bytes); + + // Copy the data after the header + let data_end = header_size.checked_add(items.len()).unwrap(); + buf[header_size..data_end].copy_from_slice(items); buf } #[test] fn remove_at_first_item() { - let mut buff = make_buffer(15, &[10, 20, 30, 40]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(0).unwrap(); + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(0).unwrap(); assert_eq!(removed, 10); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[20, 30, 40]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[20, 30, 40]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_at_middle_item() { - let mut buff = make_buffer(15, &[10, 20, 30, 40]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(2).unwrap(); + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(2).unwrap(); assert_eq!(removed, 30); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 40]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 40]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_at_last_item() { - let mut buff = make_buffer(15, &[10, 20, 30, 40]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(3).unwrap(); + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(3).unwrap(); assert_eq!(removed, 40); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 30]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 30]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_at_out_of_bounds() { - let mut buff = make_buffer(3, &[1, 2, 3]); + let mut buff = make_buffer::(3, &[1, 2, 3]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_at(3).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_at(3).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); - // pod_list should be unchanged - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), vec![1, 2, 3]); - } + // list_view should be unchanged + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), vec![1, 2, 3]); assert_eq!(buff, original_buff); } #[test] fn remove_at_single_element() { - let mut buff = make_buffer(1, &[10]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(0).unwrap(); + let mut buff = make_buffer::(1, &[10]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(0).unwrap(); assert_eq!(removed, 10); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 0); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[] as &[u8]); - assert_eq!(pod_list.data[0], 0); + assert_eq!(list_view.len(), 0); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[] as &[u8]); + assert_eq!(list_view.data[0], 0); } #[test] fn remove_at_empty_slice() { - let mut buff = make_buffer(0, &[]); + let mut buff = make_buffer::(0, &[]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_at(0).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_at(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 0); - } + // Assert list state is unchanged + assert_eq!(list_view.len(), 0); assert_eq!(buff, original_buff); } #[test] fn remove_first_where_first_item() { - let mut buff = make_buffer(3, &[5, 10, 15]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|&x| x == 5).unwrap(); + let mut buff = make_buffer::(3, &[5, 10, 15]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|&x| x == 5).unwrap(); assert_eq!(removed, 5); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 2); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 15]); - assert_eq!(pod_list.data[2], 0); + assert_eq!(list_view.len(), 2); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 15]); + assert_eq!(list_view.data[2], 0); } #[test] fn remove_first_where_middle_item() { - let mut buff = make_buffer(4, &[1, 2, 3, 4]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|v| *v == 3).unwrap(); + let mut buff = make_buffer::(4, &[1, 2, 3, 4]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|v| *v == 3).unwrap(); assert_eq!(removed, 3); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[1, 2, 4]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[1, 2, 4]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_first_where_last_item() { - let mut buff = make_buffer(3, &[5, 10, 15]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|&x| x == 15).unwrap(); + let mut buff = make_buffer::(3, &[5, 10, 15]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|&x| x == 15).unwrap(); assert_eq!(removed, 15); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 2); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[5, 10]); - assert_eq!(pod_list.data[2], 0); + assert_eq!(list_view.len(), 2); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[5, 10]); + assert_eq!(list_view.data[2], 0); } #[test] fn remove_first_where_multiple_matches() { - let mut buff = make_buffer(5, &[7, 8, 8, 9, 10]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|v| *v == 8).unwrap(); + let mut buff = make_buffer::(5, &[7, 8, 8, 9, 10]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|v| *v == 8).unwrap(); assert_eq!(removed, 8); // Removed *first* 8 - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 4); + assert_eq!(list_view.len(), 4); // Should remove only the *first* match. - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[7, 8, 9, 10]); - assert_eq!(pod_list.data[4], 0); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[7, 8, 9, 10]); + assert_eq!(list_view.data[4], 0); } #[test] fn remove_first_where_not_found() { - let mut buff = make_buffer(3, &[5, 6, 7]); + let mut buff = make_buffer::(3, &[5, 6, 7]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_first_where(|v| *v == 42).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - assert_eq!(u32::from(*pod_list.length) as usize, 3); - } + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_first_where(|v| *v == 42).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(list_view.len(), 3); assert_eq!(buff, original_buff); } #[test] fn remove_first_where_empty_slice() { - let mut buff = make_buffer(0, &[]); + let mut buff = make_buffer::(0, &[]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_first_where(|_| true).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - assert_eq!(u32::from(*pod_list.length) as usize, 0); - } + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_first_where(|_| true).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(list_view.len(), 0); assert_eq!(buff, original_buff); } + + #[test] + fn test_different_length_types() { + // Test with u16 length + let mut buff16 = make_buffer::(5, &[1, 2, 3]); + let list16 = ListView::::unpack(&mut buff16).unwrap(); + assert_eq!(list16.length.as_usize(), 3); + assert_eq!(list16.len(), 3); + + // Test with u32 length + let mut buff32 = make_buffer::(5, &[4, 5, 6]); + let list32 = ListView::::unpack(&mut buff32).unwrap(); + assert_eq!(list32.length.as_usize(), 3); + assert_eq!(list32.len(), 3); + + // Test with u64 length + let mut buff64 = make_buffer::(5, &[7, 8, 9]); + let list64 = ListView::::unpack(&mut buff64).unwrap(); + assert_eq!(list64.length.as_usize(), 3); + assert_eq!(list64.len(), 3); + } + + #[test] + fn test_calculate_padding() { + // When length and data have same alignment, no padding needed + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + + // When data alignment is smaller than or divides length size + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + + // When padding is needed + assert_eq!(calculate_padding::().unwrap(), 2); // 2 + 2 = 4 (align to 4) + assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 (align to 8) + assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 (align to 8) + + // Test with custom aligned structs + #[repr(C, align(8))] + #[derive(Pod, Zeroable, Copy, Clone)] + struct Align8 { + _data: [u8; 8], + } + + #[repr(C, align(16))] + #[derive(Pod, Zeroable, Copy, Clone)] + struct Align16 { + _data: [u8; 16], + } + + assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 + assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 + assert_eq!(calculate_padding::().unwrap(), 0); // 8 % 8 = 0 + + assert_eq!(calculate_padding::().unwrap(), 14); // 2 + 14 = 16 + assert_eq!(calculate_padding::().unwrap(), 12); // 4 + 12 = 16 + assert_eq!(calculate_padding::().unwrap(), 8); // 8 + 8 = 16 + } + + #[test] + fn test_alignment_in_practice() { + // u32 length with u64 data - needs 4 bytes padding + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + let list = ListView::::init(&mut buffer).unwrap(); + + // Check that data pointer is 8-byte aligned + let data_ptr = list.data.as_ptr() as usize; + assert_eq!(data_ptr % 8, 0); + + // u16 length with u64 data - needs 6 bytes padding + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + let list = ListView::::init(&mut buffer).unwrap(); + + let data_ptr = list.data.as_ptr() as usize; + assert_eq!(data_ptr % 8, 0); + } + + #[test] + fn test_length_too_large() { + // Create a buffer with capacity for 2 items + let capacity = 2; + let length_size = size_of::(); + let padding_size = calculate_padding::().unwrap(); + let header_size = length_size.saturating_add(padding_size); + let buff_len = header_size.checked_add(capacity).unwrap(); + let mut buffer = vec![0u8; buff_len]; + + // Manually write a length value that exceeds the capacity + let invalid_length = PodU32::from_usize(capacity + 1).unwrap(); + let length_bytes = bytemuck::bytes_of(&invalid_length); + buffer[..length_size].copy_from_slice(length_bytes); + + // Attempting to unpack should return BufferTooSmall error + match ListView::::unpack(&mut buffer) { + Err(err) => assert_eq!(err, PodSliceError::BufferTooSmall.into()), + Ok(_) => panic!("Expected BufferTooSmall error, but unpack succeeded"), + } + } } diff --git a/pod/src/primitives.rs b/pod/src/primitives.rs index 5eb694ed..4ae28957 100644 --- a/pod/src/primitives.rs +++ b/pod/src/primitives.rs @@ -127,6 +127,47 @@ impl_int_conversion!(PodI64, i64); pub struct PodU128(pub [u8; 16]); impl_int_conversion!(PodU128, u128); +/// Trait for types that can be used as length fields in Pod data structures +pub trait PodLength: bytemuck::Pod + Copy { + fn as_usize(&self) -> usize; + + fn from_usize(val: usize) -> Result; +} + +impl PodLength for PodU16 { + fn as_usize(&self) -> usize { + u16::from(*self) as usize + } + + fn from_usize(val: usize) -> Result { + u16::try_from(val) + .map(Into::into) + .map_err(|_| crate::error::PodSliceError::CalculationFailure) + } +} + +impl PodLength for PodU32 { + fn as_usize(&self) -> usize { + u32::from(*self) as usize + } + + fn from_usize(val: usize) -> Result { + u32::try_from(val) + .map(Into::into) + .map_err(|_| crate::error::PodSliceError::CalculationFailure) + } +} + +impl PodLength for PodU64 { + fn as_usize(&self) -> usize { + u64::from(*self) as usize + } + + fn from_usize(val: usize) -> Result { + Ok((val as u64).into()) + } +} + #[cfg(test)] mod tests { use {super::*, crate::bytemuck::pod_from_bytes}; diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 0082c7a4..8ee4826d 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,10 +2,9 @@ use { crate::{ - bytemuck::{ - pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, - }, + bytemuck::{pod_from_bytes, pod_slice_from_bytes}, error::PodSliceError, + list::ListView, primitives::PodU32, }, bytemuck::Pod, @@ -51,46 +50,23 @@ impl<'data, T: Pod> PodSlice<'data, T> { #[deprecated( since = "0.6.0", - note = "This struct will be removed in the next major release (1.0.0). Please use `PodList` instead." + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." )] -/// Special type for using a slice of mutable `Pod`s in a zero-copy way +/// Special type for using a slice of mutable `Pod`s in a zero-copy way. +/// Uses `ListView` under the hood. pub struct PodSliceMut<'data, T: Pod> { - length: &'data mut PodU32, - data: &'data mut [T], - max_length: usize, + inner: ListView<'data, T, PodU32>, } #[allow(deprecated)] impl<'data, T: Pod> PodSliceMut<'data, T> { - /// Unpack the mutable buffer into a mutable slice, with the option to - /// initialize the data - fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result - where - 'a: 'data, - { - if data.len() < LENGTH_SIZE { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (length, data) = data.split_at_mut(LENGTH_SIZE); - let length = pod_from_bytes_mut::(length)?; - if init { - *length = 0.into(); - } - let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes_mut(data)?; - Ok(Self { - length, - data, - max_length, - }) - } - /// Unpack the mutable buffer into a mutable slice pub fn unpack<'a>(data: &'a mut [u8]) -> Result where 'a: 'data, { - Self::unpack_internal(data, /* init */ false) + let inner = ListView::::unpack(data)?; + Ok(Self { inner }) } /// Unpack the mutable buffer into a mutable slice, and initialize the @@ -99,23 +75,17 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ true) + let inner = ListView::::init(data)?; + Ok(Self { inner }) } /// Add another item to the slice pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = u32::from(*self.length); - if length as usize == self.max_length { - Err(PodSliceError::BufferTooSmall.into()) - } else { - self.data[length as usize] = t; - *self.length = length.saturating_add(1).into(); - Ok(()) - } + self.inner.push(t) } } -pub fn max_len_for_type(data_len: usize, length_val: usize) -> Result { +fn max_len_for_type(data_len: usize, length_val: usize) -> Result { let item_size = std::mem::size_of::(); let max_len = data_len .checked_div(item_size) @@ -275,11 +245,33 @@ mod tests { let len_bytes = [1, 0, 0, 0]; pod_slice_bytes[0..4].copy_from_slice(&len_bytes); - let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); + // Verify initial length + assert_eq!( + u32::from_le_bytes([ + pod_slice_bytes[0], + pod_slice_bytes[1], + pod_slice_bytes[2], + pod_slice_bytes[3] + ]), + 1 + ); - assert_eq!(*pod_slice.length, PodU32::from(1)); + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(2)); + + // Check length after push + assert_eq!( + u32::from_le_bytes([ + pod_slice_bytes[0], + pod_slice_bytes[1], + pod_slice_bytes[2], + pod_slice_bytes[3] + ]), + 2 + ); + + // Test that buffer is full + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); let err = pod_slice .push(TestStruct::default()) .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); diff --git a/tlv-account-resolution/Cargo.toml b/tlv-account-resolution/Cargo.toml index 63e06a30..589ca9d3 100644 --- a/tlv-account-resolution/Cargo.toml +++ b/tlv-account-resolution/Cargo.toml @@ -24,7 +24,7 @@ solana-msg = "2.2.1" solana-pubkey = { version = "2.2.1", features = ["curve25519"] } spl-discriminator = { version = "0.4.0", path = "../discriminator" } spl-program-error = { version = "0.7.0", path = "../program-error" } -spl-pod = { version = "0.6.0", path = "../pod" } +spl-pod = { version = "0.5.1", path = "../pod" } spl-type-length-value = { version = "0.8.0", path = "../type-length-value" } thiserror = "2.0" diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index e901fff3..74ea8e36 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -1,6 +1,5 @@ //! State transition types -use spl_pod::list::PodList; use { crate::{account::ExtraAccountMeta, error::AccountResolutionError}, solana_account_info::AccountInfo, @@ -8,7 +7,7 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::slice::PodSlice, + spl_pod::{list::ListView, slice::PodSlice}, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -173,7 +172,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = PodList::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -189,7 +188,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = PodList::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } diff --git a/type-length-value/Cargo.toml b/type-length-value/Cargo.toml index 39b930d6..9f8fc535 100644 --- a/type-length-value/Cargo.toml +++ b/type-length-value/Cargo.toml @@ -21,7 +21,7 @@ solana-msg = "2.2.1" solana-program-error = "2.2.1" spl-discriminator = { version = "0.4.0", path = "../discriminator" } spl-type-length-value-derive = { version = "0.2", path = "./derive", optional = true } -spl-pod = { version = "0.6.0", path = "../pod" } +spl-pod = { version = "0.5.1", path = "../pod" } thiserror = "2.0" [lib] From 80c67eb5faa92c02f04e6c5b331fd9512351c8b4 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Wed, 9 Jul 2025 14:18:07 +0200 Subject: [PATCH 3/4] Fix test + add more specific doc string --- pod/src/slice.rs | 4 +++- tlv-account-resolution/src/state.rs | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 8ee4826d..34433466 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -50,7 +50,9 @@ impl<'data, T: Pod> PodSlice<'data, T> { #[deprecated( since = "0.6.0", - note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." + note = "This struct will be removed in the next major release (1.0.0). \ + Please use `ListView` instead. If using with existing data initialized by PodSliceMut, \ + you need to specifiy PodU32 length (e.g. ListView::::init(bytes))" )] /// Special type for using a slice of mutable `Pod`s in a zero-copy way. /// Uses `ListView` under the hood. diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index 74ea8e36..6bed11d6 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -7,7 +7,7 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::{list::ListView, slice::PodSlice}, + spl_pod::{list::ListView, primitives::PodU32, slice::PodSlice}, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -172,7 +172,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = ListView::::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -188,7 +188,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = ListView::::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } From b3982b80d3a4087de9dd9885a7e4e961c9d61cef Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Tue, 15 Jul 2025 15:08:01 +0200 Subject: [PATCH 4/4] Review updates --- pod/src/error.rs | 7 ++ pod/src/lib.rs | 1 + pod/src/list.rs | 157 ++++++++++++------------------------------ pod/src/pod_length.rs | 54 +++++++++++++++ pod/src/primitives.rs | 41 ----------- 5 files changed, 106 insertions(+), 154 deletions(-) create mode 100644 pod/src/pod_length.rs diff --git a/pod/src/error.rs b/pod/src/error.rs index c8325f4b..f4fd6ca5 100644 --- a/pod/src/error.rs +++ b/pod/src/error.rs @@ -2,6 +2,7 @@ use { solana_msg::msg, solana_program_error::{PrintProgramError, ProgramError}, + std::num::TryFromIntError, }; /// Errors that may be returned by the spl-pod library. @@ -49,3 +50,9 @@ impl PrintProgramError for PodSliceError { } } } + +impl From for PodSliceError { + fn from(_: TryFromIntError) -> Self { + PodSliceError::CalculationFailure + } +} diff --git a/pod/src/lib.rs b/pod/src/lib.rs index ab6ec113..8a00c512 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -5,6 +5,7 @@ pub mod error; pub mod list; pub mod option; pub mod optional_keys; +pub mod pod_length; pub mod primitives; pub mod slice; diff --git a/pod/src/list.rs b/pod/src/list.rs index 480761bd..44ae527a 100644 --- a/pod/src/list.rs +++ b/pod/src/list.rs @@ -2,7 +2,8 @@ use { crate::{ bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}, error::PodSliceError, - primitives::{PodLength, PodU64}, + pod_length::PodLength, + primitives::PodU64, }, bytemuck::Pod, core::mem::{align_of, size_of}, @@ -29,15 +30,15 @@ fn calculate_padding() -> Result { } } -/// A mutable, variable-length collection of `Pod` types backed by a byte buffer. +/// An API for interpreting a raw buffer (`&[u8]`) as a mutable, variable-length collection of Pod elements. /// /// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of /// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. /// It does not own the buffer itself, but acts as a mutable view over it. /// /// This is useful in environments where allocations are restricted or expensive, -/// such as Solana programs, allowing for dynamic-length data structures within a -/// fixed-size account. +/// such as Solana programs, allowing for efficient reads and manipulation of +/// dynamic-length data structures. /// /// ## Memory Layout /// @@ -85,8 +86,8 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { // Initialize the list or validate its current length. if init { - *length = L::from_usize(0)?; - } else if length.as_usize() > max_length { + *length = L::try_from(0)?; + } else if (*length).into() > max_length { return Err(PodSliceError::BufferTooSmall.into()); } @@ -115,21 +116,21 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { } /// Add another item to the slice - pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = self.length.as_usize(); - if length == self.max_length { + pub fn push(&mut self, item: T) -> Result<(), ProgramError> { + let length = (*self.length).into(); + if length >= self.max_length { Err(PodSliceError::BufferTooSmall.into()) } else { - self.data[length] = t; - *self.length = L::from_usize(length.saturating_add(1))?; + self.data[length] = item; + *self.length = L::try_from(length.saturating_add(1))?; Ok(()) } } /// Remove and return the element at `index`, shifting all later /// elements one position to the left. - pub fn remove_at(&mut self, index: usize) -> Result { - let len = self.length.as_usize(); + pub fn remove(&mut self, index: usize) -> Result { + let len = (*self.length).into(); if index >= len { return Err(ProgramError::InvalidArgument); } @@ -147,24 +148,11 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { self.data[last] = T::zeroed(); // Store the new length (len - 1) - *self.length = L::from_usize(last)?; + *self.length = L::try_from(last)?; Ok(removed_item) } - /// Find the first element that satisfies `predicate` and remove it, - /// returning the element. - pub fn remove_first_where

(&mut self, predicate: P) -> Result - where - P: FnMut(&T) -> bool, - { - if let Some(index) = self.data.iter().position(predicate) { - self.remove_at(index) - } else { - Err(ProgramError::InvalidArgument) - } - } - /// Get the amount of bytes used by `num_items` pub fn size_of(num_items: usize) -> Result { let padding_size = calculate_padding::()?; @@ -181,13 +169,25 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { /// Get the current number of items in collection pub fn len(&self) -> usize { - self.length.as_usize() + (*self.length).into() } /// Returns true if the collection is empty pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns an iterator over the current elements + pub fn iter(&self) -> std::slice::Iter { + let len = (*self.length).into(); + self.data[..len].iter() + } + + /// Returns a mutable iterator over the current elements + pub fn iter_mut(&mut self) -> std::slice::IterMut { + let len = (*self.length).into(); + self.data[..len].iter_mut() + } } #[cfg(test)] @@ -225,7 +225,11 @@ mod tests { assert_eq!(err, PodSliceError::BufferTooSmall.into()); } - fn make_buffer(capacity: usize, items: &[u8]) -> Vec { + fn make_buffer + TryFrom>(capacity: usize, items: &[u8]) -> Vec + where + PodSliceError: From<>::Error>, + >::Error: std::fmt::Debug, + { let length_size = size_of::(); let padding_size = calculate_padding::().unwrap(); let header_size = length_size.saturating_add(padding_size); @@ -233,7 +237,7 @@ mod tests { let mut buf = vec![0u8; buff_len]; // Write the length - let length = L::from_usize(items.len()).unwrap(); + let length = L::try_from(items.len()).unwrap(); let length_bytes = bytemuck::bytes_of(&length); buf[..length_size].copy_from_slice(length_bytes); @@ -247,7 +251,7 @@ mod tests { fn remove_at_first_item() { let mut buff = make_buffer::(15, &[10, 20, 30, 40]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(0).unwrap(); + let removed = list_view.remove(0).unwrap(); assert_eq!(removed, 10); assert_eq!(list_view.len(), 3); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[20, 30, 40]); @@ -258,7 +262,7 @@ mod tests { fn remove_at_middle_item() { let mut buff = make_buffer::(15, &[10, 20, 30, 40]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(2).unwrap(); + let removed = list_view.remove(2).unwrap(); assert_eq!(removed, 30); assert_eq!(list_view.len(), 3); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 40]); @@ -269,7 +273,7 @@ mod tests { fn remove_at_last_item() { let mut buff = make_buffer::(15, &[10, 20, 30, 40]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(3).unwrap(); + let removed = list_view.remove(3).unwrap(); assert_eq!(removed, 40); assert_eq!(list_view.len(), 3); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 30]); @@ -282,7 +286,7 @@ mod tests { let original_buff = buff.clone(); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_at(3).unwrap_err(); + let err = list_view.remove(3).unwrap_err(); assert_eq!(err, ProgramError::InvalidArgument); // list_view should be unchanged @@ -296,7 +300,7 @@ mod tests { fn remove_at_single_element() { let mut buff = make_buffer::(1, &[10]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(0).unwrap(); + let removed = list_view.remove(0).unwrap(); assert_eq!(removed, 10); assert_eq!(list_view.len(), 0); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[] as &[u8]); @@ -309,82 +313,9 @@ mod tests { let original_buff = buff.clone(); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_at(0).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - - // Assert list state is unchanged - assert_eq!(list_view.len(), 0); - - assert_eq!(buff, original_buff); - } - - #[test] - fn remove_first_where_first_item() { - let mut buff = make_buffer::(3, &[5, 10, 15]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|&x| x == 5).unwrap(); - assert_eq!(removed, 5); - assert_eq!(list_view.len(), 2); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 15]); - assert_eq!(list_view.data[2], 0); - } - - #[test] - fn remove_first_where_middle_item() { - let mut buff = make_buffer::(4, &[1, 2, 3, 4]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|v| *v == 3).unwrap(); - assert_eq!(removed, 3); - assert_eq!(list_view.len(), 3); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[1, 2, 4]); - assert_eq!(list_view.data[3], 0); - } - - #[test] - fn remove_first_where_last_item() { - let mut buff = make_buffer::(3, &[5, 10, 15]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|&x| x == 15).unwrap(); - assert_eq!(removed, 15); - assert_eq!(list_view.len(), 2); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[5, 10]); - assert_eq!(list_view.data[2], 0); - } - - #[test] - fn remove_first_where_multiple_matches() { - let mut buff = make_buffer::(5, &[7, 8, 8, 9, 10]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|v| *v == 8).unwrap(); - assert_eq!(removed, 8); // Removed *first* 8 - assert_eq!(list_view.len(), 4); - // Should remove only the *first* match. - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[7, 8, 9, 10]); - assert_eq!(list_view.data[4], 0); - } - - #[test] - fn remove_first_where_not_found() { - let mut buff = make_buffer::(3, &[5, 6, 7]); - let original_buff = buff.clone(); - - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_first_where(|v| *v == 42).unwrap_err(); + let err = list_view.remove(0).unwrap_err(); assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - assert_eq!(list_view.len(), 3); - - assert_eq!(buff, original_buff); - } - - #[test] - fn remove_first_where_empty_slice() { - let mut buff = make_buffer::(0, &[]); - let original_buff = buff.clone(); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_first_where(|_| true).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); // Assert list state is unchanged assert_eq!(list_view.len(), 0); @@ -396,19 +327,19 @@ mod tests { // Test with u16 length let mut buff16 = make_buffer::(5, &[1, 2, 3]); let list16 = ListView::::unpack(&mut buff16).unwrap(); - assert_eq!(list16.length.as_usize(), 3); + assert_eq!(list16.len(), 3); assert_eq!(list16.len(), 3); // Test with u32 length let mut buff32 = make_buffer::(5, &[4, 5, 6]); let list32 = ListView::::unpack(&mut buff32).unwrap(); - assert_eq!(list32.length.as_usize(), 3); + assert_eq!(list32.len(), 3); assert_eq!(list32.len(), 3); // Test with u64 length let mut buff64 = make_buffer::(5, &[7, 8, 9]); let list64 = ListView::::unpack(&mut buff64).unwrap(); - assert_eq!(list64.length.as_usize(), 3); + assert_eq!(list64.len(), 3); assert_eq!(list64.len(), 3); } @@ -484,7 +415,7 @@ mod tests { let mut buffer = vec![0u8; buff_len]; // Manually write a length value that exceeds the capacity - let invalid_length = PodU32::from_usize(capacity + 1).unwrap(); + let invalid_length = PodU32::try_from(capacity + 1).unwrap(); let length_bytes = bytemuck::bytes_of(&invalid_length); buffer[..length_size].copy_from_slice(length_bytes); diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs new file mode 100644 index 00000000..52e4a7eb --- /dev/null +++ b/pod/src/pod_length.rs @@ -0,0 +1,54 @@ +use { + crate::{ + error::PodSliceError, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck::Pod, +}; + +/// Marker trait for converting to/from Pod `uint`'s and `usize` +pub trait PodLength: Pod + Into + TryFrom {} + +impl PodLength for T where T: Pod + Into + TryFrom {} + +impl TryFrom for PodU16 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u16::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU16) -> Self { + u16::from(pod) as usize + } +} + +impl TryFrom for PodU32 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u32::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU32) -> Self { + u32::from(pod) as usize + } +} + +impl TryFrom for PodU64 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u64::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU64) -> Self { + u64::from(pod) as usize + } +} diff --git a/pod/src/primitives.rs b/pod/src/primitives.rs index 4ae28957..5eb694ed 100644 --- a/pod/src/primitives.rs +++ b/pod/src/primitives.rs @@ -127,47 +127,6 @@ impl_int_conversion!(PodI64, i64); pub struct PodU128(pub [u8; 16]); impl_int_conversion!(PodU128, u128); -/// Trait for types that can be used as length fields in Pod data structures -pub trait PodLength: bytemuck::Pod + Copy { - fn as_usize(&self) -> usize; - - fn from_usize(val: usize) -> Result; -} - -impl PodLength for PodU16 { - fn as_usize(&self) -> usize { - u16::from(*self) as usize - } - - fn from_usize(val: usize) -> Result { - u16::try_from(val) - .map(Into::into) - .map_err(|_| crate::error::PodSliceError::CalculationFailure) - } -} - -impl PodLength for PodU32 { - fn as_usize(&self) -> usize { - u32::from(*self) as usize - } - - fn from_usize(val: usize) -> Result { - u32::try_from(val) - .map(Into::into) - .map_err(|_| crate::error::PodSliceError::CalculationFailure) - } -} - -impl PodLength for PodU64 { - fn as_usize(&self) -> usize { - u64::from(*self) as usize - } - - fn from_usize(val: usize) -> Result { - Ok((val as u64).into()) - } -} - #[cfg(test)] mod tests { use {super::*, crate::bytemuck::pod_from_bytes};