Skip to content

Commit a9b1f58

Browse files
authored
Add saturating (unsigned) math operations and optimize try operations (#5527)
1 parent 506e1f8 commit a9b1f58

File tree

3 files changed

+109
-16
lines changed

3 files changed

+109
-16
lines changed

.changeset/fair-pumpkins-compete.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`.

contracts/utils/math/Math.sol

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ library Math {
5151
function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
5252
unchecked {
5353
uint256 c = a + b;
54-
if (c < a) return (false, 0);
55-
return (true, c);
54+
success = c >= a;
55+
result = c * SafeCast.toUint(success);
5656
}
5757
}
5858

@@ -61,8 +61,9 @@ library Math {
6161
*/
6262
function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
6363
unchecked {
64-
if (b > a) return (false, 0);
65-
return (true, a - b);
64+
uint256 c = a - b;
65+
success = c <= a;
66+
result = c * SafeCast.toUint(success);
6667
}
6768
}
6869

@@ -71,13 +72,14 @@ library Math {
7172
*/
7273
function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
7374
unchecked {
74-
// Gas optimization: this is cheaper than requiring 'a' not being zero, but the
75-
// benefit is lost if 'b' is also tested.
76-
// See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
77-
if (a == 0) return (true, 0);
7875
uint256 c = a * b;
79-
if (c / a != b) return (false, 0);
80-
return (true, c);
76+
assembly ("memory-safe") {
77+
// Only true when the multiplication doesn't overflow
78+
// (c / a == b) || (a == 0)
79+
success := or(eq(div(c, a), b), iszero(a))
80+
}
81+
// equivalent to: success ? c : 0
82+
result = c * SafeCast.toUint(success);
8183
}
8284
}
8385

@@ -86,8 +88,11 @@ library Math {
8688
*/
8789
function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
8890
unchecked {
89-
if (b == 0) return (false, 0);
90-
return (true, a / b);
91+
success = b > 0;
92+
assembly ("memory-safe") {
93+
// The `DIV` opcode returns zero when the denominator is 0.
94+
result := div(a, b)
95+
}
9196
}
9297
}
9398

@@ -96,11 +101,38 @@ library Math {
96101
*/
97102
function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
98103
unchecked {
99-
if (b == 0) return (false, 0);
100-
return (true, a % b);
104+
success = b > 0;
105+
assembly ("memory-safe") {
106+
// The `MOD` opcode returns zero when the denominator is 0.
107+
result := mod(a, b)
108+
}
101109
}
102110
}
103111

112+
/**
113+
* @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
114+
*/
115+
function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
116+
(bool success, uint256 result) = tryAdd(a, b);
117+
return ternary(success, result, type(uint256).max);
118+
}
119+
120+
/**
121+
* @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
122+
*/
123+
function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
124+
(, uint256 result) = trySub(a, b);
125+
return result;
126+
}
127+
128+
/**
129+
* @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
130+
*/
131+
function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
132+
(bool success, uint256 result) = tryMul(a, b);
133+
return ternary(success, result, type(uint256).max);
134+
}
135+
104136
/**
105137
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
106138
*
@@ -192,7 +224,7 @@ library Math {
192224

193225
// Make division exact by subtracting the remainder from [high low].
194226
uint256 remainder;
195-
assembly {
227+
assembly ("memory-safe") {
196228
// Compute remainder using mulmod.
197229
remainder := mulmod(x, y, denominator)
198230

@@ -205,7 +237,7 @@ library Math {
205237
// Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
206238

207239
uint256 twos = denominator & (0 - denominator);
208-
assembly {
240+
assembly ("memory-safe") {
209241
// Divide denominator by twos.
210242
denominator := div(denominator, twos)
211243

test/utils/math/Math.test.js

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,62 @@ describe('Math', function () {
168168
});
169169
});
170170

171+
describe('saturatingAdd', function () {
172+
it('adds correctly', async function () {
173+
const a = 5678n;
174+
const b = 1234n;
175+
await testCommutative(this.mock.$saturatingAdd, a, b, a + b);
176+
await testCommutative(this.mock.$saturatingAdd, a, 0n, a);
177+
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256);
178+
});
179+
180+
it('bounds on addition overflow', async function () {
181+
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256);
182+
await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal(
183+
ethers.MaxUint256,
184+
);
185+
});
186+
});
187+
188+
describe('saturatingSub', function () {
189+
it('subtracts correctly', async function () {
190+
const a = 5678n;
191+
const b = 1234n;
192+
await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b);
193+
await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n);
194+
await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a);
195+
await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n);
196+
await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n);
197+
});
198+
199+
it('bounds on subtraction overflow', async function () {
200+
await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n);
201+
await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n);
202+
await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n);
203+
await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n);
204+
});
205+
});
206+
207+
describe('saturatingMul', function () {
208+
it('multiplies correctly', async function () {
209+
const a = 1234n;
210+
const b = 5678n;
211+
await testCommutative(this.mock.$saturatingMul, a, b, a * b);
212+
});
213+
214+
it('multiplies by zero correctly', async function () {
215+
const a = 0n;
216+
const b = 5678n;
217+
await testCommutative(this.mock.$saturatingMul, a, b, 0n);
218+
});
219+
220+
it('bounds on multiplication overflow', async function () {
221+
const a = ethers.MaxUint256;
222+
const b = 2n;
223+
await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256);
224+
});
225+
});
226+
171227
describe('max', function () {
172228
it('is correctly detected in both position', async function () {
173229
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));

0 commit comments

Comments
 (0)