From 13a1d7137e70442a148a2c1beae4a17e06932fec Mon Sep 17 00:00:00 2001 From: Nabil Al Masri <142125053+NabilMx99@users.noreply.github.com> Date: Sat, 31 May 2025 23:06:19 +0000 Subject: [PATCH] Rewrite Chinese Remainder Theorem implementation and Add unit tests --- .../math/ChineseRemainderTheorem.java | 271 +++++++----------- .../math/ChineseRemainderTheoremTest.java | 82 ++++++ 2 files changed, 184 insertions(+), 169 deletions(-) create mode 100644 src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java diff --git a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java index a00d29538..e5ade3015 100644 --- a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java +++ b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java @@ -1,191 +1,124 @@ -/** - * Use the chinese remainder theorem to solve a set of congruence equations. - * - *

The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod - * m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the - * coefficient. A value of null is returned if the coefficient cannot be eliminated. - * - *

The second method (reduce) is used to reduce a set of equations so that the moduli become - * pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and - * output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod - * m_n−1). Note that the number of equations may change during this process. A value of null is - * returned if the set of equations cannot be reduced to co-prime moduli. - * - *

The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of - * moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod - * m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form - * x≡a_new(mod m_new)x≡a_new(mod m_new). - * - * @author Micah Stairs +/* + * This program solves a system of linear congruence equations using the Chinese Remainder Theorem (CRT). + * + * The user is prompted to enter the number of congruence equations, followed by the coefficients (a) and moduli (m) + * for each equation of the form: x ≡ a[i] (mod m[i]), where all moduli must be pairwise coprime. + * + * The program checks if the moduli are pairwise coprime, computes the unique solution modulo the product of all moduli, + * and prints the result. + * + * Time Complexity: O(k × log²(n)), where k is the number of equations and n is the product of all moduli. + * */ + package com.williamfiset.algorithms.math; -import java.util.*; +import java.util.Scanner; +import java.math.BigInteger; public class ChineseRemainderTheorem { + public static void main(String[] args) { + + Scanner scanner = new Scanner(System.in); + + System.out.print("Enter the number of congruence equations: "); - // eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new). - public static long[] eliminateCoefficient(long c, long a, long m) { - - long d = egcd(c, m)[0]; - - if (a % d != 0) return null; - - c /= d; - a /= d; - m /= d; - - long inv = egcd(c, m)[1]; - m = Math.abs(m); - a = (((a * inv) % m) + m) % m; - - return new long[] {a, m}; - } - - // reduce() takes a set of equations and reduces them to an equivalent - // set with pairwise co-prime moduli (or null if not solvable). - public static long[][] reduce(long[] a, long[] m) { - - List aNew = new ArrayList(); - List mNew = new ArrayList(); - - // Split up each equation into prime factors - for (int i = 0; i < a.length; i++) { - List factors = primeFactorization(m[i]); - Collections.sort(factors); - ListIterator iterator = factors.listIterator(); - while (iterator.hasNext()) { - long val = iterator.next(); - long total = val; - while (iterator.hasNext()) { - long nextVal = iterator.next(); - if (nextVal == val) { - total *= val; - } else { - iterator.previous(); - break; - } + int k = scanner.nextInt(); + + // Ensure there are at least two equations + if (k < 2) { + System.out.println("\nThe number of equations must be at least 2."); + scanner.close(); + return; } - aNew.add(a[i] % total); - mNew.add(total); - } - } - // Throw away repeated information and look for conflicts - for (int i = 0; i < aNew.size(); i++) { - for (int j = i + 1; j < aNew.size(); j++) { - if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) { - if (mNew.get(i) > mNew.get(j)) { - if ((aNew.get(i) % mNew.get(j)) == aNew.get(j)) { - aNew.remove(j); - mNew.remove(j); - j--; - continue; - } else return null; - } else { - if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) { - aNew.remove(i); - mNew.remove(i); - i--; - break; - } else return null; - } + BigInteger a[] = new BigInteger[k]; + BigInteger m[] = new BigInteger[k]; + + System.out.println("Enter the coefficient values (a): "); + + for (int i = 0; i < k; i++) { + + System.out.print("a[" + (i + 1) + "] = "); + + a[i] = scanner.nextBigInteger(); + } - } - } - // Put result into an array - long[][] res = new long[2][aNew.size()]; - for (int i = 0; i < aNew.size(); i++) { - res[0][i] = aNew.get(i); - res[1][i] = mNew.get(i); - } + System.out.println("Enter the moduli values (m): "); - return res; - } + for (int i = 0; i < k; i++) { - public static long[] crt(long[] a, long[] m) { + System.out.print("m[" + (i + 1) + "] = "); - long M = 1; - for (int i = 0; i < m.length; i++) M *= m[i]; + m[i] = scanner.nextBigInteger(); - long[] inv = new long[a.length]; - for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1]; + } - long x = 0; - for (int i = 0; i < m.length; i++) { - x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here - x = ((x % M) + M) % M; - } + if (!arePairwiseCoprime(k, m)) { + System.out.println("\nModuli values are not pairwise coprime."); + scanner.close(); + return; + } + + BigInteger x = chineseRemainder(k, a, m); + + System.out.println("\nx = " + x); + + scanner.close(); - return new long[] {x, M}; - } - - private static ArrayList primeFactorization(long n) { - ArrayList factors = new ArrayList(); - if (n <= 0) throw new IllegalArgumentException(); - else if (n == 1) return factors; - PriorityQueue divisorQueue = new PriorityQueue(); - divisorQueue.add(n); - while (!divisorQueue.isEmpty()) { - long divisor = divisorQueue.remove(); - if (isPrime(divisor)) { - factors.add(divisor); - continue; - } - long next_divisor = pollardRho(divisor); - if (next_divisor == divisor) { - divisorQueue.add(divisor); - } else { - divisorQueue.add(next_divisor); - divisorQueue.add(divisor / next_divisor); - } - } - return factors; - } - - private static long pollardRho(long n) { - if (n % 2 == 0) return 2; - // Get a number in the range [2, 10^6] - long x = 2 + (long) (999999 * Math.random()); - long c = 2 + (long) (999999 * Math.random()); - long y = x; - long d = 1; - while (d == 1) { - x = (x * x + c) % n; - y = (y * y + c) % n; - y = (y * y + c) % n; - d = gcf(Math.abs(x - y), n); - if (d == n) break; - } - return d; - } - - // Extended euclidean algorithm - private static long[] egcd(long a, long b) { - if (b == 0) return new long[] {a, 1, 0}; - else { - long[] ret = egcd(b, a % b); - long tmp = ret[1] - ret[2] * (a / b); - ret[1] = ret[2]; - ret[2] = tmp; - return ret; } - } - private static long gcf(long a, long b) { - return b == 0 ? a : gcf(b, a % b); - } + /* + * Computes the solution to the system of congruence equations using CRT. + */ + public static BigInteger chineseRemainder(int k, BigInteger a[], BigInteger m[]) { - private static boolean isPrime(long n) { - if (n < 2) return false; - if (n == 2 || n == 3) return true; - if (n % 2 == 0 || n % 3 == 0) return false; + BigInteger moduliProduct = BigInteger.ONE; + BigInteger result = BigInteger.ZERO; - int limit = (int) Math.sqrt(n); + BigInteger M[] = new BigInteger[k]; + BigInteger y[] = new BigInteger[k]; - for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false; + // Compute the product of all moduli + for (int i = 0; i < k; i++) { + moduliProduct = moduliProduct.multiply(m[i]); + } - return true; - } -} + // Compute M[i] = product of all moduli divided by m[i] + for (int i = 0; i < k; i++) { + M[i] = moduliProduct.divide(m[i]); + } + + // Compute modular inverse of M[i] modulo m[i] + for (int i = 0; i < k; i++) { + y[i] = M[i].modInverse(m[i]); + } + + // Calculate the result using the CRT formula + for (int i = 0; i < k; i++) { + result = result.add(a[i].multiply(M[i]).multiply(y[i])); + } + + result = result.mod(moduliProduct); + + return result; + + } + + /* + * Checks if all moduli are pairwise coprime. + */ + public static boolean arePairwiseCoprime(int k, BigInteger m[]) { + + for (int i = 0; i < k; i++) { + for (int j = i + 1; j < k; j++) { + if (!m[i].gcd(m[j]).equals(BigInteger.ONE)) { + return false; + } + } + } + + return true; + } +} \ No newline at end of file diff --git a/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java new file mode 100644 index 000000000..c38062415 --- /dev/null +++ b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java @@ -0,0 +1,82 @@ +package com.williamfiset.algorithms.math; + +import org.junit.jupiter.api.Test; +import java.math.BigInteger; +import static org.junit.jupiter.api.Assertions.*; + +class ChineseRemainderTheoremTest { + + @Test + void testCRT_TwoEquations() { + BigInteger[] a = {BigInteger.valueOf(0), BigInteger.valueOf(3)}; + BigInteger[] m = {BigInteger.valueOf(4), BigInteger.valueOf(5)}; + BigInteger result = ChineseRemainderTheorem.chineseRemainder(2, a, m); + assertEquals(BigInteger.valueOf(8), result); + } + + @Test + void testCRT_ThreeEquations() { + BigInteger[] a = {BigInteger.valueOf(2), BigInteger.valueOf(3), BigInteger.valueOf(2)}; + BigInteger[] m = {BigInteger.valueOf(3), BigInteger.valueOf(5), BigInteger.valueOf(7)}; + BigInteger result = ChineseRemainderTheorem.chineseRemainder(3, a, m); + assertEquals(BigInteger.valueOf(23), result); + } + + @Test + void testCRT_FourEquations() { + BigInteger[] a = { + BigInteger.valueOf(1), + BigInteger.valueOf(2), + BigInteger.valueOf(3), + BigInteger.valueOf(4) + }; + BigInteger[] m = { + BigInteger.valueOf(5), + BigInteger.valueOf(7), + BigInteger.valueOf(9), + BigInteger.valueOf(11) + }; + BigInteger result = ChineseRemainderTheorem.chineseRemainder(4, a, m); + + assertEquals(BigInteger.valueOf(1731), result); + } + + @Test + void testArePairwiseCoprime_True() { + BigInteger[] m = { + BigInteger.valueOf(3), + BigInteger.valueOf(5), + BigInteger.valueOf(7) + }; + assertTrue(ChineseRemainderTheorem.arePairwiseCoprime(3, m)); + } + + @Test + void testArePairwiseCoprime_False() { + BigInteger[] m = { + BigInteger.valueOf(6), + BigInteger.valueOf(8), + BigInteger.valueOf(9) + }; + assertFalse(ChineseRemainderTheorem.arePairwiseCoprime(3, m)); + } + + @Test + void testCRT_LargeNumbers() { + BigInteger[] a = { + BigInteger.valueOf(123456), + BigInteger.valueOf(789012), + BigInteger.valueOf(345678) + }; + BigInteger[] m = { + BigInteger.valueOf(1000003), + BigInteger.valueOf(1000033), + BigInteger.valueOf(1000037) + }; + BigInteger result = ChineseRemainderTheorem.chineseRemainder(3, a, m); + + assertEquals(a[0], result.mod(m[0])); + assertEquals(a[1], result.mod(m[1])); + assertEquals(a[2], result.mod(m[2])); + } +} \ No newline at end of file