Skip to content

Rewrite Chinese Remainder Theorem implementation #1236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,191 +1,124 @@
/**
* Use the chinese remainder theorem to solve a set of congruence equations.
*
* <p>The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
* m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
* coefficient. A value of null is returned if the coefficient cannot be eliminated.
*
* <p>The second method (reduce) is used to reduce a set of equations so that the moduli become
* pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
* output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
* m_n−1). Note that the number of equations may change during this process. A value of null is
* returned if the set of equations cannot be reduced to co-prime moduli.
*
* <p>The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
* moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
* m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
* x≡a_new(mod m_new)x≡a_new(mod m_new).
*
* @author Micah Stairs
/*
* This program solves a system of linear congruence equations using the Chinese Remainder Theorem (CRT).
*
* The user is prompted to enter the number of congruence equations, followed by the coefficients (a) and moduli (m)
* for each equation of the form: x ≡ a[i] (mod m[i]), where all moduli must be pairwise coprime.
*
* The program checks if the moduli are pairwise coprime, computes the unique solution modulo the product of all moduli,
* and prints the result.
*
* Time Complexity: O(k × log²(n)), where k is the number of equations and n is the product of all moduli.
*
*/

package com.williamfiset.algorithms.math;

import java.util.*;
import java.util.Scanner;
import java.math.BigInteger;

public class ChineseRemainderTheorem {
public static void main(String[] args) {

Scanner scanner = new Scanner(System.in);

System.out.print("Enter the number of congruence equations: ");

// eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
public static long[] eliminateCoefficient(long c, long a, long m) {

long d = egcd(c, m)[0];

if (a % d != 0) return null;

c /= d;
a /= d;
m /= d;

long inv = egcd(c, m)[1];
m = Math.abs(m);
a = (((a * inv) % m) + m) % m;

return new long[] {a, m};
}

// reduce() takes a set of equations and reduces them to an equivalent
// set with pairwise co-prime moduli (or null if not solvable).
public static long[][] reduce(long[] a, long[] m) {

List<Long> aNew = new ArrayList<Long>();
List<Long> mNew = new ArrayList<Long>();

// Split up each equation into prime factors
for (int i = 0; i < a.length; i++) {
List<Long> factors = primeFactorization(m[i]);
Collections.sort(factors);
ListIterator<Long> iterator = factors.listIterator();
while (iterator.hasNext()) {
long val = iterator.next();
long total = val;
while (iterator.hasNext()) {
long nextVal = iterator.next();
if (nextVal == val) {
total *= val;
} else {
iterator.previous();
break;
}
int k = scanner.nextInt();

// Ensure there are at least two equations
if (k < 2) {
System.out.println("\nThe number of equations must be at least 2.");
scanner.close();
return;
}
aNew.add(a[i] % total);
mNew.add(total);
}
}

// Throw away repeated information and look for conflicts
for (int i = 0; i < aNew.size(); i++) {
for (int j = i + 1; j < aNew.size(); j++) {
if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) {
if (mNew.get(i) > mNew.get(j)) {
if ((aNew.get(i) % mNew.get(j)) == aNew.get(j)) {
aNew.remove(j);
mNew.remove(j);
j--;
continue;
} else return null;
} else {
if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) {
aNew.remove(i);
mNew.remove(i);
i--;
break;
} else return null;
}
BigInteger a[] = new BigInteger[k];
BigInteger m[] = new BigInteger[k];

System.out.println("Enter the coefficient values (a): ");

for (int i = 0; i < k; i++) {

System.out.print("a[" + (i + 1) + "] = ");

a[i] = scanner.nextBigInteger();

}
}
}

// Put result into an array
long[][] res = new long[2][aNew.size()];
for (int i = 0; i < aNew.size(); i++) {
res[0][i] = aNew.get(i);
res[1][i] = mNew.get(i);
}
System.out.println("Enter the moduli values (m): ");

return res;
}
for (int i = 0; i < k; i++) {

public static long[] crt(long[] a, long[] m) {
System.out.print("m[" + (i + 1) + "] = ");

long M = 1;
for (int i = 0; i < m.length; i++) M *= m[i];
m[i] = scanner.nextBigInteger();

long[] inv = new long[a.length];
for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1];
}

long x = 0;
for (int i = 0; i < m.length; i++) {
x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here
x = ((x % M) + M) % M;
}
if (!arePairwiseCoprime(k, m)) {
System.out.println("\nModuli values are not pairwise coprime.");
scanner.close();
return;
}

BigInteger x = chineseRemainder(k, a, m);

System.out.println("\nx = " + x);

scanner.close();

return new long[] {x, M};
}

private static ArrayList<Long> primeFactorization(long n) {
ArrayList<Long> factors = new ArrayList<Long>();
if (n <= 0) throw new IllegalArgumentException();
else if (n == 1) return factors;
PriorityQueue<Long> divisorQueue = new PriorityQueue<Long>();
divisorQueue.add(n);
while (!divisorQueue.isEmpty()) {
long divisor = divisorQueue.remove();
if (isPrime(divisor)) {
factors.add(divisor);
continue;
}
long next_divisor = pollardRho(divisor);
if (next_divisor == divisor) {
divisorQueue.add(divisor);
} else {
divisorQueue.add(next_divisor);
divisorQueue.add(divisor / next_divisor);
}
}
return factors;
}

private static long pollardRho(long n) {
if (n % 2 == 0) return 2;
// Get a number in the range [2, 10^6]
long x = 2 + (long) (999999 * Math.random());
long c = 2 + (long) (999999 * Math.random());
long y = x;
long d = 1;
while (d == 1) {
x = (x * x + c) % n;
y = (y * y + c) % n;
y = (y * y + c) % n;
d = gcf(Math.abs(x - y), n);
if (d == n) break;
}
return d;
}

// Extended euclidean algorithm
private static long[] egcd(long a, long b) {
if (b == 0) return new long[] {a, 1, 0};
else {
long[] ret = egcd(b, a % b);
long tmp = ret[1] - ret[2] * (a / b);
ret[1] = ret[2];
ret[2] = tmp;
return ret;
}
}

private static long gcf(long a, long b) {
return b == 0 ? a : gcf(b, a % b);
}
/*
* Computes the solution to the system of congruence equations using CRT.
*/
public static BigInteger chineseRemainder(int k, BigInteger a[], BigInteger m[]) {

private static boolean isPrime(long n) {
if (n < 2) return false;
if (n == 2 || n == 3) return true;
if (n % 2 == 0 || n % 3 == 0) return false;
BigInteger moduliProduct = BigInteger.ONE;
BigInteger result = BigInteger.ZERO;

int limit = (int) Math.sqrt(n);
BigInteger M[] = new BigInteger[k];
BigInteger y[] = new BigInteger[k];

for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false;
// Compute the product of all moduli
for (int i = 0; i < k; i++) {
moduliProduct = moduliProduct.multiply(m[i]);
}

return true;
}
}
// Compute M[i] = product of all moduli divided by m[i]
for (int i = 0; i < k; i++) {
M[i] = moduliProduct.divide(m[i]);
}

// Compute modular inverse of M[i] modulo m[i]
for (int i = 0; i < k; i++) {
y[i] = M[i].modInverse(m[i]);
}

// Calculate the result using the CRT formula
for (int i = 0; i < k; i++) {
result = result.add(a[i].multiply(M[i]).multiply(y[i]));
}

result = result.mod(moduliProduct);

return result;

}

/*
* Checks if all moduli are pairwise coprime.
*/
public static boolean arePairwiseCoprime(int k, BigInteger m[]) {

for (int i = 0; i < k; i++) {
for (int j = i + 1; j < k; j++) {
if (!m[i].gcd(m[j]).equals(BigInteger.ONE)) {
return false;
}
}
}

return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.williamfiset.algorithms.math;

import org.junit.jupiter.api.Test;
import java.math.BigInteger;
import static org.junit.jupiter.api.Assertions.*;

class ChineseRemainderTheoremTest {

@Test
void testCRT_TwoEquations() {
BigInteger[] a = {BigInteger.valueOf(0), BigInteger.valueOf(3)};
BigInteger[] m = {BigInteger.valueOf(4), BigInteger.valueOf(5)};
BigInteger result = ChineseRemainderTheorem.chineseRemainder(2, a, m);
assertEquals(BigInteger.valueOf(8), result);
}

@Test
void testCRT_ThreeEquations() {
BigInteger[] a = {BigInteger.valueOf(2), BigInteger.valueOf(3), BigInteger.valueOf(2)};
BigInteger[] m = {BigInteger.valueOf(3), BigInteger.valueOf(5), BigInteger.valueOf(7)};
BigInteger result = ChineseRemainderTheorem.chineseRemainder(3, a, m);
assertEquals(BigInteger.valueOf(23), result);
}

@Test
void testCRT_FourEquations() {
BigInteger[] a = {
BigInteger.valueOf(1),
BigInteger.valueOf(2),
BigInteger.valueOf(3),
BigInteger.valueOf(4)
};
BigInteger[] m = {
BigInteger.valueOf(5),
BigInteger.valueOf(7),
BigInteger.valueOf(9),
BigInteger.valueOf(11)
};
BigInteger result = ChineseRemainderTheorem.chineseRemainder(4, a, m);

assertEquals(BigInteger.valueOf(1731), result);
}

@Test
void testArePairwiseCoprime_True() {
BigInteger[] m = {
BigInteger.valueOf(3),
BigInteger.valueOf(5),
BigInteger.valueOf(7)
};
assertTrue(ChineseRemainderTheorem.arePairwiseCoprime(3, m));
}

@Test
void testArePairwiseCoprime_False() {
BigInteger[] m = {
BigInteger.valueOf(6),
BigInteger.valueOf(8),
BigInteger.valueOf(9)
};
assertFalse(ChineseRemainderTheorem.arePairwiseCoprime(3, m));
}

@Test
void testCRT_LargeNumbers() {
BigInteger[] a = {
BigInteger.valueOf(123456),
BigInteger.valueOf(789012),
BigInteger.valueOf(345678)
};
BigInteger[] m = {
BigInteger.valueOf(1000003),
BigInteger.valueOf(1000033),
BigInteger.valueOf(1000037)
};
BigInteger result = ChineseRemainderTheorem.chineseRemainder(3, a, m);

assertEquals(a[0], result.mod(m[0]));
assertEquals(a[1], result.mod(m[1]));
assertEquals(a[2], result.mod(m[2]));
}
}