From 13a1d7137e70442a148a2c1beae4a17e06932fec Mon Sep 17 00:00:00 2001
From: Nabil Al Masri <142125053+NabilMx99@users.noreply.github.com>
Date: Sat, 31 May 2025 23:06:19 +0000
Subject: [PATCH] Rewrite Chinese Remainder Theorem implementation and Add unit
tests
---
.../math/ChineseRemainderTheorem.java | 271 +++++++-----------
.../math/ChineseRemainderTheoremTest.java | 82 ++++++
2 files changed, 184 insertions(+), 169 deletions(-)
create mode 100644 src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
diff --git a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java
index a00d29538..e5ade3015 100644
--- a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java
+++ b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java
@@ -1,191 +1,124 @@
-/**
- * Use the chinese remainder theorem to solve a set of congruence equations.
- *
- *
The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
- * m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
- * coefficient. A value of null is returned if the coefficient cannot be eliminated.
- *
- *
The second method (reduce) is used to reduce a set of equations so that the moduli become
- * pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
- * output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
- * m_n−1). Note that the number of equations may change during this process. A value of null is
- * returned if the set of equations cannot be reduced to co-prime moduli.
- *
- *
The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
- * moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
- * m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
- * x≡a_new(mod m_new)x≡a_new(mod m_new).
- *
- * @author Micah Stairs
+/*
+ * This program solves a system of linear congruence equations using the Chinese Remainder Theorem (CRT).
+ *
+ * The user is prompted to enter the number of congruence equations, followed by the coefficients (a) and moduli (m)
+ * for each equation of the form: x ≡ a[i] (mod m[i]), where all moduli must be pairwise coprime.
+ *
+ * The program checks if the moduli are pairwise coprime, computes the unique solution modulo the product of all moduli,
+ * and prints the result.
+ *
+ * Time Complexity: O(k × log²(n)), where k is the number of equations and n is the product of all moduli.
+ *
*/
+
package com.williamfiset.algorithms.math;
-import java.util.*;
+import java.util.Scanner;
+import java.math.BigInteger;
public class ChineseRemainderTheorem {
+ public static void main(String[] args) {
+
+ Scanner scanner = new Scanner(System.in);
+
+ System.out.print("Enter the number of congruence equations: ");
- // eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
- public static long[] eliminateCoefficient(long c, long a, long m) {
-
- long d = egcd(c, m)[0];
-
- if (a % d != 0) return null;
-
- c /= d;
- a /= d;
- m /= d;
-
- long inv = egcd(c, m)[1];
- m = Math.abs(m);
- a = (((a * inv) % m) + m) % m;
-
- return new long[] {a, m};
- }
-
- // reduce() takes a set of equations and reduces them to an equivalent
- // set with pairwise co-prime moduli (or null if not solvable).
- public static long[][] reduce(long[] a, long[] m) {
-
- List aNew = new ArrayList();
- List mNew = new ArrayList();
-
- // Split up each equation into prime factors
- for (int i = 0; i < a.length; i++) {
- List factors = primeFactorization(m[i]);
- Collections.sort(factors);
- ListIterator iterator = factors.listIterator();
- while (iterator.hasNext()) {
- long val = iterator.next();
- long total = val;
- while (iterator.hasNext()) {
- long nextVal = iterator.next();
- if (nextVal == val) {
- total *= val;
- } else {
- iterator.previous();
- break;
- }
+ int k = scanner.nextInt();
+
+ // Ensure there are at least two equations
+ if (k < 2) {
+ System.out.println("\nThe number of equations must be at least 2.");
+ scanner.close();
+ return;
}
- aNew.add(a[i] % total);
- mNew.add(total);
- }
- }
- // Throw away repeated information and look for conflicts
- for (int i = 0; i < aNew.size(); i++) {
- for (int j = i + 1; j < aNew.size(); j++) {
- if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) {
- if (mNew.get(i) > mNew.get(j)) {
- if ((aNew.get(i) % mNew.get(j)) == aNew.get(j)) {
- aNew.remove(j);
- mNew.remove(j);
- j--;
- continue;
- } else return null;
- } else {
- if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) {
- aNew.remove(i);
- mNew.remove(i);
- i--;
- break;
- } else return null;
- }
+ BigInteger a[] = new BigInteger[k];
+ BigInteger m[] = new BigInteger[k];
+
+ System.out.println("Enter the coefficient values (a): ");
+
+ for (int i = 0; i < k; i++) {
+
+ System.out.print("a[" + (i + 1) + "] = ");
+
+ a[i] = scanner.nextBigInteger();
+
}
- }
- }
- // Put result into an array
- long[][] res = new long[2][aNew.size()];
- for (int i = 0; i < aNew.size(); i++) {
- res[0][i] = aNew.get(i);
- res[1][i] = mNew.get(i);
- }
+ System.out.println("Enter the moduli values (m): ");
- return res;
- }
+ for (int i = 0; i < k; i++) {
- public static long[] crt(long[] a, long[] m) {
+ System.out.print("m[" + (i + 1) + "] = ");
- long M = 1;
- for (int i = 0; i < m.length; i++) M *= m[i];
+ m[i] = scanner.nextBigInteger();
- long[] inv = new long[a.length];
- for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1];
+ }
- long x = 0;
- for (int i = 0; i < m.length; i++) {
- x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here
- x = ((x % M) + M) % M;
- }
+ if (!arePairwiseCoprime(k, m)) {
+ System.out.println("\nModuli values are not pairwise coprime.");
+ scanner.close();
+ return;
+ }
+
+ BigInteger x = chineseRemainder(k, a, m);
+
+ System.out.println("\nx = " + x);
+
+ scanner.close();
- return new long[] {x, M};
- }
-
- private static ArrayList primeFactorization(long n) {
- ArrayList factors = new ArrayList();
- if (n <= 0) throw new IllegalArgumentException();
- else if (n == 1) return factors;
- PriorityQueue divisorQueue = new PriorityQueue();
- divisorQueue.add(n);
- while (!divisorQueue.isEmpty()) {
- long divisor = divisorQueue.remove();
- if (isPrime(divisor)) {
- factors.add(divisor);
- continue;
- }
- long next_divisor = pollardRho(divisor);
- if (next_divisor == divisor) {
- divisorQueue.add(divisor);
- } else {
- divisorQueue.add(next_divisor);
- divisorQueue.add(divisor / next_divisor);
- }
- }
- return factors;
- }
-
- private static long pollardRho(long n) {
- if (n % 2 == 0) return 2;
- // Get a number in the range [2, 10^6]
- long x = 2 + (long) (999999 * Math.random());
- long c = 2 + (long) (999999 * Math.random());
- long y = x;
- long d = 1;
- while (d == 1) {
- x = (x * x + c) % n;
- y = (y * y + c) % n;
- y = (y * y + c) % n;
- d = gcf(Math.abs(x - y), n);
- if (d == n) break;
- }
- return d;
- }
-
- // Extended euclidean algorithm
- private static long[] egcd(long a, long b) {
- if (b == 0) return new long[] {a, 1, 0};
- else {
- long[] ret = egcd(b, a % b);
- long tmp = ret[1] - ret[2] * (a / b);
- ret[1] = ret[2];
- ret[2] = tmp;
- return ret;
}
- }
- private static long gcf(long a, long b) {
- return b == 0 ? a : gcf(b, a % b);
- }
+ /*
+ * Computes the solution to the system of congruence equations using CRT.
+ */
+ public static BigInteger chineseRemainder(int k, BigInteger a[], BigInteger m[]) {
- private static boolean isPrime(long n) {
- if (n < 2) return false;
- if (n == 2 || n == 3) return true;
- if (n % 2 == 0 || n % 3 == 0) return false;
+ BigInteger moduliProduct = BigInteger.ONE;
+ BigInteger result = BigInteger.ZERO;
- int limit = (int) Math.sqrt(n);
+ BigInteger M[] = new BigInteger[k];
+ BigInteger y[] = new BigInteger[k];
- for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false;
+ // Compute the product of all moduli
+ for (int i = 0; i < k; i++) {
+ moduliProduct = moduliProduct.multiply(m[i]);
+ }
- return true;
- }
-}
+ // Compute M[i] = product of all moduli divided by m[i]
+ for (int i = 0; i < k; i++) {
+ M[i] = moduliProduct.divide(m[i]);
+ }
+
+ // Compute modular inverse of M[i] modulo m[i]
+ for (int i = 0; i < k; i++) {
+ y[i] = M[i].modInverse(m[i]);
+ }
+
+ // Calculate the result using the CRT formula
+ for (int i = 0; i < k; i++) {
+ result = result.add(a[i].multiply(M[i]).multiply(y[i]));
+ }
+
+ result = result.mod(moduliProduct);
+
+ return result;
+
+ }
+
+ /*
+ * Checks if all moduli are pairwise coprime.
+ */
+ public static boolean arePairwiseCoprime(int k, BigInteger m[]) {
+
+ for (int i = 0; i < k; i++) {
+ for (int j = i + 1; j < k; j++) {
+ if (!m[i].gcd(m[j]).equals(BigInteger.ONE)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
new file mode 100644
index 000000000..c38062415
--- /dev/null
+++ b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
@@ -0,0 +1,82 @@
+package com.williamfiset.algorithms.math;
+
+import org.junit.jupiter.api.Test;
+import java.math.BigInteger;
+import static org.junit.jupiter.api.Assertions.*;
+
+class ChineseRemainderTheoremTest {
+
+ @Test
+ void testCRT_TwoEquations() {
+ BigInteger[] a = {BigInteger.valueOf(0), BigInteger.valueOf(3)};
+ BigInteger[] m = {BigInteger.valueOf(4), BigInteger.valueOf(5)};
+ BigInteger result = ChineseRemainderTheorem.chineseRemainder(2, a, m);
+ assertEquals(BigInteger.valueOf(8), result);
+ }
+
+ @Test
+ void testCRT_ThreeEquations() {
+ BigInteger[] a = {BigInteger.valueOf(2), BigInteger.valueOf(3), BigInteger.valueOf(2)};
+ BigInteger[] m = {BigInteger.valueOf(3), BigInteger.valueOf(5), BigInteger.valueOf(7)};
+ BigInteger result = ChineseRemainderTheorem.chineseRemainder(3, a, m);
+ assertEquals(BigInteger.valueOf(23), result);
+ }
+
+ @Test
+ void testCRT_FourEquations() {
+ BigInteger[] a = {
+ BigInteger.valueOf(1),
+ BigInteger.valueOf(2),
+ BigInteger.valueOf(3),
+ BigInteger.valueOf(4)
+ };
+ BigInteger[] m = {
+ BigInteger.valueOf(5),
+ BigInteger.valueOf(7),
+ BigInteger.valueOf(9),
+ BigInteger.valueOf(11)
+ };
+ BigInteger result = ChineseRemainderTheorem.chineseRemainder(4, a, m);
+
+ assertEquals(BigInteger.valueOf(1731), result);
+ }
+
+ @Test
+ void testArePairwiseCoprime_True() {
+ BigInteger[] m = {
+ BigInteger.valueOf(3),
+ BigInteger.valueOf(5),
+ BigInteger.valueOf(7)
+ };
+ assertTrue(ChineseRemainderTheorem.arePairwiseCoprime(3, m));
+ }
+
+ @Test
+ void testArePairwiseCoprime_False() {
+ BigInteger[] m = {
+ BigInteger.valueOf(6),
+ BigInteger.valueOf(8),
+ BigInteger.valueOf(9)
+ };
+ assertFalse(ChineseRemainderTheorem.arePairwiseCoprime(3, m));
+ }
+
+ @Test
+ void testCRT_LargeNumbers() {
+ BigInteger[] a = {
+ BigInteger.valueOf(123456),
+ BigInteger.valueOf(789012),
+ BigInteger.valueOf(345678)
+ };
+ BigInteger[] m = {
+ BigInteger.valueOf(1000003),
+ BigInteger.valueOf(1000033),
+ BigInteger.valueOf(1000037)
+ };
+ BigInteger result = ChineseRemainderTheorem.chineseRemainder(3, a, m);
+
+ assertEquals(a[0], result.mod(m[0]));
+ assertEquals(a[1], result.mod(m[1]));
+ assertEquals(a[2], result.mod(m[2]));
+ }
+}
\ No newline at end of file