Skip to content

8355719: Reduce memory consumption of BigInteger.pow() #24690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 62 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
98a5b53
Add nthRoot(int) methods and optimize pow(int)
fabioromano1 Apr 16, 2025
926970e
Merge branch 'openjdk:master' into BigInteger-nth-root
fabioromano1 Apr 16, 2025
a3f1489
Remove trailing whitespaces
fabioromano1 Apr 16, 2025
d280c37
Correct loop recurrence according to proof of convergence
fabioromano1 Apr 16, 2025
91c3d1a
Correct initial estimate of nth root for BigIntegers
fabioromano1 Apr 17, 2025
1cfb775
Removed trailing whitespace
fabioromano1 Apr 17, 2025
c25bd32
Avoid an overflow in computing nth root estimate
fabioromano1 Apr 17, 2025
9c32ce4
Merge branch 'BigInteger-nth-root' of https://github.com/fabioromano1…
fabioromano1 Apr 17, 2025
7ff919b
optimize division in loop iteration of nth root
fabioromano1 Apr 17, 2025
1f5a9b4
An optimization
fabioromano1 Apr 17, 2025
b6c3320
Format code
fabioromano1 Apr 17, 2025
5d971fa
Correct left shift if shift is zero
fabioromano1 Apr 18, 2025
e459c23
Memory usage optimization
fabioromano1 Apr 18, 2025
48650de
An optimization
fabioromano1 Apr 18, 2025
54ec8f8
BigIntegers nth root's initial estimate optimization
fabioromano1 Apr 18, 2025
3ea5190
Extend use cases of MutableBigInteger.valueOf(double)
fabioromano1 Apr 19, 2025
6c9b364
An optimization
fabioromano1 Apr 19, 2025
3ca7d29
An optimization
fabioromano1 Apr 19, 2025
4516d88
Format code
fabioromano1 Apr 19, 2025
b427091
Format code
fabioromano1 Apr 19, 2025
524f195
Merge branch 'openjdk:master' into BigInteger-nth-root
fabioromano1 Apr 19, 2025
21fbf27
Code simplification
fabioromano1 Apr 20, 2025
e11b32f
Added reference for proof of convergence in the comment
fabioromano1 Apr 21, 2025
b527fa2
Revert format code changes
fabioromano1 Apr 21, 2025
8de6b82
Merge branch 'BigInteger-nth-root' of https://github.com/fabioromano1…
fabioromano1 Apr 21, 2025
00365c9
Merge branch 'BigInteger-nth-root' of https://github.com/fabioromano1…
fabioromano1 Apr 21, 2025
100d0e1
Merge remote-tracking branch 'origin/BigInteger-nth-root' into BigInt…
fabioromano1 Apr 22, 2025
1942fd1
Suggested change
fabioromano1 Apr 22, 2025
0e1a99e
Delete useless folder
fabioromano1 Apr 22, 2025
9cd136c
Optimized computation of nth root's remainder
fabioromano1 Apr 22, 2025
c3bd1b2
Format code
fabioromano1 Apr 22, 2025
f0d0605
An optimization
fabioromano1 Apr 22, 2025
f20d19b
Optimized BigInteger.pow(int) for single-word values
fabioromano1 Apr 24, 2025
8676af7
Optimized repeated squaring trick using cache for powers
fabioromano1 Apr 24, 2025
3cf820b
Some optimizations
fabioromano1 Apr 24, 2025
23914e8
Systematization of special cases in BigInteger.pow(int)
fabioromano1 Apr 24, 2025
bf099e4
Optimized nth root iteration loop
fabioromano1 Apr 26, 2025
f9bfd22
Merge branch 'openjdk:master' into BigInteger-nth-root
fabioromano1 Apr 26, 2025
7ceae87
Moved nth-root implementation to a dependent PR
fabioromano1 Apr 26, 2025
cb61ddc
Removed method used by nth-root
fabioromano1 Apr 26, 2025
10e122e
Put power's computation in a stand-alone method
fabioromano1 Apr 26, 2025
b94ca7e
Optimized BigInteger.pow(int) to support unsigned long bases
fabioromano1 Apr 26, 2025
139735d
Use BigInteger(long, int) constructor
fabioromano1 Apr 26, 2025
fcd5d55
Throw away unsignedIntPow(int, int)
fabioromano1 Apr 26, 2025
b831d01
Pre-cache the powers of x up to x^3 to simplify the code
fabioromano1 Apr 28, 2025
51272bc
Added tests for memory consumption
fabioromano1 Apr 28, 2025
70da95c
Decrease exponents in tests
fabioromano1 Apr 29, 2025
d85a634
Update test parameters
fabioromano1 Apr 29, 2025
9a5d696
Removed needless condition
fabioromano1 Apr 29, 2025
6033d25
Don't exclude a priori valid results
fabioromano1 Apr 29, 2025
5deb21a
Take into account special case exponent == 1
fabioromano1 Apr 29, 2025
280859b
Use a more loose formula to do range check
fabioromano1 Apr 29, 2025
925806b
Adjust the type of operand
fabioromano1 Apr 29, 2025
ad49b56
Use a more accurate formula to detect certain overflows
fabioromano1 Apr 29, 2025
6895926
Simplified the formula for detecting overflows
fabioromano1 Apr 29, 2025
b8ca4fe
Simplify long power computing
fabioromano1 Apr 30, 2025
4103e49
Suggested changes
fabioromano1 May 7, 2025
5ebc16b
Removed needless brackets
fabioromano1 May 7, 2025
e0816d5
Code simplification
fabioromano1 May 7, 2025
009937b
Suggested changes
fabioromano1 May 7, 2025
2e08f77
Suggested changes
fabioromano1 May 8, 2025
261dd31
Code simplification
fabioromano1 May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 148 additions & 39 deletions src/java.base/share/classes/java/math/BigInteger.java
Original file line number Diff line number Diff line change
Expand Up @@ -2593,14 +2593,14 @@ public BigInteger pow(int exponent) {
return (exponent == 0 ? ONE : this);
}

BigInteger partToSquare = this.abs();
BigInteger base = this.abs();

// Factor out powers of two from the base, as the exponentiation of
// these can be done by left shifts only.
// The remaining part can then be exponentiated faster. The
// powers of two will be multiplied back at the end.
int powersOfTwo = partToSquare.getLowestSetBit();
long bitsToShiftLong = (long)powersOfTwo * exponent;
int powersOfTwo = base.getLowestSetBit();
long bitsToShiftLong = (long) powersOfTwo * exponent;
if (bitsToShiftLong > Integer.MAX_VALUE) {
reportOverflow();
}
Expand All @@ -2610,19 +2610,19 @@ public BigInteger pow(int exponent) {

// Factor the powers of two out quickly by shifting right, if needed.
if (powersOfTwo > 0) {
partToSquare = partToSquare.shiftRight(powersOfTwo);
remainingBits = partToSquare.bitLength();
base = base.shiftRight(powersOfTwo);
remainingBits = base.bitLength();
if (remainingBits == 1) { // Nothing left but +/- 1?
if (signum < 0 && (exponent&1) == 1) {
if (signum < 0 && (exponent & 1) == 1) {
return NEGATIVE_ONE.shiftLeft(bitsToShift);
} else {
return ONE.shiftLeft(bitsToShift);
}
}
} else {
remainingBits = partToSquare.bitLength();
remainingBits = base.bitLength();
if (remainingBits == 1) { // Nothing left but +/- 1?
if (signum < 0 && (exponent&1) == 1) {
if (signum < 0 && (exponent & 1) == 1) {
return NEGATIVE_ONE;
} else {
return ONE;
Expand All @@ -2633,74 +2633,109 @@ public BigInteger pow(int exponent) {
// This is a quick way to approximate the size of the result,
// similar to doing log2[n] * exponent. This will give an upper bound
// of how big the result can be, and which algorithm to use.
long scaleFactor = (long)remainingBits * exponent;
long scaleFactor = (long) remainingBits * exponent;

// Use slightly different algorithms for small and large operands.
// See if the result will safely fit into a long. (Largest 2^63-1)
if (partToSquare.mag.length == 1 && scaleFactor <= 62) {
if (base.mag.length == 1 && scaleFactor <= 62) {
// Small number algorithm. Everything fits into a long.
int newSign = (signum <0 && (exponent&1) == 1 ? -1 : 1);
long result = 1;
long baseToPow2 = partToSquare.mag[0] & LONG_MASK;

int workingExponent = exponent;

// Perform exponentiation using repeated squaring trick
while (workingExponent != 0) {
if ((workingExponent & 1) == 1) {
result = result * baseToPow2;
}

if ((workingExponent >>>= 1) != 0) {
baseToPow2 = baseToPow2 * baseToPow2;
}
}
int newSign = (signum < 0 && (exponent & 1) == 1 ? -1 : 1);
long result = unsignedLongPow(base.mag[0] & LONG_MASK, exponent);

// Multiply back the powers of two (quickly, by shifting left)
if (powersOfTwo > 0) {
if (bitsToShift + scaleFactor <= 62) { // Fits in long?
return valueOf((result << bitsToShift) * newSign);
} else {
return valueOf(result*newSign).shiftLeft(bitsToShift);
return valueOf(result * newSign).shiftLeft(bitsToShift);
}
} else {
return valueOf(result*newSign);
return valueOf(result * newSign);
}
} else {
if ((long)bitLength() * exponent / Integer.SIZE > MAX_MAG_LENGTH) {
if ((long) bitLength() * exponent / Integer.SIZE > MAX_MAG_LENGTH) {
reportOverflow();
}

// Large number algorithm. This is basically identical to
// the algorithm above, but calls multiply() and square()
// the algorithm above, but calls multiply()
// which may use more efficient algorithms for large numbers.
BigInteger answer = ONE;

int workingExponent = exponent;
final int expZeros = Integer.numberOfLeadingZeros(exponent);
int workingExp = exponent << expZeros;
// Perform exponentiation using repeated squaring trick
while (workingExponent != 0) {
if ((workingExponent & 1) == 1) {
answer = answer.multiply(partToSquare);
}
for (int expLen = Integer.SIZE - expZeros; expLen > 0; expLen--) {
answer = answer.multiply(answer);
if (workingExp < 0) // leading bit is set
answer = answer.multiply(base);

if ((workingExponent >>>= 1) != 0) {
partToSquare = partToSquare.square();
}
workingExp <<= 1;
}

// Multiply back the (exponentiated) powers of two (quickly,
// by shifting left)
if (powersOfTwo > 0) {
answer = answer.shiftLeft(bitsToShift);
}

if (signum < 0 && (exponent&1) == 1) {
if (signum < 0 && (exponent & 1) == 1) {
return answer.negate();
} else {
return answer;
}
}
}

/**
* Computes {@code x^n} using repeated squaring trick.
* Assumes {@code x != 0 && x^n < 2^Long.SIZE}.
*/
static long unsignedLongPow(long x, int n) {
// Double.PRECISION / bitLength(x) is the largest integer e
// such that x^e fits into a double. If e <= 3, we won't use fp arithmetic.
// This allows to use fp arithmetic where possible.
final int maxExp = Math.max(3, Double.PRECISION / bitLengthForLong(x));
final int maxExpLen = bitLengthForInt(maxExp);

final int nZeros = Integer.numberOfLeadingZeros(n);
n <<= nZeros;

long pow = 1L;
int blockLen;
for (int nLen = Integer.SIZE - nZeros; nLen > 0; nLen -= blockLen) {
blockLen = maxExpLen < nLen ? maxExpLen : nLen;
// compute pow^(2^blockLen)
if (pow != 1L) {
for (int i = 0; i < blockLen; i++)
pow *= pow;
}

// add exp to power's exponent
int exp = n >>> -blockLen;
if (exp > 0) {
// adjust exp to fit x^expAdj into a double
int expAdj = exp <= maxExp ? exp : exp >>> 1;

// don't use fp arithmetic if expAdj <= 3
long xToExp = expAdj == 1 ? x :
(expAdj == 2 ? x*x :
(expAdj == 3 ? x*x*x : (long) Math.pow(x, expAdj)));

// append exp's rightmost bit to expAdj
if (expAdj != exp) {
xToExp *= xToExp;
if ((exp & 1) == 1)
xToExp *= x;
}
pow *= xToExp;
}
n <<= blockLen; // shift to next block of bits
}

return pow;
}

/**
* Returns the integer square root of this BigInteger. The integer square
* root of the corresponding mathematical integer {@code n} is the largest
Expand Down Expand Up @@ -2750,6 +2785,80 @@ public BigInteger[] sqrtAndRemainder() {
return new BigInteger[] { sqrtRem[0].toBigInteger(), sqrtRem[1].toBigInteger() };
}

/**
* Returns the integer {@code n}th root of this BigInteger. The integer
* {@code n}th root of the corresponding mathematical integer {@code x} has the
* same sign of {@code x}, and its magnitude is the largest integer {@code r}
* such that {@code r**n <= abs(x)}. It is equal to the value of
* {@code (x.signum() * floor(abs(nthRoot(x, n))))}, where {@code nthRoot(x, n)}
* denotes the real {@code n}th root of {@code x} treated as a real. If {@code n}
* is even and this BigInteger is negative, an {@code ArithmeticException} will be
* thrown.
*
* <p>Note that the magnitude of the integer {@code n}th root will be less than
* the magnitude of the real {@code n}th root if the latter is not representable
* as an integral value.
*
* @param n the root degree
* @return the integer {@code n}th root of {@code this}
* @throws ArithmeticException if {@code n == 0} (Zeroth roots are not
* defined.)
* @throws ArithmeticException if {@code n} is negative. (This would cause the
* operation to yield a non-integer value.)
* @throws ArithmeticException if {@code n} is even and {@code this} is
* negative. (This would cause the operation to
* yield non-real roots.)
* @see #sqrt()
* @since 25
*/
public BigInteger nthRoot(int n) {
if (n == 1)
return this;

if (n == 2)
return sqrt();

if (n <= 0)
throw new ArithmeticException("Non-positive root degree");

if ((n & 1) == 0 && this.signum < 0)
throw new ArithmeticException("Negative radicand with even root degree");

return new MutableBigInteger(this.mag).nthRoot(n).toBigInteger(signum);
}

/**
* Returns an array of two BigIntegers containing the integer {@code n}th root
* {@code r} of {@code this} and its remainder {@code this - r^n},
* respectively.
*
* @param n the root degree
* @return an array of two BigIntegers with the integer {@code n}th root at
* offset 0 and the remainder at offset 1
* @throws ArithmeticException if {@code n == 0} (Zeroth roots are not
* defined.)
* @throws ArithmeticException if {@code n} is negative. (This would cause the
* operation to yield a non-integer value.)
* @throws ArithmeticException if {@code n} is even and {@code this} is
* negative. (This would cause the operation to
* yield non-real roots.)
* @see #sqrt()
* @see #sqrtAndRemainder()
* @see #nthRoot(int)
* @since 25
*/
public BigInteger[] nthRootAndRemainder(int n) {
if (n == 1)
return new BigInteger[] { this, ZERO };

if (n == 2)
return sqrtAndRemainder();

BigInteger root = nthRoot(n), rem = this.subtract(root.pow(n));
assert rem.signum == 0 || rem.signum == this.signum;
return new BigInteger[] { root, rem };
}

/**
* Returns a BigInteger whose value is the greatest common divisor of
* {@code abs(this)} and {@code abs(val)}. Returns 0 if
Expand Down
114 changes: 114 additions & 0 deletions src/java.base/share/classes/java/math/MutableBigInteger.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ private void init(int val) {
value = Arrays.copyOfRange(val.value, val.offset, val.offset + intLen);
}

/**
* Returns a MutableBigInteger with a magnitude specified by
* the absolute value of the double val. Any fractional part is discarded.
*
* Assume val is in the finite double range.
*/
static MutableBigInteger valueOf(double val) {
val = Math.abs(val);
if (val < 0x1p63)
return new MutableBigInteger((long) val);
// Translate the double into exponent and significand, according
// to the formulae in JLS, Section 20.10.22.
long valBits = Double.doubleToRawLongBits(val);
int exponent = (int) ((valBits >> 52) & 0x7ffL) - 1075;
long significand = (valBits & ((1L << 52) - 1)) | (1L << 52);
// At this point, val == significand * 2^exponent, with exponent > 0
MutableBigInteger result = new MutableBigInteger(significand);
result.leftShift(exponent);
return result;
}

/**
* Makes this number an {@code n}-int number all of whose bits are ones.
* Used by Burnikel-Ziegler division.
Expand Down Expand Up @@ -1892,6 +1913,99 @@ private boolean unsignedLongCompare(long one, long two) {
return (one+Long.MIN_VALUE) > (two+Long.MIN_VALUE);
}

/**
* Calculate the integer {@code n}th root {@code floor(nthRoot(this, n))} where
* {@code nthRoot(., n)} denotes the mathematical {@code n}th root. The contents of
* {@code this} are <b>not</b> changed. The value of {@code this} is assumed
* to be non-negative and the root degree {@code n >= 3}.
*
* @implNote The implementation is based on the material in Henry S. Warren,
* Jr., <i>Hacker's Delight (2nd ed.)</i> (Addison Wesley, 2013), 279-282.
*
* @return the integer {@code n}th of {@code this}
*/
MutableBigInteger nthRoot(int n) {
// Special cases.
if (this.isZero() || this.isOne())
return this;

final int bitLength = (int) this.bitLength();
// if this < 2^n, result is unity
if (bitLength <= n)
return new MutableBigInteger(1);

MutableBigInteger r;
if (bitLength <= Long.SIZE) {
// Initial estimate is the root of the unsigned long value.
final long x = this.toLong();
// Use fp arithmetic to get an upper bound of the root
final double base = Math.nextUp(x >= 0 ? x : x + 0x1p64);
final double exp = Math.nextUp(1.0 / n);
long rLong = (long) Math.ceil(Math.nextUp(Math.pow(base, exp)));

if (BigInteger.bitLengthForLong(rLong) * n <= Long.SIZE) {
// Refine the estimate.
do {
long rToN1 = BigInteger.unsignedLongPow(rLong, n - 1);
long rToN = rToN1 * rLong;
if (Long.compareUnsigned(rToN, x) <= 0)
return new MutableBigInteger(rLong);

// compute rLong - ceil((rToN - x) / (n * rToN1))
long dividend = rToN - x, divisor = n * rToN1;
if (Long.remainderUnsigned(dividend, divisor) != 0)
rLong--;

rLong -= Long.divideUnsigned(dividend, divisor);
} while (true);
} else { // r^n could overflow long range, use MutableBigInteger loop instead
r = new MutableBigInteger(rLong);
}
} else {
// Set up the initial estimate of the iteration.
// Determine a right shift that is a multiple of n into finite double range.
long shift = Math.max(0, bitLength - Double.MAX_EXPONENT); // use long to avoid overflow later
int shiftExcess = (int) (shift % n);

// Shift the value into finite double range
double base = this.toBigInteger().shiftRight((int) shift).doubleValue();
// Complete the shift to a multiple of n,
// avoiding to lose more bits than necessary.
if (shiftExcess != 0) {
int shiftLack = n - shiftExcess;
shift += shiftLack; // shift is long, no overflow
base /= Double.valueOf("0x1p" + shiftLack);
}

// Use the root of the shifted value as an estimate.
base = Math.nextUp(base);
final double exp = Math.nextUp(1.0 / n);
r = valueOf(Math.ceil(Math.nextUp(Math.pow(base, exp))));

// Shift the approximate root back into the original range.
r.safeLeftShift((int) (shift / n));
}

// Refine the estimate.
do {
BigInteger rBig = r.toBigInteger();
BigInteger rToN1 = rBig.pow(n - 1);
MutableBigInteger rToN = new MutableBigInteger(rToN1.multiply(rBig).mag);
if (rToN.subtract(this) <= 0)
return r;

// compute r - ceil((r^n - this) / (n * r^(n - 1)))
MutableBigInteger q1 = new MutableBigInteger();
MutableBigInteger delta = new MutableBigInteger();
// Don't use conditional-or to ensure to do both divisions
if (rToN.divideOneWord(n, q1) != 0
| !q1.divide(new MutableBigInteger(rToN1.mag), delta).isZero())
r.subtract(ONE);

r.subtract(delta);
} while (true);
}

/**
* Calculate the integer square root {@code floor(sqrt(this))} and the remainder
* if needed, where {@code sqrt(.)} denotes the mathematical square root.
Expand Down