Skip to content

Commit 71bc0f7

Browse files
Amxxarr00
andauthored
Add function to update a leaf in a MerkleTree structure (#5453)
Co-authored-by: Arr00 <13561405+arr00@users.noreply.github.com>
1 parent 7276774 commit 71bc0f7

File tree

4 files changed

+213
-28
lines changed

4 files changed

+213
-28
lines changed

.changeset/good-zebras-ring.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`MerkleTree`: Add an update function that replaces a previously inserted leaf with a new value, updating the tree root along the way.

contracts/mocks/MerkleTreeMock.sol

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ contract MerkleTreeMock {
1414
bytes32 public root;
1515

1616
event LeafInserted(bytes32 leaf, uint256 index, bytes32 root);
17+
event LeafUpdated(bytes32 oldLeaf, bytes32 newLeaf, uint256 index, bytes32 root);
1718

1819
function setup(uint8 _depth, bytes32 _zero) public {
1920
root = _tree.setup(_depth, _zero);
@@ -25,6 +26,13 @@ contract MerkleTreeMock {
2526
root = currentRoot;
2627
}
2728

29+
function update(uint256 index, bytes32 oldValue, bytes32 newValue, bytes32[] memory proof) public {
30+
(bytes32 oldRoot, bytes32 newRoot) = _tree.update(index, oldValue, newValue, proof);
31+
if (oldRoot != root) revert MerkleTree.MerkleTreeUpdateInvalidProof();
32+
emit LeafUpdated(oldValue, newValue, index, newRoot);
33+
root = newRoot;
34+
}
35+
2836
function depth() public view returns (uint256) {
2937
return _tree.depth();
3038
}

contracts/utils/structs/MerkleTree.sol

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pragma solidity ^0.8.20;
66
import {Hashes} from "../cryptography/Hashes.sol";
77
import {Arrays} from "../Arrays.sol";
88
import {Panic} from "../Panic.sol";
9+
import {StorageSlot} from "../StorageSlot.sol";
910

1011
/**
1112
* @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures.
@@ -27,6 +28,12 @@ import {Panic} from "../Panic.sol";
2728
* _Available since v5.1._
2829
*/
2930
library MerkleTree {
31+
/// @dev Error emitted when trying to update a leaf that was not previously pushed.
32+
error MerkleTreeUpdateInvalidIndex(uint256 index, uint256 length);
33+
34+
/// @dev Error emitted when the proof used during an update is invalid (could not reproduce the side).
35+
error MerkleTreeUpdateInvalidProof();
36+
3037
/**
3138
* @dev A complete `bytes32` Merkle tree.
3239
*
@@ -166,6 +173,91 @@ library MerkleTree {
166173
return (index, currentLevelHash);
167174
}
168175

176+
/**
177+
* @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
178+
* root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
179+
* root is the last known one.
180+
*
181+
* The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
182+
* vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
183+
* all "in flight" updates invalid.
184+
*
185+
* This variant uses {Hashes-commutativeKeccak256} to hash internal nodes. It should only be used on merkle trees
186+
* that were setup using the same (default) hashing function (i.e. by calling
187+
* {xref-MerkleTree-setup-struct-MerkleTree-Bytes32PushTree-uint8-bytes32-}[the default setup] function).
188+
*/
189+
function update(
190+
Bytes32PushTree storage self,
191+
uint256 index,
192+
bytes32 oldValue,
193+
bytes32 newValue,
194+
bytes32[] memory proof
195+
) internal returns (bytes32 oldRoot, bytes32 newRoot) {
196+
return update(self, index, oldValue, newValue, proof, Hashes.commutativeKeccak256);
197+
}
198+
199+
/**
200+
* @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
201+
* root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
202+
* root is the last known one.
203+
*
204+
* The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
205+
* vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
206+
* all "in flight" updates invalid.
207+
*
208+
* This variant uses a custom hashing function to hash internal nodes. It should only be called with the same
209+
* function as the one used during the initial setup of the merkle tree.
210+
*/
211+
function update(
212+
Bytes32PushTree storage self,
213+
uint256 index,
214+
bytes32 oldValue,
215+
bytes32 newValue,
216+
bytes32[] memory proof,
217+
function(bytes32, bytes32) view returns (bytes32) fnHash
218+
) internal returns (bytes32 oldRoot, bytes32 newRoot) {
219+
unchecked {
220+
// Check index range
221+
uint256 length = self._nextLeafIndex;
222+
if (index >= length) revert MerkleTreeUpdateInvalidIndex(index, length);
223+
224+
// Cache read
225+
uint256 treeDepth = depth(self);
226+
227+
// Workaround stack too deep
228+
bytes32[] storage sides = self._sides;
229+
230+
// This cannot overflow because: 0 <= index < length
231+
uint256 lastIndex = length - 1;
232+
uint256 currentIndex = index;
233+
bytes32 currentLevelHashOld = oldValue;
234+
bytes32 currentLevelHashNew = newValue;
235+
for (uint32 i = 0; i < treeDepth; i++) {
236+
bool isLeft = currentIndex % 2 == 0;
237+
238+
lastIndex >>= 1;
239+
currentIndex >>= 1;
240+
241+
if (isLeft && currentIndex == lastIndex) {
242+
StorageSlot.Bytes32Slot storage side = Arrays.unsafeAccess(sides, i);
243+
if (side.value != currentLevelHashOld) revert MerkleTreeUpdateInvalidProof();
244+
side.value = currentLevelHashNew;
245+
}
246+
247+
bytes32 sibling = proof[i];
248+
currentLevelHashOld = fnHash(
249+
isLeft ? currentLevelHashOld : sibling,
250+
isLeft ? sibling : currentLevelHashOld
251+
);
252+
currentLevelHashNew = fnHash(
253+
isLeft ? currentLevelHashNew : sibling,
254+
isLeft ? sibling : currentLevelHashNew
255+
);
256+
}
257+
return (currentLevelHashOld, currentLevelHashNew);
258+
}
259+
}
260+
169261
/**
170262
* @dev Tree's depth (set at initialization)
171263
*/

test/utils/structs/MerkleTree.test.js

Lines changed: 108 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,23 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
55
const { StandardMerkleTree } = require('@openzeppelin/merkle-tree');
66

77
const { generators } = require('../../helpers/random');
8+
const { range } = require('../../helpers/iterate');
89

9-
const makeTree = (leaves = [ethers.ZeroHash]) =>
10+
const DEPTH = 4; // 16 slots
11+
12+
const makeTree = (leaves = [], length = 2 ** DEPTH, zero = ethers.ZeroHash) =>
1013
StandardMerkleTree.of(
11-
leaves.map(leaf => [leaf]),
14+
[]
15+
.concat(
16+
leaves,
17+
Array.from({ length: length - leaves.length }, () => zero),
18+
)
19+
.map(leaf => [leaf]),
1220
['bytes32'],
1321
{ sortLeaves: false },
1422
);
1523

16-
const hashLeaf = leaf => makeTree().leafHash([leaf]);
17-
18-
const DEPTH = 4n; // 16 slots
19-
const ZERO = hashLeaf(ethers.ZeroHash);
24+
const ZERO = makeTree().leafHash([ethers.ZeroHash]);
2025

2126
async function fixture() {
2227
const mock = await ethers.deployContract('MerkleTreeMock');
@@ -30,69 +35,144 @@ describe('MerkleTree', function () {
3035
});
3136

3237
it('sets initial values at setup', async function () {
33-
const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash));
38+
const merkleTree = makeTree();
3439

35-
expect(await this.mock.root()).to.equal(merkleTree.root);
36-
expect(await this.mock.depth()).to.equal(DEPTH);
37-
expect(await this.mock.nextLeafIndex()).to.equal(0n);
40+
await expect(this.mock.root()).to.eventually.equal(merkleTree.root);
41+
await expect(this.mock.depth()).to.eventually.equal(DEPTH);
42+
await expect(this.mock.nextLeafIndex()).to.eventually.equal(0n);
3843
});
3944

4045
describe('push', function () {
41-
it('tree is correctly updated', async function () {
42-
const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
46+
it('pushing correctly updates the tree', async function () {
47+
const leaves = [];
4348

4449
// for each leaf slot
45-
for (const i in leaves) {
46-
// generate random leaf and hash it
47-
const hashedLeaf = hashLeaf((leaves[i] = generators.bytes32()));
50+
for (const i in range(2 ** DEPTH)) {
51+
// generate random leaf
52+
leaves.push(generators.bytes32());
4853

49-
// update leaf list and rebuild tree.
54+
// rebuild tree.
5055
const tree = makeTree(leaves);
56+
const hash = tree.leafHash(tree.at(i));
5157

5258
// push value to tree
53-
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root);
59+
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, i, tree.root);
5460

5561
// check tree
56-
expect(await this.mock.root()).to.equal(tree.root);
57-
expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n);
62+
await expect(this.mock.root()).to.eventually.equal(tree.root);
63+
await expect(this.mock.nextLeafIndex()).to.eventually.equal(BigInt(i) + 1n);
5864
}
5965
});
6066

61-
it('revert when tree is full', async function () {
67+
it('pushing to a full tree reverts', async function () {
6268
await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash)));
6369

6470
await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED);
6571
});
6672
});
6773

74+
describe('update', function () {
75+
for (const { leafCount, leafIndex } of range(2 ** DEPTH + 1).flatMap(leafCount =>
76+
range(leafCount).map(leafIndex => ({ leafCount, leafIndex })),
77+
))
78+
it(`updating a leaf correctly updates the tree (leaf #${leafIndex + 1}/${leafCount})`, async function () {
79+
// initial tree
80+
const leaves = Array.from({ length: leafCount }, generators.bytes32);
81+
const oldTree = makeTree(leaves);
82+
83+
// fill tree and verify root
84+
for (const i in leaves) {
85+
await this.mock.push(oldTree.leafHash(oldTree.at(i)));
86+
}
87+
await expect(this.mock.root()).to.eventually.equal(oldTree.root);
88+
89+
// create updated tree
90+
leaves[leafIndex] = generators.bytes32();
91+
const newTree = makeTree(leaves);
92+
93+
const oldLeafHash = oldTree.leafHash(oldTree.at(leafIndex));
94+
const newLeafHash = newTree.leafHash(newTree.at(leafIndex));
95+
96+
// perform update
97+
await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, oldTree.getProof(leafIndex)))
98+
.to.emit(this.mock, 'LeafUpdated')
99+
.withArgs(oldLeafHash, newLeafHash, leafIndex, newTree.root);
100+
101+
// verify updated root
102+
await expect(this.mock.root()).to.eventually.equal(newTree.root);
103+
104+
// if there is still room in the tree, fill it
105+
for (const i of range(leafCount, 2 ** DEPTH)) {
106+
// push new value and rebuild tree
107+
leaves.push(generators.bytes32());
108+
const nextTree = makeTree(leaves);
109+
110+
// push and verify root
111+
await this.mock.push(nextTree.leafHash(nextTree.at(i)));
112+
await expect(this.mock.root()).to.eventually.equal(nextTree.root);
113+
}
114+
});
115+
116+
it('replacing a leaf that was not previously pushed reverts', async function () {
117+
// changing leaf 0 on an empty tree
118+
await expect(this.mock.update(1, ZERO, ZERO, []))
119+
.to.be.revertedWithCustomError(this.mock, 'MerkleTreeUpdateInvalidIndex')
120+
.withArgs(1, 0);
121+
});
122+
123+
it('replacing a leaf using an invalid proof reverts', async function () {
124+
const leafCount = 4;
125+
const leafIndex = 2;
126+
127+
const leaves = Array.from({ length: leafCount }, generators.bytes32);
128+
const tree = makeTree(leaves);
129+
130+
// fill tree and verify root
131+
for (const i in leaves) {
132+
await this.mock.push(tree.leafHash(tree.at(i)));
133+
}
134+
await expect(this.mock.root()).to.eventually.equal(tree.root);
135+
136+
const oldLeafHash = tree.leafHash(tree.at(leafIndex));
137+
const newLeafHash = generators.bytes32();
138+
const proof = tree.getProof(leafIndex);
139+
// invalid proof (tamper)
140+
proof[1] = generators.bytes32();
141+
142+
await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, proof)).to.be.revertedWithCustomError(
143+
this.mock,
144+
'MerkleTreeUpdateInvalidProof',
145+
);
146+
});
147+
});
148+
68149
it('reset', async function () {
69150
// empty tree
70-
const zeroLeaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
71-
const zeroTree = makeTree(zeroLeaves);
151+
const emptyTree = makeTree();
72152

73153
// tree with one element
74-
const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
75-
const hashedLeaf = hashLeaf((leaves[0] = generators.bytes32())); // fill first leaf and hash it
154+
const leaves = [generators.bytes32()];
76155
const tree = makeTree(leaves);
156+
const hash = tree.leafHash(tree.at(0));
77157

78158
// root should be that of a zero tree
79-
expect(await this.mock.root()).to.equal(zeroTree.root);
159+
expect(await this.mock.root()).to.equal(emptyTree.root);
80160
expect(await this.mock.nextLeafIndex()).to.equal(0n);
81161

82162
// push leaf and check root
83-
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
163+
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);
84164

85165
expect(await this.mock.root()).to.equal(tree.root);
86166
expect(await this.mock.nextLeafIndex()).to.equal(1n);
87167

88168
// reset tree
89169
await this.mock.setup(DEPTH, ZERO);
90170

91-
expect(await this.mock.root()).to.equal(zeroTree.root);
171+
expect(await this.mock.root()).to.equal(emptyTree.root);
92172
expect(await this.mock.nextLeafIndex()).to.equal(0n);
93173

94174
// re-push leaf and check root
95-
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
175+
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);
96176

97177
expect(await this.mock.root()).to.equal(tree.root);
98178
expect(await this.mock.nextLeafIndex()).to.equal(1n);

0 commit comments

Comments
 (0)