Skip to content

Commit f999ba4

Browse files
authored
Add 512bits add and mult operations (#5035)
1 parent 2ed8956 commit f999ba4

File tree

5 files changed

+372
-195
lines changed

5 files changed

+372
-195
lines changed

.changeset/blue-nails-give.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Math`: Add `add512`, `mul512` and `mulShr`.

contracts/utils/math/Math.sol

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,34 @@ library Math {
1717
Expand // Away from zero
1818
}
1919

20+
/**
21+
* @dev Return the 512-bit addition of two uint256.
22+
*
23+
* The result is stored in two 256 variables such that sum = high * 2²⁵⁶ + low.
24+
*/
25+
function add512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
26+
assembly ("memory-safe") {
27+
low := add(a, b)
28+
high := lt(low, a)
29+
}
30+
}
31+
32+
/**
33+
* @dev Return the 512-bit multiplication of two uint256.
34+
*
35+
* The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
36+
*/
37+
function mul512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
38+
// 512-bit multiply [high low] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
39+
// the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
40+
// variables such that product = high * 2²⁵⁶ + low.
41+
assembly ("memory-safe") {
42+
let mm := mulmod(a, b, not(0))
43+
low := mul(a, b)
44+
high := sub(sub(mm, low), lt(mm, low))
45+
}
46+
}
47+
2048
/**
2149
* @dev Returns the addition of two unsigned integers, with an success flag (no overflow).
2250
*/
@@ -143,42 +171,34 @@ library Math {
143171
*/
144172
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
145173
unchecked {
146-
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
147-
// the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
148-
// variables such that product = prod1 * 2²⁵⁶ + prod0.
149-
uint256 prod0 = x * y; // Least significant 256 bits of the product
150-
uint256 prod1; // Most significant 256 bits of the product
151-
assembly {
152-
let mm := mulmod(x, y, not(0))
153-
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
154-
}
174+
(uint256 high, uint256 low) = mul512(x, y);
155175

156176
// Handle non-overflow cases, 256 by 256 division.
157-
if (prod1 == 0) {
177+
if (high == 0) {
158178
// Solidity will revert if denominator == 0, unlike the div opcode on its own.
159179
// The surrounding unchecked block does not change this fact.
160180
// See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
161-
return prod0 / denominator;
181+
return low / denominator;
162182
}
163183

164184
// Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
165-
if (denominator <= prod1) {
185+
if (denominator <= high) {
166186
Panic.panic(ternary(denominator == 0, Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
167187
}
168188

169189
///////////////////////////////////////////////
170190
// 512 by 256 division.
171191
///////////////////////////////////////////////
172192

173-
// Make division exact by subtracting the remainder from [prod1 prod0].
193+
// Make division exact by subtracting the remainder from [high low].
174194
uint256 remainder;
175195
assembly {
176196
// Compute remainder using mulmod.
177197
remainder := mulmod(x, y, denominator)
178198

179199
// Subtract 256 bit number from 512 bit number.
180-
prod1 := sub(prod1, gt(remainder, prod0))
181-
prod0 := sub(prod0, remainder)
200+
high := sub(high, gt(remainder, low))
201+
low := sub(low, remainder)
182202
}
183203

184204
// Factor powers of two out of denominator and compute largest power of two divisor of denominator.
@@ -189,15 +209,15 @@ library Math {
189209
// Divide denominator by twos.
190210
denominator := div(denominator, twos)
191211

192-
// Divide [prod1 prod0] by twos.
193-
prod0 := div(prod0, twos)
212+
// Divide [high low] by twos.
213+
low := div(low, twos)
194214

195215
// Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
196216
twos := add(div(sub(0, twos), twos), 1)
197217
}
198218

199-
// Shift in bits from prod1 into prod0.
200-
prod0 |= prod1 * twos;
219+
// Shift in bits from high into low.
220+
low |= high * twos;
201221

202222
// Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
203223
// that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
@@ -215,9 +235,9 @@ library Math {
215235

216236
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
217237
// This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
218-
// less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and prod1
238+
// less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and high
219239
// is no longer required.
220-
result = prod0 * inverse;
240+
result = low * inverse;
221241
return result;
222242
}
223243
}
@@ -229,6 +249,26 @@ library Math {
229249
return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
230250
}
231251

252+
/**
253+
* @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
254+
*/
255+
function mulShr(uint256 x, uint256 y, uint8 n) internal pure returns (uint256 result) {
256+
unchecked {
257+
(uint256 high, uint256 low) = mul512(x, y);
258+
if (high >= 1 << n) {
259+
Panic.panic(Panic.UNDER_OVERFLOW);
260+
}
261+
return (high << (256 - n)) | (low >> n);
262+
}
263+
}
264+
265+
/**
266+
* @dev Calculates x * y >> n with full precision, following the selected rounding direction.
267+
*/
268+
function mulShr(uint256 x, uint256 y, uint8 n, Rounding rounding) internal pure returns (uint256) {
269+
return mulShr(x, y, n) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, 1 << n) > 0);
270+
}
271+
232272
/**
233273
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
234274
*

test/helpers/enums.js

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
function Enum(...options) {
2-
return Object.fromEntries(options.map((key, i) => [key, BigInt(i)]));
3-
}
1+
const { ethers } = require('ethers');
2+
3+
const Enum = (...options) => Object.fromEntries(options.map((key, i) => [key, BigInt(i)]));
4+
const EnumTyped = (...options) => Object.fromEntries(options.map((key, i) => [key, ethers.Typed.uint8(i)]));
45

56
module.exports = {
67
Enum,
8+
EnumTyped,
79
ProposalState: Enum('Pending', 'Active', 'Canceled', 'Defeated', 'Succeeded', 'Queued', 'Expired', 'Executed'),
810
VoteType: Object.assign(Enum('Against', 'For', 'Abstain'), { Parameters: 255n }),
9-
Rounding: Enum('Floor', 'Ceil', 'Trunc', 'Expand'),
11+
Rounding: EnumTyped('Floor', 'Ceil', 'Trunc', 'Expand'),
1012
OperationState: Enum('Unset', 'Waiting', 'Ready', 'Done'),
11-
RevertType: Enum('None', 'RevertWithoutMessage', 'RevertWithMessage', 'RevertWithCustomError', 'Panic'),
13+
RevertType: EnumTyped('None', 'RevertWithoutMessage', 'RevertWithMessage', 'RevertWithCustomError', 'Panic'),
1214
};

test/utils/math/Math.t.sol

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,48 @@ contract MathTest is Test {
1111
assertEq(Math.ternary(f, a, b), f ? a : b);
1212
}
1313

14+
// ADD512 & MUL512
15+
function testAdd512(uint256 a, uint256 b) public pure {
16+
(uint256 high, uint256 low) = Math.add512(a, b);
17+
18+
// test against tryAdd
19+
(bool success, uint256 result) = Math.tryAdd(a, b);
20+
if (success) {
21+
assertEq(high, 0);
22+
assertEq(low, result);
23+
} else {
24+
assertEq(high, 1);
25+
}
26+
27+
// test against unchecked
28+
unchecked {
29+
assertEq(low, a + b); // unchecked allow overflow
30+
}
31+
}
32+
33+
function testMul512(uint256 a, uint256 b) public pure {
34+
(uint256 high, uint256 low) = Math.mul512(a, b);
35+
36+
// test against tryMul
37+
(bool success, uint256 result) = Math.tryMul(a, b);
38+
if (success) {
39+
assertEq(high, 0);
40+
assertEq(low, result);
41+
} else {
42+
assertGt(high, 0);
43+
}
44+
45+
// test against unchecked
46+
unchecked {
47+
assertEq(low, a * b); // unchecked allow overflow
48+
}
49+
50+
// test against alternative method
51+
(uint256 _high, uint256 _low) = _mulKaratsuba(a, b);
52+
assertEq(high, _high);
53+
assertEq(low, _low);
54+
}
55+
1456
// MIN & MAX
1557
function testSymbolicMinMax(uint256 a, uint256 b) public pure {
1658
assertEq(Math.min(a, b), a < b ? a : b);
@@ -184,7 +226,7 @@ contract MathTest is Test {
184226
// MULDIV
185227
function testMulDiv(uint256 x, uint256 y, uint256 d) public pure {
186228
// Full precision for x * y
187-
(uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y);
229+
(uint256 xyHi, uint256 xyLo) = Math.mul512(x, y);
188230

189231
// Assume result won't overflow (see {testMulDivDomain})
190232
// This also checks that `d` is positive
@@ -194,9 +236,9 @@ contract MathTest is Test {
194236
uint256 q = Math.mulDiv(x, y, d);
195237

196238
// Full precision for q * d
197-
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
239+
(uint256 qdHi, uint256 qdLo) = Math.mul512(q, d);
198240
// Add remainder of x * y / d (computed as rem = (x * y % d))
199-
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
241+
(uint256 c, uint256 qdRemLo) = Math.add512(qdLo, mulmod(x, y, d));
200242
uint256 qdRemHi = qdHi + c;
201243

202244
// Full precision check that x * y = q * d + rem
@@ -206,7 +248,7 @@ contract MathTest is Test {
206248

207249
/// forge-config: default.allow_internal_expect_revert = true
208250
function testMulDivDomain(uint256 x, uint256 y, uint256 d) public {
209-
(uint256 xyHi, ) = _mulHighLow(x, y);
251+
(uint256 xyHi, ) = Math.mul512(x, y);
210252

211253
// Violate {testMulDiv} assumption (covers d is 0 and result overflow)
212254
vm.assume(xyHi >= d);
@@ -266,26 +308,13 @@ contract MathTest is Test {
266308
}
267309
}
268310

269-
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
270-
if (m == 1) return 0;
271-
uint256 r = 1;
272-
while (e > 0) {
273-
if (e % 2 > 0) {
274-
r = mulmod(r, b, m);
275-
}
276-
b = mulmod(b, b, m);
277-
e >>= 1;
278-
}
279-
return r;
280-
}
281-
282311
// Helpers
283312
function _asRounding(uint8 r) private pure returns (Math.Rounding) {
284313
vm.assume(r < uint8(type(Math.Rounding).max));
285314
return Math.Rounding(r);
286315
}
287316

288-
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
317+
function _mulKaratsuba(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
289318
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
290319
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
291320

@@ -305,10 +334,16 @@ contract MathTest is Test {
305334
}
306335
}
307336

308-
function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) {
309-
unchecked {
310-
res = x + y;
337+
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
338+
if (m == 1) return 0;
339+
uint256 r = 1;
340+
while (e > 0) {
341+
if (e % 2 > 0) {
342+
r = mulmod(r, b, m);
343+
}
344+
b = mulmod(b, b, m);
345+
e >>= 1;
311346
}
312-
carry = res < x ? 1 : 0;
347+
return r;
313348
}
314349
}

0 commit comments

Comments
 (0)