Skip to content

Commit 24a641d

Browse files
authored
Get leaves from memory in processMultiProofCalldata (#5140)
1 parent aec36dd commit 24a641d

File tree

2 files changed

+42
-37
lines changed

2 files changed

+42
-37
lines changed

contracts/utils/cryptography/MerkleProof.sol

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ library MerkleProof {
105105
* This version handles proofs in calldata with the default hashing function.
106106
*/
107107
function verifyCalldata(bytes32[] calldata proof, bytes32 root, bytes32 leaf) internal pure returns (bool) {
108-
return processProof(proof, leaf) == root;
108+
return processProofCalldata(proof, leaf) == root;
109109
}
110110

111111
/**
@@ -138,7 +138,7 @@ library MerkleProof {
138138
bytes32 leaf,
139139
function(bytes32, bytes32) view returns (bytes32) hasher
140140
) internal view returns (bool) {
141-
return processProof(proof, leaf, hasher) == root;
141+
return processProofCalldata(proof, leaf, hasher) == root;
142142
}
143143

144144
/**
@@ -200,15 +200,16 @@ library MerkleProof {
200200
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
201201
// the Merkle tree.
202202
uint256 leavesLen = leaves.length;
203+
uint256 proofFlagsLen = proofFlags.length;
203204

204205
// Check proof validity.
205-
if (leavesLen + proof.length != proofFlags.length + 1) {
206+
if (leavesLen + proof.length != proofFlagsLen + 1) {
206207
revert MerkleProofInvalidMultiproof();
207208
}
208209

209210
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
210211
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
211-
bytes32[] memory hashes = new bytes32[](proofFlags.length);
212+
bytes32[] memory hashes = new bytes32[](proofFlagsLen);
212213
uint256 leafPos = 0;
213214
uint256 hashPos = 0;
214215
uint256 proofPos = 0;
@@ -217,20 +218,20 @@ library MerkleProof {
217218
// get the next hash.
218219
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
219220
// `proof` array.
220-
for (uint256 i = 0; i < proofFlags.length; i++) {
221+
for (uint256 i = 0; i < proofFlagsLen; i++) {
221222
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
222223
bytes32 b = proofFlags[i]
223224
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
224225
: proof[proofPos++];
225226
hashes[i] = Hashes.commutativeKeccak256(a, b);
226227
}
227228

228-
if (proofFlags.length > 0) {
229+
if (proofFlagsLen > 0) {
229230
if (proofPos != proof.length) {
230231
revert MerkleProofInvalidMultiproof();
231232
}
232233
unchecked {
233-
return hashes[proofFlags.length - 1];
234+
return hashes[proofFlagsLen - 1];
234235
}
235236
} else if (leavesLen > 0) {
236237
return leaves[0];
@@ -280,15 +281,16 @@ library MerkleProof {
280281
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
281282
// the Merkle tree.
282283
uint256 leavesLen = leaves.length;
284+
uint256 proofFlagsLen = proofFlags.length;
283285

284286
// Check proof validity.
285-
if (leavesLen + proof.length != proofFlags.length + 1) {
287+
if (leavesLen + proof.length != proofFlagsLen + 1) {
286288
revert MerkleProofInvalidMultiproof();
287289
}
288290

289291
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
290292
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
291-
bytes32[] memory hashes = new bytes32[](proofFlags.length);
293+
bytes32[] memory hashes = new bytes32[](proofFlagsLen);
292294
uint256 leafPos = 0;
293295
uint256 hashPos = 0;
294296
uint256 proofPos = 0;
@@ -297,20 +299,20 @@ library MerkleProof {
297299
// get the next hash.
298300
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
299301
// `proof` array.
300-
for (uint256 i = 0; i < proofFlags.length; i++) {
302+
for (uint256 i = 0; i < proofFlagsLen; i++) {
301303
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
302304
bytes32 b = proofFlags[i]
303305
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
304306
: proof[proofPos++];
305307
hashes[i] = hasher(a, b);
306308
}
307309

308-
if (proofFlags.length > 0) {
310+
if (proofFlagsLen > 0) {
309311
if (proofPos != proof.length) {
310312
revert MerkleProofInvalidMultiproof();
311313
}
312314
unchecked {
313-
return hashes[proofFlags.length - 1];
315+
return hashes[proofFlagsLen - 1];
314316
}
315317
} else if (leavesLen > 0) {
316318
return leaves[0];
@@ -331,9 +333,9 @@ library MerkleProof {
331333
bytes32[] calldata proof,
332334
bool[] calldata proofFlags,
333335
bytes32 root,
334-
bytes32[] calldata leaves
336+
bytes32[] memory leaves
335337
) internal pure returns (bool) {
336-
return processMultiProof(proof, proofFlags, leaves) == root;
338+
return processMultiProofCalldata(proof, proofFlags, leaves) == root;
337339
}
338340

339341
/**
@@ -351,22 +353,23 @@ library MerkleProof {
351353
function processMultiProofCalldata(
352354
bytes32[] calldata proof,
353355
bool[] calldata proofFlags,
354-
bytes32[] calldata leaves
356+
bytes32[] memory leaves
355357
) internal pure returns (bytes32 merkleRoot) {
356358
// This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
357359
// consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the
358360
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
359361
// the Merkle tree.
360362
uint256 leavesLen = leaves.length;
363+
uint256 proofFlagsLen = proofFlags.length;
361364

362365
// Check proof validity.
363-
if (leavesLen + proof.length != proofFlags.length + 1) {
366+
if (leavesLen + proof.length != proofFlagsLen + 1) {
364367
revert MerkleProofInvalidMultiproof();
365368
}
366369

367370
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
368371
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
369-
bytes32[] memory hashes = new bytes32[](proofFlags.length);
372+
bytes32[] memory hashes = new bytes32[](proofFlagsLen);
370373
uint256 leafPos = 0;
371374
uint256 hashPos = 0;
372375
uint256 proofPos = 0;
@@ -375,20 +378,20 @@ library MerkleProof {
375378
// get the next hash.
376379
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
377380
// `proof` array.
378-
for (uint256 i = 0; i < proofFlags.length; i++) {
381+
for (uint256 i = 0; i < proofFlagsLen; i++) {
379382
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
380383
bytes32 b = proofFlags[i]
381384
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
382385
: proof[proofPos++];
383386
hashes[i] = Hashes.commutativeKeccak256(a, b);
384387
}
385388

386-
if (proofFlags.length > 0) {
389+
if (proofFlagsLen > 0) {
387390
if (proofPos != proof.length) {
388391
revert MerkleProofInvalidMultiproof();
389392
}
390393
unchecked {
391-
return hashes[proofFlags.length - 1];
394+
return hashes[proofFlagsLen - 1];
392395
}
393396
} else if (leavesLen > 0) {
394397
return leaves[0];
@@ -409,10 +412,10 @@ library MerkleProof {
409412
bytes32[] calldata proof,
410413
bool[] calldata proofFlags,
411414
bytes32 root,
412-
bytes32[] calldata leaves,
415+
bytes32[] memory leaves,
413416
function(bytes32, bytes32) view returns (bytes32) hasher
414417
) internal view returns (bool) {
415-
return processMultiProof(proof, proofFlags, leaves, hasher) == root;
418+
return processMultiProofCalldata(proof, proofFlags, leaves, hasher) == root;
416419
}
417420

418421
/**
@@ -430,23 +433,24 @@ library MerkleProof {
430433
function processMultiProofCalldata(
431434
bytes32[] calldata proof,
432435
bool[] calldata proofFlags,
433-
bytes32[] calldata leaves,
436+
bytes32[] memory leaves,
434437
function(bytes32, bytes32) view returns (bytes32) hasher
435438
) internal view returns (bytes32 merkleRoot) {
436439
// This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
437440
// consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the
438441
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
439442
// the Merkle tree.
440443
uint256 leavesLen = leaves.length;
444+
uint256 proofFlagsLen = proofFlags.length;
441445

442446
// Check proof validity.
443-
if (leavesLen + proof.length != proofFlags.length + 1) {
447+
if (leavesLen + proof.length != proofFlagsLen + 1) {
444448
revert MerkleProofInvalidMultiproof();
445449
}
446450

447451
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
448452
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
449-
bytes32[] memory hashes = new bytes32[](proofFlags.length);
453+
bytes32[] memory hashes = new bytes32[](proofFlagsLen);
450454
uint256 leafPos = 0;
451455
uint256 hashPos = 0;
452456
uint256 proofPos = 0;
@@ -455,20 +459,20 @@ library MerkleProof {
455459
// get the next hash.
456460
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
457461
// `proof` array.
458-
for (uint256 i = 0; i < proofFlags.length; i++) {
462+
for (uint256 i = 0; i < proofFlagsLen; i++) {
459463
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
460464
bytes32 b = proofFlags[i]
461465
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
462466
: proof[proofPos++];
463467
hashes[i] = hasher(a, b);
464468
}
465469

466-
if (proofFlags.length > 0) {
470+
if (proofFlagsLen > 0) {
467471
if (proofPos != proof.length) {
468472
revert MerkleProofInvalidMultiproof();
469473
}
470474
unchecked {
471-
return hashes[proofFlags.length - 1];
475+
return hashes[proofFlagsLen - 1];
472476
}
473477
} else if (leavesLen > 0) {
474478
return leaves[0];

scripts/generate/templates/MerkleProof.js

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function verify${suffix}(${(hash ? formatArgsMultiline : formatArgsSingleLine)(
5656
'bytes32 leaf',
5757
hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
5858
)}) internal ${visibility} returns (bool) {
59-
return processProof(proof, leaf${hash ? `, ${hash}` : ''}) == root;
59+
return processProof${suffix}(proof, leaf${hash ? `, ${hash}` : ''}) == root;
6060
}
6161
6262
/**
@@ -93,10 +93,10 @@ function multiProofVerify${suffix}(${formatArgsMultiline(
9393
`bytes32[] ${location} proof`,
9494
`bool[] ${location} proofFlags`,
9595
'bytes32 root',
96-
`bytes32[] ${location} leaves`,
96+
`bytes32[] memory leaves`,
9797
hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
9898
)}) internal ${visibility} returns (bool) {
99-
return processMultiProof(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root;
99+
return processMultiProof${suffix}(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root;
100100
}
101101
102102
/**
@@ -114,23 +114,24 @@ function multiProofVerify${suffix}(${formatArgsMultiline(
114114
function processMultiProof${suffix}(${formatArgsMultiline(
115115
`bytes32[] ${location} proof`,
116116
`bool[] ${location} proofFlags`,
117-
`bytes32[] ${location} leaves`,
117+
`bytes32[] memory leaves`,
118118
hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
119119
)}) internal ${visibility} returns (bytes32 merkleRoot) {
120120
// This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
121121
// consuming and producing values on a queue. The queue starts with the \`leaves\` array, then goes onto the
122122
// \`hashes\` array. At the end of the process, the last hash in the \`hashes\` array should contain the root of
123123
// the Merkle tree.
124124
uint256 leavesLen = leaves.length;
125+
uint256 proofFlagsLen = proofFlags.length;
125126
126127
// Check proof validity.
127-
if (leavesLen + proof.length != proofFlags.length + 1) {
128+
if (leavesLen + proof.length != proofFlagsLen + 1) {
128129
revert MerkleProofInvalidMultiproof();
129130
}
130131
131132
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
132133
// \`xxx[xxxPos++]\`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
133-
bytes32[] memory hashes = new bytes32[](proofFlags.length);
134+
bytes32[] memory hashes = new bytes32[](proofFlagsLen);
134135
uint256 leafPos = 0;
135136
uint256 hashPos = 0;
136137
uint256 proofPos = 0;
@@ -139,20 +140,20 @@ function processMultiProof${suffix}(${formatArgsMultiline(
139140
// get the next hash.
140141
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
141142
// \`proof\` array.
142-
for (uint256 i = 0; i < proofFlags.length; i++) {
143+
for (uint256 i = 0; i < proofFlagsLen; i++) {
143144
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
144145
bytes32 b = proofFlags[i]
145146
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
146147
: proof[proofPos++];
147148
hashes[i] = ${hash ?? DEFAULT_HASH}(a, b);
148149
}
149150
150-
if (proofFlags.length > 0) {
151+
if (proofFlagsLen > 0) {
151152
if (proofPos != proof.length) {
152153
revert MerkleProofInvalidMultiproof();
153154
}
154155
unchecked {
155-
return hashes[proofFlags.length - 1];
156+
return hashes[proofFlagsLen - 1];
156157
}
157158
} else if (leavesLen > 0) {
158159
return leaves[0];

0 commit comments

Comments
 (0)