Skip to content

fix: seperate bytes and uint #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions contracts/src/Common.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ enum Operator {
}

enum Type {
UINT8,
UINT16,
UINT24,
UINT32,
UINT64,
UINT128,
UINT256,
INT8,
INT16,
INT24,
Expand Down
156 changes: 139 additions & 17 deletions contracts/src/SmartVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,24 @@ contract SmartVault is ISmartVault {
}
}

function _checkUint256(Operator operator, uint256 value, uint256 threshold) private pure {
if (operator == Operator.EQ) {
require(value == threshold, "SmartVault: invalid int256");
} else if (operator == Operator.NEQ) {
require(value != threshold, "SmartVault: invalid int256");
} else if (operator == Operator.GT) {
require(value > threshold, "SmartVault: invalid int256");
} else if (operator == Operator.GTE) {
require(value >= threshold, "SmartVault: invalid int256");
} else if (operator == Operator.LT) {
require(value < threshold, "SmartVault: invalid int256");
} else if (operator == Operator.LTE) {
require(value <= threshold, "SmartVault: invalid int256");
} else if (operator != Operator.NONE) {
revert("SmartVault: invalid operation");
}
}

function _checkInt256(Operator operator, int256 value, int256 threshold) private pure {
if (operator == Operator.EQ) {
require(value == threshold, "SmartVault: invalid int256");
Expand Down Expand Up @@ -338,140 +356,244 @@ contract SmartVault is ISmartVault {
for (uint256 i = 0; i < rules.types.length; i++) {
if (rules.types[i] == Type.BYTES1) {
bytes1 value;
bytes1 threshold = abi.decode(rules.thresholds[i], (bytes1));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes1 threshold = abi.decode(rules.thresholds[i], (bytes1));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES2) {
bytes2 value;
bytes2 threshold = abi.decode(rules.thresholds[i], (bytes2));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes2 threshold = abi.decode(rules.thresholds[i], (bytes2));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES3) {
bytes3 value;
bytes3 threshold = abi.decode(rules.thresholds[i], (bytes3));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes3 threshold = abi.decode(rules.thresholds[i], (bytes3));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES4) {
bytes4 value;
bytes4 threshold = abi.decode(rules.thresholds[i], (bytes4));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes4 threshold = abi.decode(rules.thresholds[i], (bytes4));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES8) {
bytes8 value;
bytes8 threshold = abi.decode(rules.thresholds[i], (bytes8));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes8 threshold = abi.decode(rules.thresholds[i], (bytes8));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES16) {
bytes16 value;
bytes16 threshold = abi.decode(rules.thresholds[i], (bytes16));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes16 threshold = abi.decode(rules.thresholds[i], (bytes16));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES32) {
bytes32 value;
bytes32 threshold = abi.decode(rules.thresholds[i], (bytes32));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes32 threshold = abi.decode(rules.thresholds[i], (bytes32));
_checkBytes32(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT8) {
uint8 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint8 threshold = abi.decode(rules.thresholds[i], (uint8));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT16) {
uint16 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint16 threshold = abi.decode(rules.thresholds[i], (uint16));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT24) {
uint24 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint24 threshold = abi.decode(rules.thresholds[i], (uint24));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT32) {
uint32 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint32 threshold = abi.decode(rules.thresholds[i], (uint32));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT64) {
uint64 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint64 threshold = abi.decode(rules.thresholds[i], (uint64));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT128) {
uint128 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint128 threshold = abi.decode(rules.thresholds[i], (uint128));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.UINT256) {
uint256 value;
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

uint256 threshold = abi.decode(rules.thresholds[i], (uint256));
_checkUint256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT8) {
int8 value;
int8 threshold = abi.decode(rules.thresholds[i], (int8));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int8 threshold = abi.decode(rules.thresholds[i], (int8));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT16) {
int16 value;
int16 threshold = abi.decode(rules.thresholds[i], (int16));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int16 threshold = abi.decode(rules.thresholds[i], (int16));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT24) {
int24 value;
int24 threshold = abi.decode(rules.thresholds[i], (int24));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int24 threshold = abi.decode(rules.thresholds[i], (int24));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT32) {
int32 value;
int32 threshold = abi.decode(rules.thresholds[i], (int32));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int32 threshold = abi.decode(rules.thresholds[i], (int32));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT64) {
int64 value;
int64 threshold = abi.decode(rules.thresholds[i], (int64));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int64 threshold = abi.decode(rules.thresholds[i], (int64));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT128) {
int128 value;
int128 threshold = abi.decode(rules.thresholds[i], (int128));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int128 threshold = abi.decode(rules.thresholds[i], (int128));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.INT256) {
int256 value;
int256 threshold = abi.decode(rules.thresholds[i], (int256));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

int256 threshold = abi.decode(rules.thresholds[i], (int256));
_checkInt256(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.ADDRESS) {
address value;
address threshold = abi.decode(rules.thresholds[i], (address));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

address threshold = abi.decode(rules.thresholds[i], (address));
_checkAddress(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BOOL) {
bool value;
bool threshold = abi.decode(rules.thresholds[i], (bool));
assembly {
value := mload(validationData)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bool threshold = abi.decode(rules.thresholds[i], (bool));
_checkBool(rules.operators[i], value, threshold);
} else if (rules.types[i] == Type.BYTES || rules.types[i] == Type.STRING) {
bytes memory value;
bytes memory threshold = abi.decode(rules.thresholds[i], (bytes));
assembly {
let offset := mload(validationData)
value := add(offset, pointer)
validationData := add(validationData, 0x20)
}
if (rules.operators[i] == Operator.NONE) continue;

bytes memory threshold = abi.decode(rules.thresholds[i], (bytes));
_checkBytes(rules.operators[i], value, threshold);
} else {
require(rules.operators[i] == Operator.NONE, "SmartVault: can not compare unsupported type");
Expand Down
14 changes: 14 additions & 0 deletions contracts/src/libraries/Parser.sol
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ library Parser {
return Type.INT128;
} else if (keccak256(b) == keccak256("int256") || keccak256(b) == keccak256("int")) {
return Type.INT256;
} else if (keccak256(b) == keccak256("uint8")) {
return Type.UINT8;
} else if (keccak256(b) == keccak256("uint16")) {
return Type.UINT16;
} else if (keccak256(b) == keccak256("uint24")) {
return Type.UINT24;
} else if (keccak256(b) == keccak256("uint32")) {
return Type.UINT32;
} else if (keccak256(b) == keccak256("uint64")) {
return Type.UINT64;
} else if (keccak256(b) == keccak256("uint128")) {
return Type.UINT128;
} else if (keccak256(b) == keccak256("uint256") || keccak256(b) == keccak256("uint")) {
return Type.UINT256;
} else if (keccak256(b) == keccak256("bytes1") || keccak256(b) == keccak256("uint8")) {
return Type.BYTES1;
} else if (keccak256(b) == keccak256("bytes2") || keccak256(b) == keccak256("uint16")) {
Expand Down
4 changes: 2 additions & 2 deletions contracts/test/Parser.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ contract ParserTest is Test {
string memory input = "uint256 age,address recipient,bool passed";
Type[] memory output = input.extractTypes();
assertEq(output.length, 3);
assertEq(uint256(output[0]), uint256(Type.BYTES32));
assertEq(uint256(output[0]), uint256(Type.UINT256));
assertEq(uint256(output[1]), uint256(Type.ADDRESS));
assertEq(uint256(output[2]), uint256(Type.BOOL));
}
Expand All @@ -24,7 +24,7 @@ contract ParserTest is Test {
bytes memory inputBytes = bytes(input);
Type[] memory output = inputBytes.extractTypes();
assertEq(output.length, 3);
assertEq(uint256(output[0]), uint256(Type.BYTES32));
assertEq(uint256(output[0]), uint256(Type.UINT256));
assertEq(uint256(output[1]), uint256(Type.ADDRESS));
assertEq(uint256(output[2]), uint256(Type.BOOL));
}
Expand Down
8 changes: 4 additions & 4 deletions contracts/test/SmartVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ contract SmartVaultTest is Test {
smartVault.getRules(vaultId);

assertEq(types.length, 3);
assertTrue(types[0] == Type.BYTES32);
assertTrue(types[0] == Type.UINT256);
assertTrue(types[1] == Type.STRING);
assertTrue(types[2] == Type.BOOL);

Expand All @@ -101,7 +101,7 @@ contract SmartVaultTest is Test {
ops[2] = Operator.EQ;

bytes[] memory thresholds = new bytes[](3);
thresholds[0] = abi.encode(18);
thresholds[0] = abi.encode(22);
thresholds[1] = abi.encode("MIT");
thresholds[2] = abi.encode(true);

Expand Down Expand Up @@ -152,7 +152,7 @@ contract SmartVaultTest is Test {
ops[2] = Operator.EQ;

bytes[] memory thresholds = new bytes[](3);
thresholds[0] = abi.encode(18);
thresholds[0] = abi.encode(22);
thresholds[1] = abi.encode("MIT");
thresholds[2] = abi.encode(true);

Expand Down Expand Up @@ -193,7 +193,7 @@ contract SmartVaultTest is Test {
// create attestation
address claimer = makeAddr("claimer");
AttestationRequestData memory data =
AttestationRequestData(claimer, NO_EXPIRATION_TIME, false, bytes32(0), abi.encode(20, "MIT", true), 0);
AttestationRequestData(claimer, NO_EXPIRATION_TIME, false, bytes32(0), abi.encode(23, "MIT", true), 0);
AttestationRequest memory request = AttestationRequest(validationSchema, data);
bytes32 attestationUID = eas.attest(request);

Expand Down
Loading