diff --git a/pod/src/error.rs b/pod/src/error.rs index c8325f4..f4fd6ca 100644 --- a/pod/src/error.rs +++ b/pod/src/error.rs @@ -2,6 +2,7 @@ use { solana_msg::msg, solana_program_error::{PrintProgramError, ProgramError}, + std::num::TryFromIntError, }; /// Errors that may be returned by the spl-pod library. @@ -49,3 +50,9 @@ impl PrintProgramError for PodSliceError { } } } + +impl From for PodSliceError { + fn from(_: TryFromIntError) -> Self { + PodSliceError::CalculationFailure + } +} diff --git a/pod/src/lib.rs b/pod/src/lib.rs index 292f1c0..8a00c51 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -2,8 +2,10 @@ pub mod bytemuck; pub mod error; +pub mod list; pub mod option; pub mod optional_keys; +pub mod pod_length; pub mod primitives; pub mod slice; diff --git a/pod/src/list.rs b/pod/src/list.rs new file mode 100644 index 0000000..44ae527 --- /dev/null +++ b/pod/src/list.rs @@ -0,0 +1,428 @@ +use { + crate::{ + bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}, + error::PodSliceError, + pod_length::PodLength, + primitives::PodU64, + }, + bytemuck::Pod, + core::mem::{align_of, size_of}, + solana_program_error::ProgramError, +}; + +/// Calculate padding needed between types for alignment +#[inline] +fn calculate_padding() -> Result { + let length_size = size_of::(); + let data_align = align_of::(); + + // Calculate how many bytes we need to add to length_size + // to make it a multiple of data_align + let remainder = length_size + .checked_rem(data_align) + .ok_or(ProgramError::ArithmeticOverflow)?; + if remainder == 0 { + Ok(0) + } else { + data_align + .checked_sub(remainder) + .ok_or(ProgramError::ArithmeticOverflow) + } +} + +/// An API for interpreting a raw buffer (`&[u8]`) as a mutable, variable-length collection of Pod elements. +/// +/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. +/// It does not own the buffer itself, but acts as a mutable view over it. +/// +/// This is useful in environments where allocations are restricted or expensive, +/// such as Solana programs, allowing for efficient reads and manipulation of +/// dynamic-length data structures. +/// +/// ## Memory Layout +/// +/// The structure assumes the underlying byte buffer is formatted as follows: +/// 1. **Length**: A length field of type `L` at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. Defaults to `PodU64` so the offset is then compatible with 1, 2, 4 and 8 bytes. +/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. +/// 3. **Data**: The remaining part of the buffer, which is treated as a slice +/// of `T` elements. The capacity of the collection is the number of `T` +/// elements that can fit into this data portion. +pub struct ListView<'data, T: Pod, L: PodLength = PodU64> { + length: &'data mut L, + data: &'data mut [T], + max_length: usize, +} + +impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { + /// Unpack the mutable buffer into a mutable slice, with the option to + /// initialize the data + #[inline(always)] + fn unpack_internal(buf: &'data mut [u8], init: bool) -> Result { + // Split the buffer to get the length prefix. + // buf: [ L L L L | P P D D D D D D D D ...] + // <-------> <----------------------> + // len_bytes tail + let length_size = size_of::(); + if buf.len() < length_size { + return Err(PodSliceError::BufferTooSmall.into()); + } + let (len_bytes, tail) = buf.split_at_mut(length_size); + + // Skip alignment padding to find the start of the data. + // tail: [P P | D D D D D D D D ...] + // <-> <-------------------> + // padding data_bytes + let padding = calculate_padding::()?; + let data_bytes = tail + .get_mut(padding..) + .ok_or(PodSliceError::BufferTooSmall)?; + + // Cast the bytes to typed data + let length = pod_from_bytes_mut::(len_bytes)?; + let data = pod_slice_from_bytes_mut::(data_bytes)?; + let max_length = data.len(); + + // Initialize the list or validate its current length. + if init { + *length = L::try_from(0)?; + } else if (*length).into() > max_length { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(Self { + length, + data, + max_length, + }) + } + + /// Unpack the mutable buffer into a mutable slice + pub fn unpack<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, false) + } + + /// Unpack the mutable buffer into a mutable slice, and initialize the + /// slice to 0-length + pub fn init<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, true) + } + + /// Add another item to the slice + pub fn push(&mut self, item: T) -> Result<(), ProgramError> { + let length = (*self.length).into(); + if length >= self.max_length { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length] = item; + *self.length = L::try_from(length.saturating_add(1))?; + Ok(()) + } + } + + /// Remove and return the element at `index`, shifting all later + /// elements one position to the left. + pub fn remove(&mut self, index: usize) -> Result { + let len = (*self.length).into(); + if index >= len { + return Err(ProgramError::InvalidArgument); + } + + let removed_item = self.data[index]; + + // Move the tail left by one + let tail_start = index + .checked_add(1) + .ok_or(ProgramError::ArithmeticOverflow)?; + self.data.copy_within(tail_start..len, index); + + // Zero-fill the now-unused slot at the end + let last = len.saturating_sub(1); + self.data[last] = T::zeroed(); + + // Store the new length (len - 1) + *self.length = L::try_from(last)?; + + Ok(removed_item) + } + + /// Get the amount of bytes used by `num_items` + pub fn size_of(num_items: usize) -> Result { + let padding_size = calculate_padding::()?; + let header_size = size_of::().saturating_add(padding_size); + + let data_size = size_of::() + .checked_mul(num_items) + .ok_or(PodSliceError::CalculationFailure)?; + + header_size + .checked_add(data_size) + .ok_or(PodSliceError::CalculationFailure.into()) + } + + /// Get the current number of items in collection + pub fn len(&self) -> usize { + (*self.length).into() + } + + /// Returns true if the collection is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the current elements + pub fn iter(&self) -> std::slice::Iter { + let len = (*self.length).into(); + self.data[..len].iter() + } + + /// Returns a mutable iterator over the current elements + pub fn iter_mut(&mut self) -> std::slice::IterMut { + let len = (*self.length).into(); + self.data[..len].iter_mut() + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::primitives::{PodU16, PodU32, PodU64}, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + test_field: u8, + test_pubkey: [u8; 32], + } + + #[test] + fn init_and_push() { + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + + let mut pod_slice = ListView::::init(&mut buffer).unwrap(); + + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU64::from(1)); + assert_eq!(pod_slice.len(), 1); + + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU64::from(2)); + assert_eq!(pod_slice.len(), 2); + + // Buffer should be full now + let err = pod_slice.push(TestStruct::default()).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + fn make_buffer + TryFrom>(capacity: usize, items: &[u8]) -> Vec + where + PodSliceError: From<>::Error>, + >::Error: std::fmt::Debug, + { + let length_size = size_of::(); + let padding_size = calculate_padding::().unwrap(); + let header_size = length_size.saturating_add(padding_size); + let buff_len = header_size.checked_add(capacity).unwrap(); + let mut buf = vec![0u8; buff_len]; + + // Write the length + let length = L::try_from(items.len()).unwrap(); + let length_bytes = bytemuck::bytes_of(&length); + buf[..length_size].copy_from_slice(length_bytes); + + // Copy the data after the header + let data_end = header_size.checked_add(items.len()).unwrap(); + buf[header_size..data_end].copy_from_slice(items); + buf + } + + #[test] + fn remove_at_first_item() { + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove(0).unwrap(); + assert_eq!(removed, 10); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[20, 30, 40]); + assert_eq!(list_view.data[3], 0); + } + + #[test] + fn remove_at_middle_item() { + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove(2).unwrap(); + assert_eq!(removed, 30); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 40]); + assert_eq!(list_view.data[3], 0); + } + + #[test] + fn remove_at_last_item() { + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove(3).unwrap(); + assert_eq!(removed, 40); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 30]); + assert_eq!(list_view.data[3], 0); + } + + #[test] + fn remove_at_out_of_bounds() { + let mut buff = make_buffer::(3, &[1, 2, 3]); + let original_buff = buff.clone(); + + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove(3).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + // list_view should be unchanged + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), vec![1, 2, 3]); + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_at_single_element() { + let mut buff = make_buffer::(1, &[10]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove(0).unwrap(); + assert_eq!(removed, 10); + assert_eq!(list_view.len(), 0); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[] as &[u8]); + assert_eq!(list_view.data[0], 0); + } + + #[test] + fn remove_at_empty_slice() { + let mut buff = make_buffer::(0, &[]); + let original_buff = buff.clone(); + + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + // Assert list state is unchanged + assert_eq!(list_view.len(), 0); + + assert_eq!(buff, original_buff); + } + + #[test] + fn test_different_length_types() { + // Test with u16 length + let mut buff16 = make_buffer::(5, &[1, 2, 3]); + let list16 = ListView::::unpack(&mut buff16).unwrap(); + assert_eq!(list16.len(), 3); + assert_eq!(list16.len(), 3); + + // Test with u32 length + let mut buff32 = make_buffer::(5, &[4, 5, 6]); + let list32 = ListView::::unpack(&mut buff32).unwrap(); + assert_eq!(list32.len(), 3); + assert_eq!(list32.len(), 3); + + // Test with u64 length + let mut buff64 = make_buffer::(5, &[7, 8, 9]); + let list64 = ListView::::unpack(&mut buff64).unwrap(); + assert_eq!(list64.len(), 3); + assert_eq!(list64.len(), 3); + } + + #[test] + fn test_calculate_padding() { + // When length and data have same alignment, no padding needed + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + + // When data alignment is smaller than or divides length size + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + + // When padding is needed + assert_eq!(calculate_padding::().unwrap(), 2); // 2 + 2 = 4 (align to 4) + assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 (align to 8) + assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 (align to 8) + + // Test with custom aligned structs + #[repr(C, align(8))] + #[derive(Pod, Zeroable, Copy, Clone)] + struct Align8 { + _data: [u8; 8], + } + + #[repr(C, align(16))] + #[derive(Pod, Zeroable, Copy, Clone)] + struct Align16 { + _data: [u8; 16], + } + + assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 + assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 + assert_eq!(calculate_padding::().unwrap(), 0); // 8 % 8 = 0 + + assert_eq!(calculate_padding::().unwrap(), 14); // 2 + 14 = 16 + assert_eq!(calculate_padding::().unwrap(), 12); // 4 + 12 = 16 + assert_eq!(calculate_padding::().unwrap(), 8); // 8 + 8 = 16 + } + + #[test] + fn test_alignment_in_practice() { + // u32 length with u64 data - needs 4 bytes padding + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + let list = ListView::::init(&mut buffer).unwrap(); + + // Check that data pointer is 8-byte aligned + let data_ptr = list.data.as_ptr() as usize; + assert_eq!(data_ptr % 8, 0); + + // u16 length with u64 data - needs 6 bytes padding + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + let list = ListView::::init(&mut buffer).unwrap(); + + let data_ptr = list.data.as_ptr() as usize; + assert_eq!(data_ptr % 8, 0); + } + + #[test] + fn test_length_too_large() { + // Create a buffer with capacity for 2 items + let capacity = 2; + let length_size = size_of::(); + let padding_size = calculate_padding::().unwrap(); + let header_size = length_size.saturating_add(padding_size); + let buff_len = header_size.checked_add(capacity).unwrap(); + let mut buffer = vec![0u8; buff_len]; + + // Manually write a length value that exceeds the capacity + let invalid_length = PodU32::try_from(capacity + 1).unwrap(); + let length_bytes = bytemuck::bytes_of(&invalid_length); + buffer[..length_size].copy_from_slice(length_bytes); + + // Attempting to unpack should return BufferTooSmall error + match ListView::::unpack(&mut buffer) { + Err(err) => assert_eq!(err, PodSliceError::BufferTooSmall.into()), + Ok(_) => panic!("Expected BufferTooSmall error, but unpack succeeded"), + } + } +} diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs new file mode 100644 index 0000000..52e4a7e --- /dev/null +++ b/pod/src/pod_length.rs @@ -0,0 +1,54 @@ +use { + crate::{ + error::PodSliceError, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck::Pod, +}; + +/// Marker trait for converting to/from Pod `uint`'s and `usize` +pub trait PodLength: Pod + Into + TryFrom {} + +impl PodLength for T where T: Pod + Into + TryFrom {} + +impl TryFrom for PodU16 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u16::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU16) -> Self { + u16::from(pod) as usize + } +} + +impl TryFrom for PodU32 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u32::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU32) -> Self { + u32::from(pod) as usize + } +} + +impl TryFrom for PodU64 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u64::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU64) -> Self { + u64::from(pod) as usize + } +} diff --git a/pod/src/slice.rs b/pod/src/slice.rs index ca885ab..3443346 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,10 +2,9 @@ use { crate::{ - bytemuck::{ - pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, - }, + bytemuck::{pod_from_bytes, pod_slice_from_bytes}, error::PodSliceError, + list::ListView, primitives::PodU32, }, bytemuck::Pod, @@ -49,42 +48,27 @@ impl<'data, T: Pod> PodSlice<'data, T> { } } -/// Special type for using a slice of mutable `Pod`s in a zero-copy way +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). \ + Please use `ListView` instead. If using with existing data initialized by PodSliceMut, \ + you need to specifiy PodU32 length (e.g. ListView::::init(bytes))" +)] +/// Special type for using a slice of mutable `Pod`s in a zero-copy way. +/// Uses `ListView` under the hood. pub struct PodSliceMut<'data, T: Pod> { - length: &'data mut PodU32, - data: &'data mut [T], - max_length: usize, + inner: ListView<'data, T, PodU32>, } -impl<'data, T: Pod> PodSliceMut<'data, T> { - /// Unpack the mutable buffer into a mutable slice, with the option to - /// initialize the data - fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result - where - 'a: 'data, - { - if data.len() < LENGTH_SIZE { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (length, data) = data.split_at_mut(LENGTH_SIZE); - let length = pod_from_bytes_mut::(length)?; - if init { - *length = 0.into(); - } - let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes_mut(data)?; - Ok(Self { - length, - data, - max_length, - }) - } +#[allow(deprecated)] +impl<'data, T: Pod> PodSliceMut<'data, T> { /// Unpack the mutable buffer into a mutable slice pub fn unpack<'a>(data: &'a mut [u8]) -> Result where 'a: 'data, { - Self::unpack_internal(data, /* init */ false) + let inner = ListView::::unpack(data)?; + Ok(Self { inner }) } /// Unpack the mutable buffer into a mutable slice, and initialize the @@ -93,19 +77,13 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ true) + let inner = ListView::::init(data)?; + Ok(Self { inner }) } /// Add another item to the slice pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = u32::from(*self.length); - if length as usize == self.max_length { - Err(PodSliceError::BufferTooSmall.into()) - } else { - self.data[length as usize] = t; - *self.length = length.saturating_add(1).into(); - Ok(()) - } + self.inner.push(t) } } @@ -136,6 +114,7 @@ fn max_len_for_type(data_len: usize, length_val: usize) -> Result::unpack(&mut pod_slice_bytes).unwrap(); + // Verify initial length + assert_eq!( + u32::from_le_bytes([ + pod_slice_bytes[0], + pod_slice_bytes[1], + pod_slice_bytes[2], + pod_slice_bytes[3] + ]), + 1 + ); - assert_eq!(*pod_slice.length, PodU32::from(1)); + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(2)); + + // Check length after push + assert_eq!( + u32::from_le_bytes([ + pod_slice_bytes[0], + pod_slice_bytes[1], + pod_slice_bytes[2], + pod_slice_bytes[3] + ]), + 2 + ); + + // Test that buffer is full + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); let err = pod_slice .push(TestStruct::default()) .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index f17ee3d..6bed11d 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -7,7 +7,7 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::slice::{PodSlice, PodSliceMut}, + spl_pod::{list::ListView, primitives::PodU32, slice::PodSlice}, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -172,7 +172,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -188,7 +188,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; }