@@ -17,6 +17,34 @@ library Math {
17
17
Expand // Away from zero
18
18
}
19
19
20
+ /**
21
+ * @dev Return the 512-bit addition of two uint256.
22
+ *
23
+ * The result is stored in two 256 variables such that sum = high * 2²⁵⁶ + low.
24
+ */
25
+ function add512 (uint256 a , uint256 b ) internal pure returns (uint256 high , uint256 low ) {
26
+ assembly ("memory-safe" ) {
27
+ low := add (a, b)
28
+ high := lt (low, a)
29
+ }
30
+ }
31
+
32
+ /**
33
+ * @dev Return the 512-bit multiplication of two uint256.
34
+ *
35
+ * The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
36
+ */
37
+ function mul512 (uint256 a , uint256 b ) internal pure returns (uint256 high , uint256 low ) {
38
+ // 512-bit multiply [high low] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
39
+ // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
40
+ // variables such that product = high * 2²⁵⁶ + low.
41
+ assembly ("memory-safe" ) {
42
+ let mm := mulmod (a, b, not (0 ))
43
+ low := mul (a, b)
44
+ high := sub (sub (mm, low), lt (mm, low))
45
+ }
46
+ }
47
+
20
48
/**
21
49
* @dev Returns the addition of two unsigned integers, with an success flag (no overflow).
22
50
*/
@@ -143,42 +171,34 @@ library Math {
143
171
*/
144
172
function mulDiv (uint256 x , uint256 y , uint256 denominator ) internal pure returns (uint256 result ) {
145
173
unchecked {
146
- // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
147
- // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
148
- // variables such that product = prod1 * 2²⁵⁶ + prod0.
149
- uint256 prod0 = x * y; // Least significant 256 bits of the product
150
- uint256 prod1; // Most significant 256 bits of the product
151
- assembly {
152
- let mm := mulmod (x, y, not (0 ))
153
- prod1 := sub (sub (mm, prod0), lt (mm, prod0))
154
- }
174
+ (uint256 high , uint256 low ) = mul512 (x, y);
155
175
156
176
// Handle non-overflow cases, 256 by 256 division.
157
- if (prod1 == 0 ) {
177
+ if (high == 0 ) {
158
178
// Solidity will revert if denominator == 0, unlike the div opcode on its own.
159
179
// The surrounding unchecked block does not change this fact.
160
180
// See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
161
- return prod0 / denominator;
181
+ return low / denominator;
162
182
}
163
183
164
184
// Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
165
- if (denominator <= prod1 ) {
185
+ if (denominator <= high ) {
166
186
Panic.panic (ternary (denominator == 0 , Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
167
187
}
168
188
169
189
///////////////////////////////////////////////
170
190
// 512 by 256 division.
171
191
///////////////////////////////////////////////
172
192
173
- // Make division exact by subtracting the remainder from [prod1 prod0 ].
193
+ // Make division exact by subtracting the remainder from [high low ].
174
194
uint256 remainder;
175
195
assembly {
176
196
// Compute remainder using mulmod.
177
197
remainder := mulmod (x, y, denominator)
178
198
179
199
// Subtract 256 bit number from 512 bit number.
180
- prod1 := sub (prod1 , gt (remainder, prod0 ))
181
- prod0 := sub (prod0 , remainder)
200
+ high := sub (high , gt (remainder, low ))
201
+ low := sub (low , remainder)
182
202
}
183
203
184
204
// Factor powers of two out of denominator and compute largest power of two divisor of denominator.
@@ -189,15 +209,15 @@ library Math {
189
209
// Divide denominator by twos.
190
210
denominator := div (denominator, twos)
191
211
192
- // Divide [prod1 prod0 ] by twos.
193
- prod0 := div (prod0 , twos)
212
+ // Divide [high low ] by twos.
213
+ low := div (low , twos)
194
214
195
215
// Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
196
216
twos := add (div (sub (0 , twos), twos), 1 )
197
217
}
198
218
199
- // Shift in bits from prod1 into prod0 .
200
- prod0 |= prod1 * twos;
219
+ // Shift in bits from high into low .
220
+ low |= high * twos;
201
221
202
222
// Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
203
223
// that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
@@ -215,9 +235,9 @@ library Math {
215
235
216
236
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
217
237
// This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
218
- // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and prod1
238
+ // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and high
219
239
// is no longer required.
220
- result = prod0 * inverse;
240
+ result = low * inverse;
221
241
return result;
222
242
}
223
243
}
@@ -229,6 +249,26 @@ library Math {
229
249
return mulDiv (x, y, denominator) + SafeCast.toUint (unsignedRoundsUp (rounding) && mulmod (x, y, denominator) > 0 );
230
250
}
231
251
252
+ /**
253
+ * @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
254
+ */
255
+ function mulShr (uint256 x , uint256 y , uint8 n ) internal pure returns (uint256 result ) {
256
+ unchecked {
257
+ (uint256 high , uint256 low ) = mul512 (x, y);
258
+ if (high >= 1 << n) {
259
+ Panic.panic (Panic.UNDER_OVERFLOW);
260
+ }
261
+ return (high << (256 - n)) | (low >> n);
262
+ }
263
+ }
264
+
265
+ /**
266
+ * @dev Calculates x * y >> n with full precision, following the selected rounding direction.
267
+ */
268
+ function mulShr (uint256 x , uint256 y , uint8 n , Rounding rounding ) internal pure returns (uint256 ) {
269
+ return mulShr (x, y, n) + SafeCast.toUint (unsignedRoundsUp (rounding) && mulmod (x, y, 1 << n) > 0 );
270
+ }
271
+
232
272
/**
233
273
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
234
274
*
0 commit comments