Skip to content

chore: tight range checks in NodeIndex #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl<L: LeafData> FilledTree<L> for FilledTreeImpl<L> {
}

fn get_root_hash(&self) -> Result<HashOutput, FilledTreeError> {
match self.tree_map.get(&NodeIndex::root_index()) {
match self.tree_map.get(&NodeIndex::ROOT) {
Some(root_node) => Ok(root_node.hash),
None => Err(FilledTreeError::MissingRoot),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ impl OriginalSkeletonTreeImpl {
assert_child(*last_leaf);

let child_direction_mask = U256::ONE << (root_height.0 - 1);
(first_leaf.0 & child_direction_mask) != (last_leaf.0 & child_direction_mask)
(U256::from(first_leaf) & child_direction_mask)
!= (U256::from(*last_leaf) & child_direction_mask)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ fn empty_skeleton(height: u8) -> OriginalSkeletonTreeImpl {
#[case::large_tree_farthest_leaves(
251,
1,
vec![NodeIndex(U256::ONE << 251),NodeIndex((U256::ONE << 252) - U256::ONE)],
vec![NodeIndex::ROOT << 251, NodeIndex::MAX_INDEX],
true)]
#[case::large_tree_positive_consecutive_indices_of_different_sides(
251,
1,
vec![NodeIndex((U256::from(3u8) << 250) - U256::ONE), NodeIndex(U256::from(3u8) << 250)],
vec![NodeIndex::new((U256::from(3u8) << 250) - U256::ONE), NodeIndex::new(U256::from(3u8) << 250)],
true)]
#[case::large_tree_negative_one_shift_of_positive_case(
251,
1,
vec![NodeIndex(U256::from(3u8) << 250), NodeIndex((U256::from(3u8) << 250)+ U256::ONE)],
vec![NodeIndex::new(U256::from(3u8) << 250), NodeIndex::new((U256::from(3u8) << 250)+ U256::ONE)],
false)]
fn test_has_leaves_on_both_sides(
#[case] tree_height: u8,
Expand All @@ -41,7 +41,7 @@ fn test_has_leaves_on_both_sides(
#[case] expected: bool,
) {
let skeleton_tree = empty_skeleton(tree_height);
let root_index = NodeIndex(root_index.into());
let root_index = NodeIndex::new(root_index.into());
assert_eq!(
skeleton_tree.has_leaves_on_both_sides(&root_index, &leaf_indices),
expected
Expand All @@ -58,6 +58,6 @@ fn test_has_leaves_on_both_sides_assertions(
#[case] leaf_indices: Vec<NodeIndex>,
) {
let skeleton_tree = empty_skeleton(tree_height);
let root_index = NodeIndex(root_index.into());
let root_index = NodeIndex::new(root_index.into());
skeleton_tree.has_leaves_on_both_sides(&root_index, &leaf_indices);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use crate::storage::storage_trait::Storage;
use crate::storage::storage_trait::StorageKey;
use crate::storage::storage_trait::StoragePrefix;
use bisection::{bisect_left, bisect_right};
use ethnum::U256;
use std::collections::HashMap;
#[cfg(test)]
#[path = "create_tree_test.rs"]
Expand All @@ -42,7 +41,7 @@ impl<'a> SubTree<'a> {
) -> (&'a [NodeIndex], &'a [NodeIndex]) {
let height = self.get_height(total_tree_height);
let leftmost_index_in_right_subtree =
((self.root_index << 1) + NodeIndex(U256::ONE)) << (height.0 - 1);
((self.root_index << 1) + NodeIndex::ROOT) << (height.0 - 1);
let mid = bisect_left(self.sorted_leaf_indices, &leftmost_index_in_right_subtree);
(
&self.sorted_leaf_indices[..mid],
Expand All @@ -65,7 +64,7 @@ impl<'a> SubTree<'a> {
self.get_height(total_tree_height) - TreeHeight(path_to_bottom.length.0);
let leftmost_in_subtree = bottom_index << bottom_height.0;
let rightmost_in_subtree =
leftmost_in_subtree + (NodeIndex(U256::ONE) << bottom_height.0) - NodeIndex(U256::ONE);
leftmost_in_subtree + (NodeIndex::ROOT << bottom_height.0) - NodeIndex::ROOT;
let bottom_leaves =
&self.sorted_leaf_indices[bisect_left(self.sorted_leaf_indices, &leftmost_in_subtree)
..bisect_right(self.sorted_leaf_indices, &rightmost_in_subtree)];
Expand All @@ -84,7 +83,7 @@ impl<'a> SubTree<'a> {
total_tree_height: &TreeHeight,
) -> (Self, Self) {
let (left_leaves, right_leaves) = self.split_leaves(total_tree_height);
let left_root_index = self.root_index * 2;
let left_root_index = self.root_index * 2.into();
(
SubTree {
sorted_leaf_indices: left_leaves,
Expand All @@ -93,7 +92,7 @@ impl<'a> SubTree<'a> {
},
SubTree {
sorted_leaf_indices: right_leaves,
root_index: left_root_index + NodeIndex(U256::ONE),
root_index: left_root_index + NodeIndex::ROOT,
root_hash: right_hash,
},
)
Expand Down Expand Up @@ -216,7 +215,7 @@ impl OriginalSkeletonTreeImpl {
) -> OriginalSkeletonTreeResult<Self> {
let main_subtree = SubTree {
sorted_leaf_indices,
root_index: NodeIndex::root_index(),
root_index: NodeIndex::ROOT,
root_hash,
};
let mut skeleton_tree = Self {
Expand Down
80 changes: 56 additions & 24 deletions crates/committer/src/patricia_merkle_tree/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,36 @@ use ethnum::U256;
#[path = "types_test.rs"]
pub mod types_test;

#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, derive_more::Sub)]
pub struct TreeHeight(pub u8);

impl TreeHeight {
pub const MAX_HEIGHT: u8 = 251;
}
#[derive(
Clone,
Copy,
Debug,
PartialEq,
Eq,
Hash,
derive_more::Add,
derive_more::Mul,
derive_more::Sub,
PartialOrd,
Ord,
Clone, Copy, Debug, PartialEq, Eq, Hash, derive_more::BitAnd, derive_more::Sub, PartialOrd, Ord,
)]
pub(crate) struct NodeIndex(pub U256);
pub(crate) struct NodeIndex(U256);

#[allow(dead_code)]
// Wraps a U256. Maximal possible value is the largest index in a tree of height 251 (2 ^ 252 - 1).
impl NodeIndex {
pub(crate) fn root_index() -> NodeIndex {
NodeIndex(U256::ONE)
pub const BITS: u8 = TreeHeight::MAX_HEIGHT + 1;

/// [NodeIndex] constant that represents the root index.
pub const ROOT: Self = Self(U256::ONE);

/// [NodeIndex] constant that represents the largest index in a tree.
#[allow(clippy::as_conversions)]
pub const MAX_INDEX: Self = Self(U256::from_words(
u128::MAX >> (U256::BITS - Self::BITS as u32),
u128::MAX,
));

pub(crate) fn new(index: U256) -> Self {
if index > Self::MAX_INDEX.0 {
panic!("Index {index} is too large.");
}
Self(index)
}

// TODO(Amos, 1/5/2024): Move to EdgePath.
Expand All @@ -41,10 +50,15 @@ impl NodeIndex {
(index << length.0) + Self::from_felt_value(&path.0)
}

pub(crate) fn bit_length(&self) -> u8 {
(U256::BITS - self.0.leading_zeros())
/// Returns the number of leading zeroes when represented with Self::BITS bits.
pub(crate) fn leading_zeros(&self) -> u8 {
(self.0.leading_zeros() - (U256::BITS - u32::from(Self::BITS)))
.try_into()
.expect("Failed to convert to u8.")
.expect("Leading zeroes are unexpectedly larger than a u8.")
}

pub(crate) fn bit_length(&self) -> u8 {
Self::BITS - self.leading_zeros()
}

pub(crate) fn from_starknet_storage_key(
Expand Down Expand Up @@ -74,12 +88,28 @@ impl NodeIndex {
}
}

impl std::ops::Add for NodeIndex {
type Output = Self;

fn add(self, rhs: Self) -> Self {
Self::new(self.0 + rhs.0)
}
}

impl std::ops::Mul for NodeIndex {
type Output = Self;

fn mul(self, rhs: Self) -> Self {
Self::new(self.0 * rhs.0)
}
}

impl std::ops::Shl<u8> for NodeIndex {
type Output = Self;

/// Returns the index of the left descendant (child for rhs=1) of the node.
fn shl(self, rhs: u8) -> Self::Output {
NodeIndex(self.0 << rhs)
Self::new(self.0 << rhs)
}
}

Expand All @@ -88,7 +118,7 @@ impl std::ops::Shr<u8> for NodeIndex {

/// Returns the index of the ancestor (parent for rhs=1) of the node.
fn shr(self, rhs: u8) -> Self::Output {
NodeIndex(self.0 >> rhs)
Self::new(self.0 >> rhs)
}
}

Expand All @@ -98,6 +128,8 @@ impl From<u128> for NodeIndex {
}
}

#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, derive_more::Sub)]
pub struct TreeHeight(pub u8);
impl From<NodeIndex> for U256 {
fn from(value: NodeIndex) -> Self {
value.0
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use ethnum::U256;

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

Expand Down Expand Up @@ -119,8 +117,8 @@ impl<L: LeafData + std::clone::Clone + std::marker::Sync + std::marker::Send>
let node = self.get_node(index)?;
match node {
UpdatedSkeletonNode::Binary => {
let left_index = index * 2;
let right_index = left_index + NodeIndex(U256::ONE);
let left_index = index * 2.into();
let right_index = left_index + NodeIndex::ROOT;

let (left_hash, right_hash) = tokio::join!(
self.compute_filled_tree_rec::<H, TH>(left_index, Arc::clone(&output_map)),
Expand Down Expand Up @@ -171,11 +169,8 @@ impl<L: LeafData + std::clone::Clone + std::marker::Sync + std::marker::Send> Up
// 2. Fill in the hash values.
let filled_tree_map = Arc::new(self.initialize_with_placeholders());

self.compute_filled_tree_rec::<H, TH>(
NodeIndex::root_index(),
Arc::clone(&filled_tree_map),
)
.await?;
self.compute_filled_tree_rec::<H, TH>(NodeIndex::ROOT, Arc::clone(&filled_tree_map))
.await?;

// Create and return a new FilledTreeImpl from the hashmap.
Ok(FilledTreeImpl::new(Self::remove_arc_mutex_and_option(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::patricia_merkle_tree::updated_skeleton_tree::tree::{
async fn test_filled_tree_sanity() {
let mut skeleton_tree: HashMap<NodeIndex, UpdatedSkeletonNode<LeafDataImpl>> = HashMap::new();
skeleton_tree.insert(
NodeIndex::root_index(),
NodeIndex::ROOT,
UpdatedSkeletonNode::Leaf(LeafDataImpl::CompiledClassHash(ClassHash(Felt::ONE))),
);
let updated_skeleton_tree = UpdatedSkeletonTreeImpl { skeleton_tree };
Expand Down
Loading