diff --git a/src/java.base/share/classes/java/math/BigInteger.java b/src/java.base/share/classes/java/math/BigInteger.java index fb6d1eca3f562..b0a7f784aab09 100644 --- a/src/java.base/share/classes/java/math/BigInteger.java +++ b/src/java.base/share/classes/java/math/BigInteger.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 1996, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1996, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -1246,6 +1246,16 @@ else if (val < 0 && val >= -MAX_CONSTANT) return new BigInteger(val); } + /** + * Constructs a BigInteger with magnitude specified by the long, + * which may not be zero, and the signum specified by the int. + */ + private BigInteger(long mag, int signum) { + assert mag != 0 && signum != 0; + this.signum = signum; + this.mag = toMagArray(mag); + } + /** * Constructs a BigInteger with the specified value, which may not be zero. */ @@ -1256,16 +1266,14 @@ private BigInteger(long val) { } else { signum = 1; } + mag = toMagArray(val); + } - int highWord = (int)(val >>> 32); - if (highWord == 0) { - mag = new int[1]; - mag[0] = (int)val; - } else { - mag = new int[2]; - mag[0] = highWord; - mag[1] = (int)val; - } + private static int[] toMagArray(long mag) { + int highWord = (int) (mag >>> 32); + return highWord == 0 + ? new int[] { (int) mag } + : new int[] { highWord, (int) mag }; } /** @@ -2589,116 +2597,101 @@ public BigInteger pow(int exponent) { if (exponent < 0) { throw new ArithmeticException("Negative exponent"); } - if (signum == 0) { - return (exponent == 0 ? ONE : this); - } + if (exponent == 0 || this.equals(ONE)) + return ONE; + + if (signum == 0 || exponent == 1) + return this; - BigInteger partToSquare = this.abs(); + BigInteger base = this.abs(); + final boolean negative = signum < 0 && (exponent & 1) == 1; // Factor out powers of two from the base, as the exponentiation of // these can be done by left shifts only. // The remaining part can then be exponentiated faster. The // powers of two will be multiplied back at the end. - int powersOfTwo = partToSquare.getLowestSetBit(); - long bitsToShiftLong = (long)powersOfTwo * exponent; - if (bitsToShiftLong > Integer.MAX_VALUE) { + final int powersOfTwo = base.getLowestSetBit(); + final long bitsToShiftLong = (long) powersOfTwo * exponent; + final int bitsToShift = (int) bitsToShiftLong; + if (bitsToShift != bitsToShiftLong) { reportOverflow(); } - int bitsToShift = (int)bitsToShiftLong; - int remainingBits; - - // Factor the powers of two out quickly by shifting right, if needed. - if (powersOfTwo > 0) { - partToSquare = partToSquare.shiftRight(powersOfTwo); - remainingBits = partToSquare.bitLength(); - if (remainingBits == 1) { // Nothing left but +/- 1? - if (signum < 0 && (exponent&1) == 1) { - return NEGATIVE_ONE.shiftLeft(bitsToShift); - } else { - return ONE.shiftLeft(bitsToShift); - } - } - } else { - remainingBits = partToSquare.bitLength(); - if (remainingBits == 1) { // Nothing left but +/- 1? - if (signum < 0 && (exponent&1) == 1) { - return NEGATIVE_ONE; - } else { - return ONE; - } - } - } + // Factor the powers of two out quickly by shifting right. + base = base.shiftRight(powersOfTwo); + final int remainingBits = base.bitLength(); + if (remainingBits == 1) // Nothing left but +/- 1? + return (negative ? NEGATIVE_ONE : ONE).shiftLeft(bitsToShift); // This is a quick way to approximate the size of the result, // similar to doing log2[n] * exponent. This will give an upper bound // of how big the result can be, and which algorithm to use. - long scaleFactor = (long)remainingBits * exponent; + final long scaleFactor = (long) remainingBits * exponent; // Use slightly different algorithms for small and large operands. - // See if the result will safely fit into a long. (Largest 2^63-1) - if (partToSquare.mag.length == 1 && scaleFactor <= 62) { - // Small number algorithm. Everything fits into a long. - int newSign = (signum <0 && (exponent&1) == 1 ? -1 : 1); - long result = 1; - long baseToPow2 = partToSquare.mag[0] & LONG_MASK; - - int workingExponent = exponent; - - // Perform exponentiation using repeated squaring trick - while (workingExponent != 0) { - if ((workingExponent & 1) == 1) { - result = result * baseToPow2; - } - - if ((workingExponent >>>= 1) != 0) { - baseToPow2 = baseToPow2 * baseToPow2; - } - } + // See if the result will safely fit into an unsigned long. (Largest 2^64-1) + if (scaleFactor <= Long.SIZE) { + // Small number algorithm. Everything fits into an unsigned long. + final int newSign = negative ? -1 : 1; + final long result = unsignedLongPow(base.mag[0] & LONG_MASK, exponent); // Multiply back the powers of two (quickly, by shifting left) - if (powersOfTwo > 0) { - if (bitsToShift + scaleFactor <= 62) { // Fits in long? - return valueOf((result << bitsToShift) * newSign); - } else { - return valueOf(result*newSign).shiftLeft(bitsToShift); - } - } else { - return valueOf(result*newSign); - } - } else { - if ((long)bitLength() * exponent / Integer.SIZE > MAX_MAG_LENGTH) { - reportOverflow(); - } + return bitsToShift + scaleFactor <= Long.SIZE // Fits in long? + ? new BigInteger(result << bitsToShift, newSign) + : new BigInteger(result, newSign).shiftLeft(bitsToShift); + } - // Large number algorithm. This is basically identical to - // the algorithm above, but calls multiply() and square() - // which may use more efficient algorithms for large numbers. - BigInteger answer = ONE; + if ((bitLength() - 1L) * exponent >= Integer.MAX_VALUE) { + reportOverflow(); + } - int workingExponent = exponent; - // Perform exponentiation using repeated squaring trick - while (workingExponent != 0) { - if ((workingExponent & 1) == 1) { - answer = answer.multiply(partToSquare); - } + // Large number algorithm. This is basically identical to + // the algorithm above, but calls multiply() + // which may use more efficient algorithms for large numbers. + BigInteger answer = ONE; - if ((workingExponent >>>= 1) != 0) { - partToSquare = partToSquare.square(); - } - } - // Multiply back the (exponentiated) powers of two (quickly, - // by shifting left) - if (powersOfTwo > 0) { - answer = answer.shiftLeft(bitsToShift); - } + final int expZeros = Integer.numberOfLeadingZeros(exponent); + int workingExp = exponent << expZeros; + // Perform exponentiation using repeated squaring trick + // The loop relies on this invariant: + // base^exponent == answer^(2^expLen) * base^(workingExp >>> (32-expLen)) + for (int expLen = Integer.SIZE - expZeros; expLen > 0; expLen--) { + answer = answer.multiply(answer); + if (workingExp < 0) // leading bit is set + answer = answer.multiply(base); - if (signum < 0 && (exponent&1) == 1) { - return answer.negate(); - } else { - return answer; - } + workingExp <<= 1; + } + + // Multiply back the (exponentiated) powers of two (quickly, + // by shifting left) + answer = answer.shiftLeft(bitsToShift); + return negative ? answer.negate() : answer; + } + + /** + * Computes {@code x^n} using repeated squaring trick. + * Assumes {@code x != 0 && x^n < 2^Long.SIZE}. + */ + static long unsignedLongPow(long x, int n) { + if (x == 1L || n == 0) + return 1L; + + if (x == 2L) + return 1L << n; + + /* + * The method assumption means that n <= 40 here. + * Thus, the loop body executes at most 5 times. + */ + long pow = 1L; + for (; n != 1; n >>>= 1) { + if ((n & 1) != 0) + pow *= x; + + x *= x; } + return pow * x; } /** diff --git a/test/micro/org/openjdk/bench/java/math/BigIntegerPow.java b/test/micro/org/openjdk/bench/java/math/BigIntegerPow.java new file mode 100644 index 0000000000000..007d9bb975a04 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/math/BigIntegerPow.java @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package org.openjdk.bench.java.math; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.profile.GCProfiler; + +import java.math.BigInteger; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Thread) +@Warmup(iterations = 1, time = 1) +@Measurement(iterations = 1, time = 1) +@Fork(value = 3) +public class BigIntegerPow { + + private static final int TESTSIZE = 1; + + private int xsExp = (1 << 20) - 1; + /* Each array entry is atmost 64 bits in size */ + private BigInteger[] xsArray = new BigInteger[TESTSIZE]; + + private int sExp = (1 << 18) - 1; + /* Each array entry is atmost 256 bits in size */ + private BigInteger[] sArray = new BigInteger[TESTSIZE]; + + private int mExp = (1 << 16) - 1; + /* Each array entry is atmost 1024 bits in size */ + private BigInteger[] mArray = new BigInteger[TESTSIZE]; + + private int lExp = (1 << 14) - 1; + /* Each array entry is atmost 4096 bits in size */ + private BigInteger[] lArray = new BigInteger[TESTSIZE]; + + private int xlExp = (1 << 12) - 1; + /* Each array entry is atmost 16384 bits in size */ + private BigInteger[] xlArray = new BigInteger[TESTSIZE]; + + private int[] randomExps; + + /* + * You can run this test via the command line: + * $ make test TEST="micro:java.math.BigIntegerPow" MICRO="OPTIONS=-prof gc" + */ + + @Setup + public void setup() { + Random r = new Random(1123); + + randomExps = new int[TESTSIZE]; + for (int i = 0; i < TESTSIZE; i++) { + xsArray[i] = new BigInteger(64, r); + sArray[i] = new BigInteger(256, r); + mArray[i] = new BigInteger(1024, r); + lArray[i] = new BigInteger(4096, r); + xlArray[i] = new BigInteger(16384, r); + randomExps[i] = r.nextInt(1 << 12); + } + } + + /** Test BigInteger.pow() with numbers long at most 64 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowXS(Blackhole bh) { + for (BigInteger xs : xsArray) { + bh.consume(xs.pow(xsExp)); + } + } + + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowXSRandomExps(Blackhole bh) { + int i = 0; + for (BigInteger xs : xsArray) { + bh.consume(xs.pow(randomExps[i++])); + } + } + + /** Test BigInteger.pow() with numbers long at most 256 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowS(Blackhole bh) { + for (BigInteger s : sArray) { + bh.consume(s.pow(sExp)); + } + } + + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowSRandomExps(Blackhole bh) { + int i = 0; + for (BigInteger s : sArray) { + bh.consume(s.pow(randomExps[i++])); + } + } + + /** Test BigInteger.pow() with numbers long at most 1024 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowM(Blackhole bh) { + for (BigInteger m : mArray) { + bh.consume(m.pow(mExp)); + } + } + + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowMRandomExps(Blackhole bh) { + int i = 0; + for (BigInteger m : mArray) { + bh.consume(m.pow(randomExps[i++])); + } + } + + /** Test BigInteger.pow() with numbers long at most 4096 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowL(Blackhole bh) { + for (BigInteger l : lArray) { + bh.consume(l.pow(lExp)); + } + } + + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowLRandomExps(Blackhole bh) { + int i = 0; + for (BigInteger l : lArray) { + bh.consume(l.pow(randomExps[i++])); + } + } + + /** Test BigInteger.pow() with numbers long at most 16384 bits */ + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowXL(Blackhole bh) { + for (BigInteger xl : xlArray) { + bh.consume(xl.pow(xlExp)); + } + } + + @Benchmark + @OperationsPerInvocation(TESTSIZE) + public void testPowXLRandomExps(Blackhole bh) { + int i = 0; + for (BigInteger xl : xlArray) { + bh.consume(xl.pow(randomExps[i++])); + } + } +}