From 0185c636cc65b978f8458d431fc9618bed97fd00 Mon Sep 17 00:00:00 2001 From: Vinh Tran Date: Mon, 1 Jul 2024 19:57:24 +0700 Subject: [PATCH] fix:seperate bytes and uint --- contracts/src/Common.sol | 7 ++ contracts/src/SmartVault.sol | 156 +++++++++++++++++++++++++---- contracts/src/libraries/Parser.sol | 14 +++ contracts/test/Parser.t.sol | 4 +- contracts/test/SmartVault.t.sol | 8 +- 5 files changed, 166 insertions(+), 23 deletions(-) diff --git a/contracts/src/Common.sol b/contracts/src/Common.sol index 7a7b154..3b5511f 100644 --- a/contracts/src/Common.sol +++ b/contracts/src/Common.sol @@ -12,6 +12,13 @@ enum Operator { } enum Type { + UINT8, + UINT16, + UINT24, + UINT32, + UINT64, + UINT128, + UINT256, INT8, INT16, INT24, diff --git a/contracts/src/SmartVault.sol b/contracts/src/SmartVault.sol index fc99f9d..4c09fed 100644 --- a/contracts/src/SmartVault.sol +++ b/contracts/src/SmartVault.sol @@ -248,6 +248,24 @@ contract SmartVault is ISmartVault { } } + function _checkUint256(Operator operator, uint256 value, uint256 threshold) private pure { + if (operator == Operator.EQ) { + require(value == threshold, "SmartVault: invalid int256"); + } else if (operator == Operator.NEQ) { + require(value != threshold, "SmartVault: invalid int256"); + } else if (operator == Operator.GT) { + require(value > threshold, "SmartVault: invalid int256"); + } else if (operator == Operator.GTE) { + require(value >= threshold, "SmartVault: invalid int256"); + } else if (operator == Operator.LT) { + require(value < threshold, "SmartVault: invalid int256"); + } else if (operator == Operator.LTE) { + require(value <= threshold, "SmartVault: invalid int256"); + } else if (operator != Operator.NONE) { + revert("SmartVault: invalid operation"); + } + } + function _checkInt256(Operator operator, int256 value, int256 threshold) private pure { if (operator == Operator.EQ) { require(value == threshold, "SmartVault: invalid int256"); @@ -338,140 +356,244 @@ contract SmartVault is ISmartVault { for (uint256 i = 0; i < rules.types.length; i++) { if (rules.types[i] == Type.BYTES1) { bytes1 value; - bytes1 threshold = abi.decode(rules.thresholds[i], (bytes1)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes1 threshold = abi.decode(rules.thresholds[i], (bytes1)); _checkBytes32(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES2) { bytes2 value; - bytes2 threshold = abi.decode(rules.thresholds[i], (bytes2)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes2 threshold = abi.decode(rules.thresholds[i], (bytes2)); _checkBytes32(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES3) { bytes3 value; - bytes3 threshold = abi.decode(rules.thresholds[i], (bytes3)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes3 threshold = abi.decode(rules.thresholds[i], (bytes3)); _checkBytes32(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES4) { bytes4 value; - bytes4 threshold = abi.decode(rules.thresholds[i], (bytes4)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes4 threshold = abi.decode(rules.thresholds[i], (bytes4)); _checkBytes32(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES8) { bytes8 value; - bytes8 threshold = abi.decode(rules.thresholds[i], (bytes8)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes8 threshold = abi.decode(rules.thresholds[i], (bytes8)); _checkBytes32(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES16) { bytes16 value; - bytes16 threshold = abi.decode(rules.thresholds[i], (bytes16)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes16 threshold = abi.decode(rules.thresholds[i], (bytes16)); _checkBytes32(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES32) { bytes32 value; - bytes32 threshold = abi.decode(rules.thresholds[i], (bytes32)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes32 threshold = abi.decode(rules.thresholds[i], (bytes32)); _checkBytes32(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT8) { + uint8 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint8 threshold = abi.decode(rules.thresholds[i], (uint8)); + _checkUint256(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT16) { + uint16 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint16 threshold = abi.decode(rules.thresholds[i], (uint16)); + _checkUint256(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT24) { + uint24 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint24 threshold = abi.decode(rules.thresholds[i], (uint24)); + _checkUint256(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT32) { + uint32 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint32 threshold = abi.decode(rules.thresholds[i], (uint32)); + _checkUint256(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT64) { + uint64 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint64 threshold = abi.decode(rules.thresholds[i], (uint64)); + _checkUint256(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT128) { + uint128 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint128 threshold = abi.decode(rules.thresholds[i], (uint128)); + _checkUint256(rules.operators[i], value, threshold); + } else if (rules.types[i] == Type.UINT256) { + uint256 value; + assembly { + value := mload(validationData) + validationData := add(validationData, 0x20) + } + if (rules.operators[i] == Operator.NONE) continue; + + uint256 threshold = abi.decode(rules.thresholds[i], (uint256)); + _checkUint256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT8) { int8 value; - int8 threshold = abi.decode(rules.thresholds[i], (int8)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int8 threshold = abi.decode(rules.thresholds[i], (int8)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT16) { int16 value; - int16 threshold = abi.decode(rules.thresholds[i], (int16)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int16 threshold = abi.decode(rules.thresholds[i], (int16)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT24) { int24 value; - int24 threshold = abi.decode(rules.thresholds[i], (int24)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int24 threshold = abi.decode(rules.thresholds[i], (int24)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT32) { int32 value; - int32 threshold = abi.decode(rules.thresholds[i], (int32)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int32 threshold = abi.decode(rules.thresholds[i], (int32)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT64) { int64 value; - int64 threshold = abi.decode(rules.thresholds[i], (int64)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int64 threshold = abi.decode(rules.thresholds[i], (int64)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT128) { int128 value; - int128 threshold = abi.decode(rules.thresholds[i], (int128)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int128 threshold = abi.decode(rules.thresholds[i], (int128)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.INT256) { int256 value; - int256 threshold = abi.decode(rules.thresholds[i], (int256)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + int256 threshold = abi.decode(rules.thresholds[i], (int256)); _checkInt256(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.ADDRESS) { address value; - address threshold = abi.decode(rules.thresholds[i], (address)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + address threshold = abi.decode(rules.thresholds[i], (address)); _checkAddress(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BOOL) { bool value; - bool threshold = abi.decode(rules.thresholds[i], (bool)); assembly { value := mload(validationData) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bool threshold = abi.decode(rules.thresholds[i], (bool)); _checkBool(rules.operators[i], value, threshold); } else if (rules.types[i] == Type.BYTES || rules.types[i] == Type.STRING) { bytes memory value; - bytes memory threshold = abi.decode(rules.thresholds[i], (bytes)); assembly { let offset := mload(validationData) value := add(offset, pointer) validationData := add(validationData, 0x20) } + if (rules.operators[i] == Operator.NONE) continue; + + bytes memory threshold = abi.decode(rules.thresholds[i], (bytes)); _checkBytes(rules.operators[i], value, threshold); } else { require(rules.operators[i] == Operator.NONE, "SmartVault: can not compare unsupported type"); diff --git a/contracts/src/libraries/Parser.sol b/contracts/src/libraries/Parser.sol index 6d987f4..1ceedb6 100644 --- a/contracts/src/libraries/Parser.sol +++ b/contracts/src/libraries/Parser.sol @@ -71,6 +71,20 @@ library Parser { return Type.INT128; } else if (keccak256(b) == keccak256("int256") || keccak256(b) == keccak256("int")) { return Type.INT256; + } else if (keccak256(b) == keccak256("uint8")) { + return Type.UINT8; + } else if (keccak256(b) == keccak256("uint16")) { + return Type.UINT16; + } else if (keccak256(b) == keccak256("uint24")) { + return Type.UINT24; + } else if (keccak256(b) == keccak256("uint32")) { + return Type.UINT32; + } else if (keccak256(b) == keccak256("uint64")) { + return Type.UINT64; + } else if (keccak256(b) == keccak256("uint128")) { + return Type.UINT128; + } else if (keccak256(b) == keccak256("uint256") || keccak256(b) == keccak256("uint")) { + return Type.UINT256; } else if (keccak256(b) == keccak256("bytes1") || keccak256(b) == keccak256("uint8")) { return Type.BYTES1; } else if (keccak256(b) == keccak256("bytes2") || keccak256(b) == keccak256("uint16")) { diff --git a/contracts/test/Parser.t.sol b/contracts/test/Parser.t.sol index d02f58f..05a9aac 100644 --- a/contracts/test/Parser.t.sol +++ b/contracts/test/Parser.t.sol @@ -14,7 +14,7 @@ contract ParserTest is Test { string memory input = "uint256 age,address recipient,bool passed"; Type[] memory output = input.extractTypes(); assertEq(output.length, 3); - assertEq(uint256(output[0]), uint256(Type.BYTES32)); + assertEq(uint256(output[0]), uint256(Type.UINT256)); assertEq(uint256(output[1]), uint256(Type.ADDRESS)); assertEq(uint256(output[2]), uint256(Type.BOOL)); } @@ -24,7 +24,7 @@ contract ParserTest is Test { bytes memory inputBytes = bytes(input); Type[] memory output = inputBytes.extractTypes(); assertEq(output.length, 3); - assertEq(uint256(output[0]), uint256(Type.BYTES32)); + assertEq(uint256(output[0]), uint256(Type.UINT256)); assertEq(uint256(output[1]), uint256(Type.ADDRESS)); assertEq(uint256(output[2]), uint256(Type.BOOL)); } diff --git a/contracts/test/SmartVault.t.sol b/contracts/test/SmartVault.t.sol index a3ea90c..d443f0d 100644 --- a/contracts/test/SmartVault.t.sol +++ b/contracts/test/SmartVault.t.sol @@ -76,7 +76,7 @@ contract SmartVaultTest is Test { smartVault.getRules(vaultId); assertEq(types.length, 3); - assertTrue(types[0] == Type.BYTES32); + assertTrue(types[0] == Type.UINT256); assertTrue(types[1] == Type.STRING); assertTrue(types[2] == Type.BOOL); @@ -101,7 +101,7 @@ contract SmartVaultTest is Test { ops[2] = Operator.EQ; bytes[] memory thresholds = new bytes[](3); - thresholds[0] = abi.encode(18); + thresholds[0] = abi.encode(22); thresholds[1] = abi.encode("MIT"); thresholds[2] = abi.encode(true); @@ -152,7 +152,7 @@ contract SmartVaultTest is Test { ops[2] = Operator.EQ; bytes[] memory thresholds = new bytes[](3); - thresholds[0] = abi.encode(18); + thresholds[0] = abi.encode(22); thresholds[1] = abi.encode("MIT"); thresholds[2] = abi.encode(true); @@ -193,7 +193,7 @@ contract SmartVaultTest is Test { // create attestation address claimer = makeAddr("claimer"); AttestationRequestData memory data = - AttestationRequestData(claimer, NO_EXPIRATION_TIME, false, bytes32(0), abi.encode(20, "MIT", true), 0); + AttestationRequestData(claimer, NO_EXPIRATION_TIME, false, bytes32(0), abi.encode(23, "MIT", true), 0); AttestationRequest memory request = AttestationRequest(validationSchema, data); bytes32 attestationUID = eas.attest(request);