From a4ab6549faed9289a44e059e8d809f1e20479383 Mon Sep 17 00:00:00 2001 From: Pawan Dhananjay Date: Thu, 6 Feb 2025 17:52:46 -0800 Subject: [PATCH] Check length after serde deserializing --- src/fixed_vector.rs | 45 +++++++++++++++++++++++++++++++++++++++++-- src/variable_list.rs | 46 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 4 deletions(-) diff --git a/src/fixed_vector.rs b/src/fixed_vector.rs index f4a2029..53bf8ad 100644 --- a/src/fixed_vector.rs +++ b/src/fixed_vector.rs @@ -1,6 +1,7 @@ use crate::tree_hash::vec_tree_hash_root; use crate::Error; -use serde_derive::{Deserialize, Serialize}; +use serde::Deserialize; +use serde_derive::Serialize; use std::marker::PhantomData; use std::ops::{Deref, DerefMut, Index, IndexMut}; use std::slice::SliceIndex; @@ -44,7 +45,7 @@ pub use typenum; /// let long: FixedVector<_, typenum::U5> = FixedVector::from(base); /// assert_eq!(&long[..], &[1, 2, 3, 4, 0]); /// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] #[serde(transparent)] pub struct FixedVector { vec: Vec, @@ -340,6 +341,31 @@ where } } +impl<'de, T, N> Deserialize<'de> for FixedVector +where + T: Deserialize<'de>, + N: Unsigned, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let vec = Vec::::deserialize(deserializer)?; + if vec.len() == N::to_usize() { + Ok(FixedVector { + vec, + _phantom: PhantomData, + }) + } else { + Err(serde::de::Error::custom(format!( + "Wrong number of FixedVector elements. Expected {}, actual {}", + N::to_usize(), + vec.len(), + ))) + } + } +} + #[cfg(feature = "arbitrary")] impl<'a, T: arbitrary::Arbitrary<'a>, N: 'static + Unsigned> arbitrary::Arbitrary<'a> for FixedVector @@ -528,4 +554,19 @@ mod test { } assert_eq!(hashset.len(), 2); } + #[test] + fn serde_invalid_length() { + use typenum::U4; + let json = serde_json::json!([1, 2, 3, 4, 5]); + let result: Result, _> = serde_json::from_value(json); + assert!(result.is_err()); + + let json = serde_json::json!([1, 2, 3]); + let result: Result, _> = serde_json::from_value(json); + assert!(result.is_err()); + + let json = serde_json::json!([1, 2, 3, 4]); + let result: Result, _> = serde_json::from_value(json); + assert!(result.is_ok()); + } } diff --git a/src/variable_list.rs b/src/variable_list.rs index 7c1140b..5dffad1 100644 --- a/src/variable_list.rs +++ b/src/variable_list.rs @@ -1,6 +1,7 @@ use crate::tree_hash::vec_tree_hash_root; use crate::Error; -use serde_derive::{Deserialize, Serialize}; +use serde::Deserialize; +use serde_derive::Serialize; use std::marker::PhantomData; use std::ops::{Deref, DerefMut, Index, IndexMut}; use std::slice::SliceIndex; @@ -46,7 +47,7 @@ pub use typenum; /// // Push a value to if it _does_ exceed the maximum. /// assert!(long.push(6).is_err()); /// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize)] #[serde(transparent)] pub struct VariableList { vec: Vec, @@ -312,6 +313,31 @@ where } } +impl<'de, T, N> Deserialize<'de> for VariableList +where + T: Deserialize<'de>, + N: Unsigned, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let vec = Vec::::deserialize(deserializer)?; + if vec.len() <= N::to_usize() { + Ok(VariableList { + vec, + _phantom: PhantomData, + }) + } else { + Err(serde::de::Error::custom(format!( + "VariableList length {} exceeds maximum length {}", + vec.len(), + N::to_usize() + ))) + } + } +} + #[cfg(feature = "arbitrary")] impl<'a, T: arbitrary::Arbitrary<'a>, N: 'static + Unsigned> arbitrary::Arbitrary<'a> for VariableList @@ -574,4 +600,20 @@ mod test { } assert_eq!(hashset.len(), 2); } + + #[test] + fn serde_invalid_length() { + use typenum::U4; + let json = serde_json::json!([1, 2, 3, 4, 5]); + let result: Result, _> = serde_json::from_value(json); + assert!(result.is_err()); + + let json = serde_json::json!([1, 2, 3]); + let result: Result, _> = serde_json::from_value(json); + assert!(result.is_ok()); + + let json = serde_json::json!([1, 2, 3, 4]); + let result: Result, _> = serde_json::from_value(json); + assert!(result.is_ok()); + } }