From 9a68432153b3e77c750895010d46d86dbe8aaf6d Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Mon, 19 May 2025 11:37:41 +0800 Subject: [PATCH 1/4] trie: rework trie hasher --- trie/hasher.go | 131 ++++++++++++++++++++-------------------------- trie/iterator.go | 7 ++- trie/node.go | 4 -- trie/node_enc.go | 14 +++-- trie/proof.go | 17 +++--- trie/trie.go | 6 +-- trie/trie_test.go | 2 +- 7 files changed, 81 insertions(+), 100 deletions(-) diff --git a/trie/hasher.go b/trie/hasher.go index 393cb0bd4d9e..cd65115777c8 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -17,6 +17,7 @@ package trie import ( + "fmt" "sync" "github.com/ethereum/go-ethereum/crypto" @@ -54,7 +55,7 @@ func returnHasherToPool(h *hasher) { } // hash collapses a node down into a hash node. -func (h *hasher) hash(n node, force bool) node { +func (h *hasher) hash(n node, force bool) []byte { // Return the cached hash if it's available if hash, _ := n.cache(); hash != nil { return hash @@ -62,57 +63,68 @@ func (h *hasher) hash(n node, force bool) node { // Trie not processed yet, walk the children switch n := n.(type) { case *shortNode: - collapsed := h.hashShortNodeChildren(n) - hashed := h.shortnodeToHash(collapsed, force) - if hn, ok := hashed.(hashNode); ok { - n.flags.hash = hn - } else { - n.flags.hash = nil + enc := h.encodeShortNode(n) + if len(enc) < 32 && !force { + buf := make([]byte, len(enc)) + copy(buf, enc) + return buf // Nodes smaller than 32 bytes are stored inside their parent } - return hashed + hash := h.hashData(enc) + n.flags.hash = hash + return hash + case *fullNode: - collapsed := h.hashFullNodeChildren(n) - hashed := h.fullnodeToHash(collapsed, force) - if hn, ok := hashed.(hashNode); ok { - n.flags.hash = hn - } else { - n.flags.hash = nil + enc := h.encodeFullNode(n) + if len(enc) < 32 && !force { + buf := make([]byte, len(enc)) + copy(buf, enc) + return buf // Nodes smaller than 32 bytes are stored inside their parent } - return hashed - default: - // Value and hash nodes don't have children, so they're left as were + hash := h.hashData(enc) + n.flags.hash = hash + return hash + + case hashNode: + // hash nodes don't have children, so they're left as were return n + + default: + panic(fmt.Errorf("unexpected node type, %T", n)) } } -// hashShortNodeChildren returns a copy of the supplied shortNode, with its child -// being replaced by either the hash or an embedded node if the child is small. -func (h *hasher) hashShortNodeChildren(n *shortNode) *shortNode { - var collapsed shortNode - collapsed.Key = hexToCompact(n.Key) - switch n.Val.(type) { - case *fullNode, *shortNode: - collapsed.Val = h.hash(n.Val, false) - default: - collapsed.Val = n.Val +func (h *hasher) encodeShortNode(n *shortNode) []byte { + // Encode leaf node + if hasTerm(n.Key) { + var ln leafNodeEncoder + ln.Key = hexToCompact(n.Key) + ln.Val = n.Val.(valueNode) + ln.encode(h.encbuf) + return h.encodedBytes() } - return &collapsed + // Encode extension node + var en extNodeEncoder + en.Key = hexToCompact(n.Key) + en.Val = h.hash(n.Val, false) + en.encode(h.encbuf) + return h.encodedBytes() } -// hashFullNodeChildren returns a copy of the supplied fullNode, with its child +// encodeFullNode returns a copy of the supplied fullNode, with its child // being replaced by either the hash or an embedded node if the child is small. 
-func (h *hasher) hashFullNodeChildren(n *fullNode) *fullNode { - var children [17]node +func (h *hasher) encodeFullNode(n *fullNode) []byte { + var fn fullnodeEncoder if h.parallel { var wg sync.WaitGroup - wg.Add(16) for i := 0; i < 16; i++ { + if n.Children[i] == nil { + continue + } + wg.Add(1) go func(i int) { hasher := newHasher(false) if child := n.Children[i]; child != nil { - children[i] = hasher.hash(child, false) - } else { - children[i] = nilValueNode + fn.Children[i] = hasher.hash(child, false) } returnHasherToPool(hasher) wg.Done() @@ -122,41 +134,15 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) *fullNode { } else { for i := 0; i < 16; i++ { if child := n.Children[i]; child != nil { - children[i] = h.hash(child, false) - } else { - children[i] = nilValueNode + fn.Children[i] = h.hash(child, false) } } } if n.Children[16] != nil { - children[16] = n.Children[16] - } - return &fullNode{flags: nodeFlag{}, Children: children} -} - -// shortNodeToHash computes the hash of the given shortNode. The shortNode must -// first be collapsed, with its key converted to compact form. If the RLP-encoded -// node data is smaller than 32 bytes, the node itself is returned. -func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { - n.encode(h.encbuf) - enc := h.encodedBytes() - - if len(enc) < 32 && !force { - return n // Nodes smaller than 32 bytes are stored inside their parent - } - return h.hashData(enc) -} - -// fullnodeToHash computes the hash of the given fullNode. If the RLP-encoded -// node data is smaller than 32 bytes, the node itself is returned. -func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { - n.encode(h.encbuf) - enc := h.encodedBytes() - - if len(enc) < 32 && !force { - return n // Nodes smaller than 32 bytes are stored inside their parent + fn.Children[16] = n.Children[16].(valueNode) } - return h.hashData(enc) + fn.encode(h.encbuf) + return h.encodedBytes() } // encodedBytes returns the result of the last encoding operation on h.encbuf. @@ -176,8 +162,8 @@ func (h *hasher) encodedBytes() []byte { } // hashData hashes the provided data -func (h *hasher) hashData(data []byte) hashNode { - n := make(hashNode, 32) +func (h *hasher) hashData(data []byte) []byte { + n := make([]byte, 32) h.sha.Reset() h.sha.Write(data) h.sha.Read(n) @@ -196,16 +182,13 @@ func (h *hasher) hashDataTo(dst, data []byte) { // node (for later RLP encoding) as well as the hashed node -- unless the // node is smaller than 32 bytes, in which case it will be returned as is. // This method does not do anything on value- or hash-nodes. 
-func (h *hasher) proofHash(original node) (collapsed, hashed node) { +func (h *hasher) proofHash(original node) []byte { switch n := original.(type) { case *shortNode: - sn := h.hashShortNodeChildren(n) - return sn, h.shortnodeToHash(sn, false) + return h.encodeShortNode(n) case *fullNode: - fn := h.hashFullNodeChildren(n) - return fn, h.fullnodeToHash(fn, false) + return h.encodeFullNode(n) default: - // Value and hash nodes don't have children, so they're left as were - return n, n + panic(fmt.Errorf("unexpected node type, %T", original)) } } diff --git a/trie/iterator.go b/trie/iterator.go index fa016110636a..9dcd98e2d872 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -20,7 +20,6 @@ import ( "bytes" "container/heap" "errors" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) @@ -240,9 +239,9 @@ func (it *nodeIterator) LeafProof() [][]byte { for i, item := range it.stack[:len(it.stack)-1] { // Gather nodes that end up as hash nodes (or the root) - node, hashed := hasher.proofHash(item.node) - if _, ok := hashed.(hashNode); ok || i == 0 { - proofs = append(proofs, nodeToBytes(node)) + enc := hasher.proofHash(item.node) + if len(enc) >= 32 || i == 0 { + proofs = append(proofs, common.CopyBytes(enc)) } } return proofs diff --git a/trie/node.go b/trie/node.go index 96f077ebbb78..74fac4fd4ea6 100644 --- a/trie/node.go +++ b/trie/node.go @@ -68,10 +68,6 @@ type ( } ) -// nilValueNode is used when collapsing internal trie nodes for hashing, since -// unset children need to serialize correctly. -var nilValueNode = valueNode(nil) - // EncodeRLP encodes a full node into the consensus RLP format. func (n *fullNode) EncodeRLP(w io.Writer) error { eb := rlp.NewEncoderBuffer(w) diff --git a/trie/node_enc.go b/trie/node_enc.go index c95587eeabb7..cd863002c52b 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -42,13 +42,19 @@ func (n *fullNode) encode(w rlp.EncoderBuffer) { func (n *fullnodeEncoder) encode(w rlp.EncoderBuffer) { offset := w.List() - for _, c := range n.Children { + for i, c := range n.Children { if c == nil { w.Write(rlp.EmptyString) - } else if len(c) < 32 { - w.Write(c) // rawNode } else { - w.WriteBytes(c) // hashNode + if i == 16 { + w.WriteBytes(c) // valueNode + } else { + if len(c) < 32 { + w.Write(c) // rawNode + } else { + w.WriteBytes(c) // hashNode + } + } } } w.ListEnd(offset) diff --git a/trie/proof.go b/trie/proof.go index 751d6f620f3b..dd9a105beaf8 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" ) @@ -85,16 +86,12 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { defer returnHasherToPool(hasher) for i, n := range nodes { - var hn node - n, hn = hasher.proofHash(n) - if hash, ok := hn.(hashNode); ok || i == 0 { - // If the node's database encoding is a hash (or is the - // root node), it becomes a proof element. 
- enc := nodeToBytes(n) - if !ok { - hash = hasher.hashData(enc) - } - proofDb.Put(hash, enc) + enc := hasher.proofHash(n) + if len(enc) == 32 { + fmt.Println("DEBUG") + } + if len(enc) >= 32 || i == 0 { + proofDb.Put(crypto.Keccak256(enc), enc) } } return nil diff --git a/trie/trie.go b/trie/trie.go index fdb4da9be47a..222bf8b1f023 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -626,7 +626,7 @@ func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - return common.BytesToHash(t.hashRoot().(hashNode)) + return common.BytesToHash(t.hashRoot()) } // Commit collects all dirty nodes in the trie and replaces them with the @@ -677,9 +677,9 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } // hashRoot calculates the root hash of the given trie -func (t *Trie) hashRoot() node { +func (t *Trie) hashRoot() []byte { if t.root == nil { - return hashNode(types.EmptyRootHash.Bytes()) + return types.EmptyRootHash.Bytes() } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) diff --git a/trie/trie_test.go b/trie/trie_test.go index 91fde6dbf260..756e7800ae08 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -863,7 +863,7 @@ func (s *spongeDb) Flush() { s.sponge.Write([]byte(key)) s.sponge.Write([]byte(s.values[key])) } - fmt.Println(len(s.keys)) + //fmt.Println(len(s.keys)) } // spongeBatch is a dummy batch which immediately writes to the underlying spongedb From 8e9ff7056c9c13c5fd0fb080d45d6cba93e51eec Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Sun, 25 May 2025 21:04:19 +0800 Subject: [PATCH 2/4] trie: add sync pool --- trie/hasher.go | 13 ++++++++++++- trie/node_enc.go | 10 +++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/trie/hasher.go b/trie/hasher.go index cd65115777c8..9a0b55530a0f 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -110,10 +110,19 @@ func (h *hasher) encodeShortNode(n *shortNode) []byte { return h.encodedBytes() } +var fnEncoderPool = sync.Pool{ + New: func() interface{} { + var enc fullnodeEncoder + return &enc + }, +} + // encodeFullNode returns a copy of the supplied fullNode, with its child // being replaced by either the hash or an embedded node if the child is small. 
func (h *hasher) encodeFullNode(n *fullNode) []byte { - var fn fullnodeEncoder + fn := fnEncoderPool.Get().(*fullnodeEncoder) + fn.reset() + if h.parallel { var wg sync.WaitGroup for i := 0; i < 16; i++ { @@ -142,6 +151,8 @@ func (h *hasher) encodeFullNode(n *fullNode) []byte { fn.Children[16] = n.Children[16].(valueNode) } fn.encode(h.encbuf) + fnEncoderPool.Put(fn) + return h.encodedBytes() } diff --git a/trie/node_enc.go b/trie/node_enc.go index cd863002c52b..81b3677f06f6 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -43,7 +43,7 @@ func (n *fullNode) encode(w rlp.EncoderBuffer) { func (n *fullnodeEncoder) encode(w rlp.EncoderBuffer) { offset := w.List() for i, c := range n.Children { - if c == nil { + if len(c) == 0 { w.Write(rlp.EmptyString) } else { if i == 16 { @@ -60,6 +60,14 @@ func (n *fullnodeEncoder) encode(w rlp.EncoderBuffer) { w.ListEnd(offset) } +func (n *fullnodeEncoder) reset() { + for i, c := range n.Children { + if len(c) != 0 { + n.Children[i] = n.Children[i][:0] + } + } +} + func (n *shortNode) encode(w rlp.EncoderBuffer) { offset := w.List() w.WriteBytes(n.Key) From aae088da2ed6802feedaa7cd175626e8e09b14e3 Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Mon, 26 May 2025 09:25:29 +0800 Subject: [PATCH 3/4] trie: polish --- trie/hasher.go | 15 +++++++++------ trie/iterator.go | 1 + trie/node_enc.go | 13 +++++-------- trie/proof.go | 3 --- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/trie/hasher.go b/trie/hasher.go index 9a0b55530a0f..060f15f13807 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -65,9 +65,11 @@ func (h *hasher) hash(n node, force bool) []byte { case *shortNode: enc := h.encodeShortNode(n) if len(enc) < 32 && !force { + // Nodes smaller than 32 bytes are embedded directly in their parent. + // In such cases, return the raw encoded blob instead of the node hash. buf := make([]byte, len(enc)) copy(buf, enc) - return buf // Nodes smaller than 32 bytes are stored inside their parent + return buf } hash := h.hashData(enc) n.flags.hash = hash @@ -76,9 +78,11 @@ func (h *hasher) hash(n node, force bool) []byte { case *fullNode: enc := h.encodeFullNode(n) if len(enc) < 32 && !force { + // Nodes smaller than 32 bytes are embedded directly in their parent. + // In such cases, return the raw encoded blob instead of the node hash. buf := make([]byte, len(enc)) copy(buf, enc) - return buf // Nodes smaller than 32 bytes are stored inside their parent + return buf } hash := h.hashData(enc) n.flags.hash = hash @@ -110,6 +114,8 @@ func (h *hasher) encodeShortNode(n *shortNode) []byte { return h.encodedBytes() } +// fnEncoderPool is the pool for storing shared fullNode encoder to mitigate +// the significant memory allocation overhead. var fnEncoderPool = sync.Pool{ New: func() interface{} { var enc fullnodeEncoder @@ -189,10 +195,7 @@ func (h *hasher) hashDataTo(dst, data []byte) { h.sha.Read(dst) } -// proofHash is used to construct trie proofs, and returns the 'collapsed' -// node (for later RLP encoding) as well as the hashed node -- unless the -// node is smaller than 32 bytes, in which case it will be returned as is. -// This method does not do anything on value- or hash-nodes. +// proofHash is used to construct trie proofs, returning the rlp-encoded node blobs. 
func (h *hasher) proofHash(original node) []byte { switch n := original.(type) { case *shortNode: diff --git a/trie/iterator.go b/trie/iterator.go index 9dcd98e2d872..d4f1e60a61a4 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -20,6 +20,7 @@ import ( "bytes" "container/heap" "errors" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) diff --git a/trie/node_enc.go b/trie/node_enc.go index 81b3677f06f6..02b93ee6f3ee 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -46,14 +46,11 @@ func (n *fullnodeEncoder) encode(w rlp.EncoderBuffer) { if len(c) == 0 { w.Write(rlp.EmptyString) } else { - if i == 16 { - w.WriteBytes(c) // valueNode + // valueNode or hashNode + if i == 16 || len(c) >= 32 { + w.WriteBytes(c) } else { - if len(c) < 32 { - w.Write(c) // rawNode - } else { - w.WriteBytes(c) // hashNode - } + w.Write(c) // rawNode } } } @@ -84,7 +81,7 @@ func (n *extNodeEncoder) encode(w rlp.EncoderBuffer) { w.WriteBytes(n.Key) if n.Val == nil { - w.Write(rlp.EmptyString) + w.Write(rlp.EmptyString) // theoretically impossible to happen } else if len(n.Val) < 32 { w.Write(n.Val) // rawNode } else { diff --git a/trie/proof.go b/trie/proof.go index dd9a105beaf8..53b7acc30c9d 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -87,9 +87,6 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { for i, n := range nodes { enc := hasher.proofHash(n) - if len(enc) == 32 { - fmt.Println("DEBUG") - } if len(enc) >= 32 || i == 0 { proofDb.Put(crypto.Keccak256(enc), enc) } From 6b5ccf960e69584827cb9c99584046b3d6ec3047 Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Mon, 26 May 2025 09:57:34 +0800 Subject: [PATCH 4/4] trie: polish --- trie/hasher.go | 16 ++++++++++------ trie/trie_test.go | 1 - 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/trie/hasher.go b/trie/hasher.go index 060f15f13807..cd45313547e1 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -67,6 +67,8 @@ func (h *hasher) hash(n node, force bool) []byte { if len(enc) < 32 && !force { // Nodes smaller than 32 bytes are embedded directly in their parent. // In such cases, return the raw encoded blob instead of the node hash. + // It's essential to deep-copy the node blob, as the underlying buffer + // of enc will be reused later. buf := make([]byte, len(enc)) copy(buf, enc) return buf @@ -80,6 +82,8 @@ func (h *hasher) hash(n node, force bool) []byte { if len(enc) < 32 && !force { // Nodes smaller than 32 bytes are embedded directly in their parent. // In such cases, return the raw encoded blob instead of the node hash. + // It's essential to deep-copy the node blob, as the underlying buffer + // of enc will be reused later. buf := make([]byte, len(enc)) copy(buf, enc) return buf @@ -137,12 +141,11 @@ func (h *hasher) encodeFullNode(n *fullNode) []byte { } wg.Add(1) go func(i int) { - hasher := newHasher(false) - if child := n.Children[i]; child != nil { - fn.Children[i] = hasher.hash(child, false) - } - returnHasherToPool(hasher) - wg.Done() + defer wg.Done() + + h := newHasher(false) + defer returnHasherToPool(h) + fn.Children[i] = h.hash(n.Children[i], false) }(i) } wg.Wait() @@ -196,6 +199,7 @@ func (h *hasher) hashDataTo(dst, data []byte) { } // proofHash is used to construct trie proofs, returning the rlp-encoded node blobs. +// Note, only resolved node (shortNode or fullNode) is expected for proofing. 
func (h *hasher) proofHash(original node) []byte { switch n := original.(type) { case *shortNode: diff --git a/trie/trie_test.go b/trie/trie_test.go index 756e7800ae08..7b864685790e 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -863,7 +863,6 @@ func (s *spongeDb) Flush() { s.sponge.Write([]byte(key)) s.sponge.Write([]byte(s.values[key])) } - //fmt.Println(len(s.keys)) } // spongeBatch is a dummy batch which immediately writes to the underlying spongedb
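Note on the two mechanisms this series leans on, for readers skimming the patches: the reworked hash() embeds any node whose RLP encoding is shorter than 32 bytes directly in its parent (deep-copying the blob, since the encoder's buffer is reused), and patch 2 recycles the fullnodeEncoder through a sync.Pool with an in-place reset instead of allocating a fresh one per node. The sketch below is a minimal, standalone illustration of both patterns under assumed names -- childEncoder, encoderPool and hashOrEmbed are stand-ins for this note only, not the actual types in trie/; it only assumes crypto.Keccak256 from the go-ethereum module.

package main

import (
	"fmt"
	"sync"

	"github.com/ethereum/go-ethereum/crypto"
)

// childEncoder stands in for the pooled fullnodeEncoder: 17 reusable child
// slots that are truncated (not reallocated) between uses.
type childEncoder struct {
	children [17][]byte
}

// reset truncates the slots in place so their backing arrays survive the
// round-trip through the pool, mirroring fullnodeEncoder.reset in patch 2.
func (e *childEncoder) reset() {
	for i := range e.children {
		if len(e.children[i]) != 0 {
			e.children[i] = e.children[i][:0]
		}
	}
}

var encoderPool = sync.Pool{
	New: func() interface{} { return new(childEncoder) },
}

// hashOrEmbed mirrors the embed-or-hash rule in the reworked hash(): blobs
// shorter than 32 bytes are deep-copied and returned for embedding in the
// parent, anything larger is replaced by its keccak256 hash.
func hashOrEmbed(enc []byte, force bool) []byte {
	if len(enc) < 32 && !force {
		buf := make([]byte, len(enc)) // copy: the encoding buffer is reused later
		copy(buf, enc)
		return buf
	}
	return crypto.Keccak256(enc)
}

func main() {
	enc := encoderPool.Get().(*childEncoder)
	enc.reset()

	// A tiny RLP blob stays embedded verbatim; a large one becomes a 32-byte hash.
	enc.children[0] = append(enc.children[0], hashOrEmbed([]byte{0xc2, 0x80, 0x80}, false)...)
	enc.children[1] = append(enc.children[1], hashOrEmbed(make([]byte, 64), false)...)

	fmt.Printf("embedded child: %x\n", enc.children[0])
	fmt.Printf("hashed child:   %x\n", enc.children[1])

	encoderPool.Put(enc)
}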