Skip to content

Commit 95969bf

Browse files
authored
SMT proof verification (#149)
* Add proof verification for SMT * refactor and add on-chain siblings trimming * refactor and bump the patch version
1 parent 7822238 commit 95969bf

File tree

5 files changed

+252
-9
lines changed

5 files changed

+252
-9
lines changed

contracts/libs/data-structures/SparseMerkleTree.sol

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ pragma solidity ^0.8.21;
5555
*
5656
* SparseMerkleTree.Proof memory proof = uintTree.getProof(100);
5757
*
58+
* uintTree.verifyProof(proof);
59+
*
5860
* uintTree.getNodeByKey(100);
5961
*
6062
* uintTree.remove(100);
@@ -177,6 +179,18 @@ library SparseMerkleTree {
177179
return _proof(tree._tree, bytes32(key_));
178180
}
179181

182+
/**
183+
* @notice The function to verify the proof for inclusion or exclusion of a node in the SMT.
184+
* Complexity is O(log(n)), where n is the max depth of the tree.
185+
*
186+
* @param tree self.
187+
* @param proof_ The SMT proof struct.
188+
* @return True if the proof is valid, false otherwise.
189+
*/
190+
function verifyProof(UintSMT storage tree, Proof memory proof_) internal view returns (bool) {
191+
return _verifyProof(tree._tree, proof_);
192+
}
193+
180194
/**
181195
* @notice The function to get the root of the Merkle tree.
182196
* Complexity is O(1).
@@ -347,6 +361,21 @@ library SparseMerkleTree {
347361
return _proof(tree._tree, key_);
348362
}
349363

364+
/**
365+
* @notice The function to verify the proof for inclusion or exclusion of a node in the SMT.
366+
* Complexity is O(log(n)), where n is the max depth of the tree.
367+
*
368+
* @param tree self.
369+
* @param proof_ The SMT proof struct.
370+
* @return True if the proof is valid, false otherwise.
371+
*/
372+
function verifyProof(
373+
Bytes32SMT storage tree,
374+
Proof memory proof_
375+
) internal view returns (bool) {
376+
return _verifyProof(tree._tree, proof_);
377+
}
378+
350379
/**
351380
* @notice The function to get the root of the Merkle tree.
352381
* Complexity is O(1).
@@ -523,6 +552,21 @@ library SparseMerkleTree {
523552
return _proof(tree._tree, key_);
524553
}
525554

555+
/**
556+
* @notice The function to verify the proof for inclusion or exclusion of a node in the SMT.
557+
* Complexity is O(log(n)), where n is the max depth of the tree.
558+
*
559+
* @param tree self.
560+
* @param proof_ The SMT proof struct.
561+
* @return True if the proof is valid, false otherwise.
562+
*/
563+
function verifyProof(
564+
AddressSMT storage tree,
565+
Proof memory proof_
566+
) internal view returns (bool) {
567+
return _verifyProof(tree._tree, proof_);
568+
}
569+
526570
/**
527571
* @notice The function to get the root of the Merkle tree.
528572
* Complexity is O(1).
@@ -988,12 +1032,10 @@ library SparseMerkleTree {
9881032
* non-empty nodes and is not intended for external use.
9891033
*/
9901034
function _getNodeHash(SMT storage tree, Node memory node_) private view returns (bytes32) {
991-
function(bytes32, bytes32) view returns (bytes32) hash2_ = tree.isCustomHasherSet
992-
? tree.hash2
993-
: _hash2;
994-
function(bytes32, bytes32, bytes32) view returns (bytes32) hash3_ = tree.isCustomHasherSet
995-
? tree.hash3
996-
: _hash3;
1035+
(
1036+
function(bytes32, bytes32) view returns (bytes32) hash2_,
1037+
function(bytes32, bytes32, bytes32) view returns (bytes32) hash3_
1038+
) = _getHashFunctions(tree);
9971039

9981040
if (node_.nodeType == NodeType.LEAF) {
9991041
return hash3_(node_.key, node_.value, bytes32(uint256(1)));
@@ -1054,6 +1096,66 @@ library SparseMerkleTree {
10541096
return proof_;
10551097
}
10561098

1099+
/**
1100+
* @dev Computes the root by hashing up the path from the leaf or aux leaf of the proof.
1101+
* The `tree` argument is used only to access its configured hash functions.
1102+
* If no custom hash functions are configured, default hashing implementations are used instead.
1103+
*/
1104+
function _verifyProof(SMT storage tree, Proof memory proof_) private view returns (bool) {
1105+
// invalid exclusion proof
1106+
if (!proof_.existence && proof_.auxExistence && proof_.key == proof_.auxKey) {
1107+
return false;
1108+
}
1109+
1110+
(
1111+
function(bytes32, bytes32) view returns (bytes32) hash2_,
1112+
function(bytes32, bytes32, bytes32) view returns (bytes32) hash3_
1113+
) = _getHashFunctions(tree);
1114+
1115+
bytes32 computedHash_;
1116+
1117+
if (proof_.existence) {
1118+
computedHash_ = hash3_(proof_.key, proof_.value, bytes32(uint256(1)));
1119+
} else if (proof_.auxExistence) {
1120+
computedHash_ = hash3_(proof_.auxKey, proof_.auxValue, bytes32(uint256(1)));
1121+
} else {
1122+
computedHash_ = bytes32(0);
1123+
}
1124+
1125+
uint256 pathIndex_ = uint256(proof_.key);
1126+
uint256 depth_ = proof_.siblings.length;
1127+
1128+
while (depth_ > 0 && proof_.siblings[depth_ - 1] == bytes32(0)) {
1129+
--depth_;
1130+
}
1131+
1132+
for (uint256 i = depth_; i > 0; --i) {
1133+
uint256 sIndex_ = i - 1;
1134+
1135+
if ((pathIndex_ >> sIndex_) & 1 == 1) {
1136+
computedHash_ = hash2_(proof_.siblings[sIndex_], computedHash_);
1137+
} else {
1138+
computedHash_ = hash2_(computedHash_, proof_.siblings[sIndex_]);
1139+
}
1140+
}
1141+
1142+
return computedHash_ == proof_.root;
1143+
}
1144+
1145+
function _getHashFunctions(
1146+
SMT storage tree
1147+
)
1148+
private
1149+
view
1150+
returns (
1151+
function(bytes32, bytes32) view returns (bytes32) hash2_,
1152+
function(bytes32, bytes32, bytes32) view returns (bytes32) hash3_
1153+
)
1154+
{
1155+
hash2_ = tree.isCustomHasherSet ? tree.hash2 : _hash2;
1156+
hash3_ = tree.isCustomHasherSet ? tree.hash3 : _hash3;
1157+
}
1158+
10571159
function _hash2(bytes32 a_, bytes32 b_) private pure returns (bytes32 result_) {
10581160
assembly {
10591161
mstore(0, a_)

contracts/mock/libs/data-structures/SparseMerkleTreeMock.sol

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,22 @@ contract SparseMerkleTreeMock {
103103
return _addressTree.getProof(key_);
104104
}
105105

106+
function verifyUintProof(SparseMerkleTree.Proof memory proof_) external view returns (bool) {
107+
return _uintTree.verifyProof(proof_);
108+
}
109+
110+
function verifyBytes32Proof(
111+
SparseMerkleTree.Proof memory proof_
112+
) external view returns (bool) {
113+
return _bytes32Tree.verifyProof(proof_);
114+
}
115+
116+
function verifyAddressProof(
117+
SparseMerkleTree.Proof memory proof_
118+
) external view returns (bool) {
119+
return _addressTree.verifyProof(proof_);
120+
}
121+
106122
function getUintRoot() external view returns (bytes32) {
107123
return _uintTree.getRoot();
108124
}

package-lock.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@solarity/solidity-lib",
3-
"version": "3.1.1",
3+
"version": "3.1.2",
44
"license": "MIT",
55
"author": "Distributed Lab",
66
"readme": "README.md",

test/libs/data-structures/SparseMerkleTree.test.ts

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,65 @@ describe("SparseMerkleTree", () => {
410410

411411
expect((await merkleTree.getUintNodeByKey(5n)).nodeType).to.be.equal(0);
412412
});
413+
414+
it("should handle proof verification correctly", async () => {
415+
const treeSize = 20;
416+
417+
const keys: string[] = new Array(treeSize);
418+
419+
for (let i = 1; i <= treeSize; i++) {
420+
const value = BigInt(toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32));
421+
const key = poseidonHash(toBeHex(`0x` + value.toString(16), 32));
422+
423+
keys[i - 1] = key;
424+
425+
await merkleTree.addUint(key, value);
426+
427+
await localMerkleTree.add(BigInt(key), BigInt(value));
428+
}
429+
430+
const randomNum = Math.floor(Math.random() * (treeSize - 1));
431+
const randomKey = keys[randomNum];
432+
433+
const inclusionProof = JSON.parse(JSON.stringify(await merkleTree.getUintProof(randomKey)));
434+
435+
expect(await merkleTree.verifyUintProof(inclusionProof)).to.be.true;
436+
437+
inclusionProof[0] = inclusionProof[3];
438+
expect(await merkleTree.verifyUintProof(inclusionProof)).to.be.false;
439+
440+
await merkleTree.removeUint(randomKey);
441+
442+
let exclusionProof = JSON.parse(JSON.stringify(await merkleTree.getUintProof(randomKey)));
443+
444+
expect(await merkleTree.verifyUintProof(exclusionProof)).to.be.true;
445+
446+
exclusionProof[0] = exclusionProof[3];
447+
expect(await merkleTree.verifyUintProof(exclusionProof)).to.be.false;
448+
449+
const [root, siblings, , , value, , , auxValue] = exclusionProof;
450+
451+
const invalidExclusionProof = {
452+
root,
453+
siblings,
454+
key: randomKey,
455+
value,
456+
existence: false,
457+
auxKey: randomKey,
458+
auxValue,
459+
auxExistence: true,
460+
};
461+
462+
expect(await merkleTree.verifyUintProof(invalidExclusionProof)).to.be.false;
463+
464+
do {
465+
const outOfRangeKey = toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32);
466+
467+
exclusionProof = JSON.parse(JSON.stringify(await merkleTree.getUintProof(outOfRangeKey)));
468+
} while (exclusionProof[2] || exclusionProof[5]);
469+
470+
expect(await merkleTree.verifyUintProof(exclusionProof)).to.be.true;
471+
});
413472
});
414473

415474
describe("Bytes32 SMT", () => {
@@ -525,6 +584,39 @@ describe("SparseMerkleTree", () => {
525584
expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), BigInt(value))).to.be.true;
526585
}
527586
});
587+
588+
it("should handle proof verification correctly", async () => {
589+
const treeSize = 20;
590+
591+
const keys: string[] = new Array(treeSize);
592+
593+
for (let i = 1; i <= treeSize; i++) {
594+
const value = toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32);
595+
const key = poseidonHash(value);
596+
597+
keys[i - 1] = key;
598+
599+
await merkleTree.addBytes32(key, value);
600+
}
601+
602+
const randomKey = keys[Math.floor(Math.random() * (treeSize - 1))];
603+
604+
const inclusionProof = JSON.parse(JSON.stringify(await merkleTree.getBytes32Proof(randomKey)));
605+
606+
expect(await merkleTree.verifyBytes32Proof(inclusionProof)).to.be.true;
607+
608+
inclusionProof[0] = inclusionProof[3];
609+
expect(await merkleTree.verifyBytes32Proof(inclusionProof)).to.be.false;
610+
611+
await merkleTree.removeBytes32(randomKey);
612+
613+
const exclusionProof = JSON.parse(JSON.stringify(await merkleTree.getBytes32Proof(randomKey)));
614+
615+
expect(await merkleTree.verifyBytes32Proof(exclusionProof)).to.be.true;
616+
617+
exclusionProof[0] = exclusionProof[3];
618+
expect(await merkleTree.verifyBytes32Proof(exclusionProof)).to.be.false;
619+
});
528620
});
529621

530622
describe("Address SMT", () => {
@@ -640,5 +732,38 @@ describe("SparseMerkleTree", () => {
640732
expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), BigInt(value))).to.be.true;
641733
}
642734
});
735+
736+
it("should handle proof verification correctly", async () => {
737+
const treeSize = 20;
738+
739+
const keys: string[] = new Array(treeSize);
740+
741+
for (let i = 1; i <= treeSize; i++) {
742+
const value = toBeHex(BigInt(await USER1.getAddress()) + BigInt(i));
743+
const key = poseidonHash(value);
744+
745+
keys[i - 1] = key;
746+
747+
await merkleTree.addAddress(key, value);
748+
}
749+
750+
const randomKey = keys[Math.floor(Math.random() * (treeSize - 1))];
751+
752+
const inclusionProof = JSON.parse(JSON.stringify(await merkleTree.getAddressProof(randomKey)));
753+
754+
expect(await merkleTree.verifyAddressProof(inclusionProof)).to.be.true;
755+
756+
inclusionProof[0] = inclusionProof[3];
757+
expect(await merkleTree.verifyAddressProof(inclusionProof)).to.be.false;
758+
759+
await merkleTree.removeAddress(randomKey);
760+
761+
const exclusionProof = JSON.parse(JSON.stringify(await merkleTree.getAddressProof(randomKey)));
762+
763+
expect(await merkleTree.verifyAddressProof(exclusionProof)).to.be.true;
764+
765+
exclusionProof[0] = exclusionProof[3];
766+
expect(await merkleTree.verifyAddressProof(exclusionProof)).to.be.false;
767+
});
643768
});
644769
});

0 commit comments

Comments
 (0)