From e610daea60a668b8de2833c091f19f1f5e60c78f Mon Sep 17 00:00:00 2001 From: Zachary DeStefano Date: Thu, 14 Aug 2025 13:54:20 -0400 Subject: [PATCH 01/38] Modularize multiplication by small constants. Allow for turning large bigints into field elements. --- curves/bn254/src/fields/fq.rs | 2 +- ff/src/biginteger/mod.rs | 109 ++++ ff/src/fields/models/fp/mod.rs | 6 +- ff/src/fields/models/fp/montgomery_backend.rs | 561 +++++------------- ff/src/fields/prime.rs | 4 +- 5 files changed, 268 insertions(+), 414 deletions(-) diff --git a/curves/bn254/src/fields/fq.rs b/curves/bn254/src/fields/fq.rs index 26deabf79..79cf6ba89 100644 --- a/curves/bn254/src/fields/fq.rs +++ b/curves/bn254/src/fields/fq.rs @@ -4,4 +4,4 @@ use ark_ff::fields::{Fp256, MontBackend, MontConfig}; #[modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] #[generator = "3"] pub struct FqConfig; -pub type Fq = Fp256>; +pub type Fq = Fp256>; diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 62f7bc658..322a1c092 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -443,6 +443,103 @@ impl BigInteger for BigInt { } } + #[inline] + #[unroll_for_loops(8)] + fn mul_u64_in_place(&mut self, other: u64) { + // special cases for 0 and 1 + if other == 0 || self.is_zero() { + *self = Self::zero(); + return; + } else if other == 1 { + return; + } + // Calculate the full 128-bit product of the lowest limb + let mut prod: u128 = (self.0[0] as u128) * (other as u128); + self.0[0] = prod as u64; + let mut carry = (prod >> 64) as u64; + // iterate through the remaining limbs + for i in 1..N { + // Calculate the full 128-bit product of the current limb and the u64 multiplier + prod = (self.0[i] as u128) * (other as u128) + (carry as u128); + self.0[i] = prod as u64; + carry = (prod >> 64) as u64; + } + debug_assert!(carry == 0, "Overflow in BigInt::mul_u64_in_place"); + } + + #[inline] + #[unroll_for_loops(8)] + fn mul_u64_w_carry(&self, other: u64) -> BigInt { + // ensure NPLUS1 is the correct size + debug_assert!(NPLUS1 == N + 1); + // special cases for 0 and 1 + if other == 0 || self.is_zero() { + return BigInt::::zero(); + } else if other == 1 { + let mut res = BigInt::::zero(); + for i in 0..N { + res.0[i] = self.0[i]; + } + return res; + } + // initialize result + let mut res: [u64; NPLUS1] = [0u64; NPLUS1]; + // Calculate the full 128-bit product of the lowest limb + let mut prod: u128 = (self.0[0] as u128) * (other as u128); + res[0] = prod as u64; + let mut carry = (prod >> 64) as u64; + // iterate through the remaining limbs + for i in 1..N { + // Calculate the full 128-bit product of the current limb and the u64 multiplier + prod = (self.0[i] as u128) * (other as u128) + (carry as u128); + res[i] = prod as u64; + carry = (prod >> 64) as u64; + } + // add final carry + res[N] = carry; + // and return + BigInt::(res) + } + + #[inline] + #[unroll_for_loops(8)] + fn mul_u128_w_carry( + &self, + other: u128, + ) -> BigInt { + // NPLUS1 is N + 1, NPLUS2 is N + 2 + debug_assert!(NPLUS1 == N + 1); + debug_assert!(NPLUS2 == N + 2); + // special cases for 0 and 1 + if other == 0 || self.is_zero() { + return BigInt::::zero(); + } else if other == 1 { + let mut res = BigInt::::zero(); + for i in 0..N { + res.0[i] = self.0[i]; + } + return res; + } + // split other into two u64s + let other_lo = other as u64; + let other_hi = (other >> 64) as u64; + // two u64 multiplications with carry + let lo_part = self.mul_u64_w_carry::(other_lo); + let hi_part = self.mul_u64_w_carry::(other_hi); + // pad lo_part right by one limb (extra high zero limb) + // pad hi_part left by one limb (i.e. multiply by 2^64) + let mut lo_padded = BigInt::::zero(); + let mut hi_padded = BigInt::::zero(); + for i in 0..NPLUS1 { + lo_padded.0[i] = lo_part.0[i]; + hi_padded.0[i + 1] = hi_part.0[i]; + } + // add the two padded parts + let (res, carry) = lo_padded.const_add_with_carry(&hi_padded); + debug_assert!(carry == false, "Overflow in BigInt::mul_u128_w_carry"); + res + } + #[inline] fn mul(&self, other: &Self) -> (Self, Self) { if self.is_zero() || other.is_zero() { @@ -1110,6 +1207,18 @@ pub trait BigInteger: #[deprecated(since = "0.4.2", note = "please use the operator `<<` instead")] fn muln(&mut self, amt: u32); + /// NEW! Multiplies self by a u64 in place. Overflow is ignored. + fn mul_u64_in_place(&mut self, other: u64); + + /// NEW! Multiplies self by a u64, returning a bigint with one extra limb to hold overflow. + fn mul_u64_w_carry(&self, other: u64) -> BigInt; + + /// NEW! Multiplies self by a u128, returning a bigint with two extra limbs to hold overflow. + fn mul_u128_w_carry( + &self, + other: u128, + ) -> BigInt; + /// Multiplies this [`BigInteger`] by another `BigInteger`, storing the result in `self`. /// Overflow is ignored. /// diff --git a/ff/src/fields/models/fp/mod.rs b/ff/src/fields/models/fp/mod.rs index 342788f8d..3fb2110a3 100644 --- a/ff/src/fields/models/fp/mod.rs +++ b/ff/src/fields/models/fp/mod.rs @@ -105,7 +105,7 @@ pub trait FpConfig: Send + Sync + 'static + Sized { /// Creates a field element from a `u64`. /// Returns `None` if the `u64` is larger than or equal to the modulus. - fn from_u64(val: u64) -> Option>; + fn from_u64(val: u64) -> Option>; } /// Represents an element of the prime field F_p, where `p == P::MODULUS`. @@ -374,8 +374,8 @@ impl, const N: usize> PrimeField for Fp { } #[inline] - fn from_u64(r: u64) -> Option { - P::from_u64(r) + fn from_u64(r: u64) -> Option { + P::from_u64::(r) } } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index eac675038..299a7b8fb 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -429,31 +429,36 @@ pub trait MontConfig: 'static + Sync + Send + Sized { } } - fn from_i128(r: i128) -> Option, N>> { + fn from_i128( + r: i128, + ) -> Option, N>> { // TODO: small table for signed values? - Some(Fp::new_unchecked(Self::R).mul_i128(r)) + Some(Fp::new_unchecked(Self::R).mul_i128::(r)) } - fn from_u128(r: u128) -> Option, N>> { + fn from_u128( + r: u128, + ) -> Option, N>> { if r < PRECOMP_TABLE_SIZE as u128 { Some(Self::SMALL_ELEMENT_MONTGOMERY_PRECOMP[r as usize]) } else { // Multiply R (one in Montgomery form) with the u128 - Some(Fp::new_unchecked(Self::R).mul_u128(r)) + Some(Fp::new_unchecked(Self::R).mul_u128::(r)) } } - fn from_i64(r: i64) -> Option, N>> { + fn from_i64(r: i64) -> Option, N>> { // TODO: small table for signed values? - Some(Fp::new_unchecked(Self::R).mul_i64(r)) + Some(Fp::new_unchecked(Self::R).mul_i64::(r)) } - fn from_u64(r: u64) -> Option, N>> { + fn from_u64(r: u64) -> Option, N>> { + debug_assert!(NPLUS1 == N + 1); if r < PRECOMP_TABLE_SIZE as u64 { Some(Self::SMALL_ELEMENT_MONTGOMERY_PRECOMP[r as usize]) } else { // Multiply R (one in Montgomery form) with the u64 - Some(Fp::new_unchecked(Self::R).mul_u64(r)) + Some(Fp::new_unchecked(Self::R).mul_u64::(r)) } } @@ -797,8 +802,8 @@ impl, const N: usize> FpConfig for MontBackend { T::into_bigint(a) } - fn from_u64(r: u64) -> Option> { - T::from_u64(r) + fn from_u64(r: u64) -> Option> { + T::from_u64::(r) } } @@ -835,6 +840,34 @@ impl, const N: usize> Fp, N> { Self(element, PhantomData) } + /// NEW! Construct a new field element from a BigInt + /// which is in montgomery form and just needs to be reduced + /// via a barrett reduction. + #[inline] + pub fn from_unchecked_nplus1(element: BigInt<{ NPLUS1 }>) -> Self { + debug_assert!(NPLUS1 == N + 1); + // Barrett reduction + let r = barrett_reduce_nplus1_to_n::(element); + Self::new_unchecked(r) + } + + /// NEW! Construct a new field element from a BigInt + /// which is in montgomery form and just needs to be reduced + /// via a barrett reduction. + #[inline] + pub fn from_unchecked_nplus2( + element: BigInt<{ NPLUS2 }>, + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(NPLUS2 == N + 2); + let c1 = BigInt::(element.0[1..NPLUS2].try_into().unwrap()); // c1 has N+1 limbs + let r1 = barrett_reduce_nplus1_to_n::(c1); // r1 = c1 mod p ([u64; N]) + // Round 2: Reduce c2 = c_lo[0] + r1 * r. + let c2 = nplus1_pair_low_to_bigint::((element.0[0], r1.0)); // c2 has N+1 limbs + let r2 = barrett_reduce_nplus1_to_n::(c2); // r2 = c2 mod p = c mod p ([u64; N]) + Self::new_unchecked(r2) + } + const fn const_is_zero(&self) -> bool { self.0.const_is_zero() } @@ -913,29 +946,24 @@ impl, const N: usize> Fp, N> { } #[inline(always)] - pub fn mul_u64(self, other: u64) -> Self { - // Stage 1: Bignum Multiplication - // Compute c = self.0 * other. Result c has N+1 limbs. - let c = bigint_mul_by_u64(&self.0 .0, other); - - // Stage 2: Barrett Reduction - let r = barrett_reduce_nplus1_to_n::(c); - - // Use the final r_n_limbs which holds the correct N-limb result - Self::new_unchecked(BigInt::(r)) + pub fn mul_u64(self, other: u64) -> Self { + debug_assert!(NPLUS1 == N + 1); + let c: BigInt = BigInt::mul_u64_w_carry(&self.0, other); // multiply + Self::from_unchecked_nplus1(c) // reduce and return the result } /// Multiply by an i64. Invokes `mul_u64` if the input is positive, /// otherwise negates the result of `mul_u64` of the absolute value. #[inline(always)] - pub fn mul_i64(self, other: i64) -> Self { + pub fn mul_i64(self, other: i64) -> Self { + debug_assert!(NPLUS1 == N + 1); if other >= 0 { // Multiply by the positive value directly - self.mul_u64(other as u64) + self.mul_u64::(other as u64) } else { // Multiply by the absolute value and then negate the result // (-other) cannot overflow since other is not i64::MIN - -(self.mul_u64((-other) as u64)) + -(self.mul_u64::((-other) as u64)) } } @@ -943,15 +971,15 @@ impl, const N: usize> Fp, N> { /// Uses optimized mul_u64 if the absolute value of the input fits within u64, /// otherwise falls back to the two-step Barrett reduction (`mul_u128_aux`). #[inline(always)] - pub fn mul_i128(self, other: i128) -> Self { + pub fn mul_i128(self, other: i128) -> Self { if other >= 0 { let other_u128 = other as u128; if other_u128 <= u64::MAX as u128 { // Positive value fits in u64 - self.mul_u64(other_u128 as u64) + self.mul_u64::(other_u128 as u64) } else { // Positive value requires u128 path - self.mul_u128_aux(other_u128) + self.mul_u128_aux::(other_u128) } } else { // Negative value, compute absolute value as u128 @@ -959,10 +987,10 @@ impl, const N: usize> Fp, N> { let abs_other = (-other) as u128; if abs_other <= u64::MAX as u128 { // Absolute value fits in u64 - -(self.mul_u64(abs_other as u64)) + -(self.mul_u64::(abs_other as u64)) } else { // Absolute value requires u128 path - -(self.mul_u128_aux(abs_other)) + -(self.mul_u128_aux::(abs_other)) } } } @@ -971,37 +999,20 @@ impl, const N: usize> Fp, N> { /// Uses optimized mul_u64 if the input fits within u64, /// otherwise falls back to standard multiplication. #[inline(always)] - pub fn mul_u128(self, other: u128) -> Self { + pub fn mul_u128(self, other: u128) -> Self { if other >> 64 == 0 { - self.mul_u64(other as u64) + self.mul_u64::(other as u64) } else { - self.mul_u128_aux(other) + self.mul_u128_aux::(other) } } /// Fallback option for mul_u128: if the input does not fit within u64, /// we perform a more expensive procedure with 2 rounds of Barrett reduction. #[inline(always)] - pub fn mul_u128_aux(self, other: u128) -> Self { - // Stage 1: Bignum Multiplication - // Compute c = self.0 * other. Result c has N+2 limbs: (c_lo: [u64; 2], c_hi: [u64; N]) - let (c_lo, c_hi) = bigint_mul_by_u128(&self.0, other); - - // Stage 2: Two rounds of Barrett Reduction using the modular subroutine - - // Round 1: Reduce the top N+1 limbs c1 = floor(c / r). - // c1 has low limb c_lo[1] and high N limbs c_hi. - // Input to barrett_reduce is (u64, [u64; N]). - let c1 = (c_lo[1], c_hi); - let r1 = barrett_reduce_nplus1_to_n::(c1); // r1 = c1 mod p ([u64; N]) - - // Round 2: Reduce c2 = c_lo[0] + r1 * r. - // c2 has low limb c_lo[0] and high N limbs r1.0. - // Input to barrett_reduce is (u64, [u64; N]). - let c2 = (c_lo[0], r1); // Pass r1 directly as the high N limbs array - let r2 = barrett_reduce_nplus1_to_n::(c2); // r2 = c2 mod p = c mod p ([u64; N]) - - Self::new_unchecked(BigInt::(r2)) + pub fn mul_u128_aux(self, other: u128) -> Self { + let c = BigInt::mul_u128_w_carry::(&self.0, other); // mul + Self::from_unchecked_nplus2::(c) // Reduce and return the result } const fn const_is_valid(&self) -> bool { @@ -1036,330 +1047,52 @@ impl, const N: usize> Fp, N> { } } -/// Multiply a N-limb big integer with a u64, producing a N+1 limb result, -/// represented as a tuple of a u64 low limb and an array of N high limbs. -#[unroll_for_loops(8)] -#[inline(always)] -fn bigint_mul_by_u64(val: &[u64; N], other: u64) -> (u64, [u64; N]) { - let mut result_hi = [0u64; N]; - let mut carry: u64; // Start with carry = 0 - - // Calculate the full 128-bit product of the lowest limb - let prod_lo: u128 = (val[0] as u128) * (other as u128); - let result_lo = prod_lo as u64; // Lowest limb of the result - carry = (prod_lo >> 64) as u64; // Carry into the high part - - // Iterate through the remaining limbs of the input BigInt - for i in 1..N { - // Calculate the full 128-bit product of the current limb and the u64 multiplier - let prod_hi: u128 = (val[i] as u128) * (other as u128) + (carry as u128); - result_hi[i - 1] = prod_hi as u64; // Store in result_hi[0] to result_hi[N-2] - carry = (prod_hi >> 64) as u64; - } - - // After the loop, the final carry is the highest limb (N-th limb of the high part) - result_hi[N - 1] = carry; - - (result_lo, result_hi) -} - -/// Multiply a N+1 limb big integer (represented as low N limbs and high u64) with a u64, -/// producing a N+1 limb result in the same format. -/// Also returns a boolean indicating if there was a carry out (overflow). -#[unroll_for_loops(8)] -#[inline(always)] -fn bigint_plus_one_mul_by_u64( - val: ([u64; N], u64), - other: u64, -) -> (([u64; N], u64), bool) { - let (val_lo_n, val_hi) = val; - let mut result_lo_n = [0u64; N]; - let mut carry: u128 = 0; // Use u128 for intermediate carry - - // Stage 1: Multiply the low N limbs - for i in 0..N { - let prod: u128 = (val_lo_n[i] as u128) * (other as u128) + carry; - result_lo_n[i] = prod as u64; - carry = prod >> 64; - } - - // Stage 2: Multiply the high limb - let prod_hi: u128 = (val_hi as u128) * (other as u128) + carry; - let result_hi = prod_hi as u64; - let final_carry = prod_hi >> 64; - - // Final carry indicates overflow - let overflow = final_carry != 0; - ((result_lo_n, result_hi), overflow) -} - -/// Subtract two N+1 limb big integers (represented as low N limbs and high u64). -/// Returns the N+1 limb result and a boolean indicating if a borrow occurred. -#[unroll_for_loops(8)] -#[inline(always)] -fn sub_bigint_plus_one( - a: ([u64; N], u64), - b: ([u64; N], u64), -) -> (([u64; N], u64), bool) { - let (mut a_lo_n, mut a_hi) = a; - let (b_lo_n, b_hi) = b; - let mut borrow: u64 = 0; // sbb uses u64 for borrow - - // Subtract low N limbs - for i in 0..N { - // Updates a_lo_n[i] in place and returns the new borrow - borrow = fa::sbb(&mut a_lo_n[i], b_lo_n[i], borrow); - } - - // Subtract high u64 limb - borrow = fa::sbb(&mut a_hi, b_hi, borrow); - - // Final borrow indicates if the result is negative (b > a) - let final_borrow_occurred = borrow != 0; - - ((a_lo_n, a_hi), final_borrow_occurred) -} - -/// Subtract two N+1 limb big integers where `a` is (u64, [u64; N]) and `b` is ([u64; N], u64). -/// Returns the N+1 limb result as ([u64; N], u64) and a boolean indicating if a borrow occurred. -#[unroll_for_loops(8)] -#[inline(always)] -fn sub_bigint_plus_one_prime( - a: (u64, [u64; N]), // Format: (low_limb, high_n_limbs) - b: ([u64; N], u64), // Format: (low_n_limbs, high_limb) -) -> (([u64; N], u64), bool) { - let (a_lo, a_hi_n) = a; - let (b_lo_n, b_hi) = b; - let mut result_lo_n = [0u64; N]; - let mut borrow: u64 = 0; - - // Subtract low limb: result_lo_n[0] = a_lo - b_lo_n[0] - borrow (initial borrow = 0) - result_lo_n[0] = a_lo; // Initialize result limb with a_lo - borrow = fa::sbb(&mut result_lo_n[0], b_lo_n[0], borrow); // result_lo_n[0] -= b_lo_n[0] + borrow - - // Subtract middle limbs (if N > 1): result_lo_n[i] = a_hi_n[i-1] - b_lo_n[i] - borrow - // This loop covers indices i = 1 to N-1. - // It uses a_hi_n limbs from index 0 to N-2. - for i in 1..N { - result_lo_n[i] = a_hi_n[i - 1]; // Initialize result limb with corresponding a limb - borrow = fa::sbb(&mut result_lo_n[i], b_lo_n[i], borrow); // result_lo_n[i] -= b_lo_n[i] + borrow - } - - // Subtract high limb: result_hi = a_hi_n[N-1] - b_hi - borrow - let mut result_hi = a_hi_n[N - 1]; // Initialize result limb with last a limb - borrow = fa::sbb(&mut result_hi, b_hi, borrow); // result_hi -= b_hi + borrow - - let final_borrow_occurred = borrow != 0; - - ((result_lo_n, result_hi), final_borrow_occurred) -} - -/// Compare two N+1 limb big integers (represented as low N limbs and high u64). -#[unroll_for_loops(8)] -#[inline(always)] -fn compare_bigint_plus_one( - a: ([u64; N], u64), - b: ([u64; N], u64), -) -> core::cmp::Ordering { - // Compare high u64 limb first - if a.1 > b.1 { - return core::cmp::Ordering::Greater; - } else if a.1 < b.1 { - return core::cmp::Ordering::Less; - } - // High limbs are equal, compare the low N limbs from most significant (N-1) down to 0 - for i in (0..N).rev() { - if a.0[i] > b.0[i] { - return core::cmp::Ordering::Greater; - } else if a.0[i] < b.0[i] { - return core::cmp::Ordering::Less; - } - } - // All limbs are equal - return core::cmp::Ordering::Equal; -} - -/// Multiply a N-limb big integer with a u128, producing a N+2 limb result, -/// represented as a tuple of an array of 2 low limbs and an array of N high limbs. -#[unroll_for_loops(8)] #[inline(always)] -fn bigint_mul_by_u128(val: &BigInt, other: u128) -> ([u64; 2], [u64; N]) { - let other_lo = other as u64; - let other_hi = (other >> 64) as u64; - - // Compute partial products - // p1 = val * other_lo -> (N+1) limbs: (p1_lo: u64, p1_hi: [u64; N]) - let (p1_lo, p1_hi) = bigint_mul_by_u64(&val.0, other_lo); - // p2 = val * other_hi -> (N+1) limbs: (p2_lo: u64, p2_hi: [u64; N]) - let (p2_lo, p2_hi) = bigint_mul_by_u64(&val.0, other_hi); - - // Calculate the final result r = p1 + (p2 << 64) limb by limb. - // p1 : [p1_lo, p1_hi[0], ..., p1_hi[N-1]] - // p2 << 64 : [0, p2_lo, p2_hi[0], ..., p2_hi[N-1]] - // Sum (r) : [r_lo[0], r_lo[1], r_hi[0], ..., r_hi[N-1]] (N+2 limbs) - - let mut r_lo = [0u64; 2]; - let mut r_hi = [0u64; N]; - let mut carry: u64 = 0; - - // r_lo[0] = p1_lo + 0 + carry (carry is initially 0) - r_lo[0] = p1_lo; - // carry = 0; // Initial carry is 0 - - // Calculate r_lo[1] = p1_hi[0] + p2_lo + carry (limb 1) - r_lo[1] = p1_hi[0]; // Initialize with p1 limb - carry = fa::adc(&mut r_lo[1], p2_lo, carry); // Add p2 limb and carry - - // Calculate r_hi[0] to r_hi[N-1] (limbs 2 to N+1) - for i in 0..N { - let p1_limb = if i + 1 < N { p1_hi[i + 1] } else { 0 }; // Limb p1[i+2] - let p2_limb = p2_hi[i]; // Limb p2[i+1] - - // r_hi[i] = p1_limb + p2_limb + carry - r_hi[i] = p1_limb; // Initialize with p1 limb - carry = fa::adc(&mut r_hi[i], p2_limb, carry); // Add p2 limb and carry - } - - // The final carry MUST be zero for the result to fit in N+2 limbs. - debug_assert!(carry == 0, "Overflow in bigint_mul_by_u128"); - - (r_lo, r_hi) +fn nplus1_pair_high_to_bigint( + r_tmp: ([u64; N], u64), +) -> BigInt { + debug_assert!(NPLUS1 == N + 1); + let mut limbs = [0u64; NPLUS1]; + limbs[..N].copy_from_slice(&r_tmp.0); + limbs[N] = r_tmp.1; + BigInt::(limbs) } -/// Old conditional subtraction logic for Barrett reduction -/// Takes an N+1 limb intermediate result `r_tmp` and returns the N-limb final result. #[inline(always)] -fn _barrett_cond_subtract_old, const N: usize>( - r_tmp: ([u64; N], u64), -) -> [u64; N] { - let mut current_r = r_tmp; // Working variable in ([u64; N], u64) format - - if T::MODULUS_NUM_SPARE_BITS >= 1 { - // Case S >= 1 - if T::MODULUS_NUM_SPARE_BITS >= 2 { - // Optimization for S >= 2: r_tmp = c - m*2p < 4p - // High limb of current_r should initially be 0 - debug_assert!( - current_r.1 == 0, - "High limb of r_tmp should be zero when S >= 2" - ); - - // Conditional subtraction 1 (if r >= 2P) using N+1 compare/sub - if compare_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1) - != core::cmp::Ordering::Less - { - let (sub_res, sub_borrow) = - sub_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1); - debug_assert!(!sub_borrow, "Borrow should not occur subtracting 2P (S>=2)"); - current_r = sub_res; - } - // Conditional subtraction 2 (if r >= P) using N+1 compare/sub - if compare_bigint_plus_one(current_r, T::MODULUS_NPLUS1) != core::cmp::Ordering::Less { - let (sub_res, sub_borrow) = sub_bigint_plus_one(current_r, T::MODULUS_NPLUS1); - debug_assert!(!sub_borrow, "Borrow should not occur subtracting P (S>=2)"); - current_r = sub_res; - } - // Result must fit in N limbs now - debug_assert!( - current_r.1 == 0, - "High limb must be zero after final subtraction (S>=2)" - ); - current_r.0 // Return N low limbs - } else { - // Case S == 1: r_tmp = c - m*2p might temporarily exceed N limbs initially - - // Conditional subtraction 1: if r >= 2p - if compare_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1) - != core::cmp::Ordering::Less - { - // Subtract 2P using N+1 limb subtraction - let (sub_res, sub_borrow) = - sub_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1); - // After subtracting 2P, the result MUST fit in N limbs - debug_assert!( - sub_res.1 == 0, - "High limb must be 0 after 2P subtraction when S=1" - ); - debug_assert!( - !sub_borrow, - "Borrow should not occur when subtracting 2P for S=1" - ); - current_r = sub_res; // Update current_r (now guaranteed to have high limb 0) - } - // At this point, current_r < 2P and its high limb is 0. - debug_assert!( - current_r.1 == 0, - "High limb should be 0 before N-limb subtraction (S=1)" - ); - let mut r_n_limbs = BigInt::(current_r.0); // Extract N low limbs - - // Conditional subtraction 2 (if r >= P) using N limbs directly - if r_n_limbs >= T::MODULUS { - // Compare N limbs - r_n_limbs.sub_with_borrow(&T::MODULUS); // Subtract N limbs. Ignore borrow. - } - r_n_limbs.0 // Return N low limbs - } - } else { - // Case S == 0: Use (N+1)-limb helpers throughout - - // Conditional subtraction 1: if r >= 2p - if compare_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1) - != core::cmp::Ordering::Less - { - let (sub_res, sub_borrow) = sub_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1); - debug_assert!(!sub_borrow, "Borrow should not occur subtracting 2P (S=0)"); - current_r = sub_res; - } - // Now current_r = c mod 2p, represented as ([u64; N], u64) - - // Conditional subtraction 2: if r >= p - if compare_bigint_plus_one(current_r, T::MODULUS_NPLUS1) != core::cmp::Ordering::Less { - // if r >= p - let (sub_res, sub_borrow) = sub_bigint_plus_one(current_r, T::MODULUS_NPLUS1); - // Result MUST fit in N limbs now - debug_assert!( - sub_res.1 == 0, - "High limb must be zero after subtracting P (S=0)" - ); - debug_assert!( - !sub_borrow, - "Borrow should not occur when subtracting P for S=0" - ); - current_r = sub_res; - } - // At this point, current_r < P and its high limb must be 0. - debug_assert!( - current_r.1 == 0, - "High limb must be zero after final subtraction (S=0)" - ); - current_r.0 // Return N low limbs - } +fn nplus1_pair_low_to_bigint( + r_tmp: (u64, [u64; N]), +) -> BigInt { + debug_assert!(NPLUS1 == N + 1); + let mut limbs = [0u64; NPLUS1]; + limbs[0] = r_tmp.0; + limbs[1..NPLUS1].copy_from_slice(&r_tmp.1); + BigInt::(limbs) } /// Conditional subtraction logic for Barrett reduction, trading an extra comparison for a conditional subtraction. /// Includes optimizations based on MODULUS_NUM_SPARE_BITS. -/// Takes an N+1 limb intermediate result `r_tmp` (in `([u64; N], u64)` format) and returns the N-limb final result. +/// Takes an N+1 limb intermediate result `r_tmp` and returns the N-limb final result. #[unroll_for_loops(4)] #[inline(always)] -fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64)) -> [u64; N] { - let final_limbs: [u64; N]; - let r_n = BigInt::(r_tmp.0); // N low limbs as BigInt - let r_hi = r_tmp.1; // High limb - +fn barrett_cond_subtract, const N: usize, const NPLUS1: usize>( + r_tmp: BigInt, +) -> BigInt { + debug_assert!(NPLUS1 == N + 1); // Compare with 2p let compare_2p = if T::MODULUS_NUM_SPARE_BITS == 0 { // S = 0: Must use N+1 compare - compare_bigint_plus_one(r_tmp, T::MODULUS_TIMES_2_NPLUS1) + r_tmp.cmp(&nplus1_pair_high_to_bigint::( + T::MODULUS_TIMES_2_NPLUS1, + )) } else { // S >= 1: 2p fits N limbs (mostly). Compare N limbs. // We assume r_tmp's high limb is 0 here if S >= 1. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 1 before 2p comparison" ); let p2_n = BigInt::(T::MODULUS_TIMES_2_NPLUS1.0); - r_n.cmp(&p2_n) + BigInt::(r_tmp.0[0..N].try_into().unwrap()).cmp(&p2_n) // Compare N limbs }; if compare_2p != core::cmp::Ordering::Less { @@ -1367,15 +1100,17 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 // Compare with 3p let compare_3p = if T::MODULUS_NUM_SPARE_BITS < 2 { // S < 2 (S=0 or S=1): Need N+1 compare - compare_bigint_plus_one(r_tmp, T::MODULUS_TIMES_3_NPLUS1) + r_tmp.cmp(&nplus1_pair_high_to_bigint::( + T::MODULUS_TIMES_3_NPLUS1, + )) } else { // S >= 2: 3p fits N limbs. Compare N limbs. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 2 before 3p comparison" ); let p3_n = BigInt::(T::MODULUS_TIMES_3_NPLUS1.0); - r_n.cmp(&p3_n) + BigInt::(r_tmp.0[0..N].try_into().unwrap()).cmp(&p3_n) // Compare N limbs }; if compare_3p != core::cmp::Ordering::Less { @@ -1384,23 +1119,27 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 if T::MODULUS_NUM_SPARE_BITS >= 2 { // S >= 2: 3p fits N limbs. Use N-limb sub. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 2 for 3p subtraction" ); let p3_n = BigInt::(T::MODULUS_TIMES_3_NPLUS1.0); + let r_n = BigInt::(r_tmp.0[0..N].try_into().unwrap()); + // Subtract 3p from r_n + // Use const_sub_with_borrow to avoid borrow checking issues + // This is safe because we know r_n >= 3p from the comparison above. let (res_n, borrow_n) = r_n.const_sub_with_borrow(&p3_n); debug_assert!(!borrow_n, "Borrow should not occur subtracting 3p (S>=2)"); - final_limbs = res_n.0; + return res_n; // Return the N-limb result directly } else { // S < 2: Use N+1 limb sub. - let ((res_n_limbs, res_hi_limb), borrow_n1) = - sub_bigint_plus_one(r_tmp, T::MODULUS_TIMES_3_NPLUS1); - debug_assert!(!borrow_n1, "Borrow should not occur subtracting 3p (S<2)"); + let p3_n1 = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_3_NPLUS1); + let (res_n1, borrow) = r_tmp.const_sub_with_borrow(&p3_n1); + debug_assert!(!borrow, "Borrow should not occur subtracting 3p (S<2)"); debug_assert!( - res_hi_limb == 0, + res_n1.0[N] == 0, "High limb must be zero after subtracting 3p" ); - final_limbs = res_n_limbs; + return BigInt::(res_n1.0[0..N].try_into().unwrap()); } } else { // 2p <= r_tmp < 3p @@ -1408,23 +1147,24 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 if T::MODULUS_NUM_SPARE_BITS >= 1 { // S >= 1: 2p fits N limbs (mostly). Use N-limb sub. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 1 for 2p subtraction" ); let p2_n = BigInt::(T::MODULUS_TIMES_2_NPLUS1.0); + let r_n = BigInt::(r_tmp.0[0..N].try_into().unwrap()); let (res_n, borrow_n) = r_n.const_sub_with_borrow(&p2_n); debug_assert!(!borrow_n, "Borrow should not occur subtracting 2p (S>=1)"); - final_limbs = res_n.0; + return res_n; // Return the N-limb result directly } else { // S == 0: Use N+1 limb sub. - let ((res_n_limbs, res_hi_limb), borrow_n1) = - sub_bigint_plus_one(r_tmp, T::MODULUS_TIMES_2_NPLUS1); - debug_assert!(!borrow_n1, "Borrow should not occur subtracting 2p (S=0)"); + let p2_n1 = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_2_NPLUS1); + let (res_n1, borrow) = r_tmp.const_sub_with_borrow(&p2_n1); + debug_assert!(!borrow, "Borrow should not occur subtracting 2p (S=0)"); debug_assert!( - res_hi_limb == 0, + res_n1.0[N] == 0, "High limb must be zero after subtracting 2p" ); - final_limbs = res_n_limbs; + return BigInt::(res_n1.0[0..N].try_into().unwrap()); } } } else { @@ -1433,11 +1173,15 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 let compare_p = if T::MODULUS_NUM_SPARE_BITS >= 1 { // S >= 1: Use N-limb compare. // Assume r_tmp high limb is 0 because r_tmp < 2p and 2p fits N limbs (mostly) if S >= 1 - debug_assert!(r_hi == 0, "High limb expected to be 0 if S >= 1 and r < 2p"); - r_n.cmp(&T::MODULUS) // Compare N limbs + debug_assert!( + r_tmp.0[N] == 0, + "High limb expected to be 0 if S >= 1 before p comparison" + ); + let p_n = BigInt::(T::MODULUS.0); + BigInt::(r_tmp.0[0..N].try_into().unwrap()).cmp(&p_n) // Compare N limbs } else { // S == 0: Use N+1 limb compare. - compare_bigint_plus_one(r_tmp, T::MODULUS_NPLUS1) + r_tmp.cmp(&nplus1_pair_high_to_bigint::(T::MODULUS_NPLUS1)) }; if compare_p != core::cmp::Ordering::Less { @@ -1445,33 +1189,30 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 // Subtract p if T::MODULUS_NUM_SPARE_BITS >= 1 { // S >= 1: Use N-limb sub. - debug_assert!( - r_hi == 0, - "High limb expected to be 0 if S >= 1 for p subtraction" - ); - let (res_n, borrow_n) = r_n.const_sub_with_borrow(&T::MODULUS); + let p_n = BigInt::(T::MODULUS.0); + let r_n = BigInt::(r_tmp.0[0..N].try_into().unwrap()); + let (res_n, borrow_n) = r_n.const_sub_with_borrow(&p_n); debug_assert!(!borrow_n, "Borrow should not occur subtracting p (S>=1)"); - final_limbs = res_n.0; + return res_n; // Return the N-limb result directly } else { // S == 0: Use N+1 limb sub. - let ((res_n_limbs, res_hi_limb), borrow_n1) = - sub_bigint_plus_one(r_tmp, T::MODULUS_NPLUS1); - debug_assert!(!borrow_n1, "Borrow should not occur subtracting p (S=0)"); + let p_n1 = nplus1_pair_high_to_bigint::(T::MODULUS_NPLUS1); + let (res_n1, borrow) = r_tmp.const_sub_with_borrow(&p_n1); + debug_assert!(!borrow, "Borrow should not occur subtracting p (S=0)"); debug_assert!( - res_hi_limb == 0, + res_n1.0[N] == 0, "High limb must be zero after subtracting p" ); - final_limbs = res_n_limbs; + return BigInt::(res_n1.0[0..N].try_into().unwrap()); } } else { // r_tmp < p // Subtract 0 (No-op) // Result must already fit in N limbs. Assert high limb is 0. - debug_assert!(r_hi == 0, "High limb must be zero when r_tmp < p"); - final_limbs = r_n.0; // Use the low N limbs directly + debug_assert!(r_tmp.0[N] == 0, "High limb must be zero when r_tmp < p"); + return BigInt::(r_tmp.0[0..N].try_into().unwrap()); } } - final_limbs } /// Helper function to perform Barrett reduction from N+1 limbs to N limbs. @@ -1480,38 +1221,38 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 /// Output is the N-limb result `[u64; N]`. #[unroll_for_loops(4)] #[inline(always)] -fn barrett_reduce_nplus1_to_n, const N: usize>(c: (u64, [u64; N])) -> [u64; N] { - let (c_lo, c_hi) = c; // c_lo is the lowest limb, c_hi holds the top N limbs - +fn barrett_reduce_nplus1_to_n, const N: usize, const NPLUS1: usize>( + c: BigInt, +) -> BigInt { + debug_assert!(NPLUS1 == N + 1, "NPLUS1 must be N + 1 for this function"); // Compute tilde_c = floor(c / R') = floor(c / 2^MODULUS_BITS) // This involves the top two limbs of the N+1 limb number `c`. - // The highest limb is c_hi[N-1]. The second highest is c_hi[N-2]. // Assume that `N >= 1` let tilde_c: u64 = if T::MODULUS_HAS_SPARE_BIT { - let high_limb = c_hi[N - 1]; - let second_high_limb = if N > 1 { c_hi[N - 2] } else { c_lo }; // Use c_lo if N=1 + let high_limb = c.0[N]; + let second_high_limb = c.0[N - 1]; // N is at least 1, so this is safe (high_limb << T::MODULUS_NUM_SPARE_BITS) + (second_high_limb >> (64 - T::MODULUS_NUM_SPARE_BITS)) } else { - c_hi[N - 1] // If no spare bits, tilde_c is just the highest limb + c.0[N] // If no spare bits, tilde_c is just the highest limb }; // Estimate m = floor( (tilde_c * BARRETT_MU) / r ) // where r = 2^64 let m: u64 = ((tilde_c as u128 * T::BARRETT_MU as u128) >> 64) as u64; + // unroll T::MODULUS_TIMES_2_NPLUS1 from ([u64; N], u64) to BigInt + let mut m2p = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_2_NPLUS1); // Compute m * 2p (N+1 limbs) - let (m_times_2p, m2p_overflow) = bigint_plus_one_mul_by_u64::(T::MODULUS_TIMES_2_NPLUS1, m); - // If m * 2p overflows N+1 limbs, the logic might be flawed or input c was too large. - debug_assert!(!m2p_overflow, "Overflow calculating m * 2p"); + BigInt::mul_u64_in_place(&mut m2p, m); - // Compute r_tmp = c - m * 2p (result is ([u64; N], u64)) - let (r_tmp, r_tmp_borrow) = sub_bigint_plus_one_prime(c, m_times_2p); + // Compute r_tmp = c - m * 2p // A borrow here implies c was smaller than m*2p, which shouldn't happen with correct m. - debug_assert!(!r_tmp_borrow, "Borrow occurred calculating c - m*2p"); + let (r_tmp, borrow) = c.const_sub_with_borrow(&m2p); + debug_assert!(!borrow, "Borrow should not occur in Barrett reduction"); - // Use the optimized conditional subtraction which expects ([u64; N], u64) - barrett_cond_subtract::(r_tmp) + // Use the optimized conditional subtraction to go from N+1 limbs to N limbs. + barrett_cond_subtract::(r_tmp) } #[cfg(test)] @@ -1521,6 +1262,10 @@ mod test { use ark_test_curves::bn254::Fr; use num_bigint::{BigInt, BigUint, Sign}; use rand::Rng; + // constants for the number of limbs in bn254 + const N: usize = 4; + const NPLUS1: usize = N + 1; + const NPLUS2: usize = N + 2; #[test] fn test_mul_u64_random() { @@ -1534,7 +1279,7 @@ mod test { let expected_c = a * Fr::from(b_bigint); // Actual result using the function under test - let result_c = a.mul_u64(b_val); + let result_c = a.mul_u64::(b_val); assert_eq!( result_c, @@ -1562,7 +1307,7 @@ mod test { }; // Actual result using the function under test - let result_c = a.mul_i64(b_val); + let result_c = a.mul_i64::(b_val); assert_eq!( result_c, expected_c, @@ -1584,7 +1329,7 @@ mod test { let expected_c = a * Fr::from(b_bigint); // Actual result using the function under test - let result_c = a.mul_u128(b_val); + let result_c = a.mul_u128::(b_val); assert_eq!( result_c, expected_c, @@ -1612,7 +1357,7 @@ mod test { }; // Actual result using the function under test - let result_c = a.mul_i128(b_val); + let result_c = a.mul_i128::(b_val); assert_eq!( result_c, expected_c, diff --git a/ff/src/fields/prime.rs b/ff/src/fields/prime.rs index 28b896e59..67a7f0db0 100644 --- a/ff/src/fields/prime.rs +++ b/ff/src/fields/prime.rs @@ -57,9 +57,9 @@ pub trait PrimeField: /// Converts an element of the prime field into an integer in the range 0..(p - 1). fn into_bigint(self) -> Self::BigInt; - /// Creates a field element from a `u64`. + /// Creates a field element from a `u64`. /// Returns `None` if the `u64` is larger than or equal to the modulus. - fn from_u64(val: u64) -> Option; + fn from_u64(val: u64) -> Option; /// Reads bytes in big-endian, and converts them to a field element. /// If the integer represented by `bytes` is larger than the modulus `p`, this method From d91c33ff1e3149d6656fbe056a60e8c06fbd396d Mon Sep 17 00:00:00 2001 From: Zachary DeStefano Date: Thu, 14 Aug 2025 17:25:11 -0400 Subject: [PATCH 02/38] revert extra bn254 change --- curves/bn254/src/fields/fq.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/curves/bn254/src/fields/fq.rs b/curves/bn254/src/fields/fq.rs index 79cf6ba89..26deabf79 100644 --- a/curves/bn254/src/fields/fq.rs +++ b/curves/bn254/src/fields/fq.rs @@ -4,4 +4,4 @@ use ark_ff::fields::{Fp256, MontBackend, MontConfig}; #[modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] #[generator = "3"] pub struct FqConfig; -pub type Fq = Fp256>; +pub type Fq = Fp256>; From 8669d83c69dca5e3ceec3565ef97c9166924e2f6 Mon Sep 17 00:00:00 2001 From: Zachary DeStefano Date: Thu, 14 Aug 2025 17:25:45 -0400 Subject: [PATCH 03/38] add fma for multiplication by u64 --- ff/src/biginteger/mod.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 322a1c092..98ff7b841 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -501,6 +501,32 @@ impl BigInteger for BigInt { BigInt::(res) } + #[inline] + #[unroll_for_loops(8)] + fn fmu64a(&self, other: u64, acc: &mut BigInt) { + // ensure NPLUS1 is the correct size + debug_assert!(NPLUS1 == N + 1); + // special cases for 0 and 1 + if other == 0 || self.is_zero() { + // idempotent + return; + } else if other == 1 { + // just addition + let mut carry = 0; + for i in 0..N { + carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); + } + acc.0[N] += carry as u64; + return; + } + // otherwise fma + let mut carry = 0; + for i in 0..N { + acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry); + } + acc.0[N] += carry as u64; + } + #[inline] #[unroll_for_loops(8)] fn mul_u128_w_carry( @@ -1213,6 +1239,10 @@ pub trait BigInteger: /// NEW! Multiplies self by a u64, returning a bigint with one extra limb to hold overflow. fn mul_u64_w_carry(&self, other: u64) -> BigInt; + /// NEW! Multiplies self by a u64, accumulating the result in `acc`, which must have one extra limb. + /// overflow causes a wraparound in the highest limb of the accumulator. + fn fmu64a(&self, other: u64, acc: &mut BigInt); + /// NEW! Multiplies self by a u128, returning a bigint with two extra limbs to hold overflow. fn mul_u128_w_carry( &self, From f72f76c0b8eb464478bfbee2d005e112fcb5b3cb Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Thu, 14 Aug 2025 21:28:11 -0600 Subject: [PATCH 04/38] fix bench --- test-curves/benches/small_mul.rs | 14 ++++++++------ test-templates/src/msm.rs | 10 +++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index d79b9c530..abdcbb254 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -46,7 +46,8 @@ fn mul_small_bench(c: &mut Criterion) { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u64(b_u64_s[i])) + // bn254 Fr has N=4 limbs => N+1 = 5 + criterion::black_box(a_s[i].mul_u64::<5>(b_u64_s[i])) }) }); @@ -54,7 +55,7 @@ fn mul_small_bench(c: &mut Criterion) { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_i64(b_i64_s[i])) + criterion::black_box(a_s[i].mul_i64::<5>(b_i64_s[i])) }) }); @@ -64,7 +65,8 @@ fn mul_small_bench(c: &mut Criterion) { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u128(b_u128_s[i])) + // bn254 Fr has N=4 limbs => N+1 = 5, N+2 = 6 + criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u128_s[i])) }) }); @@ -72,7 +74,7 @@ fn mul_small_bench(c: &mut Criterion) { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_i128(b_i128_s[i])) + criterion::black_box(a_s[i].mul_i128::<5, 6>(b_i128_s[i])) }) }); @@ -90,7 +92,7 @@ fn mul_small_bench(c: &mut Criterion) { bench.iter(|| { i = (i + 1) % SAMPLES; // Call mul_u128 but provide a u64 input cast to u128 - criterion::black_box(a_s[i].mul_u128(b_u64_as_u128_s[i])) + criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u64_as_u128_s[i])) }) }); @@ -102,7 +104,7 @@ fn mul_small_bench(c: &mut Criterion) { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u128_aux(b_u128_s[i])) + criterion::black_box(a_s[i].mul_u128_aux::<5, 6>(b_u128_s[i])) }) }); diff --git a/test-templates/src/msm.rs b/test-templates/src/msm.rs index 5db14ced6..2b091ed38 100644 --- a/test-templates/src/msm.rs +++ b/test-templates/src/msm.rs @@ -81,31 +81,31 @@ pub fn test_var_base_msm_specialized() { let v = (0..SAMPLES).map(|_| bool::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u1(g.as_slice(), v.as_slice()); + let fast = G::msm_u1(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u8::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u8(g.as_slice(), v.as_slice()); + let fast = G::msm_u8(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u16::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u16(g.as_slice(), v.as_slice()); + let fast = G::msm_u16(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u32::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u32(g.as_slice(), v.as_slice()); + let fast = G::msm_u32(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u64::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u64(g.as_slice(), v.as_slice()); + let fast = G::msm_u64(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); } From 4786b216e08ae85f61f9fdd7ce0429358dcc12be Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Thu, 14 Aug 2025 21:42:52 -0600 Subject: [PATCH 05/38] use mac_with_carry rather than hand-written code --- ff/src/biginteger/mod.rs | 165 +++++++++++++++++++++++---------------- 1 file changed, 97 insertions(+), 68 deletions(-) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 98ff7b841..1ae1eb754 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -285,17 +285,8 @@ impl BigInt { /// leading zeros in the most significant limb. #[doc(hidden)] pub const fn num_spare_bits(self) -> u32 { - // Count the leading zeros in the most significant limb - let msb = self.0[N - 1]; - let mut count = 0; - let mut mask = 1u64 << 63; // Start with the highest bit - - while count < 64 && (msb & mask) == 0 { - count += 1; - mask >>= 1; - } - - count + // Fast path: directly use the intrinsic on the most significant limb + self.0[N - 1].leading_zeros() } #[inline] @@ -447,23 +438,19 @@ impl BigInteger for BigInt { #[unroll_for_loops(8)] fn mul_u64_in_place(&mut self, other: u64) { // special cases for 0 and 1 - if other == 0 || self.is_zero() { - *self = Self::zero(); - return; - } else if other == 1 { - return; - } - // Calculate the full 128-bit product of the lowest limb - let mut prod: u128 = (self.0[0] as u128) * (other as u128); - self.0[0] = prod as u64; - let mut carry = (prod >> 64) as u64; - // iterate through the remaining limbs - for i in 1..N { - // Calculate the full 128-bit product of the current limb and the u64 multiplier - prod = (self.0[i] as u128) * (other as u128) + (carry as u128); - self.0[i] = prod as u64; - carry = (prod >> 64) as u64; + // if other == 0 || self.is_zero() { + // *self = Self::zero(); + // return; + // } else if other == 1 { + // return; + // } + // Use the same low-level multiply-accumulate primitive that already + // benefits from x86 optimizations in this crate. + let mut carry = 0u64; + for i in 0..N { + self.0[i] = mac_with_carry!(0u64, self.0[i], other, &mut carry); } + // Overflow is ignored by contract; assert in debug to catch misuse. debug_assert!(carry == 0, "Overflow in BigInt::mul_u64_in_place"); } @@ -473,32 +460,23 @@ impl BigInteger for BigInt { // ensure NPLUS1 is the correct size debug_assert!(NPLUS1 == N + 1); // special cases for 0 and 1 - if other == 0 || self.is_zero() { - return BigInt::::zero(); - } else if other == 1 { - let mut res = BigInt::::zero(); - for i in 0..N { - res.0[i] = self.0[i]; - } - return res; - } - // initialize result - let mut res: [u64; NPLUS1] = [0u64; NPLUS1]; - // Calculate the full 128-bit product of the lowest limb - let mut prod: u128 = (self.0[0] as u128) * (other as u128); - res[0] = prod as u64; - let mut carry = (prod >> 64) as u64; - // iterate through the remaining limbs - for i in 1..N { - // Calculate the full 128-bit product of the current limb and the u64 multiplier - prod = (self.0[i] as u128) * (other as u128) + (carry as u128); - res[i] = prod as u64; - carry = (prod >> 64) as u64; + // if other == 0 || self.is_zero() { + // return BigInt::::zero(); + // } else if other == 1 { + // let mut res = BigInt::::zero(); + // for i in 0..N { + // res.0[i] = self.0[i]; + // } + // return res; + // } + // Use the same multiply-accumulate primitive and capture the final carry + let mut res = BigInt::::zero(); + let mut carry = 0u64; + for i in 0..N { + res.0[i] = mac_with_carry!(0u64, self.0[i], other, &mut carry); } - // add final carry - res[N] = carry; - // and return - BigInt::(res) + res.0[N] = carry; + res } #[inline] @@ -516,7 +494,7 @@ impl BigInteger for BigInt { for i in 0..N { carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); } - acc.0[N] += carry as u64; + acc.0[N] = acc.0[N].wrapping_add(carry as u64); return; } // otherwise fma @@ -524,7 +502,50 @@ impl BigInteger for BigInt { for i in 0..N { acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry); } - acc.0[N] += carry as u64; + acc.0[N] = acc.0[N].wrapping_add(carry as u64); + } + + #[inline] + #[unroll_for_loops(8)] + fn fm128a(&self, other: u128, acc: &mut BigInt) { + // ensure NPLUS2 is the correct size (N + 2 limbs) + debug_assert!(NPLUS2 == N + 2); + // special cases for 0 and 1 + // if other == 0 || self.is_zero() { + // // idempotent + // return; + // } else if other == 1 { + // // just addition into lower N limbs; propagate final carry into acc[N] + // let mut carry = 0; + // for i in 0..N { + // carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); + // } + // // carry is at most 1; fold into limb N (wrapping into highest limb if needed later) + // acc.0[N] = acc.0[N].wrapping_add(carry as u64); + // return; + // } + + let other_lo = other as u64; + let other_hi = (other >> 64) as u64; + + // Accumulate self * other_lo into acc[0..=N] + let mut carry = 0u64; + for i in 0..N { + acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other_lo, &mut carry); + } + // Add final carry into limb N, propagating into highest limb if it overflows + let (new_n, of1) = acc.0[N].overflowing_add(carry); + acc.0[N] = new_n; + if of1 { + acc.0[N + 1] = acc.0[N + 1].wrapping_add(1); + } + + // Accumulate self * other_hi into acc[1..=N+1] + let mut carry2 = 0u64; + for i in 0..N { + acc.0[i + 1] = mac_with_carry!(acc.0[i + 1], self.0[i], other_hi, &mut carry2); + } + acc.0[N + 1] = acc.0[N + 1].wrapping_add(carry2); } #[inline] @@ -546,23 +567,26 @@ impl BigInteger for BigInt { } return res; } - // split other into two u64s + // Split other into two u64s and accumulate directly into the result buffer. let other_lo = other as u64; let other_hi = (other >> 64) as u64; - // two u64 multiplications with carry - let lo_part = self.mul_u64_w_carry::(other_lo); - let hi_part = self.mul_u64_w_carry::(other_hi); - // pad lo_part right by one limb (extra high zero limb) - // pad hi_part left by one limb (i.e. multiply by 2^64) - let mut lo_padded = BigInt::::zero(); - let mut hi_padded = BigInt::::zero(); - for i in 0..NPLUS1 { - lo_padded.0[i] = lo_part.0[i]; - hi_padded.0[i + 1] = hi_part.0[i]; + + let mut res = BigInt::::zero(); + + // First pass: res[i] += self[i] * other_lo + let mut carry = 0u64; + for i in 0..N { + res.0[i] = mac_with_carry!(res.0[i], self.0[i], other_lo, &mut carry); } - // add the two padded parts - let (res, carry) = lo_padded.const_add_with_carry(&hi_padded); - debug_assert!(carry == false, "Overflow in BigInt::mul_u128_w_carry"); + res.0[N] = carry; + + // Second pass: res[i+1] += self[i] * other_hi + let mut carry2 = 0u64; + for i in 0..N { + res.0[i + 1] = mac_with_carry!(res.0[i + 1], self.0[i], other_hi, &mut carry2); + } + res.0[N + 1] = carry2; + res } @@ -1249,6 +1273,11 @@ pub trait BigInteger: other: u128, ) -> BigInt; + /// NEW! Fused multiply-accumulate with a u128 multiplier. + /// Accumulate self * other into `acc`, which must have two extra limbs. + /// Overflow causes wraparound in the highest limb of the accumulator. + fn fm128a(&self, other: u128, acc: &mut BigInt); + /// Multiplies this [`BigInteger`] by another `BigInteger`, storing the result in `self`. /// Overflow is ignored. /// From 3a31318c2a58e48f12d1179158f862814439b323 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Thu, 14 Aug 2025 21:43:03 -0600 Subject: [PATCH 06/38] added test for new bigint mul --- ff/src/biginteger/tests.rs | 208 +++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 4b3fa54b3..4e0dd0a0e 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -279,3 +279,211 @@ fn test_biginteger832() { use crate::biginteger::BigInteger832 as B; test_biginteger(B::new([u64::MAX; 13]), B::new([0u64; 13])); } + +// Tests for NEW functions +use crate::biginteger::BigInteger256; + +#[test] +fn test_mul_u64_in_place() { + let mut a = BigInteger256::from(0x123456789ABCDEFu64); + let b = 0x987654321u64; + + // Test against reference implementation + let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(b); + a.mul_u64_in_place(b); + assert_eq!(BigUint::from(a), expected); + + // Test zero multiplication + let mut zero = BigInteger256::zero(); + zero.mul_u64_in_place(12345); + assert!(zero.is_zero()); + + // Test multiplication by zero + let mut a = BigInteger256::from(12345u64); + a.mul_u64_in_place(0); + assert!(a.is_zero()); + + // Test multiplication by one + let orig = BigInteger256::from(0xDEADBEEFu64); + let mut a = orig; + a.mul_u64_in_place(1); + assert_eq!(a, orig); +} + +#[test] +fn test_mul_u64_w_carry() { + let a = BigInteger256::from(u64::MAX); + let b = u64::MAX; + + // Test against reference implementation + let expected = BigUint::from(u64::MAX) * BigUint::from(u64::MAX); + let result = a.mul_u64_w_carry::<5>(b); + assert_eq!(BigUint::from(result), expected); + + // Test with small numbers + let a = BigInteger256::from(12345u64); + let b = 67890u64; + let expected = BigUint::from(12345u64) * BigUint::from(67890u64); + let result = a.mul_u64_w_carry::<5>(b); + assert_eq!(BigUint::from(result), expected); + + // Test zero cases + let zero = BigInteger256::zero(); + let result = zero.mul_u64_w_carry::<5>(12345); + assert!(result.is_zero()); + + let a = BigInteger256::from(12345u64); + let result = a.mul_u64_w_carry::<5>(0); + assert!(result.is_zero()); + + // Test multiplication by one + let a = BigInteger256::from(0xDEADBEEFu64); + let result = a.mul_u64_w_carry::<5>(1); + let expected_bytes = a.to_bytes_le(); + let result_bytes = result.to_bytes_le(); + assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); +} + +#[test] +fn test_fmu64a() { + let a = BigInteger256::from(12345u64); + let b = 67890u64; + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + + // Perform fused multiply-accumulate + a.fmu64a(b, &mut acc); + + // Compare against separate multiply and add + let expected_mul = BigUint::from(12345u64) * BigUint::from(67890u64); + let expected_total = expected_mul + BigUint::from(11111u64); + assert_eq!(BigUint::from(acc), expected_total); + + // Test zero cases + let zero = BigInteger256::zero(); + let mut acc = BigInteger256::from(12345u64).mul_u64_w_carry::<5>(1); + let acc_copy = acc; + zero.fmu64a(67890, &mut acc); + assert_eq!(acc, acc_copy); // Should be unchanged + + // Test multiplication by zero + let a = BigInteger256::from(12345u64); + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + let acc_copy = acc; + a.fmu64a(0, &mut acc); + assert_eq!(acc, acc_copy); // Should be unchanged + + // Test multiplication by one (should be just addition) + let a = BigInteger256::from(12345u64); + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + a.fmu64a(1, &mut acc); + let expected = BigUint::from(12345u64) + BigUint::from(11111u64); + assert_eq!(BigUint::from(acc), expected); +} + +#[test] +fn test_mul_u128_w_carry() { + let a = BigInteger256::from(0x123456789ABCDEFu64); + let b = 0x987654321DEADBEEFu128; + + // Test against reference implementation + let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(0x987654321DEADBEEFu128); + let result = a.mul_u128_w_carry::<5, 6>(b); + assert_eq!(BigUint::from(result), expected); + + // Test with u64 value (should be same as mul_u64_w_carry) + let b_u64 = 0x987654321u64; + let result_u128 = a.mul_u128_w_carry::<5, 6>(b_u64 as u128); + let result_u64 = a.mul_u64_w_carry::<5>(b_u64); + + // Compare first 5 limbs (u64 result size) + for i in 0..5 { + assert_eq!(result_u128.0[i], result_u64.0[i]); + } + assert_eq!(result_u128.0[5], 0); // Extra limb should be zero + + // Test zero cases + let zero = BigInteger256::zero(); + let result = zero.mul_u128_w_carry::<5, 6>(12345); + assert!(result.is_zero()); + + let a = BigInteger256::from(12345u64); + let result = a.mul_u128_w_carry::<5, 6>(0); + assert!(result.is_zero()); + + // Test multiplication by one + let a = BigInteger256::from(0xDEADBEEFu64); + let result = a.mul_u128_w_carry::<5, 6>(1); + let expected_bytes = a.to_bytes_le(); + let result_bytes = result.to_bytes_le(); + assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); +} + +#[test] +fn test_fm128a_basic_and_edges() { + use crate::biginteger::BigInteger256 as B; + // Basic reference check against BigUint + let a = B::from(0x123456789ABCDEFu64); + let b = 0x987654321DEADBEEFu128; + let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); // zero-extended accumulator (6 limbs) + a.fm128a::<6>(b, &mut acc); + let expected = num_bigint::BigUint::from(0x123456789ABCDEFu64) + * num_bigint::BigUint::from(0x987654321DEADBEEFu128); + assert_eq!(num_bigint::BigUint::from(acc), expected); + + // Zero multiplier: no change + let a = B::from(12345u64); + let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); + let acc_copy = acc; + a.fm128a::<6>(0, &mut acc); + assert_eq!(acc, acc_copy); + + // One multiplier: reduces to addition + let a = B::from(12345u64); + let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); + a.fm128a::<6>(1, &mut acc); + let expected = num_bigint::BigUint::from(12345u64) + num_bigint::BigUint::from(11111u64); + assert_eq!(num_bigint::BigUint::from(acc), expected); + + // Overflow propagation from limb N into highest limb + let a = B::new([u64::MAX; 4]); + let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); + // Pre-fill limb N to force overflow when adding the final carry from low pass + acc.0[4] = u64::MAX; // limb N + acc.0[5] = 0; // highest limb + // cause carry=1 from low pass (a * 2) + a.fm128a::<6>(2, &mut acc); + // Expect highest limb incremented by 1 due to overflow from limb N + assert_eq!(acc.0[5], 1); +} + +#[test] +fn test_overflow_behavior_fmu64a() { + // Test that overflow in the highest limb wraps around as documented + let a = BigInteger256::new([u64::MAX; 4]); + let mut acc = BigInteger256::new([0, 0, 0, 0]).mul_u64_w_carry::<5>(1); + acc.0[4] = u64::MAX; // Set highest limb to max + + // This should cause overflow in the highest limb + a.fmu64a(2, &mut acc); + + // The overflow should wrap around + // u64::MAX * 2 = 2^65 - 2, which when added to u64::MAX = 2^65 + u64::MAX - 2 + // This wraps to u64::MAX - 2 with a carry of 1 that itself wraps + assert_eq!(acc.0[4], u64::MAX.wrapping_add(1)); // Wrapped result +} + +#[test] +fn test_edge_cases_large_numbers() { + // Test with maximum values + let max_bi = BigInteger256::new([u64::MAX; 4]); + + // mul_u64_w_carry with max values + let result = max_bi.mul_u64_w_carry::<5>(u64::MAX); + let expected = BigUint::from(max_bi) * BigUint::from(u64::MAX); + assert_eq!(BigUint::from(result), expected); + + // mul_u128_w_carry with max values + let result = max_bi.mul_u128_w_carry::<5, 6>(u128::MAX); + let expected = BigUint::from(max_bi) * BigUint::from(u128::MAX); + assert_eq!(BigUint::from(result), expected); +} From 64af2cb520968b3773c4c2d3b8e7e4a0492a82c5 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Thu, 14 Aug 2025 21:43:12 -0600 Subject: [PATCH 07/38] added uncheck nplus3 conversion --- ff/src/fields/models/fp/montgomery_backend.rs | 67 ++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index 299a7b8fb..b552b4c2d 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -868,6 +868,29 @@ impl, const N: usize> Fp, N> { Self::new_unchecked(r2) } + /// Construct a new field element from a BigInt which is in + /// Montgomery form and should be reduced via two Barrett rounds then a final combine. + #[inline] + pub fn from_unchecked_nplus3( + element: BigInt<{ NPLUS3 }>, + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(NPLUS2 == N + 2); + debug_assert!(NPLUS3 == N + 3); + + // Reduce the upper N+2 limbs of `element` to N limbs + let c_hi = BigInt::(element.0[1..NPLUS3].try_into().unwrap()); + let c_hi_hi = BigInt::(c_hi.0[1..NPLUS2].try_into().unwrap()); + let r1 = barrett_reduce_nplus1_to_n::(c_hi_hi); + let c_hi_merged = nplus1_pair_low_to_bigint::((c_hi.0[0], r1.0)); + let r_hi = barrett_reduce_nplus1_to_n::(c_hi_merged); + + // Combine the original lowest limb with r_hi and perform final Barrett reduction + let c_final = nplus1_pair_low_to_bigint::((element.0[0], r_hi.0)); + let r_final = barrett_reduce_nplus1_to_n::(c_final); + Self::new_unchecked(r_final) + } + const fn const_is_zero(&self) -> bool { self.0.const_is_zero() } @@ -945,6 +968,45 @@ impl, const N: usize> Fp, N> { } } + /// Montgomery reduction for 2N-limb inputs (standard Montgomery reduction) + /// Takes a 2N-limb BigInt that represents a product in "unreduced" form + /// and reduces it to N limbs in Montgomery form. + #[inline(always)] + pub fn montgomery_reduce_2n(input: BigInt) -> Self { + debug_assert!(TWON == 2 * N); + // Work in-place over the owned 2N-limb buffer + let mut limbs = input.0; + let (lo, hi) = limbs.split_at_mut(N); + + // Montgomery reduction - mirrors mul_without_cond_subtract + let mut carry2 = 0u64; + for i in 0..N { + let tmp = lo[i].wrapping_mul(T::INV); + let mut carry = 0u64; + fa::mac_discard(lo[i], tmp, T::MODULUS.0[0], &mut carry); + for j in 1..N { + let k = i + j; + if k >= N { + hi[k - N] = fa::mac_with_carry(hi[k - N], tmp, T::MODULUS.0[j], &mut carry); + } else { + lo[k] = fa::mac_with_carry(lo[k], tmp, T::MODULUS.0[j], &mut carry); + } + } + carry2 = fa::adc(&mut hi[i], carry, carry2); + } + + // Move the high half into the output BigInt + let mut hi_out = [0u64; N]; + hi_out.copy_from_slice(hi); + let mut result = Self::new_unchecked(BigInt::(hi_out)); + if T::MODULUS_HAS_SPARE_BIT { + result.subtract_modulus(); + } else { + result.subtract_modulus_with_carry(carry2 != 0); + } + result + } + #[inline(always)] pub fn mul_u64(self, other: u64) -> Self { debug_assert!(NPLUS1 == N + 1); @@ -1369,11 +1431,12 @@ mod test { #[test] fn test_mont_macro_correctness() { - // This test succeeds **only** on the secp256k1 curve. + // This test succeeds only on the secp256k1 curve. let (is_positive, limbs) = str_to_limbs_u64( "111192936301596926984056301862066282284536849596023571352007112326586892541694", ); - let t = Fr::from_sign_and_limbs(is_positive, &limbs); + // Use secp256k1::Fr here (do not use the bn254 alias `Fr` above). + let t = ark_test_curves::secp256k1::Fr::from_sign_and_limbs(is_positive, &limbs); let result: BigUint = t.into(); let expected = BigUint::from_str( From 317d5094b1428e96988e3c30c8d6822e05dd78fb Mon Sep 17 00:00:00 2001 From: Zachary DeStefano Date: Fri, 15 Aug 2025 02:19:34 -0400 Subject: [PATCH 08/38] odd speedup in barrett reduce --- ff/src/fields/models/fp/montgomery_backend.rs | 56 +++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index b552b4c2d..4bb737018 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -1,3 +1,4 @@ + use super::{Fp, FpConfig}; use crate::{ biginteger::arithmetic as fa, BigInt, BigInteger, PrimeField, SqrtPrecomputation, Zero, @@ -1277,6 +1278,40 @@ fn barrett_cond_subtract, const N: usize, const NPLUS1: usize>( } } +/// Subtract two N+1 limb big integers where `a` is (u64, [u64; N]) and `b` is ([u64; N], u64). +/// Returns the N+1 limb result as ([u64; N], u64) and a boolean indicating if a borrow occurred. +#[unroll_for_loops(8)] +#[inline(always)] +fn sub_bigint_plus_one_prime( + a: (u64, [u64; N]), // Format: (low_limb, high_n_limbs) + b: ([u64; N], u64), // Format: (low_n_limbs, high_limb) +) -> (([u64; N], u64), bool) { + let (a_lo, a_hi_n) = a; + let (b_lo_n, b_hi) = b; + let mut result_lo_n = [0u64; N]; + let mut borrow: u64 = 0; + + // Subtract low limb: result_lo_n[0] = a_lo - b_lo_n[0] - borrow (initial borrow = 0) + result_lo_n[0] = a_lo; // Initialize result limb with a_lo + borrow = fa::sbb(&mut result_lo_n[0], b_lo_n[0], borrow); // result_lo_n[0] -= b_lo_n[0] + borrow + + // Subtract middle limbs (if N > 1): result_lo_n[i] = a_hi_n[i-1] - b_lo_n[i] - borrow + // This loop covers indices i = 1 to N-1. + // It uses a_hi_n limbs from index 0 to N-2. + for i in 1..N { + result_lo_n[i] = a_hi_n[i - 1]; // Initialize result limb with corresponding a limb + borrow = fa::sbb(&mut result_lo_n[i], b_lo_n[i], borrow); // result_lo_n[i] -= b_lo_n[i] + borrow + } + + // Subtract high limb: result_hi = a_hi_n[N-1] - b_hi - borrow + let mut result_hi = a_hi_n[N - 1]; // Initialize result limb with last a limb + borrow = fa::sbb(&mut result_hi, b_hi, borrow); // result_hi -= b_hi + borrow + + let final_borrow_occurred = borrow != 0; + + ((result_lo_n, result_hi), final_borrow_occurred) +} + /// Helper function to perform Barrett reduction from N+1 limbs to N limbs. /// Input `c` is represented as `(u64, [u64; N])` (to be compatible with outside invocations). /// Internally, it converts to `([u64; N], u64)` and operates in that format. @@ -1308,13 +1343,24 @@ fn barrett_reduce_nplus1_to_n, const N: usize, const NPLUS1: us // Compute m * 2p (N+1 limbs) BigInt::mul_u64_in_place(&mut m2p, m); - // Compute r_tmp = c - m * 2p + // I really have no idea why the following sequence of operations + // is significantly faster than a simple BigInt sub operation. + // Compute r_tmp = c - m * 2p (result is ([u64; N], u64)) + let m_times_2p = ( + m2p.0[0..N].try_into().unwrap(), // Convert to ([u64; N], u64) + m2p.0[N] // High limb remains as u64 + ); + let (r_tmp, r_tmp_borrow) = sub_bigint_plus_one_prime((c.0[0], c.0[1..N+1].try_into().unwrap()), m_times_2p); // A borrow here implies c was smaller than m*2p, which shouldn't happen with correct m. - let (r_tmp, borrow) = c.const_sub_with_borrow(&m2p); - debug_assert!(!borrow, "Borrow should not occur in Barrett reduction"); - + debug_assert!(!r_tmp_borrow, "Borrow occurred calculating c - m*2p"); + // Change formats again! + let r_tmp_bigint = nplus1_pair_high_to_bigint::(r_tmp); + // Alternative simple BigInt subtraction (much slower for some reason): + /*let (r_tmp_bigint, r_borrow) = c.const_sub_with_borrow(&m2p); + debug_assert!(!r_borrow, "Borrow occurred calculating c - m*2p");*/ + // Use the optimized conditional subtraction to go from N+1 limbs to N limbs. - barrett_cond_subtract::(r_tmp) + barrett_cond_subtract::(r_tmp_bigint) } #[cfg(test)] From 79f2614ae68072a0f6a665f32c11e10befec8687 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Sat, 16 Aug 2025 16:03:59 -0600 Subject: [PATCH 09/38] added fma into nplus4 --- ff/src/biginteger/mod.rs | 234 +++++++++++++++++++++++++++++++++++++ ff/src/biginteger/tests.rs | 110 +++++++++++++++++ 2 files changed, 344 insertions(+) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 1ae1eb754..cea78a0b3 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -505,6 +505,41 @@ impl BigInteger for BigInt { acc.0[N] = acc.0[N].wrapping_add(carry as u64); } + #[inline] + #[unroll_for_loops(8)] + fn fmu64a_carry_propagating( + &self, + other: u64, + acc: &mut BigInt, + ) { + // ensure NPLUS2 is the correct size (N + 2 limbs) + debug_assert!(NPLUS2 == N + 2); + if other == 0 || self.is_zero() { + return; + } + if other == 1 { + let mut carry: u8 = 0; + for i in 0..N { + carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); + } + let (new_n, of1) = acc.0[N].overflowing_add(carry as u64); + acc.0[N] = new_n; + if of1 { + acc.0[N + 1] = acc.0[N + 1].wrapping_add(1); + } + return; + } + let mut carry = 0u64; + for i in 0..N { + acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry); + } + let (new_n, of1) = acc.0[N].overflowing_add(carry); + acc.0[N] = new_n; + if of1 { + acc.0[N + 1] = acc.0[N + 1].wrapping_add(1); + } + } + #[inline] #[unroll_for_loops(8)] fn fm128a(&self, other: u128, acc: &mut BigInt) { @@ -548,6 +583,180 @@ impl BigInteger for BigInt { acc.0[N + 1] = acc.0[N + 1].wrapping_add(carry2); } + #[inline] + #[unroll_for_loops(8)] + fn fmu64a_into_nplus4(&self, other: u64, acc: &mut BigInt) { + debug_assert!(NPLUS4 == N + 4); + if other == 0 || self.is_zero() { + return; + } + if other == 1 { + let mut carry: u8 = 0; + for i in 0..N { + carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); + } + if carry != 0 { + let (n0, of0) = acc.0[N].overflowing_add(1); + acc.0[N] = n0; + if of0 { + let (n1, of1) = acc.0[N + 1].overflowing_add(1); + acc.0[N + 1] = n1; + if of1 { + let (n2, of2) = acc.0[N + 2].overflowing_add(1); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + return; + } + let mut carry0 = 0u64; + for i in 0..N { + acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry0); + } + if carry0 != 0 { + let (n0, of0) = acc.0[N].overflowing_add(carry0); + acc.0[N] = n0; + if of0 { + let (n1, of1) = acc.0[N + 1].overflowing_add(1); + acc.0[N + 1] = n1; + if of1 { + let (n2, of2) = acc.0[N + 2].overflowing_add(1); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + } + + #[inline] + #[unroll_for_loops(8)] + fn fm2x64a_into_nplus4(&self, other: [u64; 2], acc: &mut BigInt) { + debug_assert!(NPLUS4 == N + 4); + let lo = other[0]; + let hi = other[1]; + if (lo | hi) == 0 || self.is_zero() { + return; + } + + if lo != 0 { + let mut carry0 = 0u64; + for i in 0..N { + acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], lo, &mut carry0); + } + if carry0 != 0 { + let (n0, of0) = acc.0[N].overflowing_add(carry0); + acc.0[N] = n0; + if of0 { + let (n1, of1) = acc.0[N + 1].overflowing_add(1); + acc.0[N + 1] = n1; + if of1 { + let (n2, of2) = acc.0[N + 2].overflowing_add(1); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + } + + if hi != 0 { + let mut carry1 = 0u64; + for i in 0..N { + acc.0[i + 1] = mac_with_carry!(acc.0[i + 1], self.0[i], hi, &mut carry1); + } + if carry1 != 0 { + let (n1, of1) = acc.0[N + 1].overflowing_add(carry1); + acc.0[N + 1] = n1; + if of1 { + let (n2, of2) = acc.0[N + 2].overflowing_add(1); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + } + + #[inline] + #[unroll_for_loops(8)] + fn fm3x64a_into_nplus4(&self, other: [u64; 3], acc: &mut BigInt) { + debug_assert!(NPLUS4 == N + 4); + let o0 = other[0]; + let o1 = other[1]; + let o2 = other[2]; + if (o0 | o1 | o2) == 0 || self.is_zero() { + return; + } + + if o0 != 0 { + let mut carry0 = 0u64; + for i in 0..N { + acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], o0, &mut carry0); + } + if carry0 != 0 { + let (n0, of0) = acc.0[N].overflowing_add(carry0); + acc.0[N] = n0; + if of0 { + let (n1, of1) = acc.0[N + 1].overflowing_add(1); + acc.0[N + 1] = n1; + if of1 { + let (n2, of2) = acc.0[N + 2].overflowing_add(1); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + } + + if o1 != 0 { + let mut carry1 = 0u64; + for i in 0..N { + acc.0[i + 1] = mac_with_carry!(acc.0[i + 1], self.0[i], o1, &mut carry1); + } + if carry1 != 0 { + let (n1, of1) = acc.0[N + 1].overflowing_add(carry1); + acc.0[N + 1] = n1; + if of1 { + let (n2, of2) = acc.0[N + 2].overflowing_add(1); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + + if o2 != 0 { + let mut carry2 = 0u64; + for i in 0..N { + acc.0[i + 2] = mac_with_carry!(acc.0[i + 2], self.0[i], o2, &mut carry2); + } + if carry2 != 0 { + let (n2, of2) = acc.0[N + 2].overflowing_add(carry2); + acc.0[N + 2] = n2; + if of2 { + let (n3, _of3) = acc.0[N + 3].overflowing_add(1); + acc.0[N + 3] = n3; + } + } + } + } + #[inline] #[unroll_for_loops(8)] fn mul_u128_w_carry( @@ -1267,6 +1476,15 @@ pub trait BigInteger: /// overflow causes a wraparound in the highest limb of the accumulator. fn fmu64a(&self, other: u64, acc: &mut BigInt); + /// NEW! Fused multiply-accumulate with a u64 multiplier and explicit overflow propagation. + /// Accumulates `self * other` into `acc`, which must have two extra limbs (N + 2). + /// Any overflow from limb N is carried into limb N+1 instead of wrapping. + fn fmu64a_carry_propagating( + &self, + other: u64, + acc: &mut BigInt, + ); + /// NEW! Multiplies self by a u128, returning a bigint with two extra limbs to hold overflow. fn mul_u128_w_carry( &self, @@ -1278,6 +1496,22 @@ pub trait BigInteger: /// Overflow causes wraparound in the highest limb of the accumulator. fn fm128a(&self, other: u128, acc: &mut BigInt); + /// NEW! Fused multiply-accumulate of `self` by a single `u64` limb, accumulating into + /// an accumulator with four extra limbs (N + 4), with carry propagation within the width. + /// This will accumulate `self * other` into `acc` and propagate any overflow from limb N + /// into limbs N+1..=N+3. Overflow beyond limb N+3 is dropped by contract. + fn fmu64a_into_nplus4(&self, other: u64, acc: &mut BigInt); + + /// NEW! Fused multiply-accumulate of `self` by a two-limb `[u64; 2]` multiplier, accumulating + /// into an accumulator with four extra limbs (N + 4). Carries are propagated within the width. + /// This is equivalent to doing two u64 passes offset by one limb and cascading carries. + fn fm2x64a_into_nplus4(&self, other: [u64; 2], acc: &mut BigInt); + + /// NEW! Fused multiply-accumulate of `self` by a three-limb `[u64; 3]` multiplier, accumulating + /// into an accumulator with four extra limbs (N + 4). Carries are propagated within the width. + /// This is equivalent to doing three u64 passes offset by 0, 1, and 2 limbs, respectively. + fn fm3x64a_into_nplus4(&self, other: [u64; 3], acc: &mut BigInt); + /// Multiplies this [`BigInteger`] by another `BigInteger`, storing the result in `self`. /// Overflow is ignored. /// diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 4e0dd0a0e..c60c13439 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -487,3 +487,113 @@ fn test_edge_cases_large_numbers() { let expected = BigUint::from(max_bi) * BigUint::from(u128::MAX); assert_eq!(BigUint::from(result), expected); } + +#[test] +fn test_fmu64a_into_nplus4_correctness_and_edges() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0xDEADBEEFCAFEBABEu64); + let other = 0xFEDCBA9876543210u64; + let mut acc = BigInt::<8>::zero(); // N+4 accumulator for N=4 + + // Reference: (a * other + acc_before) mod 2^(64*(N+4)) + let before = BigUint::from(acc.clone()); + a.fmu64a_into_nplus4::<8>(other, &mut acc); + let mut expected = BigUint::from(a); + expected *= BigUint::from(other); + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Zero multiplier is no-op + let mut acc2 = acc.clone(); + a.fmu64a_into_nplus4::<8>(0, &mut acc2); + assert_eq!(acc2, acc); + + // One multiplier reduces to addition + let mut acc3 = BigInt::<8>::zero(); + acc3.0[0] = 11111; + let before3 = BigUint::from(acc3.clone()); + a.fmu64a_into_nplus4::<8>(1, &mut acc3); + let mut expected3 = BigUint::from(a); + expected3 += before3; + expected3 %= &modulus; + assert_eq!(BigUint::from(acc3), expected3); + + // Force cascading carry across N..=N+3 + let a = B::new([u64::MAX; 4]); + let mut acc4 = BigInt::<8>::zero(); + acc4.0[4] = u64::MAX; // limb N + acc4.0[5] = u64::MAX; // limb N+1 + acc4.0[6] = u64::MAX; // limb N+2 + acc4.0[7] = 0; // limb N+3 (top) + // Use multiplier 2 so the low pass produces a carry=1 + a.fmu64a_into_nplus4::<8>(2, &mut acc4); + assert_eq!(acc4.0[7], 1); +} + +#[test] +fn test_fm2x64a_into_nplus4_correctness() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0x1234567890ABCDEFu64); + let other = [0x0FEDCBA987654321u64, 0x0011223344556677u64]; + let mut acc = BigInt::<8>::zero(); + + let before = BigUint::from(acc.clone()); + a.fm2x64a_into_nplus4::<8>(other, &mut acc); + + // Expected: a * (lo + (hi << 64)) + acc_before mod 2^(64*8) + let hi = BigUint::from(other[1]); + let lo = BigUint::from(other[0]); + let factor = (hi << 64) + lo; + let mut expected = BigUint::from(a); + expected *= factor; + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Zero limbs are no-op + let mut acc2 = acc.clone(); + a.fm2x64a_into_nplus4::<8>([0, 0], &mut acc2); + assert_eq!(acc2, acc); +} + +#[test] +fn test_fm3x64a_into_nplus4_correctness() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0x0F0E0D0C0B0A0908u64); + let other = [0x89ABCDEF01234567u64, 0x76543210FEDCBA98u64, 0x1122334455667788u64]; + let mut acc = BigInt::<8>::zero(); + + let before = BigUint::from(acc.clone()); + a.fm3x64a_into_nplus4::<8>(other, &mut acc); + + // Expected: a * (o0 + (o1<<64) + (o2<<128)) + acc_before mod 2^(64*8) + let term0 = BigUint::from(other[0]); + let term1 = BigUint::from(other[1]) << 64; + let term2 = BigUint::from(other[2]) << 128; + let factor = term0 + term1 + term2; + let mut expected = BigUint::from(a); + expected *= factor; + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Edge: ensure offset accumulation lands in correct limbs + // Fill acc with a pattern, then accumulate using only the highest limb to ensure writes start at index 2 + let a = B::from(3u64); + let mut acc2 = BigInt::<8>::zero(); + acc2.0[0] = 5; + acc2.0[1] = 7; + let other2 = [0, 0, 2]; // Only offset by 2 limbs + let before2 = BigUint::from(acc2.clone()); + a.fm3x64a_into_nplus4::<8>(other2, &mut acc2); + let mut expected2 = BigUint::from(a); + expected2 *= BigUint::from(2u64) << 128; + expected2 += before2; + let modulus = BigUint::from(1u8) << (64 * 8); + expected2 %= &modulus; + assert_eq!(BigUint::from(acc2), expected2); +} From 1ae55ed3ecb75509647d8bf11047222bfb6dc2dd Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Sat, 16 Aug 2025 17:35:41 -0600 Subject: [PATCH 10/38] refactor mul by high limbs --- ff/src/fields/models/fp/montgomery_backend.rs | 161 ++++++++++++++++++ test-curves/benches/small_mul.rs | 61 ++++++- 2 files changed, 218 insertions(+), 4 deletions(-) diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index 4bb737018..d194ee743 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -969,6 +969,165 @@ impl, const N: usize> Fp, N> { } } + /// Multiply-assign by a RHS that is zero in its low N-2 limbs in Montgomery limbs, + /// and whose highest two limbs are provided by `hi` (low 64 bits map to limb N-2, + /// high 64 bits map to limb N-1). This is equivalent to K=2 non-zero high limbs. + #[inline] + pub const fn mul_assign_hi_u128(&mut self, hi: u128) { + // Construct a synthetic RHS by using the const CIOS with K=2, passing limbs directly. + // Leverage existing const CIOS specialized by K via a tiny adapter. + *self = self.const_cios_mul_rhs_hi2(hi as u64, (hi >> 64) as u64); + } + + /// Returns self * rhs_high_limbs, where RHS is zero in low N-2 limbs and has its top two + /// limbs provided by `hi` (low 64 -> limb N-2, high 64 -> limb N-1). Equivalent to K=2. + #[inline] + pub const fn mul_hi_u128(self, hi: u128) -> Self { + self.const_cios_mul_rhs_hi2(hi as u64, (hi >> 64) as u64) + } + + /// Const-capable CIOS fastpath specialized for exactly two high limbs (K=2), passed + /// directly as u64s instead of via an Fp operand. Assumes all lower limbs are zero. + #[inline] + #[allow(unused_assignments)] + const fn const_cios_mul_rhs_hi2(self, limb_n2: u64, limb_n1: u64) -> Self { + let mut r = [0u64; N]; + // i = N-2 + if N >= 2 { + let mut carry1 = 0u64; + r[0] = mac!(r[0], (self.0).0[0], limb_n2, &mut carry1); + let k = r[0].wrapping_mul(T::INV); + let mut carry2 = 0u64; + let _discard = mac!(r[0], k, T::MODULUS.0[0], &mut carry2); + crate::const_for!((j in 1..N) { + let new_rj = mac_with_carry!(r[j], (self.0).0[j], limb_n2, &mut carry1); + let new_rj_minus_1 = mac_with_carry!(new_rj, k, T::MODULUS.0[j], &mut carry2); + r[j] = new_rj; + r[j - 1] = new_rj_minus_1; + }); + r[N - 1] = carry1.wrapping_add(carry2); + } + // i = N-1 + { + let mut carry1 = 0u64; + r[0] = mac!(r[0], (self.0).0[0], limb_n1, &mut carry1); + let k = r[0].wrapping_mul(T::INV); + let mut carry2 = 0u64; + let _discard = mac!(r[0], k, T::MODULUS.0[0], &mut carry2); + crate::const_for!((j in 1..N) { + let new_rj = mac_with_carry!(r[j], (self.0).0[j], limb_n1, &mut carry1); + let new_rj_minus_1 = mac_with_carry!(new_rj, k, T::MODULUS.0[j], &mut carry2); + r[j] = new_rj; + r[j - 1] = new_rj_minus_1; + }); + r[N - 1] = carry1.wrapping_add(carry2); + } + let mut out = Self::new_unchecked(crate::BigInt::(r)); + out = out.const_subtract_modulus(); + out + } + + /// Multiply-assign by a BigInt that populates the highest K Montgomery limbs of the RHS, + /// with all lower limbs zero. Lower 64 bits of `rhs_hi.0[0]` map to limb N-K, etc. + #[inline] + pub const fn mul_assign_hi_bigint(&mut self, rhs_hi: &crate::BigInt) { + if T::CAN_USE_NO_CARRY_MUL_OPT { + *self = self.const_cios_mul_rhs_hi::(rhs_hi); + } else { + let (carry, res) = self.mul_without_cond_subtract_rhs_hi::(rhs_hi); + *self = res; + if T::MODULUS_HAS_SPARE_BIT { + self.const_subtract_modulus(); + } else { + self.const_subtract_modulus_with_carry(carry); + } + } + } + + /// Returns self * BigInt that populates the highest K Montgomery limbs of the RHS, + /// with all lower limbs zero. + #[inline] + pub const fn mul_hi_bigint(self, rhs_hi: &crate::BigInt) -> Self { + if T::CAN_USE_NO_CARRY_MUL_OPT { + self.const_cios_mul_rhs_hi::(rhs_hi) + } else { + let (carry, res) = self.mul_without_cond_subtract_rhs_hi::(rhs_hi); + if T::MODULUS_HAS_SPARE_BIT { + res.const_subtract_modulus() + } else { + res.const_subtract_modulus_with_carry(carry) + } + } + } + + /// Const-capable CIOS kernel for a RHS with exactly K non-zero HIGH limbs provided via BigInt. + #[inline] + #[allow(unused_assignments)] + const fn const_cios_mul_rhs_hi(self, rhs_hi: &crate::BigInt) -> Self { + let mut r = [0u64; N]; + // Iterate high columns: t indexes 0..K-1 mapping to global i = N-K+t + crate::const_for!((t in 0..K) { + let b_i = rhs_hi.0[t]; + let mut carry1 = 0u64; + r[0] = mac!(r[0], (self.0).0[0], b_i, &mut carry1); + let k = r[0].wrapping_mul(T::INV); + let mut carry2 = 0u64; + let _discard = mac!(r[0], k, T::MODULUS.0[0], &mut carry2); + crate::const_for!((j in 1..N) { + let new_rj = mac_with_carry!(r[j], (self.0).0[j], b_i, &mut carry1); + let new_rj_minus_1 = mac_with_carry!(new_rj, k, T::MODULUS.0[j], &mut carry2); + r[j] = new_rj; + r[j - 1] = new_rj_minus_1; + }); + r[N - 1] = carry1.wrapping_add(carry2); + }); + let mut out = Self::new_unchecked(crate::BigInt::(r)); + out = out.const_subtract_modulus(); + out + } + + /// Two-phase (schoolbook+REDC) multiply with a RHS whose highest K limbs are provided + /// in `rhs_hi` and lower limbs are zero. + #[inline] + const fn mul_without_cond_subtract_rhs_hi(mut self, rhs_hi: &crate::BigInt) -> (bool, Self) { + let (mut lo, mut hi) = ([0u64; N], [0u64; N]); + // Schoolbook: only columns j in [N-K, N) + crate::const_for!((i in 0..N) { + let mut carry = 0u64; + crate::const_for!((t in 0..K) { + let j = N - K + t; + let b = rhs_hi.0[t]; + let k = i + j; + if k >= N { + hi[k - N] = mac_with_carry!(hi[k - N], (self.0).0[i], b, &mut carry); + } else { + lo[k] = mac_with_carry!(lo[k], (self.0).0[i], b, &mut carry); + } + }); + hi[i] = carry; + }); + // REDC: only i in [N-K, N) + let mut carry2 = 0u64; + crate::const_for!((i in 0..N) { + if i < N - K { /* skip */ } else { + let tmp = lo[i].wrapping_mul(T::INV); + let mut carry; + mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry); + crate::const_for!((j in 1..N) { + let k = i + j; + if k >= N { + hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry); + } else { + lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry); + } + }); + hi[i] = adc!(hi[i], carry, &mut carry2); + } + }); + crate::const_for!((i in 0..N) { (self.0).0[i] = hi[i]; }); + (carry2 != 0, self) + } + /// Montgomery reduction for 2N-limb inputs (standard Montgomery reduction) /// Takes a 2N-limb BigInt that represents a product in "unreduced" form /// and reduces it to N limbs in Montgomery form. @@ -1475,6 +1634,8 @@ mod test { } } + // Removed trailing-zero API tests due to API consolidation + #[test] fn test_mont_macro_correctness() { // This test succeeds only on the secp256k1 curve. diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index abdcbb254..5e3e9cdaa 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -1,10 +1,16 @@ +// This bench prefers bn254; if not enabled, provide a no-op main +#[cfg(feature = "bn254")] use ark_ff::UniformRand; +#[cfg(feature = "bn254")] use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; -use ark_test_curves::bn254::{Fr, FrConfig}; +#[cfg(feature = "bn254")] +use ark_test_curves::bn254::Fr; +#[cfg(feature = "bn254")] use criterion::{criterion_group, criterion_main, Criterion}; // Hack: copy over the helper functions from the Montgomery backend to be benched +#[cfg(feature = "bn254")] fn mul_small_bench(c: &mut Criterion) { const SAMPLES: usize = 1000; // Use a fixed seed for reproducibility @@ -13,7 +19,7 @@ fn mul_small_bench(c: &mut Criterion) { let a_s = (0..SAMPLES) .map(|_| Fr::rand(&mut rng)) .collect::>(); - let a_limbs_s = a_s.iter().map(|a| a.0.0).collect::>(); + // let a_limbs_s = a_s.iter().map(|a| a.0.0).collect::>(); let b_u64_s = (0..SAMPLES) .map(|_| rng.gen::()) @@ -86,7 +92,49 @@ fn mul_small_bench(c: &mut Criterion) { }) }); - // Benchmark mul_u128 specifically with inputs known to fit in u64 + // Bench specialized trailing-zero RHS fastpaths (K = 1, 2) + // Construct b' with K trailing zeros in limbs for K=1 and K=2 + let mut b_k1 = b_fr_s.clone(); + for b in &mut b_k1 { (b.0).0[0] = 0; } + let mut b_k2 = b_fr_s.clone(); + for b in &mut b_k2 { (b.0).0[0] = 0; (b.0).0[1] = 0; } + + group.bench_function("mul_assign_rhs_trailing_zeros::<1>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut x = a_s[i]; + x.mul_assign_rhs_trailing_zeros::<1>(&b_k1[i]); + criterion::black_box(x) + }) + }); + + group.bench_function("mul_assign_rhs_trailing_zeros::<2>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut x = a_s[i]; + x.mul_assign_rhs_trailing_zeros::<2>(&b_k2[i]); + criterion::black_box(x) + }) + }); + + group.bench_function("mul_rhs_trailing_zeros::<1>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_s[i].mul_rhs_trailing_zeros::<1>(&b_k1[i])) + }) + }); + + group.bench_function("mul_rhs_trailing_zeros::<2>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_s[i].mul_rhs_trailing_zeros::<2>(&b_k2[i])) + }) + }); + group.bench_function("mul_u128 (u64 inputs)", |bench| { let mut i = 0; bench.iter(|| { @@ -119,5 +167,10 @@ fn mul_small_bench(c: &mut Criterion) { group.finish(); } +#[cfg(feature = "bn254")] criterion_group!(benches, mul_small_bench); -criterion_main!(benches); \ No newline at end of file +#[cfg(feature = "bn254")] +criterion_main!(benches); + +#[cfg(not(feature = "bn254"))] +fn main() {} \ No newline at end of file From 1756d8b5f32e03f889ea26a87f2e49c88ac09b34 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Fri, 22 Aug 2025 10:57:02 -0600 Subject: [PATCH 11/38] added signed bigint --- ff/src/biginteger/mod.rs | 22 ++ ff/src/biginteger/signed.rs | 513 ++++++++++++++++++++++++++++++++++++ ff/src/biginteger/tests.rs | 142 +++++++++- ff/src/lib.rs | 2 +- 4 files changed, 677 insertions(+), 2 deletions(-) create mode 100644 ff/src/biginteger/signed.rs diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index cea78a0b3..d446a8daf 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -29,6 +29,9 @@ use zeroize::Zeroize; #[macro_use] pub mod arithmetic; +pub mod signed; +pub use signed::SignedBigInt; + #[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize)] pub struct BigInt(pub [u64; N]); @@ -289,6 +292,25 @@ impl BigInt { self.0[N - 1].leading_zeros() } + /// Truncated-width multiplication: compute self * other and fit into P limbs; overflow is ignored. + #[inline] + pub fn mul_trunc(&self, other: &BigInt) -> BigInt

{ + let mut res = BigInt::

::zero(); + let i_limit = core::cmp::min(N, P); + for i in 0..i_limit { + let mut carry = 0u64; + let j_limit = core::cmp::min(M, P - i); + for j in 0..j_limit { + res.0[i + j] = mac_with_carry!(res.0[i + j], self.0[i], other.0[j], &mut carry); + } + if i + j_limit < P { + let (new_val, _of) = res.0[i + j_limit].overflowing_add(carry); + res.0[i + j_limit] = new_val; + } + } + res + } + #[inline] pub(crate) const fn const_sub_with_borrow(mut self, other: &Self) -> (Self, bool) { let mut borrow = 0; diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs new file mode 100644 index 000000000..50622485d --- /dev/null +++ b/ff/src/biginteger/signed.rs @@ -0,0 +1,513 @@ +use crate::biginteger::{BigInt, BigInteger}; +use core::cmp::Ordering; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// A signed big integer using arkworks BigInt for magnitude and a sign bit +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SignedBigInt { + pub magnitude: BigInt, + pub is_positive: bool, +} + +impl SignedBigInt { + /// Construct from limbs and sign; limbs are little-endian. + #[inline] + pub fn new(limbs: [u64; N], is_positive: bool) -> Self { + Self { + magnitude: BigInt::new(limbs), + is_positive, + } + } + + /// Construct from an existing BigInt magnitude and sign. + #[inline] + pub fn from_bigint(magnitude: BigInt, is_positive: bool) -> Self { + Self { magnitude, is_positive } + } + + /// Zero value with a positive sign (negative zero allowed elsewhere). + #[inline] + pub fn zero() -> Self { + Self { magnitude: BigInt::from(0u64), is_positive: true } + } + + /// One with a positive sign. + #[inline] + pub fn one() -> Self { + Self { magnitude: BigInt::from(1u64), is_positive: true } + } + + /// Return true if magnitude is zero (sign is not considered). + #[inline] + pub fn is_zero(&self) -> bool { + self.magnitude.is_zero() + } + + /// Borrow the magnitude (absolute value). + #[inline] + pub fn as_magnitude(&self) -> &BigInt { &self.magnitude } + + /// Return the magnitude limbs by value (copy). + #[inline] + pub fn magnitude_limbs(&self) -> [u64; N] { self.magnitude.0 } + + /// Return true iff the value is non-negative. + #[inline] + pub fn sign(&self) -> bool { + self.is_positive + } + + /// Compute self + other modulo 2^(64*N); carry beyond N limbs is dropped. + #[inline] + pub fn add(mut self, other: Self) -> Self { self += other; self } + + /// Compute self - other modulo 2^(64*N); borrow beyond N limbs is dropped. + #[inline] + pub fn sub(mut self, other: Self) -> Self { self -= other; self } + + /// Compute self * other and keep only the low N limbs; high limbs are discarded. + #[inline] + pub fn mul(mut self, other: Self) -> Self { self *= other; self } + + /// Flip the sign; zero is not canonicalized (negative zero may occur). + #[inline] + pub fn neg(self) -> Self { + Self::from_bigint(self.magnitude, !self.is_positive) + } + + // ===== in-place helpers ===== + /// In-place addition with sign handling; drops overflow beyond N limbs. + #[inline] + fn add_assign_in_place(&mut self, rhs: &Self) { + if self.is_positive == rhs.is_positive { + let _carry = self.magnitude.add_with_carry(&rhs.magnitude); + // overflow ignored by design + } else { + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); + } + Ordering::Less => { + let mut tmp = rhs.magnitude; + let _borrow = tmp.sub_with_borrow(&self.magnitude); + self.magnitude = tmp; + self.is_positive = rhs.is_positive; + } + } + } + } + + /// In-place subtraction with sign handling; drops borrow beyond N limbs. + #[inline] + fn sub_assign_in_place(&mut self, rhs: &Self) { + // self - rhs == self + (-rhs) + let rhs_neg = Self { magnitude: rhs.magnitude, is_positive: !rhs.is_positive }; + self.add_assign_in_place(&rhs_neg); + } + + /// In-place multiply using low-limb product only; updates sign, discards high limbs. + #[inline] + fn mul_assign_in_place(&mut self, rhs: &Self) { + let low = self.magnitude.mul_low(&rhs.magnitude); + self.magnitude = low; + self.is_positive = self.is_positive == rhs.is_positive; + } +} + +impl SignedBigInt { + // ===== truncated-width operations ===== + + /// Truncated add: compute (self + rhs) and fit into M limbs; overflow is ignored. + #[inline] + pub fn add_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt { + if self.is_positive == rhs.is_positive { + // Same sign -> truncate limbwise sum + let mut res = BigInt::::zero(); + let mut carry: u8 = 0; + let lim = core::cmp::min(N, M); + for i in 0..lim { + let (s1, c1) = self.magnitude.0[i].overflowing_add(rhs.magnitude.0[i]); + let (s2, c2) = s1.overflowing_add(carry as u64); + res.0[i] = s2; + carry = (c1 as u8) | (c2 as u8); + } + // propagate carry into next limb if within M, else drop + if lim < M { + let (s, _c) = 0u64.overflowing_add(carry as u64); + res.0[lim] = s; + } + SignedBigInt:: { magnitude: res, is_positive: self.is_positive } + } else { + // Different signs -> subtract smaller magnitude from larger + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let mut res = BigInt::::zero(); + let lim = core::cmp::min(N, M); + let mut borrow: bool = false; + for i in 0..lim { + let (d1, b1) = self.magnitude.0[i].overflowing_sub(rhs.magnitude.0[i]); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt:: { magnitude: res, is_positive: self.is_positive } + } + Ordering::Less => { + let mut res = BigInt::::zero(); + let lim = core::cmp::min(N, M); + let mut borrow: bool = false; + for i in 0..lim { + let (d1, b1) = rhs.magnitude.0[i].overflowing_sub(self.magnitude.0[i]); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt:: { magnitude: res, is_positive: rhs.is_positive } + } + } + } + } + + /// Truncated sub: compute (self - rhs) and fit into M limbs; overflow is ignored. + #[inline] + pub fn sub_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt { + if self.is_positive != rhs.is_positive { + // same as addition path + let mut res = BigInt::::zero(); + let mut carry: u8 = 0; + let lim = core::cmp::min(N, M); + for i in 0..lim { + let (s1, c1) = self.magnitude.0[i].overflowing_add(rhs.magnitude.0[i]); + let (s2, c2) = s1.overflowing_add(carry as u64); + res.0[i] = s2; + carry = (c1 as u8) | (c2 as u8); + } + if lim < M { + let (s, _c) = 0u64.overflowing_add(carry as u64); + res.0[lim] = s; + } + SignedBigInt:: { magnitude: res, is_positive: self.is_positive } + } else { + // different signs wrt subtraction => subtract magnitudes + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let mut res = BigInt::::zero(); + let lim = core::cmp::min(N, M); + let mut borrow: bool = false; + for i in 0..lim { + let (d1, b1) = self.magnitude.0[i].overflowing_sub(rhs.magnitude.0[i]); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt:: { magnitude: res, is_positive: self.is_positive } + } + Ordering::Less => { + let mut res = BigInt::::zero(); + let lim = core::cmp::min(N, M); + let mut borrow: bool = false; + for i in 0..lim { + let (d1, b1) = rhs.magnitude.0[i].overflowing_sub(self.magnitude.0[i]); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt:: { magnitude: res, is_positive: !self.is_positive } + } + } + } + } + + /// Truncated mul: compute self * rhs and fit into P limbs; no assumption on P; overflow ignored. + #[inline] + pub fn mul_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt

{ + let mag = self.magnitude.mul_trunc::(&rhs.magnitude); + let sign = self.is_positive == rhs.is_positive; + SignedBigInt::

{ magnitude: mag, is_positive: sign } + } + + /// Fused multiply-add: acc += self * rhs, fitted into P limbs; overflow is ignored. + #[inline] + pub fn fmadd_trunc(&self, rhs: &SignedBigInt, acc: &mut SignedBigInt

) { + let prod_mag = self.magnitude.mul_trunc::(&rhs.magnitude); + let prod_sign = self.is_positive == rhs.is_positive; + let prod = SignedBigInt::

{ magnitude: prod_mag, is_positive: prod_sign }; + acc.add_assign_in_place(&prod); + } +} + +impl SignedBigInt { + // ===== generic conversions ===== + + /// Construct from u64 with positive sign. + #[inline] + pub fn from_u64(value: u64) -> Self { + Self::from_bigint(BigInt::from(value), true) + } + + /// Construct from (u64, sign); sign=true is non-negative. + #[inline] + pub fn from_u64_with_sign(value: u64, is_positive: bool) -> Self { + Self::from_bigint(BigInt::from(value), is_positive) + } + + /// Construct from i64; magnitude is |value|, sign reflects value>=0. + #[inline] + pub fn from_i64(value: i64) -> Self { + if value >= 0 { + Self::from_bigint(BigInt::from(value as u64), true) + } else { + // wrapping_neg handles i64::MIN + Self::from_bigint(BigInt::from(value.wrapping_neg() as u64), false) + } + } + + /// Construct from u128 with positive sign (N must be >= 2 in debug builds). + #[inline] + pub fn from_u128(value: u128) -> Self { + debug_assert!(N >= 2, "from_u128 requires at least 2 limbs"); + Self::from_bigint(BigInt::from(value), true) + } + + /// Construct from i128; magnitude is |value|, sign reflects value>=0 (N must be >= 2 in debug builds). + #[inline] + pub fn from_i128(value: i128) -> Self { + debug_assert!(N >= 2, "from_i128 requires at least 2 limbs"); + if value >= 0 { + Self::from_bigint(BigInt::from(value as u128), true) + } else { + let mag = (value as i128).unsigned_abs(); + Self::from_bigint(BigInt::from(mag), false) + } + } +} + +impl From for SignedBigInt { + /// From: positive sign; higher limbs are zeroed. + #[inline] + fn from(value: u64) -> Self { + Self::from_u64(value) + } +} + +impl From for SignedBigInt { + /// From: sign from value; magnitude is |value|; higher limbs are zeroed. + #[inline] + fn from(value: i64) -> Self { + Self::from_i64(value) + } +} + +impl From<(u64, bool)> for SignedBigInt { + /// From<(u64,bool)>: (magnitude, is_positive); higher limbs are zeroed. + #[inline] + fn from(value_and_sign: (u64, bool)) -> Self { + Self::from_u64_with_sign(value_and_sign.0, value_and_sign.1) + } +} + +impl From for SignedBigInt { + /// From: positive sign; debug-assert N >= 2; higher limbs are zeroed. + #[inline] + fn from(value: u128) -> Self { + debug_assert!(N >= 2, "From requires at least 2 limbs"); + Self::from_u128(value) + } +} + +impl From for SignedBigInt { + /// From: sign from value; debug-assert N >= 2; magnitude is |value|. + #[inline] + fn from(value: i128) -> Self { + debug_assert!(N >= 2, "From requires at least 2 limbs"); + Self::from_i128(value) + } +} + +// Specializations for common sizes +impl SignedBigInt<1> { + /// Convert to i128; any u64 magnitude fits for both signs. + #[inline] + pub fn to_i128(&self) -> i128 { + let magnitude = self.magnitude.0[0]; + if self.is_positive { magnitude as i128 } else { -(magnitude as i128) } + } +} + +impl SignedBigInt<2> { + /// Convert to i128 using 2^127 bounds: positive requires mag <= i128::MAX; negative allows mag == 2^127. + #[inline] + pub fn to_i128(&self) -> Option { + let hi = self.magnitude.0[1]; + let lo = self.magnitude.0[0]; + let hi_top_bit = hi >> 63; // bit 127 + if self.is_positive { + if hi_top_bit != 0 { return None; } + let mag = ((hi as u128) << 64) | (lo as u128); + Some(mag as i128) + } else { + if hi_top_bit == 0 { + let mag = ((hi as u128) << 64) | (lo as u128); + Some(-(mag as i128)) + } else if hi == (1u64 << 63) && lo == 0 { + Some(i128::MIN) + } else { + None + } + } + } + + /// Return the magnitude as u128 + #[inline] + pub fn magnitude_as_u128(&self) -> u128 { + (self.magnitude.0[1] as u128) << 64 | (self.magnitude.0[0] as u128) + } +} + +/// Helper function for single u64 signed arithmetic +/// Adds two signed u64 values (given as magnitude+sign) modulo 2^64; returns (magnitude, sign). +#[inline] +pub fn add_with_sign_u64(a_mag: u64, a_pos: bool, b_mag: u64, b_pos: bool) -> (u64, bool) { + let a = SignedBigInt::<1>::from_u64_with_sign(a_mag, a_pos); + let b = SignedBigInt::<1>::from_u64_with_sign(b_mag, b_pos); + let result = a + b; + (result.magnitude.0[0], result.is_positive) +} + +// =============================================== +// Standard operator trait implementations +// =============================================== + +impl Add for SignedBigInt { + type Output = Self; + + #[inline] + fn add(mut self, rhs: Self) -> Self::Output { + self.add_assign_in_place(&rhs); + self + } +} + +impl Sub for SignedBigInt { + type Output = Self; + + #[inline] + fn sub(mut self, rhs: Self) -> Self::Output { + self.sub_assign_in_place(&rhs); + self + } +} + +impl Mul for SignedBigInt { + type Output = Self; + + #[inline] + fn mul(mut self, rhs: Self) -> Self::Output { + self.mul_assign_in_place(&rhs); + self + } +} + +impl Neg for SignedBigInt { + type Output = Self; + + #[inline] + fn neg(self) -> Self::Output { + SignedBigInt::neg(self) + } +} + +impl AddAssign for SignedBigInt { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.add_assign_in_place(&rhs); + } +} + +impl SubAssign for SignedBigInt { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign_in_place(&rhs); + } +} + +impl MulAssign for SignedBigInt { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + self.mul_assign_in_place(&rhs); + } +} + +// Reference variants for efficiency +impl Add<&SignedBigInt> for SignedBigInt { + type Output = SignedBigInt; + + #[inline] + fn add(mut self, rhs: &SignedBigInt) -> Self::Output { + self.add_assign_in_place(rhs); + self + } +} + +impl Sub<&SignedBigInt> for SignedBigInt { + type Output = SignedBigInt; + + #[inline] + fn sub(mut self, rhs: &SignedBigInt) -> Self::Output { + self.sub_assign_in_place(rhs); + self + } +} + +impl Mul<&SignedBigInt> for SignedBigInt { + type Output = SignedBigInt; + + #[inline] + fn mul(mut self, rhs: &SignedBigInt) -> Self::Output { + self.mul_assign_in_place(rhs); + self + } +} + +impl AddAssign<&SignedBigInt> for SignedBigInt { + #[inline] + fn add_assign(&mut self, rhs: &SignedBigInt) { + self.add_assign_in_place(rhs); + } +} + +impl SubAssign<&SignedBigInt> for SignedBigInt { + #[inline] + fn sub_assign(&mut self, rhs: &SignedBigInt) { + self.sub_assign_in_place(rhs); + } +} + +impl MulAssign<&SignedBigInt> for SignedBigInt { + #[inline] + fn mul_assign(&mut self, rhs: &SignedBigInt) { + self.mul_assign_in_place(rhs); + } +} + + diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index c60c13439..8656f4362 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -1,4 +1,7 @@ -use crate::{biginteger::BigInteger, UniformRand}; +#[cfg(test)] +pub mod tests { + +use crate::{biginteger::{BigInteger, SignedBigInt}, UniformRand}; use num_bigint::BigUint; // Test elementary math operations for BigInteger. @@ -597,3 +600,140 @@ fn test_fm3x64a_into_nplus4_correctness() { expected2 %= &modulus; assert_eq!(BigUint::from(acc2), expected2); } + +// ============================== +// SignedBigInt tests +// ============================== + +#[test] +fn test_signed_construction() { + // zero and one + let z = SignedBigInt::<1>::zero(); + assert!(z.is_zero()); + assert!(z.is_positive); + let o = SignedBigInt::<1>::one(); + assert!(!o.is_zero()); + assert!(o.is_positive); + + // from_u64 + let p = SignedBigInt::<1>::from_u64(42); + assert_eq!(p.magnitude.0[0], 42); + assert!(p.is_positive); + let n = SignedBigInt::<1>::from((42u64, false)); + assert_eq!(n.magnitude.0[0], 42); + assert!(!n.is_positive); +} + +#[test] +fn test_signed_add_sub_mul_neg() { + let a = SignedBigInt::<1>::from_u64(10); + let b = SignedBigInt::<1>::from_u64(5); + assert_eq!((a + b).magnitude.0[0], 15); + assert_eq!((a - b).magnitude.0[0], 5); + assert_eq!((a * b).magnitude.0[0], 50); + let neg = -a; + assert_eq!(neg.magnitude.0[0], 10); + assert!(!neg.is_positive); + + // opposite signs + let x = SignedBigInt::<1>::from_u64(30); + let y = SignedBigInt::<1>::from((20u64, false)); + let r = x + y; // 30 - 20 + assert!(r.is_positive); + assert_eq!(r.magnitude.0[0], 10); + + let x2 = SignedBigInt::<1>::from((20u64, false)); + let y2 = SignedBigInt::<1>::from_u64(30); + let r2 = x2 + y2; // -20 + 30 + assert!(r2.is_positive); + assert_eq!(r2.magnitude.0[0], 10); +} + +#[test] +fn test_signed_to_i128_and_mag_helpers() { + let p = SignedBigInt::<1>::from_u64(100); + assert_eq!(p.to_i128(), 100); + let n = SignedBigInt::<1>::from((100u64, false)); + assert_eq!(n.to_i128(), -100); + + let d = SignedBigInt::<2>::from_u128(0x1234_5678_9abc_def0_1111_2222_3333_4444u128); + assert_eq!(d.magnitude.0[0], 0x1111_2222_3333_4444); + assert_eq!(d.magnitude.0[1], 0x1234_5678_9abc_def0); + // Positive below 2^127 should convert + let expected_i128 = 0x1234_5678_9abc_def0_1111_2222_3333_4444u128 as i128; + assert_eq!(d.to_i128(), Some(expected_i128)); + + // Positive at 2^127 should fail + let too_big_pos = SignedBigInt::<2>::from_u128(1u128 << 127); + assert_eq!(too_big_pos.to_i128(), None); + + let small = SignedBigInt::<2>::new([100, 0], true); + assert_eq!(small.to_i128(), Some(100)); + assert_eq!(small.magnitude_as_u128(), 100u128); +} + +#[test] +fn test_add_with_sign_u64_helper() { + let (mag, sign) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, true); + assert_eq!(mag, 15); + assert!(sign); + let (mag2, sign2) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, false); + assert_eq!(mag2, 5); + assert!(sign2); + let (mag3, sign3) = crate::biginteger::signed::add_with_sign_u64(5, true, 10, false); + assert_eq!(mag3, 5); + assert!(!sign3); +} + +#[test] +fn test_signed_truncated_add_sub() { + use crate::biginteger::SignedBigInt as S; + let a = S::<2>::from_u128(0x0000_0000_0000_0001_ffff_ffff_ffff_ffff); + let b = S::<2>::from_u128(0x0000_0000_0000_0001_0000_0000_0000_0001); + // Add and truncate to 1 limb + let r1 = a.add_trunc::<1>(&b); + // expected low limb wrap of the low words, ignoring carry to limb1 + let expected_low = (0xffff_ffff_ffff_ffffu64).wrapping_add(0x0000_0000_0000_0001u64); + assert_eq!(r1.magnitude.0[0], expected_low); + assert!(r1.is_positive); + + // Different signs: subtraction path + let a = S::<2>::from_u128(0x2); + let b = S::<2>::from(-3i128); // -3 + let r2 = a.add_trunc::<1>(&b); // 2 + (-3) = -1, truncated to 64-bit + assert_eq!(r2.magnitude.0[0], 1); + assert!(!r2.is_positive); + + // sub_trunc uses add_trunc internally + let x = S::<1>::from_u64(10); + let y = S::<1>::from_u64(7); + let r3 = x.sub_trunc::<1>(&y); + assert_eq!(r3.magnitude.0[0], 3); + assert!(r3.is_positive); +} + +#[test] +fn test_signed_truncated_mul_and_fmadd() { + use crate::biginteger::SignedBigInt as S; + // 128-bit x 64-bit -> truncated to 2 limbs (128-bit) + let a = S::<2>::from_u128(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128); + let b = S::<1>::from_u64(0x2); + let p = a.mul_trunc::<1, 2>(&b); + // Expected low 128 bits of the product + let expected = num_bigint::BigUint::from(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128) + * num_bigint::BigUint::from(2u64); + let got = num_bigint::BigUint::from(p.magnitude); + assert_eq!(got, expected & ((num_bigint::BigUint::from(1u8) << 128) - 1u8)); + assert!(p.is_positive); + + // fmadd into 1-limb accumulator (truncate to 64 bits) + let a = S::<1>::from_u64(0xFFFF_FFFF_FFFF_FFFF); + let b = S::<1>::from_u64(0x2); + let mut acc = S::<1>::from_u64(1); + a.fmadd_trunc::<1, 1>(&b, &mut acc); // acc = 1 + (a*b) mod 2^64 with sign + + // a*b = (2^64 - 1)*2 = 2^65 - 2 => low 64 = (2^64 - 2) + let expected_low = (u64::MAX).wrapping_sub(1); + assert_eq!(acc.magnitude.0[0], expected_low.wrapping_add(1)); +} + +} diff --git a/ff/src/lib.rs b/ff/src/lib.rs index 6b10158f4..7464724ea 100644 --- a/ff/src/lib.rs +++ b/ff/src/lib.rs @@ -20,7 +20,7 @@ extern crate educe; pub mod biginteger; pub use biginteger::{ signed_mod_reduction, BigInt, BigInteger, BigInteger128, BigInteger256, BigInteger320, - BigInteger384, BigInteger448, BigInteger64, BigInteger768, BigInteger832, + BigInteger384, BigInteger448, BigInteger64, BigInteger768, BigInteger832, SignedBigInt, }; #[macro_use] From 433529b5ddeb93f0fda36d8d2094d3ef51e8aa2e Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Fri, 22 Aug 2025 15:44:03 -0600 Subject: [PATCH 12/38] more additions to (signed) bigint --- ff/src/biginteger/mod.rs | 18 +++++ ff/src/biginteger/signed.rs | 156 ++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index d446a8daf..642302e06 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -311,6 +311,24 @@ impl BigInt { res } + /// Fused multiply-add with truncation: acc += self * other, fitting into P limbs; overflow is ignored. + /// This is a generic version for arbitrary limb widths of `self` and `other`. + #[inline] + pub fn fmadd_trunc(&self, other: &BigInt, acc: &mut BigInt

) { + let i_limit = core::cmp::min(N, P); + for i in 0..i_limit { + let mut carry = 0u64; + let j_limit = core::cmp::min(M, P - i); + for j in 0..j_limit { + acc.0[i + j] = mac_with_carry!(acc.0[i + j], self.0[i], other.0[j], &mut carry); + } + if i + j_limit < P { + let (new_val, _of) = acc.0[i + j_limit].overflowing_add(carry); + acc.0[i + j_limit] = new_val; + } + } + } + #[inline] pub(crate) const fn const_sub_with_borrow(mut self, other: &Self) -> (Self, bool) { let mut borrow = 0; diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index 50622485d..4a71c583f 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -238,6 +238,83 @@ impl SignedBigInt { } } + /// Truncated mixed-width addition: compute (self + rhs) where rhs can have a + /// different limb count, and fit into P limbs; overflow is ignored. + #[inline] + pub fn add_trunc_mixed(&self, rhs: &SignedBigInt) -> SignedBigInt

{ + // Case 1: same signs => add magnitudes, sign = self.is_positive + if self.is_positive == rhs.is_positive { + let mut res = BigInt::

::zero(); + let mut carry: u8 = 0; + for i in 0..P { + let a = if i < N { self.magnitude.0[i] } else { 0u64 }; + let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; + let (s1, c1) = a.overflowing_add(b); + let (s2, c2) = s1.overflowing_add(carry as u64); + res.0[i] = s2; + carry = (c1 as u8) | (c2 as u8); + } + return SignedBigInt::

{ magnitude: res, is_positive: self.is_positive }; + } + + // Case 2: different signs => subtract smaller magnitude from larger + let ord = { + let max_limbs = if N > M { N } else { M }; + let mut i = max_limbs; + let mut ordering = Ordering::Equal; + while i > 0 { + let idx = i - 1; + let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; + let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; + if a > b { ordering = Ordering::Greater; break; } + if a < b { ordering = Ordering::Less; break; } + i -= 1; + } + ordering + }; + + match ord { + Ordering::Greater | Ordering::Equal => { + // res_mag = self.mag - rhs.mag; sign = self.is_positive + let mut res = BigInt::

::zero(); + let mut borrow = false; + for i in 0..P { + let a = if i < N { self.magnitude.0[i] } else { 0u64 }; + let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; + let (d1, b1) = a.overflowing_sub(b); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt::

{ magnitude: res, is_positive: self.is_positive } + } + Ordering::Less => { + // res_mag = rhs.mag - self.mag; sign = rhs.is_positive + let mut res = BigInt::

::zero(); + let mut borrow = false; + for i in 0..P { + let a = if i < M { rhs.magnitude.0[i] } else { 0u64 }; + let b = if i < N { self.magnitude.0[i] } else { 0u64 }; + let (d1, b1) = a.overflowing_sub(b); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt::

{ magnitude: res, is_positive: rhs.is_positive } + } + } + } + /// Truncated mul: compute self * rhs and fit into P limbs; no assumption on P; overflow ignored. #[inline] pub fn mul_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt

{ @@ -300,6 +377,85 @@ impl SignedBigInt { Self::from_bigint(BigInt::from(mag), false) } } + + /// Truncated mixed-width subtraction: compute (self - rhs) where rhs can have a + /// different limb count, and fit into P limbs; overflow is ignored. + #[inline] + pub fn sub_trunc_mixed(&self, rhs: &SignedBigInt) -> SignedBigInt

{ + // Case 1: different signs => addition of magnitudes, sign = self.is_positive + if self.is_positive != rhs.is_positive { + let mut res = BigInt::

::zero(); + let mut carry: u8 = 0; + for i in 0..P { + let a = if i < N { self.magnitude.0[i] } else { 0u64 }; + let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; + let (s1, c1) = a.overflowing_add(b); + let (s2, c2) = s1.overflowing_add(carry as u64); + res.0[i] = s2; + carry = (c1 as u8) | (c2 as u8); + } + return SignedBigInt::

{ magnitude: res, is_positive: self.is_positive }; + } + + // Case 2: same signs => subtract smaller magnitude from larger; sign accordingly + // Mixed-width magnitude comparison (zero-extended to max(N, M)) + let ord = { + // Compare from most significant limb down to 0 + let max_limbs = if N > M { N } else { M }; + let mut i = max_limbs; + let mut ordering = Ordering::Equal; + while i > 0 { + let idx = i - 1; + let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; + let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; + if a > b { ordering = Ordering::Greater; break; } + if a < b { ordering = Ordering::Less; break; } + i -= 1; + } + ordering + }; + + match ord { + Ordering::Greater | Ordering::Equal => { + // res_mag = self.mag - rhs.mag; sign = self.is_positive + let mut res = BigInt::

::zero(); + let mut borrow = false; + for i in 0..P { + let a = if i < N { self.magnitude.0[i] } else { 0u64 }; + let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; + let (d1, b1) = a.overflowing_sub(b); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt::

{ magnitude: res, is_positive: self.is_positive } + } + Ordering::Less => { + // res_mag = rhs.mag - self.mag; sign = !self.is_positive + let mut res = BigInt::

::zero(); + let mut borrow = false; + for i in 0..P { + let a = if i < M { rhs.magnitude.0[i] } else { 0u64 }; + let b = if i < N { self.magnitude.0[i] } else { 0u64 }; + let (d1, b1) = a.overflowing_sub(b); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + SignedBigInt::

{ magnitude: res, is_positive: !self.is_positive } + } + } + } } impl From for SignedBigInt { From ebb26fdac76ab49d06c37fd24f6c0f724f1300dc Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Mon, 25 Aug 2025 15:28:37 -0600 Subject: [PATCH 13/38] new materials for svo --- ff/src/biginteger/arithmetic.rs | 29 +++ ff/src/biginteger/signed.rs | 186 +++++++++++--- ff/src/biginteger/tests.rs | 174 +++++++++++++ ff/src/fields/models/fp/montgomery_backend.rs | 145 ++++++++--- test-curves/benches/small_mul.rs | 243 +++++++++++++++--- 5 files changed, 666 insertions(+), 111 deletions(-) diff --git a/ff/src/biginteger/arithmetic.rs b/ff/src/biginteger/arithmetic.rs index ac15a26ae..493758ae7 100644 --- a/ff/src/biginteger/arithmetic.rs +++ b/ff/src/biginteger/arithmetic.rs @@ -123,6 +123,35 @@ pub fn mac_discard(a: u64, b: u64, c: u64, carry: &mut u64) { *carry = (tmp >> 64) as u64; } +/// Accumulate `limbs` into an N-limb accumulator starting at `lane_offset` (64-bit lanes), +/// returning the final carry. This is a helper for building wide accumulators. +#[inline(always)] +pub fn add_limbs_shifted_inplace( + acc: &mut [u64; N], + limbs: &[u64], + lane_offset: usize, +) -> u64 { + let mut carry = 0u64; + let mut i = 0usize; + while i < limbs.len() { + let idx = lane_offset + i; + if idx >= N { break; } + let tmp = (acc[idx] as u128) + (limbs[i] as u128) + (carry as u128); + acc[idx] = tmp as u64; + carry = (tmp >> 64) as u64; + i += 1; + } + // propagate carry across remaining lanes if any + let mut idx = lane_offset + i; + while carry != 0 && idx < N { + let tmp = (acc[idx] as u128) + (carry as u128); + acc[idx] = tmp as u64; + carry = (tmp >> 64) as u64; + idx += 1; + } + carry +} + macro_rules! mac_with_carry { ($a:expr, $b:expr, $c:expr, &mut $carry:expr$(,)?) => {{ let tmp = diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index 4a71c583f..13c08854c 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -10,6 +10,20 @@ pub struct SignedBigInt { } impl SignedBigInt { + #[inline] + fn cmp_magnitude_mixed(&self, rhs: &SignedBigInt) -> Ordering { + let max_limbs = if N > M { N } else { M }; + let mut i = max_limbs; + while i > 0 { + let idx = i - 1; + let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; + let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; + if a > b { return Ordering::Greater; } + if a < b { return Ordering::Less; } + i -= 1; + } + Ordering::Equal + } /// Construct from limbs and sign; limbs are little-endian. #[inline] pub fn new(limbs: [u64; N], is_positive: bool) -> Self { @@ -51,6 +65,10 @@ impl SignedBigInt { #[inline] pub fn magnitude_limbs(&self) -> [u64; N] { self.magnitude.0 } + /// Borrow the magnitude limbs as a slice (avoids copying the array). + #[inline] + pub fn magnitude_slice(&self) -> &[u64] { self.magnitude.as_ref() } + /// Return true iff the value is non-negative. #[inline] pub fn sign(&self) -> bool { @@ -77,7 +95,7 @@ impl SignedBigInt { // ===== in-place helpers ===== /// In-place addition with sign handling; drops overflow beyond N limbs. - #[inline] + #[inline(always)] fn add_assign_in_place(&mut self, rhs: &Self) { if self.is_positive == rhs.is_positive { let _carry = self.magnitude.add_with_carry(&rhs.magnitude); @@ -88,9 +106,9 @@ impl SignedBigInt { let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); } Ordering::Less => { - let mut tmp = rhs.magnitude; - let _borrow = tmp.sub_with_borrow(&self.magnitude); - self.magnitude = tmp; + // Minimize copies: move rhs magnitude into place and subtract old self + let old = core::mem::replace(&mut self.magnitude, rhs.magnitude); + let _borrow = self.magnitude.sub_with_borrow(&old); self.is_positive = rhs.is_positive; } } @@ -98,15 +116,30 @@ impl SignedBigInt { } /// In-place subtraction with sign handling; drops borrow beyond N limbs. - #[inline] + #[inline(always)] fn sub_assign_in_place(&mut self, rhs: &Self) { - // self - rhs == self + (-rhs) - let rhs_neg = Self { magnitude: rhs.magnitude, is_positive: !rhs.is_positive }; - self.add_assign_in_place(&rhs_neg); + // Implement directly to avoid temporary construction + if self.is_positive != rhs.is_positive { + // Signs differ -> add magnitudes; sign remains self.is_positive + let _carry = self.magnitude.add_with_carry(&rhs.magnitude); + } else { + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); + // sign stays the same + } + Ordering::Less => { + // Result takes rhs magnitude minus self magnitude, sign flips + let old = core::mem::replace(&mut self.magnitude, rhs.magnitude); + let _borrow = self.magnitude.sub_with_borrow(&old); + self.is_positive = !self.is_positive; + } + } + } } /// In-place multiply using low-limb product only; updates sign, discards high limbs. - #[inline] + #[inline(always)] fn mul_assign_in_place(&mut self, rhs: &Self) { let low = self.magnitude.mul_low(&rhs.magnitude); self.magnitude = low; @@ -133,8 +166,7 @@ impl SignedBigInt { } // propagate carry into next limb if within M, else drop if lim < M { - let (s, _c) = 0u64.overflowing_add(carry as u64); - res.0[lim] = s; + res.0[lim] = carry as u64; } SignedBigInt:: { magnitude: res, is_positive: self.is_positive } } else { @@ -193,8 +225,7 @@ impl SignedBigInt { carry = (c1 as u8) | (c2 as u8); } if lim < M { - let (s, _c) = 0u64.overflowing_add(carry as u64); - res.0[lim] = s; + res.0[lim] = carry as u64; } SignedBigInt:: { magnitude: res, is_positive: self.is_positive } } else { @@ -246,42 +277,46 @@ impl SignedBigInt { if self.is_positive == rhs.is_positive { let mut res = BigInt::

::zero(); let mut carry: u8 = 0; - for i in 0..P { - let a = if i < N { self.magnitude.0[i] } else { 0u64 }; - let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; - let (s1, c1) = a.overflowing_add(b); + let overlap = core::cmp::min(core::cmp::min(N, M), P); + for i in 0..overlap { + let (s1, c1) = self.magnitude.0[i].overflowing_add(rhs.magnitude.0[i]); let (s2, c2) = s1.overflowing_add(carry as u64); res.0[i] = s2; carry = (c1 as u8) | (c2 as u8); } + let mut k = overlap; + if N > M { + let end = core::cmp::min(N, P); + while k < end { + let (s1, c1) = self.magnitude.0[k].overflowing_add(carry as u64); + res.0[k] = s1; + carry = c1 as u8; + k += 1; + } + } else if M > N { + let end = core::cmp::min(M, P); + while k < end { + let (s1, c1) = rhs.magnitude.0[k].overflowing_add(carry as u64); + res.0[k] = s1; + carry = c1 as u8; + k += 1; + } + } + if k < P { res.0[k] = carry as u64; } return SignedBigInt::

{ magnitude: res, is_positive: self.is_positive }; } // Case 2: different signs => subtract smaller magnitude from larger - let ord = { - let max_limbs = if N > M { N } else { M }; - let mut i = max_limbs; - let mut ordering = Ordering::Equal; - while i > 0 { - let idx = i - 1; - let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; - let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; - if a > b { ordering = Ordering::Greater; break; } - if a < b { ordering = Ordering::Less; break; } - i -= 1; - } - ordering - }; + let ord = self.cmp_magnitude_mixed(rhs); match ord { Ordering::Greater | Ordering::Equal => { // res_mag = self.mag - rhs.mag; sign = self.is_positive let mut res = BigInt::

::zero(); let mut borrow = false; - for i in 0..P { - let a = if i < N { self.magnitude.0[i] } else { 0u64 }; - let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; - let (d1, b1) = a.overflowing_sub(b); + let overlap = core::cmp::min(core::cmp::min(N, M), P); + for i in 0..overlap { + let (d1, b1) = self.magnitude.0[i].overflowing_sub(rhs.magnitude.0[i]); if borrow { let (d2, b2) = d1.overflowing_sub(1); res.0[i] = d2; @@ -291,16 +326,29 @@ impl SignedBigInt { borrow = b1; } } + let mut k = overlap; + if N > M { + let end = core::cmp::min(N, P); + while k < end { + if borrow { + let (d2, b2) = self.magnitude.0[k].overflowing_sub(1); + res.0[k] = d2; + borrow = b2; + } else { + res.0[k] = self.magnitude.0[k]; + } + k += 1; + } + } SignedBigInt::

{ magnitude: res, is_positive: self.is_positive } } Ordering::Less => { // res_mag = rhs.mag - self.mag; sign = rhs.is_positive let mut res = BigInt::

::zero(); let mut borrow = false; - for i in 0..P { - let a = if i < M { rhs.magnitude.0[i] } else { 0u64 }; - let b = if i < N { self.magnitude.0[i] } else { 0u64 }; - let (d1, b1) = a.overflowing_sub(b); + let overlap = core::cmp::min(core::cmp::min(N, M), P); + for i in 0..overlap { + let (d1, b1) = rhs.magnitude.0[i].overflowing_sub(self.magnitude.0[i]); if borrow { let (d2, b2) = d1.overflowing_sub(1); res.0[i] = d2; @@ -310,6 +358,20 @@ impl SignedBigInt { borrow = b1; } } + let mut k = overlap; + if M > N { + let end = core::cmp::min(M, P); + while k < end { + if borrow { + let (d2, b2) = rhs.magnitude.0[k].overflowing_sub(1); + res.0[k] = d2; + borrow = b2; + } else { + res.0[k] = rhs.magnitude.0[k]; + } + k += 1; + } + } SignedBigInt::

{ magnitude: res, is_positive: rhs.is_positive } } } @@ -328,8 +390,20 @@ impl SignedBigInt { pub fn fmadd_trunc(&self, rhs: &SignedBigInt, acc: &mut SignedBigInt

) { let prod_mag = self.magnitude.mul_trunc::(&rhs.magnitude); let prod_sign = self.is_positive == rhs.is_positive; - let prod = SignedBigInt::

{ magnitude: prod_mag, is_positive: prod_sign }; - acc.add_assign_in_place(&prod); + if acc.is_positive == prod_sign { + let _ = acc.magnitude.add_with_carry(&prod_mag); + } else { + match acc.magnitude.cmp(&prod_mag) { + Ordering::Greater | Ordering::Equal => { + let _ = acc.magnitude.sub_with_borrow(&prod_mag); + } + Ordering::Less => { + let old = core::mem::replace(&mut acc.magnitude, prod_mag); + let _ = acc.magnitude.sub_with_borrow(&old); + acc.is_positive = prod_sign; + } + } + } } } @@ -666,4 +740,34 @@ impl MulAssign<&SignedBigInt> for SignedBigInt { } } +// By-ref binary operator variants to avoid copying both operands +impl core::ops::Add for &SignedBigInt { + type Output = SignedBigInt; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.add_assign_in_place(rhs); + out + } +} + +impl core::ops::Sub for &SignedBigInt { + type Output = SignedBigInt; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.sub_assign_in_place(rhs); + out + } +} + +impl core::ops::Mul for &SignedBigInt { + type Output = SignedBigInt; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.mul_assign_in_place(rhs); + out + } +} diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 8656f4362..09794edc7 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -736,4 +736,178 @@ fn test_signed_truncated_mul_and_fmadd() { assert_eq!(acc.magnitude.0[0], expected_low.wrapping_add(1)); } +#[test] +fn test_signed_truncated_add_sub_mixed() { + use crate::biginteger::SignedBigInt as S; + // Same sign, different widths, ensure carry handling and sign preservation + let a = S::<2>::from_u128(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFF); + let b = S::<1>::from_u64(0x0000_0000_0000_0002); + let r = a.add_trunc_mixed::<1, 2>(&b); // 128-bit result + let expected = num_bigint::BigUint::from(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFFu128) + + num_bigint::BigUint::from(2u64); + assert_eq!(num_bigint::BigUint::from(r.magnitude), expected); + assert!(r.is_positive); + + // Different signs, |a| > |b|: result sign should be sign(a) + let a2 = S::<2>::from_u128(5000); + let b2 = S::<1>::from((3000u64, false)); // -3000 + let r2 = a2.add_trunc_mixed::<1, 2>(&b2); + assert!(r2.is_positive); + assert_eq!(r2.magnitude.0[0], 2000); + + // Different signs, |b| > |a|: result sign should be sign(b) + let a3 = S::<2>::from_u128(1000); + let b3 = S::<1>::from((3000u64, false)); // -3000 + let r3 = a3.add_trunc_mixed::<1, 2>(&b3); + assert!(!r3.is_positive); + assert_eq!(r3.magnitude.0[0], 2000); + + // sub_trunc_mixed basic checks + let a4 = S::<2>::from_u128(10000); + let b4 = S::<1>::from_u64(9999); + let r4 = a4.sub_trunc_mixed::<1, 2>(&b4); + assert!(r4.is_positive); + assert_eq!(r4.magnitude.0[0], 1); + + let a5 = S::<2>::from_u128(1000); + let b5 = S::<1>::from_u64(5000); + let r5 = a5.sub_trunc_mixed::<1, 2>(&b5); + assert!(!r5.is_positive); + assert_eq!(r5.magnitude.0[0], 4000); +} + +#[test] +fn test_signed_fmadd_trunc_mixed_width_and_signs() { + use crate::biginteger::SignedBigInt as S; + // Case 1: same sign => pure addition of magnitudes + let a = S::<2>::from_u128(30000); + let b = S::<1>::from_u64(7); + let mut acc = S::<2>::from_u128(1000000); + a.fmadd_trunc::<1, 2>(&b, &mut acc); // acc += 210000 + assert!(acc.is_positive); + assert_eq!(acc.magnitude.0[0] as u128 + ((acc.magnitude.0[1] as u128) << 64), 1210000u128); + + // Case 2: different sign, |prod| < |acc| => sign preserved + let a2 = S::<2>::from_u128(30000); + let b2 = S::<1>::from((7u64, false)); // -7 + let mut acc2 = S::<2>::from_u128(1000000); + a2.fmadd_trunc::<1, 2>(&b2, &mut acc2); // acc2 -= 210000 => 790000 + assert!(acc2.is_positive); + assert_eq!(acc2.magnitude.0[0] as u128 + ((acc2.magnitude.0[1] as u128) << 64), 790000u128); + + // Case 3: different sign, |prod| > |acc| => sign flips to prod_sign + let a3 = S::<2>::from_u128(300); + let b3 = S::<1>::from((7u64, false)); // -7 => prod = -2100 + let mut acc3 = S::<2>::from_u128(1000); + a3.fmadd_trunc::<1, 2>(&b3, &mut acc3); // 1000 - 2100 = -1100 + assert!(!acc3.is_positive); + assert_eq!(acc3.magnitude.0[0], 1100); +} + +#[test] +fn test_prop_add_sub_trunc_mixed_random() { + use crate::biginteger::SignedBigInt as S; + use ark_std::rand::Rng; + let mut rng = ark_std::test_rng(); + + // Helper to validate a single pair for given consts + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); + let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); + let a_pos = (rng.gen::() & 1) == 1; + let b_pos = (rng.gen::() & 1) == 1; + let a = S::<$n>::from_bigint(a_mag, a_pos); + let b = S::<$m>::from_bigint(b_mag, b_pos); + + // add_trunc_mixed + let r_add = a.add_trunc_mixed::<$m, $p>(&b); + let a_bu = num_bigint::BigUint::from(a.magnitude); + let b_bu = num_bigint::BigUint::from(b.magnitude); + let (exp_add_mag, exp_add_pos) = if a_pos == b_pos { + (&a_bu + &b_bu, a_pos) + } else if a_bu >= b_bu { + (&a_bu - &b_bu, a_pos) + } else { + (&b_bu - &a_bu, b_pos) + }; + let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); + let exp_add_mag_mod = exp_add_mag % &modulus; + assert_eq!(num_bigint::BigUint::from(r_add.magnitude), exp_add_mag_mod); + if exp_add_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(r_add.is_positive, exp_add_pos); + } + + // sub_trunc_mixed: a - b + let r_sub = a.sub_trunc_mixed::<$m, $p>(&b); + let (exp_sub_mag, exp_sub_pos) = if a_pos != b_pos { + (&a_bu + &b_bu, a_pos) + } else if a_bu >= b_bu { + (&a_bu - &b_bu, a_pos) + } else { + (&b_bu - &a_bu, !a_pos) + }; + let exp_sub_mag_mod = exp_sub_mag % &modulus; + assert_eq!(num_bigint::BigUint::from(r_sub.magnitude), exp_sub_mag_mod); + if exp_sub_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(r_sub.is_positive, exp_sub_pos); + } + } + }}; + } + + run_case!(2, 3, 2, 200); + run_case!(3, 1, 2, 200); + run_case!(1, 2, 1, 200); +} + +#[test] +fn test_prop_fmadd_trunc_random() { + use crate::biginteger::SignedBigInt as S; + use ark_std::rand::Rng; + let mut rng = ark_std::test_rng(); + + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); + let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); + let acc_mag: crate::biginteger::BigInt<$p> = UniformRand::rand(&mut rng); + let a_pos = (rng.gen::() & 1) == 1; + let b_pos = (rng.gen::() & 1) == 1; + let acc_pos = (rng.gen::() & 1) == 1; + let a = S::<$n>::from_bigint(a_mag, a_pos); + let b = S::<$m>::from_bigint(b_mag, b_pos); + let mut acc = S::<$p>::from_bigint(acc_mag, acc_pos); + + // expected via BigUint with truncation of the product BEFORE combining signs + let a_bu = num_bigint::BigUint::from(a.magnitude); + let b_bu = num_bigint::BigUint::from(b.magnitude); + let acc_bu = num_bigint::BigUint::from(acc.magnitude); + let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); + let prod_mod = (&a_bu * &b_bu) % &modulus; + let prod_pos = a_pos == b_pos; + let (exp_mag_mod, exp_pos) = if acc_pos == prod_pos { + ((acc_bu + &prod_mod) % &modulus, acc_pos) + } else if acc_bu >= prod_mod { + (acc_bu - &prod_mod, acc_pos) + } else { + (prod_mod - &acc_bu, prod_pos) + }; + + a.fmadd_trunc::<$m, $p>(&b, &mut acc); + + assert_eq!(num_bigint::BigUint::from(acc.magnitude), exp_mag_mod); + if exp_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(acc.is_positive, exp_pos); + } + } + }}; + } + + run_case!(2, 1, 2, 200); + run_case!(3, 2, 2, 200); +} + } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index d194ee743..405a11797 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -844,10 +844,9 @@ impl, const N: usize> Fp, N> { /// NEW! Construct a new field element from a BigInt /// which is in montgomery form and just needs to be reduced /// via a barrett reduction. - #[inline] + #[inline(always)] pub fn from_unchecked_nplus1(element: BigInt<{ NPLUS1 }>) -> Self { debug_assert!(NPLUS1 == N + 1); - // Barrett reduction let r = barrett_reduce_nplus1_to_n::(element); Self::new_unchecked(r) } @@ -869,29 +868,6 @@ impl, const N: usize> Fp, N> { Self::new_unchecked(r2) } - /// Construct a new field element from a BigInt which is in - /// Montgomery form and should be reduced via two Barrett rounds then a final combine. - #[inline] - pub fn from_unchecked_nplus3( - element: BigInt<{ NPLUS3 }>, - ) -> Self { - debug_assert!(NPLUS1 == N + 1); - debug_assert!(NPLUS2 == N + 2); - debug_assert!(NPLUS3 == N + 3); - - // Reduce the upper N+2 limbs of `element` to N limbs - let c_hi = BigInt::(element.0[1..NPLUS3].try_into().unwrap()); - let c_hi_hi = BigInt::(c_hi.0[1..NPLUS2].try_into().unwrap()); - let r1 = barrett_reduce_nplus1_to_n::(c_hi_hi); - let c_hi_merged = nplus1_pair_low_to_bigint::((c_hi.0[0], r1.0)); - let r_hi = barrett_reduce_nplus1_to_n::(c_hi_merged); - - // Combine the original lowest limb with r_hi and perform final Barrett reduction - let c_final = nplus1_pair_low_to_bigint::((element.0[0], r_hi.0)); - let r_final = barrett_reduce_nplus1_to_n::(c_final); - Self::new_unchecked(r_final) - } - const fn const_is_zero(&self) -> bool { self.0.const_is_zero() } @@ -906,9 +882,8 @@ impl, const N: usize> Fp, N> { } /// Interpret a set of limbs (along with a sign) as a field element. - /// For *internal* use only; please use the `ark_ff::MontFp` macro instead - /// of this method - #[doc(hidden)] + /// The input limbs are interpreted little-endian. For public use; prefer + /// the `ark_ff::MontFp` macro for constant contexts. pub const fn from_sign_and_limbs(is_positive: bool, limbs: &[u64]) -> Self { let mut repr = BigInt::([0; N]); assert!(limbs.len() <= N); @@ -1267,6 +1242,118 @@ impl, const N: usize> Fp, N> { const fn sub_with_borrow(a: &BigInt, b: &BigInt) -> BigInt { a.const_sub_with_borrow(b).0 } + + /// Helper function: multiply a BigInt by u64 and accumulate into BigInt + /// This avoids creating temporary BigInt objects. + #[inline(always)] + #[unroll_for_loops(8)] + fn mul_u64_accumulate( + acc: &mut BigInt, + a: &BigInt, + b: u64 + ) { + debug_assert!(NPLUS1 == N + 1); + use crate::biginteger::arithmetic as fa; + + let mut carry = 0u64; + for i in 0..N { + acc.0[i] = fa::mac_with_carry(acc.0[i], a.0[i], b, &mut carry); + } + + // Add final carry to the high limb + let final_carry = fa::adc(&mut acc.0[N], carry, 0); + debug_assert!(final_carry == 0, "overflow in mul_u64_accumulate"); + } + + /// Compute a linear combination of field elements with u64 coefficients. + /// Performs unreduced accumulation in BigInt, then one final reduction. + /// This is more efficient than individual multiplications and additions. + #[inline(always)] + pub fn linear_combination_u64( + pairs: &[(Self, u64)] + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(!pairs.is_empty(), "linear_combination_u64 requires at least one pair"); + + // Start with first term + let mut acc = pairs[0].0.0.mul_u64_w_carry::(pairs[0].1); + + // Accumulate remaining terms using multiply-accumulate to avoid temporaries + for (a, b) in &pairs[1..] { + Self::mul_u64_accumulate::(&mut acc, &a.0, *b); + } + + Self::from_unchecked_nplus1::(acc) + } + + /// Compute a linear combination with separate positive and negative terms. + /// Each term is multiplied by a u64 coefficient, then positive and negative + /// sums are computed separately and subtracted. One final reduction is performed. + #[inline(always)] + pub fn linear_combination_i64( + pos: &[(Self, u64)], + neg: &[(Self, u64)] + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(!pos.is_empty(), "linear_combination_i64 requires at least one positive term"); + debug_assert!(!neg.is_empty(), "linear_combination_i64 requires at least one negative term"); + + // Compute unreduced positive sum + let mut pos_lc = pos[0].0.0.mul_u64_w_carry::(pos[0].1); + for (a, b) in &pos[1..] { + Self::mul_u64_accumulate::(&mut pos_lc, &a.0, *b); + } + + // Compute unreduced negative sum + let mut neg_lc = neg[0].0.0.mul_u64_w_carry::(neg[0].1); + for (a, b) in &neg[1..] { + Self::mul_u64_accumulate::(&mut neg_lc, &a.0, *b); + } + + // Subtract and reduce once + match pos_lc.cmp(&neg_lc) { + core::cmp::Ordering::Greater => { + let borrow = pos_lc.sub_with_borrow(&neg_lc); + debug_assert!(!borrow, "borrow in linear_combination_i64"); + Self::from_unchecked_nplus1::(pos_lc) + } + core::cmp::Ordering::Less => { + let borrow = neg_lc.sub_with_borrow(&pos_lc); + debug_assert!(!borrow, "borrow in linear_combination_i64"); + -Self::from_unchecked_nplus1::(neg_lc) + } + core::cmp::Ordering::Equal => Self::zero(), + } + } + + /// Optimized version for exactly 2 terms: a₁×b₁ + a₂×b₂ + /// Avoids slice overhead and loop setup costs. + #[inline(always)] + pub fn linear_combination_u64_2( + a1: &Self, b1: u64, + a2: &Self, b2: u64 + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + + let mut acc = a1.0.mul_u64_w_carry::(b1); + Self::mul_u64_accumulate::(&mut acc, &a2.0, b2); + Self::from_unchecked_nplus1::(acc) + } + + /// Optimized version for exactly 3 terms: a₁×b₁ + a₂×b₂ + a₃×b₃ + #[inline(always)] + pub fn linear_combination_u64_3( + a1: &Self, b1: u64, + a2: &Self, b2: u64, + a3: &Self, b3: u64 + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + + let mut acc = a1.0.mul_u64_w_carry::(b1); + Self::mul_u64_accumulate::(&mut acc, &a2.0, b2); + Self::mul_u64_accumulate::(&mut acc, &a3.0, b3); + Self::from_unchecked_nplus1::(acc) + } } #[inline(always)] @@ -1500,7 +1587,7 @@ fn barrett_reduce_nplus1_to_n, const N: usize, const NPLUS1: us // unroll T::MODULUS_TIMES_2_NPLUS1 from ([u64; N], u64) to BigInt let mut m2p = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_2_NPLUS1); // Compute m * 2p (N+1 limbs) - BigInt::mul_u64_in_place(&mut m2p, m); + m2p.mul_u64_in_place(m); // I really have no idea why the following sequence of operations // is significantly faster than a simple BigInt sub operation. diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index 5e3e9cdaa..eff9e544c 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -1,6 +1,6 @@ // This bench prefers bn254; if not enabled, provide a no-op main #[cfg(feature = "bn254")] -use ark_ff::UniformRand; +use ark_ff::{UniformRand, BigInteger}; #[cfg(feature = "bn254")] use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; #[cfg(feature = "bn254")] @@ -46,8 +46,43 @@ fn mul_small_bench(c: &mut Criterion) { .map(|_| Fr::rand(&mut rng)) .collect::>(); + // Generate test data for reduction benchmarks + use ark_ff::BigInt; + // Extract BigInt<4> from Fr elements for mul_u64_w_carry benchmark + let a_bigints = a_s.iter().map(|a| a.0).collect::>(); + + // For Montgomery reduction: 2N-limb inputs (N=4 for bn254, so 2N=8) + let bigint_2n_s = (0..SAMPLES) + .map(|_| BigInt::<8>([ + rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), + rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), + ])) + .collect::>(); + + // For Barrett reductions: N+1, N+2, N+3 limb inputs + let bigint_nplus1_s = (0..SAMPLES) + .map(|_| BigInt::<5>([ + rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), + ])) + .collect::>(); + + let bigint_nplus2_s = (0..SAMPLES) + .map(|_| BigInt::<6>([ + rng.gen::(), rng.gen::(), rng.gen::(), + rng.gen::(), rng.gen::(), rng.gen::(), + ])) + .collect::>(); + + let bigint_nplus3_s = (0..SAMPLES) + .map(|_| BigInt::<7>([ + rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), + rng.gen::(), rng.gen::(), rng.gen::(), + ])) + .collect::>(); + let mut group = c.benchmark_group("Fr Arithmetic Comparison"); + // Uncommented to compare with mul_u64_w_carry + Barrett reduction group.bench_function("mul_u64", |bench| { let mut i = 0; bench.iter(|| { @@ -57,110 +92,236 @@ fn mul_small_bench(c: &mut Criterion) { }) }); - group.bench_function("mul_i64", |bench| { + // Benchmark just the multiplication phase (without Barrett reduction) + group.bench_function("mul_u64_w_carry (multiplication only)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_i64::<5>(b_i64_s[i])) + // This is just the multiplication step, returns BigInt<5> + criterion::black_box(a_bigints[i].mul_u64_w_carry::<5>(b_u64_s[i])) }) }); - // Note: results might be worse than in real applications due to branch prediction being wrong - // 50% of the time - group.bench_function("mul_u128", |bench| { + // group.bench_function("mul_i64", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_i64::<5>(b_i64_s[i])) + // }) + // }); + + // // Note: results might be worse than in real applications due to branch prediction being wrong + // // 50% of the time + // group.bench_function("mul_u128", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // // bn254 Fr has N=4 limbs => N+1 = 5, N+2 = 6 + // criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u128_s[i])) + // }) + // }); + + // group.bench_function("mul_i128", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_i128::<5, 6>(b_i128_s[i])) + // }) + // }); + + group.bench_function("standard mul (Fr * Fr)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - // bn254 Fr has N=4 limbs => N+1 = 5, N+2 = 6 - criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u128_s[i])) + criterion::black_box(a_s[i] * b_fr_s[i]) }) }); - group.bench_function("mul_i128", |bench| { + // Bench specialized high-limb RHS fastpaths (K = 1, 2) + // Construct BigInt with random high limbs for K=1 and K=2 + let b_k1_bigint = (0..SAMPLES) + .map(|_| BigInt::<1>([rng.gen::()])) + .collect::>(); + let b_k2_bigint = (0..SAMPLES) + .map(|_| BigInt::<2>([rng.gen::(), rng.gen::()])) + .collect::>(); + + // group.bench_function("mul_assign_hi_bigint::<1>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // let mut x = a_s[i]; + // x.mul_assign_hi_bigint::<1>(&b_k1_bigint[i]); + // criterion::black_box(x) + // }) + // }); + + // group.bench_function("mul_assign_hi_bigint::<2>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // let mut x = a_s[i]; + // x.mul_assign_hi_bigint::<2>(&b_k2_bigint[i]); + // criterion::black_box(x) + // }) + // }); + + // group.bench_function("mul_hi_bigint::<1>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_hi_bigint::<1>(&b_k1_bigint[i])) + // }) + // }); + + // group.bench_function("mul_hi_bigint::<2>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_hi_bigint::<2>(&b_k2_bigint[i])) + // }) + // }); + + // group.bench_function("mul_u128 (u64 inputs)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // // Call mul_u128 but provide a u64 input cast to u128 + // criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u64_as_u128_s[i])) + // }) + // }); + + // Benchmark the auxiliary function directly (assuming it's made public) + // Note: Requires mul_u128_aux to be pub in montgomery_backend.rs + // Need to import it if not already done via wildcard/specific import + // Let's assume it's accessible via a_s[i].mul_u128_aux(...) for now + // group.bench_function("mul_u128_aux (u128 inputs)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_u128_aux::<5, 6>(b_u128_s[i])) + // }) + // }); + + group.bench_function("Addition (Fr + Fr)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_i128::<5, 6>(b_i128_s[i])) + criterion::black_box(a_s[i] + c_s[i]) }) }); - group.bench_function("standard mul (Fr * Fr)", |bench| { + // Reduction benchmarks + group.bench_function("montgomery_reduce_2n", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i] * b_fr_s[i]) + criterion::black_box(Fr::montgomery_reduce_2n::<8>(bigint_2n_s[i])) }) }); - // Bench specialized trailing-zero RHS fastpaths (K = 1, 2) - // Construct b' with K trailing zeros in limbs for K=1 and K=2 - let mut b_k1 = b_fr_s.clone(); - for b in &mut b_k1 { (b.0).0[0] = 0; } - let mut b_k2 = b_fr_s.clone(); - for b in &mut b_k2 { (b.0).0[0] = 0; (b.0).0[1] = 0; } + // group.bench_function("from_unchecked_nplus1 (Barrett N+1)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(Fr::from_unchecked_nplus1::<5>(bigint_nplus1_s[i])) + // }) + // }); - group.bench_function("mul_assign_rhs_trailing_zeros::<1>", |bench| { + // group.bench_function("from_unchecked_nplus2 (Barrett N+2)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(Fr::from_unchecked_nplus2::<5, 6>(bigint_nplus2_s[i])) + // }) + // }); + + // group.bench_function("from_unchecked_nplus3 (Barrett N+3)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(Fr::from_unchecked_nplus3::<5, 6, 7>(bigint_nplus3_s[i])) + // }) + // }); + + // Linear combination benchmarks + group.bench_function("linear_combination_u64 (2 terms)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - let mut x = a_s[i]; - x.mul_assign_rhs_trailing_zeros::<1>(&b_k1[i]); - criterion::black_box(x) + let pairs = [(a_s[i], b_u64_s[i]), (c_s[i], b_u64_s[(i + 1) % SAMPLES])]; + criterion::black_box(Fr::linear_combination_u64::<5>(&pairs)) }) }); - group.bench_function("mul_assign_rhs_trailing_zeros::<2>", |bench| { + group.bench_function("linear_combination_u64_2 (optimized)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - let mut x = a_s[i]; - x.mul_assign_rhs_trailing_zeros::<2>(&b_k2[i]); - criterion::black_box(x) + criterion::black_box(Fr::linear_combination_u64_2::<5>( + &a_s[i], b_u64_s[i], + &c_s[i], b_u64_s[(i + 1) % SAMPLES] + )) }) }); - group.bench_function("mul_rhs_trailing_zeros::<1>", |bench| { + group.bench_function("linear_combination_u64 (4 terms)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_rhs_trailing_zeros::<1>(&b_k1[i])) + let pairs = [ + (a_s[i], b_u64_s[i]), + (c_s[i], b_u64_s[(i + 1) % SAMPLES]), + (a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES]), + (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES]), + ]; + criterion::black_box(Fr::linear_combination_u64::<5>(&pairs)) }) }); - group.bench_function("mul_rhs_trailing_zeros::<2>", |bench| { + group.bench_function("linear_combination_u64_3 (optimized)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_rhs_trailing_zeros::<2>(&b_k2[i])) + criterion::black_box(Fr::linear_combination_u64_3::<5>( + &a_s[i], b_u64_s[i], + &c_s[i], b_u64_s[(i + 1) % SAMPLES], + &a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES] + )) }) }); - group.bench_function("mul_u128 (u64 inputs)", |bench| { + group.bench_function("linear_combination_i64 (2+2 terms)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - // Call mul_u128 but provide a u64 input cast to u128 - criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u64_as_u128_s[i])) + let pos = [(a_s[i], b_u64_s[i]), (c_s[i], b_u64_s[(i + 1) % SAMPLES])]; + let neg = [(a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES]), + (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES])]; + criterion::black_box(Fr::linear_combination_i64::<5>(&pos, &neg)) }) }); - // Benchmark the auxiliary function directly (assuming it's made public) - // Note: Requires mul_u128_aux to be pub in montgomery_backend.rs - // Need to import it if not already done via wildcard/specific import - // Let's assume it's accessible via a_s[i].mul_u128_aux(...) for now - group.bench_function("mul_u128_aux (u128 inputs)", |bench| { + // Comparison: naive approach vs linear combination (using mul_u64 for fair comparison) + group.bench_function("naive 2-term combination", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u128_aux::<5, 6>(b_u128_s[i])) + let term1 = a_s[i].mul_u64::<5>(b_u64_s[i]); + let term2 = c_s[i].mul_u64::<5>(b_u64_s[(i + 1) % SAMPLES]); + criterion::black_box(term1 + term2) }) }); - group.bench_function("Addition (Fr + Fr)", |bench| { + group.bench_function("naive 4-term combination", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i] + c_s[i]) + let term1 = a_s[i].mul_u64::<5>(b_u64_s[i]); + let term2 = c_s[i].mul_u64::<5>(b_u64_s[(i + 1) % SAMPLES]); + let term3 = a_s[(i + 2) % SAMPLES].mul_u64::<5>(b_u64_s[(i + 2) % SAMPLES]); + let term4 = c_s[(i + 3) % SAMPLES].mul_u64::<5>(b_u64_s[(i + 3) % SAMPLES]); + criterion::black_box(term1 + term2 + term3 + term4) }) }); From 6c2c89320c7a910620160e4f3152f30034cd8729 Mon Sep 17 00:00:00 2001 From: markosg04 Date: Wed, 27 Aug 2025 19:58:31 -0400 Subject: [PATCH 14/38] feat: initial sz check scheme --- jolt-optimizations/Cargo.toml | 8 + .../benches/expression_bench.rs | 61 ++++++++ jolt-optimizations/benches/sz_check_bench.rs | 42 +++++ jolt-optimizations/src/expression.rs | 78 ++++++++++ jolt-optimizations/src/fq12_poly.rs | 144 ++++++++++++++++++ jolt-optimizations/src/lib.rs | 3 + jolt-optimizations/src/sz_check.rs | 91 +++++++++++ jolt-optimizations/tests/sz_check_tests.rs | 63 ++++++++ 8 files changed, 490 insertions(+) create mode 100644 jolt-optimizations/benches/expression_bench.rs create mode 100644 jolt-optimizations/benches/sz_check_bench.rs create mode 100644 jolt-optimizations/src/expression.rs create mode 100644 jolt-optimizations/src/fq12_poly.rs create mode 100644 jolt-optimizations/src/sz_check.rs create mode 100644 jolt-optimizations/tests/sz_check_tests.rs diff --git a/jolt-optimizations/Cargo.toml b/jolt-optimizations/Cargo.toml index 49aba6276..db228eb6f 100644 --- a/jolt-optimizations/Cargo.toml +++ b/jolt-optimizations/Cargo.toml @@ -48,6 +48,14 @@ harness = false name = "g1_scalar_multiplication" harness = false +[[bench]] +name = "sz_check_bench" +harness = false + +[[bench]] +name = "expression_bench" +harness = false + [[bench]] name = "vector_scalar_mul_add_gamma_g2" harness = false diff --git a/jolt-optimizations/benches/expression_bench.rs b/jolt-optimizations/benches/expression_bench.rs new file mode 100644 index 000000000..2ea696fe9 --- /dev/null +++ b/jolt-optimizations/benches/expression_bench.rs @@ -0,0 +1,61 @@ +use ark_bn254::{Fq, Fq12}; +use ark_ff::{Field, PrimeField, UniformRand}; +use ark_std::test_rng; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use jolt_optimizations::expression::{Expression, Term}; +use jolt_optimizations::sz_check::batch_verify; + +fn benchmark_expression_verification(c: &mut Criterion) { + let mut rng = test_rng(); + + let configs = vec![(15, 6)]; + + for (n, m) in configs { + // Generate n expressions, each with m terms + let mut all_expressions = Vec::new(); + let mut all_expected_results = Vec::new(); + + for _ in 0..n { + let mut terms = Vec::new(); + let mut expected = Fq12::from(1u64); + + for _ in 0..m { + let base = Fq12::rand(&mut rng); + let exponent = Fq::rand(&mut rng); + terms.push(Term { base, exponent }); + expected *= base.pow(exponent.into_bigint()); + } + + all_expressions.push(Expression::new(terms)); + all_expected_results.push(expected); + } + + let mut all_products = Vec::new(); + for expr in &all_expressions { + all_products.extend(expr.to_products()); + } + + let r = Fq::rand(&mut rng); + + // naive computation + c.bench_function(&format!("naive_expr_{}x{}", n, m), |bench| { + bench.iter(|| { + for i in 0..n { + let mut result = Fq12::from(1u64); + for term in &all_expressions[i].terms { + result *= black_box(term.base.pow(term.exponent.into_bigint())); + } + black_box(result); + } + }); + }); + + // SZ check verification + c.bench_function(&format!("sz_check_expr_{}x{}", n, m), |bench| { + bench.iter(|| black_box(batch_verify(&all_products, &r))); + }); + } +} + +criterion_group!(benches, benchmark_expression_verification,); +criterion_main!(benches); diff --git a/jolt-optimizations/benches/sz_check_bench.rs b/jolt-optimizations/benches/sz_check_bench.rs new file mode 100644 index 000000000..92edf1599 --- /dev/null +++ b/jolt-optimizations/benches/sz_check_bench.rs @@ -0,0 +1,42 @@ +use ark_bn254::{Fq, Fq12}; +use ark_ff::UniformRand; +use ark_std::test_rng; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use jolt_optimizations::sz_check::{batch_verify, Product}; + +fn benchmark_sz_check(c: &mut Criterion) { + let mut rng = test_rng(); + let sizes = vec![100000]; + + for k in sizes { + let mut products = Vec::new(); + let mut a_values = Vec::new(); + let mut b_values = Vec::new(); + + for _ in 0..k { + let a = Fq12::rand(&mut rng); + let b = Fq12::rand(&mut rng); + let c = a * b; + a_values.push(a); + b_values.push(b); + products.push(Product::new(a, b, c)); + } + + let r = Fq::rand(&mut rng); + + c.bench_function(&format!("naive_verify_{}", k), |bench| { + bench.iter(|| { + for i in 0..k { + let _ = black_box(a_values[i] * b_values[i]); + } + }); + }); + + c.bench_function(&format!("sz_check_{}", k), |bench| { + bench.iter(|| black_box(batch_verify(&products, &r))); + }); + } +} + +criterion_group!(benches, benchmark_sz_check); +criterion_main!(benches); diff --git a/jolt-optimizations/src/expression.rs b/jolt-optimizations/src/expression.rs new file mode 100644 index 000000000..6890c855b --- /dev/null +++ b/jolt-optimizations/src/expression.rs @@ -0,0 +1,78 @@ +use crate::sz_check::Product; +use ark_bn254::{Fq, Fq12}; +use ark_ff::{BigInteger, Field, One, PrimeField}; + +pub struct Term { + pub base: Fq12, + pub exponent: Fq, +} + +pub struct Expression { + pub terms: Vec, +} + +impl Expression { + pub fn new(terms: Vec) -> Self { + Self { terms } + } + + pub fn to_products(&self) -> Vec { + let mut products = Vec::new(); + let mut current_result = Fq12::one(); + + for term in &self.terms { + let term_value = term.base.pow(term.exponent.into_bigint()); + let term_products = exponentiate_to_products(term.base, term.exponent); + + products.extend(term_products); + + if current_result != Fq12::one() { + // Multiply this term's result with the accumulated result + let new_result = current_result * term_value; + products.push(Product::new(current_result, term_value, new_result)); + current_result = new_result; + } else { + current_result = term_value; + } + } + + products + } +} + +fn exponentiate_to_products(base: Fq12, exponent: Fq) -> Vec { + let mut products = Vec::new(); + + let bigint = exponent.into_bigint(); + let exp_bits = bigint.to_bits_le(); + + let last_one = exp_bits.iter().rposition(|&b| b); + + if last_one.is_none() { + return vec![]; + } + + let last_one = last_one.unwrap(); + + if last_one == 0 { + return vec![]; + } + + let mut current_power = base; + let mut result = if exp_bits[0] { base } else { Fq12::one() }; + + // square and multiply + for i in 1..=last_one { + let squared = current_power * current_power; + products.push(Product::new(current_power, current_power, squared)); + current_power = squared; + + if exp_bits[i] { + let new_result = result * current_power; + products.push(Product::new(result, current_power, new_result)); + result = new_result; + } + } + + products +} diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs new file mode 100644 index 000000000..57a4f6561 --- /dev/null +++ b/jolt-optimizations/src/fq12_poly.rs @@ -0,0 +1,144 @@ +//! Fq12 polynomial operations and conversions for BN254 +//! +//! This module provides: +//! - Conversion between Fq12 field elements and polynomial representations +//! - Polynomial arithmetic operations over Fq[X] +//! - Evaluation and manipulation of the minimal polynomial g(X) = X^12 - 18X^6 + 82 + +use ark_bn254::{Fq, Fq12}; +use ark_ff::{Field, One, Zero}; + +/// Flatten Fq12 to 12 base-field coefficients for a(X)=Σ c_i X^i, X=w, +/// with the relation g(X) = X^12 - 18 X^6 + 82. +/// +/// The BN254 Fq12 field is constructed as a tower extension: +/// - Fq2 = Fq[u]/(u^2 + 1) +/// - Fq6 = Fq2[v]/(v^3 - (9 + u)) +/// - Fq12 = Fq6[w]/(w^2 - v) +/// +/// This function maps an Fq12 element to its polynomial representation +/// in Fq[X] where X = w, using the mapping: +/// (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6}, for k∈{0..5}. +/// @TODO(markosg04) provide proof? +pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { + let nine = Fq::from(9u64); + let mut c = [Fq::zero(); 12]; + + // (term, k) pairs mapping Fq12 basis elements to powers of w: + // 1, v, v^2, w, v·w, v^2·w ↔ w^0, w^2, w^4, w^1, w^3, w^5 + let terms = [ + (&a.c0.c0, 0usize), // 1 → w^0 + (&a.c0.c1, 2usize), // v → w^2 + (&a.c0.c2, 4usize), // v^2 → w^4 + (&a.c1.c0, 1usize), // w → w^1 + (&a.c1.c1, 3usize), // v·w → w^3 + (&a.c1.c2, 5usize), // v^2·w → w^5 + ]; + + for (fp2, k) in terms { + let x = fp2.c0; // coefficient of 1 in Fp2 + let y = fp2.c1; // coefficient of u in Fp2 (with u^2 = -1) + // Apply the mapping: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} + c[k] += x - nine * y; + c[k + 6] += y; + } + c +} + +/// Evaluate g(X) = X^12 - 18 X^6 + 82 at a given point r. +pub fn g_eval(r: &Fq) -> Fq { + let r2 = r.square(); // r^2 + let r3 = r2 * r; // r^3 + let r6 = r3.square(); // r^6 + let r12 = r6.square(); // r^12 + r12 - (Fq::from(18u64) * r6) + Fq::from(82u64) +} + +/// Horner evaluation for arbitrary-degree polynomial. +pub fn eval_poly_vec(coeffs: &[Fq], r: &Fq) -> Fq { + let mut acc = Fq::zero(); + for &c in coeffs.iter().rev() { + acc *= r; + acc += c; + } + acc +} + +/// Add polynomial b to polynomial a in place. +pub fn poly_add_in_place(a: &mut Vec, b: &[Fq]) { + if b.len() > a.len() { + a.resize(b.len(), Fq::zero()); + } + for i in 0..b.len() { + a[i] += b[i]; + } +} + +/// Subtract polynomial b from polynomial a in place. +pub fn poly_sub_in_place(a: &mut Vec, b: &[Fq]) { + if b.len() > a.len() { + a.resize(b.len(), Fq::zero()); + } + for i in 0..b.len() { + a[i] -= b[i]; + } +} + +/// Multiply two polynomials using convolution. +pub fn poly_mul(a: &[Fq], b: &[Fq]) -> Vec { + if a.is_empty() || b.is_empty() { + return vec![]; + } + let mut out = vec![Fq::zero(); a.len() + b.len() - 1]; + for i in 0..a.len() { + for j in 0..b.len() { + out[i + j] += a[i] * b[j]; + } + } + out +} + +/// Polynomial long division by a monic divisor. +pub fn poly_div_rem_monic(mut dividend: Vec, g: &[Fq]) -> (Vec, Vec) { + assert!(!g.is_empty(), "divisor g must be non-empty"); + assert!( + g.last().unwrap().is_one(), + "divisor g must be monic (leading coefficient = 1)" + ); + + if dividend.is_empty() || dividend.len() < g.len() { + return (vec![], dividend); + } + + let n = dividend.len() - 1; + let m = g.len() - 1; // deg g + let mut q = vec![Fq::zero(); n - m + 1]; + + for k in (m..=n).rev() { + let lead = dividend[k]; // since g is monic, this is the quotient coefficient + q[k - m] = lead; + if lead.is_zero() { + continue; + } + // subtract lead * x^{k-m} * g from dividend + for j in 0..=m { + dividend[k - m + j] -= lead * g[j]; + } + } + + // trim trailing zeros from remainder + while let Some(true) = dividend.last().map(|c| c.is_zero()) { + dividend.pop(); + } + + (q, dividend) +} + +/// Build the coefficients for g(X) = X^12 - 18 X^6 + 82. +pub fn g_coeffs() -> Vec { + let mut g = vec![Fq::zero(); 13]; + g[0] = Fq::from(82u64); + g[6] = -Fq::from(18u64); + g[12] = Fq::one(); + g +} diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs index 5ba72de8f..70898cba7 100644 --- a/jolt-optimizations/src/lib.rs +++ b/jolt-optimizations/src/lib.rs @@ -15,8 +15,11 @@ pub mod decomp_4d; pub mod dory_g1; pub mod dory_g2; pub mod dory_utils; +pub mod expression; +pub mod fq12_poly; pub mod frobenius; pub mod glv_two; +pub mod sz_check; mod glv_four; pub use glv_four::{ diff --git a/jolt-optimizations/src/sz_check.rs b/jolt-optimizations/src/sz_check.rs new file mode 100644 index 000000000..f6e88ee0d --- /dev/null +++ b/jolt-optimizations/src/sz_check.rs @@ -0,0 +1,91 @@ +use std::panic; + +use crate::fq12_poly::{fq12_to_poly12_coeffs, g_coeffs, poly_div_rem_monic, poly_mul}; +use ark_bn254::{Fq, Fq12}; +use ark_ff::{Field, Zero}; + +pub struct Product { + pub a: Fq12, + pub b: Fq12, + pub c: Fq12, + pub quotient: Vec, +} + +impl Product { + pub fn new(a: Fq12, b: Fq12, c: Fq12) -> Self { + let a_poly = fq12_to_poly12_coeffs(&a); + let b_poly = fq12_to_poly12_coeffs(&b); + let c_poly = fq12_to_poly12_coeffs(&c); + + let mut ab = poly_mul(&a_poly, &b_poly); + for i in 0..c_poly.len().min(ab.len()) { + ab[i] -= c_poly[i]; + } + + let (quotient, remainder) = poly_div_rem_monic(ab, &g_coeffs()); + + if !remainder.is_empty() && remainder.iter().any(|r| !r.is_zero()) { + panic!("invalid product: remainder is non-zero") + } + + Self { a, b, c, quotient } + } +} + +fn compute_r_powers(r: &Fq) -> [Fq; 12] { + let mut powers = [Fq::zero(); 12]; + powers[0] = Fq::from(1u64); + for i in 1..12 { + powers[i] = powers[i - 1] * r; + } + powers +} + +fn eval_with_powers(coeffs: &[Fq; 12], r_powers: &[Fq; 12]) -> Fq { + let mut result = Fq::zero(); + for i in 0..12 { + result += coeffs[i] * r_powers[i]; + } + result +} + +pub fn g_eval_optimized(r: &Fq) -> Fq { + let r2 = r.square(); + let r3 = r2 * r; + let r6 = r3.square(); + let r12 = r6.square(); + r12 - Fq::from(18u64) * r6 + Fq::from(82u64) +} + +pub fn batch_verify(products: &[Product], r: &Fq) -> bool { + let r_powers = compute_r_powers(r); + let g_r = g_eval_optimized(r); + + for product in products { + let a_coeffs = fq12_to_poly12_coeffs(&product.a); + let b_coeffs = fq12_to_poly12_coeffs(&product.b); + let c_coeffs = fq12_to_poly12_coeffs(&product.c); + + let a_r = eval_with_powers(&a_coeffs, &r_powers); + let b_r = eval_with_powers(&b_coeffs, &r_powers); + let c_r = eval_with_powers(&c_coeffs, &r_powers); + + let lhs = a_r * b_r - c_r; + + let mut q_r = Fq::zero(); + for (i, coeff) in product.quotient.iter().enumerate() { + if i < 12 { + q_r += *coeff * r_powers[i]; + } else { + panic!("this can't happen") + } + } + let rhs = q_r * g_r; + + if lhs != rhs { + return false; + } + } + + true +} diff --git a/jolt-optimizations/tests/sz_check_tests.rs b/jolt-optimizations/tests/sz_check_tests.rs new file mode 100644 index 000000000..95bb640a1 --- /dev/null +++ b/jolt-optimizations/tests/sz_check_tests.rs @@ -0,0 +1,63 @@ +use ark_bn254::{Fq, Fq12}; +use ark_ff::{Field, PrimeField, UniformRand}; +use ark_std::test_rng; +use jolt_optimizations::expression::{Expression, Term}; +use jolt_optimizations::sz_check::{batch_verify, Product}; + +#[test] +fn test_large_batch() { + let mut rng = test_rng(); + let k = 100000; + + let mut products = Vec::new(); + for _ in 0..k { + let a = Fq12::rand(&mut rng); + let b = Fq12::rand(&mut rng); + let c = a * b; + products.push(Product::new(a, b, c)); + } + + let r = Fq::rand(&mut rng); + + assert!(batch_verify(&products, &r)); +} + +#[test] +fn test_expression_to_sz_check() { + let mut rng = test_rng(); + let a1 = Fq12::rand(&mut rng); + let c1 = Fq::rand(&mut rng); + + let a2 = Fq12::rand(&mut rng); + let c2 = Fq::rand(&mut rng); + + let a3 = Fq12::rand(&mut rng); + let c3 = Fq::rand(&mut rng); + + let expected = a1.pow(c1.into_bigint()) * a2.pow(c2.into_bigint()) * a3.pow(c3.into_bigint()); + + let expr = Expression::new(vec![ + Term { + base: a1, + exponent: c1, + }, + Term { + base: a2, + exponent: c2, + }, + Term { + base: a3, + exponent: c3, + }, + ]); + + let products = expr.to_products(); + + let r = Fq::rand(&mut rng); + assert!(batch_verify(&products, &r)); + + if !products.is_empty() { + let final_result = products.last().unwrap().c; + assert_eq!(final_result, expected); + } +} From 30bd65d4de69f3fc43c42536fdd64d7bc895744b Mon Sep 17 00:00:00 2001 From: markosg04 Date: Tue, 2 Sep 2025 12:04:19 -0400 Subject: [PATCH 15/38] feat: multilinear fp12 --- jolt-optimizations/src/fq12_poly.rs | 16 ++++++++++++++++ jolt-optimizations/tests/sz_check_tests.rs | 3 ++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs index 57a4f6561..d7e00f4f4 100644 --- a/jolt-optimizations/src/fq12_poly.rs +++ b/jolt-optimizations/src/fq12_poly.rs @@ -142,3 +142,19 @@ pub fn g_coeffs() -> Vec { g[12] = Fq::one(); g } + +/// Convert Fq12 polynomial coefficients to multilinear evaluations by padding to 16 elements. +/// The 12 coefficients are padded with 4 zeros to make a power-of-2 size suitable for +/// multilinear polynomial commitment schemes. +pub fn to_multilinear_evals(coeffs: &[Fq; 12]) -> Vec { + let mut evals = coeffs.to_vec(); + evals.resize(16, Fq::zero()); + evals +} + +/// Convert an Fq12 element to multilinear evaluations. +/// First converts to polynomial coefficients, then pads to 16 elements. +pub fn fq12_to_multilinear_evals(a: &Fq12) -> Vec { + let coeffs = fq12_to_poly12_coeffs(a); + to_multilinear_evals(&coeffs) +} diff --git a/jolt-optimizations/tests/sz_check_tests.rs b/jolt-optimizations/tests/sz_check_tests.rs index 95bb640a1..212639979 100644 --- a/jolt-optimizations/tests/sz_check_tests.rs +++ b/jolt-optimizations/tests/sz_check_tests.rs @@ -1,7 +1,8 @@ use ark_bn254::{Fq, Fq12}; -use ark_ff::{Field, PrimeField, UniformRand}; +use ark_ff::{Field, PrimeField, UniformRand, Zero}; use ark_std::test_rng; use jolt_optimizations::expression::{Expression, Term}; +use jolt_optimizations::fq12_poly::{fq12_to_multilinear_evals, fq12_to_poly12_coeffs}; use jolt_optimizations::sz_check::{batch_verify, Product}; #[test] From 737c6034138c75f8b860064b57d26c9b393c777e Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Sat, 13 Sep 2025 18:45:09 -0400 Subject: [PATCH 16/38] WIP, add bigint `add_trunc` --- ff/src/biginteger/mod.rs | 68 ++++++++++++++ ff/src/biginteger/tests.rs | 112 +++++++++++++++++++++++ test-curves/Cargo.toml | 5 ++ test-curves/benches/bigint.rs | 163 ++++++++++++++++++++++++++++++++++ 4 files changed, 348 insertions(+) create mode 100644 test-curves/benches/bigint.rs diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 642302e06..63de685a5 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -311,6 +311,74 @@ impl BigInt { res } + /// Truncated-width addition: compute self + other and fit into P limbs; overflow is ignored. + #[inline] + pub fn add_trunc(&self, other: &BigInt) -> BigInt

{ + let mut res = BigInt::

::zero(); + let mut carry = 0u64; + + // Add all limbs up to the result size P, using 0 for missing limbs + let min_size = core::cmp::min(N, M); + let max_size = core::cmp::max(N, M); + + // Add corresponding limbs from both BigInts + for i in 0..core::cmp::min(min_size, P) { + res.0[i] = adc!(self.0[i], other.0[i], &mut carry); + } + + // Handle remaining limbs from the larger BigInt + for i in min_size..core::cmp::min(max_size, P) { + let a = if i < N { self.0[i] } else { 0 }; + let b = if i < M { other.0[i] } else { 0 }; + res.0[i] = adc!(a, b, &mut carry); + } + + // Propagate any remaining carry to unused limbs within P + let mut i = max_size; + while carry != 0 && i < P { + res.0[i] = adc!(res.0[i], 0, &mut carry); + i += 1; + } + + res + } + + /// Truncated-width addition that mutates self: self += other and fit result into P limbs; overflow is ignored. + #[inline] + pub fn add_assign_trunc(&mut self, other: &BigInt) { + let mut carry = 0u64; + let limit = core::cmp::min(P, N); + + let overlap = core::cmp::min(limit, core::cmp::min(N, M)); + for i in 0..overlap { + self.0[i] = adc!(self.0[i], other.0[i], &mut carry); + } + + // If self has remaining limbs within the limit, add carry through them + if N > M { + for i in overlap..limit { + self.0[i] = adc!(self.0[i], 0, &mut carry); + } + } else if M > N { + // If other has remaining limbs within the limit, add them into self (self's lanes may be zero) + for i in overlap..core::cmp::min(M, limit) { + self.0[i] = adc!(0, other.0[i], &mut carry); + } + } + + // Propagate any remaining carry within the limit + let mut i = core::cmp::min(core::cmp::max(N, M), limit); + while carry != 0 && i < limit { + self.0[i] = adc!(self.0[i], 0, &mut carry); + i += 1; + } + + // Zero out the remaining limbs beyond the limit (truncate to P limbs) + for i in limit..N { + self.0[i] = 0; + } + } + /// Fused multiply-add with truncation: acc += self * other, fitting into P limbs; overflow is ignored. /// This is a generic version for arbitrary limb widths of `self` and `other`. #[inline] diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 09794edc7..550f6de0c 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -910,4 +910,116 @@ fn test_prop_fmadd_trunc_random() { run_case!(3, 2, 2, 200); } +// ============================== +// Tests for add_trunc and add_assign_trunc (unsigned BigInt) +// ============================== + +#[test] +fn test_add_trunc_correctness_random() { + use crate::biginteger::BigInt; + let mut rng = ark_std::test_rng(); + + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a: BigInt<$n> = UniformRand::rand(&mut rng); + let b: BigInt<$m> = UniformRand::rand(&mut rng); + + let res = a.add_trunc::<$m, $p>(&b); + + let a_bu = BigUint::from(a); + let b_bu = BigUint::from(b); + let modulus = BigUint::from(1u8) << (64 * $p); + let expected = (a_bu + b_bu) % &modulus; + assert_eq!(BigUint::from(res), expected); + } + }}; + } + + // Same-width, truncated equal width + run_case!(4, 4, 4, 200); + // Same-width, truncate to fewer limbs + run_case!(4, 4, 3, 200); + // Mixed widths, truncate to min and to max + run_case!(4, 2, 3, 200); + run_case!(2, 4, 2, 200); +} + +#[test] +fn test_add_assign_trunc_correctness_and_zeroing() { + use crate::biginteger::BigInt; + let mut rng = ark_std::test_rng(); + + // Case 1: N = 4, M = 4, P = 4 (no truncation); compare against add_trunc and add_with_carry + for _ in 0..200 { + let a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<4> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<4, 4>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 4>(&b); + assert_eq!(a2, r_trunc); + + // Regular add_with_carry should match lower 4 limbs modulo 2^(256) + let mut a3 = a; + a3.add_with_carry(&b); + assert_eq!(a3, r_trunc); + } + + // Case 2: N = 4, M = 4, P = 3 (truncation) -> self's limb 3 must be zeroed + for _ in 0..200 { + let a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<4> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<4, 3>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 3>(&b); + // Low 3 limbs match result + for i in 0..3 { assert_eq!(a2.0[i], r_trunc.0[i]); } + // Higher limbs of self must be zero + for i in 3..4 { assert_eq!(a2.0[i], 0); } + } + + // Case 3: Mixed widths N = 4, M = 2, P = 3 + for _ in 0..200 { + let a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<2> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<2, 3>(&b); + let mut a2 = a; + a2.add_assign_trunc::<2, 3>(&b); + for i in 0..3 { assert_eq!(a2.0[i], r_trunc.0[i]); } + // Truncated limb 3.. must be zero + for i in 3..4 { assert_eq!(a2.0[i], 0); } + } + + // Case 4: Mixed widths N = 2, M = 4, P = 2 (limit is N so no zeroing beyond N) + for _ in 0..200 { + let a: BigInt<2> = UniformRand::rand(&mut rng); + let b: BigInt<4> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<4, 2>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 2>(&b); + assert_eq!(a2, r_trunc); + } +} + +#[test] +fn test_add_trunc_and_add_assign_trunc_overflow_edges() { + use crate::biginteger::BigInt; + + // All-ones + all-ones with truncation + let a = BigInt::<4>::new([u64::MAX; 4]); + let b = BigInt::<4>::new([u64::MAX; 4]); + // P = 4: result should be wrapping add modulo 2^256 + let r = a.add_trunc::<4, 4>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 4>(&b); + assert_eq!(a2, r); + + // P = 3: ensure high limb is zeroed in mutating version + let r3 = a.add_trunc::<4, 3>(&b); + let mut a3 = a; + a3.add_assign_trunc::<4, 3>(&b); + for i in 0..3 { assert_eq!(a3.0[i], r3.0[i]); } + assert_eq!(a3.0[3], 0); +} + } diff --git a/test-curves/Cargo.toml b/test-curves/Cargo.toml index 08d5380e7..5fef59054 100644 --- a/test-curves/Cargo.toml +++ b/test-curves/Cargo.toml @@ -88,3 +88,8 @@ harness = false name = "bn254" path = "benches/bn254.rs" harness = false + +[[bench]] +name = "bigint" +path = "benches/bigint.rs" +harness = false diff --git a/test-curves/benches/bigint.rs b/test-curves/benches/bigint.rs new file mode 100644 index 000000000..66319b37a --- /dev/null +++ b/test-curves/benches/bigint.rs @@ -0,0 +1,163 @@ +// Benchmark for BigInt operations +#[cfg(feature = "bn254")] +use ark_ff::{BigInteger, BigInt}; +#[cfg(feature = "bn254")] +use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; +#[cfg(feature = "bn254")] +use criterion::{criterion_group, criterion_main, Criterion}; + +#[cfg(feature = "bn254")] +fn bigint_add_bench(c: &mut Criterion) { + const SAMPLES: usize = 1000; + // Use a fixed seed for reproducibility + let mut rng = StdRng::seed_from_u64(0u64); + + // Generate random BigInt<4> instances for benchmarking + let a_bigints = (0..SAMPLES) + .map(|_| BigInt::<4>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ])) + .collect::>(); + + let b_bigints = (0..SAMPLES) + .map(|_| BigInt::<4>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ])) + .collect::>(); + + let mut group = c.benchmark_group("BigInt<4> Addition Comparison"); + + // Benchmark add_trunc with same limb count (4 -> 4) + group.bench_function("add_trunc<4, 4>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_bigints[i].add_trunc::<4, 4>(&b_bigints[i])) + }) + }); + + // Benchmark add_trunc with truncation (4 -> 3 limbs) + group.bench_function("add_trunc<4, 3>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_bigints[i].add_trunc::<4, 3>(&b_bigints[i])) + }) + }); + + // Benchmark add_trunc with expansion (4 -> 5 limbs) + group.bench_function("add_trunc<4, 5>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_bigints[i].add_trunc::<4, 5>(&b_bigints[i])) + }) + }); + + // Benchmark regular addition using add_with_carry + group.bench_function("add_with_carry (regular add)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + let carry = result.add_with_carry(&b_bigints[i]); + criterion::black_box((result, carry)) + }) + }); + + // Benchmark regular addition that ignores carry (for fair comparison) + group.bench_function("add_with_carry (ignore carry)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_with_carry(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Benchmark add_assign_trunc with same limb count (4 -> 4) + group.bench_function("add_assign_trunc<4, 4>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_assign_trunc::<4, 4>(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Benchmark add_assign_trunc with truncation (4 -> 3 limbs) + group.bench_function("add_assign_trunc<4, 3>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_assign_trunc::<4, 3>(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Benchmark add_assign_trunc with expansion (4 -> 5 limbs) + group.bench_function("add_assign_trunc<4, 5>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_assign_trunc::<4, 5>(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Test case: addition that would overflow to compare truncation behavior + let max_bigints = (0..SAMPLES) + .map(|_| BigInt::<4>([ + u64::MAX, u64::MAX, u64::MAX, u64::MAX, + ])) + .collect::>(); + + group.bench_function("add_trunc overflow case", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + // This will overflow and be truncated + criterion::black_box(max_bigints[i].add_trunc::<4, 4>(&max_bigints[i])) + }) + }); + + group.bench_function("add_with_carry overflow case", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = max_bigints[i]; + let carry = result.add_with_carry(&max_bigints[i]); + criterion::black_box((result, carry)) + }) + }); + + group.bench_function("add_assign_trunc overflow case", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = max_bigints[i]; + result.add_assign_trunc::<4, 4>(&max_bigints[i]); + criterion::black_box(result) + }) + }); + + group.finish(); +} + +#[cfg(feature = "bn254")] +criterion_group!(benches, bigint_add_bench); +#[cfg(feature = "bn254")] +criterion_main!(benches); + +#[cfg(not(feature = "bn254"))] +fn main() {} From 27f79379922919ff550cf7a7ea98026446848ceb Mon Sep 17 00:00:00 2001 From: markosg04 Date: Mon, 15 Sep 2025 14:58:18 -0400 Subject: [PATCH 17/38] debug: broken recursion example --- jolt-optimizations/src/expression.rs | 75 +++++--- jolt-optimizations/src/fq12_poly.rs | 12 +- jolt-optimizations/src/lib.rs | 9 + jolt-optimizations/src/steps.rs | 153 ++++++++++++++++ jolt-optimizations/src/sz_check.rs | 1 + jolt-optimizations/tests/steps_debug_test.rs | 175 +++++++++++++++++++ jolt-optimizations/tests/steps_test.rs | 135 ++++++++++++++ 7 files changed, 523 insertions(+), 37 deletions(-) create mode 100644 jolt-optimizations/src/steps.rs create mode 100644 jolt-optimizations/tests/steps_debug_test.rs create mode 100644 jolt-optimizations/tests/steps_test.rs diff --git a/jolt-optimizations/src/expression.rs b/jolt-optimizations/src/expression.rs index 6890c855b..e1871175e 100644 --- a/jolt-optimizations/src/expression.rs +++ b/jolt-optimizations/src/expression.rs @@ -1,7 +1,9 @@ +use crate::steps::{pow_with_steps_le, ExponentiationSteps}; use crate::sz_check::Product; use ark_bn254::{Fq, Fq12}; -use ark_ff::{BigInteger, Field, One, PrimeField}; +use ark_ff::{Field, One, PrimeField}; +#[derive(Clone)] pub struct Term { pub base: Fq12, pub exponent: Fq, @@ -11,6 +13,11 @@ pub struct Expression { pub terms: Vec, } +pub struct ExpressionSteps { + pub term_steps: Vec, + pub multiplication_products: Vec, +} + impl Expression { pub fn new(terms: Vec) -> Self { Self { terms } @@ -38,41 +45,57 @@ impl Expression { products } -} -fn exponentiate_to_products(base: Fq12, exponent: Fq) -> Vec { - let mut products = Vec::new(); - - let bigint = exponent.into_bigint(); - let exp_bits = bigint.to_bits_le(); + /// Evaluate the expression and return both the result and all computation steps + pub fn evaluate_with_steps(&self) -> (Fq12, ExpressionSteps) { + let mut term_steps = Vec::new(); + let mut multiplication_products = Vec::new(); + let mut current_result = Fq12::one(); - let last_one = exp_bits.iter().rposition(|&b| b); + for term in &self.terms { + // Compute this term with steps + let steps = pow_with_steps_le(term.base, term.exponent); + let term_value = steps.result; + term_steps.push(steps); - if last_one.is_none() { - return vec![]; - } + if current_result != Fq12::one() { + // Multiply this term's result with the accumulated result + let new_result = current_result * term_value; + multiplication_products.push(Product::new(current_result, term_value, new_result)); + current_result = new_result; + } else { + current_result = term_value; + } + } - let last_one = last_one.unwrap(); + let expression_steps = ExpressionSteps { + term_steps, + multiplication_products, + }; - if last_one == 0 { - return vec![]; + (current_result, expression_steps) } - let mut current_power = base; - let mut result = if exp_bits[0] { base } else { Fq12::one() }; + /// Convert expression steps to a flat list of products for verification + pub fn steps_to_products(steps: &ExpressionSteps) -> Vec { + let mut products = Vec::new(); - // square and multiply - for i in 1..=last_one { - let squared = current_power * current_power; - products.push(Product::new(current_power, current_power, squared)); - current_power = squared; + // Add all products from individual term exponentiations + for term_step in &steps.term_steps { + products.extend(term_step.to_products()); + } - if exp_bits[i] { - let new_result = result * current_power; - products.push(Product::new(result, current_power, new_result)); - result = new_result; + // Add products from multiplying terms together + for product in &steps.multiplication_products { + products.push(product.clone()); } + + products } +} - products +fn exponentiate_to_products(base: Fq12, exponent: Fq) -> Vec { + // Use the new stepped implementation to get products + let steps = pow_with_steps_le(base, exponent); + steps.to_products() } diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs index d7e00f4f4..751844af6 100644 --- a/jolt-optimizations/src/fq12_poly.rs +++ b/jolt-optimizations/src/fq12_poly.rs @@ -1,10 +1,4 @@ //! Fq12 polynomial operations and conversions for BN254 -//! -//! This module provides: -//! - Conversion between Fq12 field elements and polynomial representations -//! - Polynomial arithmetic operations over Fq[X] -//! - Evaluation and manipulation of the minimal polynomial g(X) = X^12 - 18X^6 + 82 - use ark_bn254::{Fq, Fq12}; use ark_ff::{Field, One, Zero}; @@ -143,17 +137,13 @@ pub fn g_coeffs() -> Vec { g } -/// Convert Fq12 polynomial coefficients to multilinear evaluations by padding to 16 elements. -/// The 12 coefficients are padded with 4 zeros to make a power-of-2 size suitable for -/// multilinear polynomial commitment schemes. +/// Convert Fq12 polynomial coefficients to multilinear evaluations by padding to 16 elements.= pub fn to_multilinear_evals(coeffs: &[Fq; 12]) -> Vec { let mut evals = coeffs.to_vec(); evals.resize(16, Fq::zero()); evals } -/// Convert an Fq12 element to multilinear evaluations. -/// First converts to polynomial coefficients, then pads to 16 elements. pub fn fq12_to_multilinear_evals(a: &Fq12) -> Vec { let coeffs = fq12_to_poly12_coeffs(a); to_multilinear_evals(&coeffs) diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs index 70898cba7..a58cc6745 100644 --- a/jolt-optimizations/src/lib.rs +++ b/jolt-optimizations/src/lib.rs @@ -19,6 +19,7 @@ pub mod expression; pub mod fq12_poly; pub mod frobenius; pub mod glv_two; +pub mod steps; pub mod sz_check; mod glv_four; @@ -57,3 +58,11 @@ pub use dory_g2::{ }; pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi}; + +pub use fq12_poly::{ + fq12_to_multilinear_evals, fq12_to_poly12_coeffs, g_coeffs, g_eval, to_multilinear_evals, +}; + +pub use steps::{pow_with_steps_le, ExponentiationStep, ExponentiationSteps}; + +pub use expression::{Expression, ExpressionSteps, Term}; diff --git a/jolt-optimizations/src/steps.rs b/jolt-optimizations/src/steps.rs new file mode 100644 index 000000000..329514b85 --- /dev/null +++ b/jolt-optimizations/src/steps.rs @@ -0,0 +1,153 @@ +use crate::sz_check::Product; +use ark_bn254::{Fq, Fq12}; +use ark_ff::{BigInteger, Field, One, PrimeField}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +/// Represents a single step in the square-and-multiply exponentiation algorithm. +#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] +pub struct ExponentiationStep { + pub step_index: usize, + pub bit_value: bool, + pub a_prev: Fq12, + pub a_curr: Fq12, + pub rho_before: Fq12, + pub rho_after: Fq12, +} + +#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] +pub struct ExponentiationSteps { + /// The base being exponentiated + pub base: Fq12, + /// The exponent + pub exponent: Fq, + /// All steps in the computation + pub steps: Vec, + /// The final result (should equal base^exponent) + pub result: Fq12, +} + +impl ExponentiationSteps { + /// Convert the steps into Products for verification with sz_check + pub fn to_products(&self) -> Vec { + let mut products = Vec::new(); + + for step in &self.steps { + // Each squaring operation creates a product: a_i = a_{i-1} * a_{i-1} + products.push(Product::new(step.a_prev, step.a_prev, step.a_curr)); + + // If the bit is 1, we multiply rho by the current power + if step.bit_value && step.rho_before != step.rho_after { + products.push(Product::new(step.rho_before, step.a_curr, step.rho_after)); + } + } + + products + } + + pub fn sanity_verify(&self) -> bool { + let expected = self.base.pow(self.exponent.into_bigint()); + if self.result != expected { + return false; + } + + for (i, step) in self.steps.iter().enumerate() { + if step.a_curr != step.a_prev * step.a_prev { + return false; + } + + let expected_rho_after = if step.bit_value { + step.rho_before * step.a_curr + } else { + step.rho_before + }; + + if step.rho_after != expected_rho_after { + return false; + } + + if i + 1 < self.steps.len() { + if step.a_curr != self.steps[i + 1].a_prev { + return false; + } + if step.rho_after != self.steps[i + 1].rho_before { + return false; + } + } + } + + if let Some(last_step) = self.steps.last() { + if last_step.rho_after != self.result { + return false; + } + } + + true + } +} + +pub fn pow_with_steps_le(base: Fq12, exponent: Fq) -> ExponentiationSteps { + let mut steps = Vec::new(); + + let bigint = exponent.into_bigint(); + let exp_bits = bigint.to_bits_le(); + + // Find the position of the last 1 bit + let last_one = exp_bits.iter().rposition(|&b| b); + + if last_one.is_none() { + // Exponent is 0, return 1 + return ExponentiationSteps { + base, + exponent, + steps: vec![], + result: Fq12::one(), + }; + } + + let last_one = last_one.unwrap(); + + if last_one == 0 { + // Exponent is 1, return base + return ExponentiationSteps { + base, + exponent, + steps: vec![], + result: base, + }; + } + + let mut a_curr = base; // Current power of base + let mut rho = if exp_bits[0] { base } else { Fq12::one() }; + + for (step_idx, bit_idx) in (1..=last_one).enumerate() { + let bit_value = exp_bits[bit_idx]; + let a_prev = a_curr; + let rho_before = rho; + + a_curr = a_prev * a_prev; + + let rho_after = if bit_value { + rho_before * a_curr + } else { + rho_before + }; + + steps.push(ExponentiationStep { + step_index: step_idx, + bit_value, + a_prev, + a_curr, + rho_before, + rho_after, + }); + + rho = rho_after; + } + + ExponentiationSteps { + base, + exponent, + steps, + result: rho, + } +} diff --git a/jolt-optimizations/src/sz_check.rs b/jolt-optimizations/src/sz_check.rs index f6e88ee0d..81ad797d4 100644 --- a/jolt-optimizations/src/sz_check.rs +++ b/jolt-optimizations/src/sz_check.rs @@ -4,6 +4,7 @@ use crate::fq12_poly::{fq12_to_poly12_coeffs, g_coeffs, poly_div_rem_monic, poly use ark_bn254::{Fq, Fq12}; use ark_ff::{Field, Zero}; +#[derive(Clone)] pub struct Product { pub a: Fq12, pub b: Fq12, diff --git a/jolt-optimizations/tests/steps_debug_test.rs b/jolt-optimizations/tests/steps_debug_test.rs new file mode 100644 index 000000000..2a170938a --- /dev/null +++ b/jolt-optimizations/tests/steps_debug_test.rs @@ -0,0 +1,175 @@ +use ark_bn254::{Fq, Fq12}; +use ark_ff::BigInteger; +use ark_ff::{Field, One, PrimeField, UniformRand}; +use ark_std::test_rng; +use jolt_optimizations::steps::pow_with_steps_le; + +#[test] +#[ignore] // Run with: cargo test --test steps_debug_test test_debug_trace -- --nocapture --ignored +fn test_debug_trace() { + let mut rng = test_rng(); + + // Use a small exponent for readable output + let base = Fq12::rand(&mut rng); + let exponent = Fq::from(13u64); // Binary: 1101 + + println!("=== Square-and-Multiply Debug Trace ==="); + println!("Base: {:?}", base); + println!("Exponent: {} (binary: 1101)", 13u64); + println!(); + + let steps = pow_with_steps_le(base, exponent); + + // Print bit representation + let bigint = exponent.into_bigint(); + let exp_bits = bigint.to_bits_le(); + println!("Bit representation (LSB first):"); + for (i, bit) in exp_bits.iter().take(8).enumerate() { + println!(" Bit {}: {}", i, if *bit { "1" } else { "0" }); + } + println!(); + + // Print initial state + println!("Initial state:"); + println!(" a_0 = base"); + println!( + " rho_0 = {} (since bit 0 = {})", + if exp_bits[0] { "base" } else { "1" }, + if exp_bits[0] { "1" } else { "0" } + ); + println!(); + + // Print each step + println!("Steps:"); + for (i, step) in steps.steps.iter().enumerate() { + println!( + "Step {} (processing bit {} = {}):", + i + 1, + i + 1, + if step.bit_value { "1" } else { "0" } + ); + + println!(" Squaring: a_{} = a_{}^2", i + 1, i); + println!(" a_{} = {:?}", i, step.a_prev); + println!(" a_{} = {:?}", i + 1, step.a_curr); + + // Verify squaring + let expected_square = step.a_prev * step.a_prev; + println!( + " Verification: a_curr == a_prev^2? {}", + if step.a_curr == expected_square { + "✓" + } else { + "✗" + } + ); + + println!(" Accumulator update:"); + println!(" rho_before = {:?}", step.rho_before); + + if step.bit_value { + println!(" Bit is 1, so: rho_after = rho_before * a_curr"); + } else { + println!(" Bit is 0, so: rho_after = rho_before (unchanged)"); + } + + println!(" rho_after = {:?}", step.rho_after); + + // Verify accumulator update + let expected_rho = if step.bit_value { + step.rho_before * step.a_curr + } else { + step.rho_before + }; + println!( + " Verification: rho_after correct? {}", + if step.rho_after == expected_rho { + "✓" + } else { + "✗" + } + ); + + println!(); + } + + // Print final result + println!("Final result: {:?}", steps.result); + + // Verify against standard pow + let expected = base.pow(exponent.into_bigint()); + println!("Expected (base^13): {:?}", expected); + println!( + "Results match: {}", + if steps.result == expected { + "✓" + } else { + "✗" + } + ); + + // Print summary of operations + println!(); + println!("=== Summary ==="); + let num_squarings = steps.steps.len(); + let num_multiplications = steps.steps.iter().filter(|s| s.bit_value).count(); + println!("Total squarings: {}", num_squarings); + println!("Total multiplications by base: {}", num_multiplications); + println!("Total operations: {}", num_squarings + num_multiplications); + + // Verify the steps + assert!(steps.sanity_verify(), "Steps verification failed"); + assert_eq!(steps.result, expected, "Result doesn't match expected"); +} + +#[test] +#[ignore] // Run with: cargo test --test steps_debug_test test_trace_products -- --nocapture --ignored +fn test_trace_products() { + let mut rng = test_rng(); + + let base = Fq12::rand(&mut rng); + let exponent = Fq::from(5u64); // Binary: 101 + + println!("=== Products Generated from Steps ==="); + println!("Exponent: 5 (binary: 101)"); + println!(); + + let steps = pow_with_steps_le(base, exponent); + let products = steps.to_products(); + + println!("Products generated:"); + for (i, product) in products.iter().enumerate() { + println!("Product {}:", i); + println!(" a * b = c"); + println!(" a: {:?}", product.a); + println!(" b: {:?}", product.b); + println!(" c: {:?}", product.c); + + // Verify the product + let expected_c = product.a * product.b; + println!( + " Verification: c == a * b? {}", + if product.c == expected_c { + "✓" + } else { + "✗" + } + ); + println!(); + } + + println!("Total products: {}", products.len()); + + // Test batch verification + use jolt_optimizations::sz_check::batch_verify; + let r = Fq::rand(&mut rng); + let batch_result = batch_verify(&products, &r); + println!( + "Batch verification with random r: {}", + if batch_result { + "✓ PASSED" + } else { + "✗ FAILED" + } + ); +} diff --git a/jolt-optimizations/tests/steps_test.rs b/jolt-optimizations/tests/steps_test.rs new file mode 100644 index 000000000..0146b5a47 --- /dev/null +++ b/jolt-optimizations/tests/steps_test.rs @@ -0,0 +1,135 @@ +use ark_bn254::{Fq, Fq12}; +use ark_ff::{Field, One, PrimeField, UniformRand}; +use ark_std::test_rng; +use jolt_optimizations::expression::{Expression, Term}; +use jolt_optimizations::steps::pow_with_steps_le; +use jolt_optimizations::sz_check::batch_verify; + +#[test] +fn test_pow_with_steps_correctness() { + let mut rng = test_rng(); + + // Test with random base and exponent + let base = Fq12::rand(&mut rng); + let exponent = Fq::rand(&mut rng); + + // Compute with steps + let steps = pow_with_steps_le(base, exponent); + + // Verify the result matches standard pow + let expected = base.pow(exponent.into_bigint()); + assert_eq!(steps.result, expected, "Result mismatch"); + + // Verify the steps are internally consistent + assert!(steps.sanity_verify(), "Steps verification failed"); + + // Verify that products can be verified using batch_verify + let products = steps.to_products(); + let r = Fq::rand(&mut rng); + assert!(batch_verify(&products, &r), "Batch verification failed"); +} + +#[test] +fn test_pow_with_steps_edge_cases() { + let mut rng = test_rng(); + let base = Fq12::rand(&mut rng); + + // Test exponent = 0 + let steps = pow_with_steps_le(base, Fq::from(0u64)); + assert_eq!(steps.result, Fq12::one()); + assert_eq!(steps.steps.len(), 0); + + // Test exponent = 1 + let steps = pow_with_steps_le(base, Fq::from(1u64)); + assert_eq!(steps.result, base); + assert_eq!(steps.steps.len(), 0); + + // Test exponent = 2 + let steps = pow_with_steps_le(base, Fq::from(2u64)); + assert_eq!(steps.result, base * base); + assert_eq!(steps.steps.len(), 1); + assert!(steps.sanity_verify()); +} + +#[test] +fn test_expression_with_steps() { + let mut rng = test_rng(); + + // Create an expression with multiple terms + let terms = vec![ + Term { + base: Fq12::rand(&mut rng), + exponent: Fq::from(5u64), + }, + Term { + base: Fq12::rand(&mut rng), + exponent: Fq::from(3u64), + }, + ]; + + let expr = Expression::new(terms); + + // Evaluate with steps + let (result, steps) = expr.evaluate_with_steps(); + + // Verify result matches expected + let expected = expr.terms[0].base.pow(expr.terms[0].exponent.into_bigint()) + * expr.terms[1].base.pow(expr.terms[1].exponent.into_bigint()); + assert_eq!(result, expected); + + // Verify all steps + for term_step in &steps.term_steps { + assert!(term_step.sanity_verify()); + } + + // Convert to products and verify + let products = Expression::steps_to_products(&steps); + let r = Fq::rand(&mut rng); + assert!( + batch_verify(&products, &r), + "Batch verification of expression steps failed" + ); +} + +#[test] +fn test_step_continuity() { + let mut rng = test_rng(); + let base = Fq12::rand(&mut rng); + let exponent = Fq::from(255u64); // Use a reasonable sized exponent + + let steps = pow_with_steps_le(base, exponent); + + // Check continuity between steps + for i in 0..steps.steps.len() - 1 { + assert_eq!( + steps.steps[i].rho_after, + steps.steps[i + 1].rho_before, + "Step continuity broken at step {}", + i + ); + } + + // Check final step leads to result + if let Some(last_step) = steps.steps.last() { + assert_eq!(last_step.rho_after, steps.result); + } +} + +#[test] +fn test_squaring_correctness() { + let mut rng = test_rng(); + let base = Fq12::rand(&mut rng); + let exponent = Fq::from(100u64); + + let steps = pow_with_steps_le(base, exponent); + + // Verify each squaring operation: a_i = a_{i-1}^2 + for step in &steps.steps { + let expected_square = step.a_prev * step.a_prev; + assert_eq!( + step.a_curr, expected_square, + "Squaring incorrect at step {}", + step.step_index + ); + } +} From 791be0d3b93a5fbf7555358fd5fb600de125617f Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Mon, 15 Sep 2025 21:39:20 -0400 Subject: [PATCH 18/38] added new types for small value --- ff/src/biginteger/i8_or_i96.rs | 448 +++++++++++++++++++++++++++++ ff/src/biginteger/mod.rs | 8 +- ff/src/biginteger/signed.rs | 9 +- ff/src/biginteger/signed_hi_32.rs | 462 ++++++++++++++++++++++++++++++ 4 files changed, 925 insertions(+), 2 deletions(-) create mode 100644 ff/src/biginteger/i8_or_i96.rs create mode 100644 ff/src/biginteger/signed_hi_32.rs diff --git a/ff/src/biginteger/i8_or_i96.rs b/ff/src/biginteger/i8_or_i96.rs new file mode 100644 index 000000000..d1d94ecc8 --- /dev/null +++ b/ff/src/biginteger/i8_or_i96.rs @@ -0,0 +1,448 @@ +use core::ops::{Add, Sub, Mul, AddAssign, SubAssign, MulAssign}; + +/// Compact signed integer optimized for the common `i8` case, widening to a 96-bit +/// split representation when needed (low 64 bits in `large_lo`, next 32 bits in `large_hi`). +/// +/// ## Design goals: +/// - Set fields so that this fits in 16 bytes +/// - Encode the vast majority of values as `i8` for space/time locality. +/// - Keep all operations `const fn` so macros and static tables can fold at compile-time. +/// - After every operation, results are canonicalized to the smallest fitting form: +/// if a result fits in `i8`, it is stored in `small_i8`; otherwise it is stored +/// as a 96-bit split in `large_lo`/`large_hi`. +/// +/// ## Layout and Semantics +/// +/// The 96-bit value is stored in two's complement format, split across two fields: +/// - `large_hi: i32`: The upper 32 bits, which includes the sign bit of the 96-bit integer. +/// - `large_lo: u64`: The lower 64 bits, treated as an unsigned block of bits. +/// +/// The full value can be reconstructed using the formula: +/// `value = (large_hi as i128) << 64 | (large_lo as i128)` +/// This is equivalent to sign-extending `large_hi` and zero-extending `large_lo`. +/// +/// ## Notes: +/// - Arithmetic uses exact `i128` semantics (no modular reduction, no saturation). +/// - The `neg` implementation avoids `i8` overflow by widening `i8::MIN` to the wide form. +/// - Conversions are total: `to_i128()` always returns the exact value. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct I8OrI96 { + /// The lower 64 bits of the constant value. + large_lo: u64, + /// The upper 32 (signed) bits above `large_lo` (bits 64..95) + large_hi: i32, + /// Small constants that fit in i8 (-128 to 127) + pub small_i8: i8, + /// Whether the constant value is small (i8) + pub is_small: bool, +} + +impl I8OrI96 { + /// Returns zero encoded as `I8(0)`. + pub const fn zero() -> Self { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: 0, + } + } + + /// Returns one encoded as `I8(1)`. + pub const fn one() -> Self { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: 1, + } + } + + /// Construct from `i8` without widening. + pub const fn from_i8(value: i8) -> Self { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: value, + } + } + + /// Construct from `i128`, canonicalizing to `I8` if it fits. + /// Assumes the value fits in 96 bits (i64 + i32) + pub const fn from_i128(value: i128) -> Self { + if value >= i8::MIN as i128 && value <= i8::MAX as i128 { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: value as i8, + } + } else { + // Store as 96-bit signed split: low 64 bits and next 32 bits + I8OrI96 { + large_lo: value as u64, + large_hi: (value >> 64) as i32, + is_small: false, + small_i8: 0, + } + } + } + + /// Mutate in-place from `i8` without widening. Only updates `small_i8` and `is_small`. + #[inline] + pub const fn set_from_i8(&mut self, value: i8) { + self.small_i8 = value; + self.is_small = true; + } + + /// Mutate in-place from `i128`, canonicalizing to `i8` when it fits. + /// Assumes the value fits in 96 bits (i64 + i32). Minimizes writes. + #[inline] + pub const fn set_from_i128(&mut self, value: i128) { + if value >= i8::MIN as i128 && value <= i8::MAX as i128 { + self.small_i8 = value as i8; + self.is_small = true; + } else { + self.large_lo = value as u64; + self.large_hi = (value >> 64) as i32; + self.is_small = false; + } + } + + /// Exact conversion to `i128`. + #[inline] + pub const fn to_i128(&self) -> i128 { + if self.is_small { + self.small_i8 as i128 + } else { + // The `large_lo` (u64) is zero-extended to i128, and `large_hi` (i32) is sign-extended. + // This correctly reconstructs the 96-bit signed value. + (self.large_lo as i128) | ((self.large_hi as i128) << 64) + } + } + + /// Absolute value as unsigned magnitude. + pub const fn unsigned_abs(&self) -> u128 { + let v = self.to_i128(); + v.unsigned_abs() + } + + /// Returns true if the value equals zero. + #[inline] + pub const fn is_zero(&self) -> bool { + if self.is_small { + self.small_i8 == 0 + } else { + self.large_lo == 0 && self.large_hi == 0 + } + } + + /// Returns true if the value is encoded as `I128`. + #[inline] + pub const fn is_large(&self) -> bool { + !self.is_small + } + + /// Add two constants, returning a canonicalized result. + /// + /// Fast-path: if both operands are `I8`, perform `i8` addition directly. + /// If the `i8` addition overflows, it falls back to the `i128` slow path. + #[inline] + pub const fn add(self, other: I8OrI96) -> I8OrI96 { + let mut out = self; + out.add_assign(&other); + out + } + + /// In-place addition assignment: `self = self + other`. + /// Preserves fast path and falls back to `i128` on `i8` overflow. + #[inline] + pub const fn add_assign(&mut self, other: &I8OrI96) { + if self.is_small && other.is_small { + let (sum, overflow) = self.small_i8.overflowing_add(other.small_i8); + if !overflow { + self.set_from_i8(sum); + return; + } + } + let sum = self.to_i128() + other.to_i128(); + self.set_from_i128(sum); + } + + /// Multiply two constants, returning a canonicalized result. + /// + /// Fast-path: if both operands are `I8`, perform `i8` multiplication directly. + /// If `i8` multiplication overflows, it falls back to the `i128` slow path. + #[inline] + pub const fn mul(self, other: I8OrI96) -> I8OrI96 { + let mut out = self; + out.mul_assign(&other); + out + } + + /// In-place multiplication assignment: `self = self * other`. + /// Preserves fast path and falls back to `i128` on `i8` overflow. + #[inline] + pub const fn mul_assign(&mut self, other: &I8OrI96) { + if self.is_small && other.is_small { + let (prod, overflow) = self.small_i8.overflowing_mul(other.small_i8); + if !overflow { + self.set_from_i8(prod); + return; + } + } + let prod = self.to_i128() * other.to_i128(); + self.set_from_i128(prod); + } + + /// Arithmetic negation with canonicalization. + /// + /// Special-cases `I8(i8::MIN)` to avoid overflow by widening to `I128`. + /// In-place arithmetic negation. Preserves `i8::MIN` widening behavior. + #[inline] + pub const fn neg(&mut self) { + if self.is_small { + let v = self.small_i8; + if v == i8::MIN { + self.set_from_i128(-(v as i128)); + } else { + self.set_from_i8(-v); + } + } else { + self.set_from_i128(-self.to_i128()); + } + } + + /// Subtraction returning a new value. Delegates to `sub_assign`. + #[inline] + pub const fn sub(self, other: I8OrI96) -> I8OrI96 { + let mut out = self; + out.sub_assign(&other); + out + } + + /// In-place subtraction assignment: `self = self - other`. + /// Fast-path: if both operands are `I8`, perform `i8` subtraction directly. + /// If `i8` subtraction overflows, it falls back to `i128` slow path. + #[inline] + pub const fn sub_assign(&mut self, other: &I8OrI96) { + if self.is_small && other.is_small { + let (diff, overflow) = self.small_i8.overflowing_sub(other.small_i8); + if !overflow { + self.set_from_i8(diff); + return; + } + } + let diff = self.to_i128() - other.to_i128(); + self.set_from_i128(diff); + } +} + +impl Add for I8OrI96 { + type Output = I8OrI96; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + I8OrI96::add(self, rhs) + } +} + +impl AddAssign for I8OrI96 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.add_assign(&rhs) + } +} + +impl Mul for I8OrI96 { + type Output = I8OrI96; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + I8OrI96::mul(self, rhs) + } +} + +impl MulAssign for I8OrI96 { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + self.mul_assign(&rhs) + } +} + +impl Sub for I8OrI96 { + type Output = I8OrI96; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + I8OrI96::sub(self, rhs) + } +} + +impl SubAssign for I8OrI96 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign(&rhs) + } +} + +impl Ord for I8OrI96 { + #[inline] + fn cmp(&self, other: &Self) -> ark_std::cmp::Ordering { + use ark_std::cmp::Ordering; + + // Fast path for when both values are small. + if self.is_small && other.is_small { + return self.small_i8.cmp(&other.small_i8); + } + + // Deconstruct into (hi, lo) parts to perform a 96-bit two's complement comparison. + // If a value is small, we convert it to its large representation on the fly. + let (self_hi, self_lo) = if self.is_small { + let val = self.small_i8 as i128; + ((val >> 64) as i32, val as u64) + } else { + (self.large_hi, self.large_lo) + }; + + let (other_hi, other_lo) = if other.is_small { + let val = other.small_i8 as i128; + ((val >> 64) as i32, val as u64) + } else { + (other.large_hi, other.large_lo) + }; + + // Compare the high parts first. If they differ, that determines the order. + match self_hi.cmp(&other_hi) { + Ordering::Equal => { + // If high parts are the same, the order is determined by the low parts. + self_lo.cmp(&other_lo) + } + order => order, + } + } +} + +impl PartialOrd for I8OrI96 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: bool) -> Self { + if value { + I8OrI96::one() + } else { + I8OrI96::zero() + } + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i8) -> Self { + I8OrI96::from_i8(value) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i16) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i32) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i64) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i128) -> Self { + I8OrI96::from_i128(value) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: isize) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u8) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u16) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u32) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u64) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: usize) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TryFromU128Error; + +impl core::fmt::Display for TryFromU128Error { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "u128 does not fit in signed 96-bit range") + } +} + +impl core::convert::TryFrom for I8OrI96 { + type Error = TryFromU128Error; + + #[inline] + fn try_from(value: u128) -> Result { + // Signed 96-bit maximum is 2^95 - 1 + const I96_POS_MAX: u128 = (1u128 << 95) - 1; + if value <= i8::MAX as u128 { + Ok(I8OrI96::from_i8(value as i8)) + } else if value <= I96_POS_MAX { + Ok(I8OrI96 { + large_lo: value as u64, + large_hi: (value >> 64) as i32, + is_small: false, + small_i8: 0, + }) + } else { + Err(TryFromU128Error) + } + } +} \ No newline at end of file diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 63de685a5..151b47232 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -30,7 +30,13 @@ use zeroize::Zeroize; pub mod arithmetic; pub mod signed; -pub use signed::SignedBigInt; +pub use signed::{SignedBigInt, S64, S128, S196, S256}; + +pub mod signed_hi_32; +pub use signed_hi_32::{SignedBigIntHi32, S96, S160, S224}; + +pub mod i8_or_i96; +pub use i8_or_i96::I8OrI96; #[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize)] pub struct BigInt(pub [u64; N]); diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index 13c08854c..fd3a0d675 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -1,14 +1,22 @@ use crate::biginteger::{BigInt, BigInteger}; use core::cmp::Ordering; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +#[cfg(feature = "allocative")] +use allocative::Allocative; /// A signed big integer using arkworks BigInt for magnitude and a sign bit +#[cfg_attr(feature = "allocative", derive(Allocative))] #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct SignedBigInt { pub magnitude: BigInt, pub is_positive: bool, } +pub type S64 = SignedBigInt<1>; +pub type S128 = SignedBigInt<2>; +pub type S196 = SignedBigInt<3>; +pub type S256 = SignedBigInt<4>; + impl SignedBigInt { #[inline] fn cmp_magnitude_mixed(&self, rhs: &SignedBigInt) -> Ordering { @@ -770,4 +778,3 @@ impl core::ops::Mul for &SignedBigInt { out } } - diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs new file mode 100644 index 000000000..7a777a7cc --- /dev/null +++ b/ff/src/biginteger/signed_hi_32.rs @@ -0,0 +1,462 @@ +use core::ops::{Add, Sub, Mul, Neg, AddAssign, SubAssign, MulAssign}; +use ark_std::cmp::Ordering; +use ark_std::vec::Vec; + +/// Compact signed big-integer parameterized by limb count `N`, with top limb being u32 +/// +/// Representation: +/// - `magnitude_lo: [u64; N]` stores low limbs in little-endian order (index 0 is least significant). +/// - `magnitude_hi: i32` is the high 32-bit tail +/// - `is_positive: bool` is the sign +/// +/// Notes: +/// - For most applications, `N` is typically ≤ 3, but the API supports larger `N`. +#[cfg_attr(feature = "allocative", derive(Allocative))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SignedBigIntHi32 { + /// Little-endian low limbs: limb 0 = low 64 bits, limb 1 = next 64 bits, and so on + magnitude_lo: [u64; N], + /// Top 32 bits + magnitude_hi: u32, + /// Whether the value is non-negative + is_positive: bool, +} + +pub type S96 = SignedBigIntHi32<1>; +pub type S160 = SignedBigIntHi32<2>; +pub type S224 = SignedBigIntHi32<3>; + +// ------------------------------------------------------------------------------------------------ +// Implementation +// ------------------------------------------------------------------------------------------------ + +impl SignedBigIntHi32 { + /// Creates a new `SignedBigIntHi32`. + /// + /// The sign is not normalized: a zero magnitude can be positive or negative. + pub const fn new(magnitude_lo: [u64; N], magnitude_hi: u32, is_positive: bool) -> Self { + Self { + magnitude_lo, + magnitude_hi, + is_positive, + } + } + + /// Returns the value `0`. + pub const fn zero() -> Self { + Self { + magnitude_lo: [0; N], + magnitude_hi: 0, + is_positive: true, + } + } + + /// Returns the value `1`. + pub fn one() -> Self { + let mut magnitude_lo = [0; N]; + let magnitude_hi; + + if N == 0 { + magnitude_hi = 1; + } else { + magnitude_lo[0] = 1; + magnitude_hi = 0; + } + + Self { + magnitude_lo, + magnitude_hi, + is_positive: true, + } + } + + // ------------------------------------------------------------------------------------------------ + // Accessors + // ------------------------------------------------------------------------------------------------ + + /// Returns the low limbs of the magnitude. + pub const fn magnitude_lo(&self) -> &[u64; N] { + &self.magnitude_lo + } + + /// Returns the high 32 bits of the magnitude. + pub const fn magnitude_hi(&self) -> u32 { + self.magnitude_hi + } + + /// Returns `true` if the number is non-negative. + pub const fn is_positive(&self) -> bool { + self.is_positive + } + + /// Returns `true` if the number is zero. + pub const fn is_zero(&self) -> bool { + let mut lo_is_zero = true; + let mut i = 0; + while i < N { + if self.magnitude_lo[i] != 0 { + lo_is_zero = false; + break; + } + i += 1; + } + self.magnitude_hi == 0 && lo_is_zero + } + + // ------------------------------------------------------------------------------------------------ + // Private arithmetic helpers + // ------------------------------------------------------------------------------------------------ + + fn compare_magnitudes(&self, other: &Self) -> Ordering { + if self.magnitude_hi != other.magnitude_hi { + return self.magnitude_hi.cmp(&other.magnitude_hi); + } + for i in (0..N).rev() { + if self.magnitude_lo[i] != other.magnitude_lo[i] { + return self.magnitude_lo[i].cmp(&other.magnitude_lo[i]); + } + } + Ordering::Equal + } + + fn add_assign_in_place(&mut self, rhs: &Self) { + if self.is_positive == rhs.is_positive { + let (lo, hi, _carry) = self.add_magnitudes_with_carry(rhs); + self.magnitude_lo = lo; + self.magnitude_hi = hi; + } else { + match self.compare_magnitudes(rhs) { + Ordering::Greater | Ordering::Equal => { + let (lo, hi, _borrow) = self.sub_magnitudes_with_borrow(rhs); + self.magnitude_lo = lo; + self.magnitude_hi = hi; + } + Ordering::Less => { + let (lo, hi, _borrow) = rhs.sub_magnitudes_with_borrow(self); + self.magnitude_lo = lo; + self.magnitude_hi = hi; + self.is_positive = rhs.is_positive; + } + } + } + } + + fn sub_assign_in_place(&mut self, rhs: &Self) { + let neg_rhs = -*rhs; + self.add_assign_in_place(&neg_rhs); + } + + fn mul_magnitudes(&self, other: &Self) -> ([u64; N], u32) { + // Fast paths for small N to avoid heap allocation and loops + if N == 0 { + let a2 = self.magnitude_hi as u64; + let b2 = other.magnitude_hi as u64; + let prod = a2.wrapping_mul(b2); + let hi = (prod & 0xFFFF_FFFF) as u32; + let lo: [u64; N] = [0u64; N]; + return (lo, hi); + } + + if N == 1 { + let a0 = self.magnitude_lo[0]; + let a1 = self.magnitude_hi as u64; // 32-bit value widened + let b0 = other.magnitude_lo[0]; + let b1 = other.magnitude_hi as u64; // 32-bit value widened + + let t0 = (a0 as u128) * (b0 as u128); + let lo0 = t0 as u64; + + let cross = (t0 >> 64) + + (a0 as u128) * (b1 as u128) + + (a1 as u128) * (b0 as u128); + + let hi = (cross as u64 & 0xFFFF_FFFF) as u32; + let mut lo = [0u64; N]; + lo[0] = lo0; + return (lo, hi); + } + + if N == 2 { + let a0 = self.magnitude_lo[0]; + let a1 = self.magnitude_lo[1]; + let a2 = self.magnitude_hi as u64; // 32-bit value widened + let b0 = other.magnitude_lo[0]; + let b1 = other.magnitude_lo[1]; + let b2 = other.magnitude_hi as u64; // 32-bit value widened + + // word 0 + let t0 = (a0 as u128) * (b0 as u128); + let r0 = t0 as u64; + let carry0 = t0 >> 64; + + // word 1 + let sum1 = carry0 + (a0 as u128) * (b1 as u128) + (a1 as u128) * (b0 as u128); + let r1 = sum1 as u64; + let carry1 = sum1 >> 64; + + // word 2 (only need low 32 bits) + let sum2 = carry1 + + (a0 as u128) * (b2 as u128) + + (a1 as u128) * (b1 as u128) + + (a2 as u128) * (b0 as u128); + let r2 = sum2 as u64; + let hi = (r2 & 0xFFFF_FFFF) as u32; + let mut lo = [0u64; N]; + lo[0] = r0; + lo[1] = r1; + return (lo, hi); + } + + // General path + // Product of (N*64 + 32)-bit numbers fits in (2*N*64 + 64) bits. + // Allocate 2*N + 2 u64 limbs to safely propagate carries; we'll truncate to N u64 + 32 bits. + let mut prod = vec![0u64; 2 * N + 2]; + + let self_limbs: Vec = self + .magnitude_lo + .iter() + .cloned() + .chain(core::iter::once(self.magnitude_hi as u64)) + .collect(); + + let other_limbs: Vec = other + .magnitude_lo + .iter() + .cloned() + .chain(core::iter::once(other.magnitude_hi as u64)) + .collect(); + + for i in 0..self_limbs.len() { + let mut carry: u128 = 0; + for j in 0..other_limbs.len() { + let idx = i + j; + let p = (self_limbs[i] as u128) + * (other_limbs[j] as u128) + + (prod[idx] as u128) + + carry; + prod[idx] = p as u64; + carry = p >> 64; + } + if carry > 0 { + let spill = i + other_limbs.len(); + if spill < prod.len() { + prod[spill] = prod[spill].wrapping_add(carry as u64); + } + // else: spill is beyond the truncated width; ignore (mod 2^(64*N+32)). + } + } + + // Truncate and split into lo and hi (keep only the low N u64 limbs and the low 32 bits of limb N) + let mut magnitude_lo = [0u64; N]; + if N > 0 { + magnitude_lo.copy_from_slice(&prod[0..N]); + } + let magnitude_hi = (prod[N] & 0xFFFF_FFFF) as u32; + + (magnitude_lo, magnitude_hi) + } + + // Returns final carry bit. + fn add_magnitudes_with_carry(&self, other: &Self) -> ([u64; N], u32, bool) { + let mut magnitude_lo = [0; N]; + let mut carry: u128 = 0; + + for i in 0..N { + let sum = + (self.magnitude_lo[i] as u128) + (other.magnitude_lo[i] as u128) + carry; + magnitude_lo[i] = sum as u64; + carry = sum >> 64; + } + + let sum_hi = (self.magnitude_hi as u128) + (other.magnitude_hi as u128) + carry; + let magnitude_hi = sum_hi as u32; + + let final_carry = (sum_hi >> 32) != 0; + (magnitude_lo, magnitude_hi, final_carry) + } + + // Returns final borrow bit. + fn sub_magnitudes_with_borrow(&self, other: &Self) -> ([u64; N], u32, bool) { + let mut magnitude_lo = [0u64; N]; + let mut borrow = false; + + for i in 0..N { + let (d1, b1) = self.magnitude_lo[i].overflowing_sub(other.magnitude_lo[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + magnitude_lo[i] = d2; + borrow = b1 || b2; + } + + let (hi1, b1) = self.magnitude_hi.overflowing_sub(other.magnitude_hi); + let (hi2, b2) = hi1.overflowing_sub(borrow as u32); + let final_borrow = b1 || b2; + + (magnitude_lo, hi2, final_borrow) + } +} + +// ------------------------------------------------------------------------------------------------ +// Operator traits +// ------------------------------------------------------------------------------------------------ + +impl Neg for SignedBigIntHi32 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::new(self.magnitude_lo, self.magnitude_hi, !self.is_positive) + } +} + +impl Add for SignedBigIntHi32 { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self::Output { + self.add_assign_in_place(&rhs); + self + } +} + +impl AddAssign for SignedBigIntHi32 { + fn add_assign(&mut self, rhs: Self) { + self.add_assign_in_place(&rhs); + } +} + +impl Sub for SignedBigIntHi32 { + type Output = Self; + + fn sub(mut self, rhs: Self) -> Self::Output { + self.sub_assign_in_place(&rhs); + self + } +} + +impl SubAssign for SignedBigIntHi32 { + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign_in_place(&rhs); + } +} + +impl MulAssign for SignedBigIntHi32 { + fn mul_assign(&mut self, rhs: Self) { + *self = self.mul(&rhs); + } +} + +// Reference variants for efficiency +impl Add<&SignedBigIntHi32> for SignedBigIntHi32 { + type Output = SignedBigIntHi32; + + #[inline] + fn add(mut self, rhs: &SignedBigIntHi32) -> Self::Output { + self.add_assign_in_place(rhs); + self + } +} + +impl Sub<&SignedBigIntHi32> for SignedBigIntHi32 { + type Output = SignedBigIntHi32; + + #[inline] + fn sub(mut self, rhs: &SignedBigIntHi32) -> Self::Output { + self.sub_assign_in_place(rhs); + self + } +} + +impl Mul<&SignedBigIntHi32> for SignedBigIntHi32 { + type Output = SignedBigIntHi32; + + #[inline] + fn mul(self, rhs: &SignedBigIntHi32) -> Self::Output { + let (lo, hi) = self.mul_magnitudes(rhs); + let is_positive = !(self.is_positive ^ rhs.is_positive); + Self::new(lo, hi, is_positive) + } +} + +impl AddAssign<&SignedBigIntHi32> for SignedBigIntHi32 { + #[inline] + fn add_assign(&mut self, rhs: &SignedBigIntHi32) { + self.add_assign_in_place(rhs); + } +} + +impl SubAssign<&SignedBigIntHi32> for SignedBigIntHi32 { + #[inline] + fn sub_assign(&mut self, rhs: &SignedBigIntHi32) { + self.sub_assign_in_place(rhs); + } +} + +impl MulAssign<&SignedBigIntHi32> for SignedBigIntHi32 { + #[inline] + fn mul_assign(&mut self, rhs: &SignedBigIntHi32) { + *self = self.mul(rhs); + } +} + +// By-ref binary operator variants to avoid copying both operands +impl<'a, const N: usize> Add for &'a SignedBigIntHi32 { + type Output = SignedBigIntHi32; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.add_assign_in_place(rhs); + out + } +} + +impl<'a, const N: usize> Sub for &'a SignedBigIntHi32 { + type Output = SignedBigIntHi32; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.sub_assign_in_place(rhs); + out + } +} + +impl<'a, const N: usize> Mul for &'a SignedBigIntHi32 { + type Output = SignedBigIntHi32; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + let (lo, hi) = self.mul_magnitudes(rhs); + let is_positive = !(self.is_positive ^ rhs.is_positive); + SignedBigIntHi32::new(lo, hi, is_positive) + } +} + +// ------------------------------------------------------------------------------------------------ +// From traits +// ------------------------------------------------------------------------------------------------ + +impl From for S96 { + fn from(val: i64) -> Self { + Self::new([val.unsigned_abs()], 0, val.is_positive()) + } +} + +impl From for S96 { + fn from(val: u64) -> Self { + Self::new([val], 0, true) + } +} + +impl From for S160 { + fn from(val: i128) -> Self { + let is_positive = val.is_positive(); + let mag = val.unsigned_abs(); + let lo = mag as u64; + let hi = (mag >> 64) as u64; + Self::new([lo, hi], 0, is_positive) + } +} + +impl From for S160 { + fn from(val: u128) -> Self { + let lo = val as u64; + let hi = (val >> 64) as u64; + Self::new([lo, hi], 0, true) + } +} From ab80fd932766671700242429ec212fe4195f9fd0 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 11:10:41 -0400 Subject: [PATCH 19/38] make allocative default, add more custom type ops --- bench-templates/src/macros/field.rs | 6 +- ec/src/lib.rs | 6 +- ec/src/pairing.rs | 10 +- ec/src/scalar_mul/fixed_base.rs | 2 +- ec/src/scalar_mul/mod.rs | 2 +- ff/Cargo.toml | 3 +- ff/src/biginteger/arithmetic.rs | 4 +- ff/src/biginteger/i8_or_i96.rs | 162 +- ff/src/biginteger/mod.rs | 25 +- ff/src/biginteger/signed.rs | 313 ++- ff/src/biginteger/signed_hi_32.rs | 86 +- ff/src/biginteger/tests.rs | 1943 +++++++++-------- ff/src/fields/models/fp/mod.rs | 2 - ff/src/fields/models/fp/montgomery_backend.rs | 87 +- test-curves/benches/bigint.rs | 34 +- test-curves/benches/small_mul.rs | 116 +- test-curves/src/bn254/fq.rs | 2 +- test-curves/src/bn254/fr.rs | 2 +- test-curves/src/bn254/g1.rs | 6 +- test-curves/src/bn254/test.rs | 4 +- 20 files changed, 1640 insertions(+), 1175 deletions(-) diff --git a/bench-templates/src/macros/field.rs b/bench-templates/src/macros/field.rs index 40fe597b1..db11baabf 100644 --- a/bench-templates/src/macros/field.rs +++ b/bench-templates/src/macros/field.rs @@ -405,16 +405,14 @@ macro_rules! prime_field { f[i].into_bigint() }) }); - let u64s = (0..SAMPLES) - .map(|_| rng.next_u64()) - .collect::>(); + let u64s = (0..SAMPLES).map(|_| rng.next_u64()).collect::>(); conversions.bench_function("From u64", |b| { let mut i = 0; b.iter(|| { i = (i + 1) % SAMPLES; <$F>::from_u64(u64s[i]) }) - }); + }); conversions.finish() } }; diff --git a/ec/src/lib.rs b/ec/src/lib.rs index ba99d4c87..47b8437e5 100644 --- a/ec/src/lib.rs +++ b/ec/src/lib.rs @@ -28,11 +28,7 @@ use ark_std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, vec::*, }; -pub use scalar_mul::{ - fixed_base::FixedBase, - variable_base::VariableBaseMSM, - ScalarMul, -}; +pub use scalar_mul::{fixed_base::FixedBase, variable_base::VariableBaseMSM, ScalarMul}; use zeroize::Zeroize; pub use ark_ff::AdditiveGroup; diff --git a/ec/src/pairing.rs b/ec/src/pairing.rs index a3aa83e26..f62d1be72 100644 --- a/ec/src/pairing.rs +++ b/ec/src/pairing.rs @@ -102,8 +102,14 @@ pub trait Pairing: Sized + 'static + Copy + Debug + Sync + Send + Eq { a: impl IntoIterator>, b: impl IntoIterator>, ) -> MillerLoopOutput { - let a_cloned = a.into_iter().map(|x| x.as_ref().clone()).collect::>(); - let b_cloned = b.into_iter().map(|x| x.as_ref().clone()).collect::>(); + let a_cloned = a + .into_iter() + .map(|x| x.as_ref().clone()) + .collect::>(); + let b_cloned = b + .into_iter() + .map(|x| x.as_ref().clone()) + .collect::>(); Self::multi_miller_loop(a_cloned, b_cloned) } diff --git a/ec/src/scalar_mul/fixed_base.rs b/ec/src/scalar_mul/fixed_base.rs index c9e5270d0..ce8001ccd 100644 --- a/ec/src/scalar_mul/fixed_base.rs +++ b/ec/src/scalar_mul/fixed_base.rs @@ -95,4 +95,4 @@ impl FixedBase { .map(|e| Self::windowed_mul::(outerc, window, table, e)) .collect::>() } -} \ No newline at end of file +} diff --git a/ec/src/scalar_mul/mod.rs b/ec/src/scalar_mul/mod.rs index 81a4c6595..cb38e432d 100644 --- a/ec/src/scalar_mul/mod.rs +++ b/ec/src/scalar_mul/mod.rs @@ -1,8 +1,8 @@ pub mod glv; pub mod wnaf; -pub mod variable_base; pub mod fixed_base; +pub mod variable_base; use crate::{ short_weierstrass::{Affine, Projective, SWCurveConfig}, diff --git a/ff/Cargo.toml b/ff/Cargo.toml index dc51b5deb..3d5f0b156 100644 --- a/ff/Cargo.toml +++ b/ff/Cargo.toml @@ -29,7 +29,7 @@ zeroize = { workspace = true, features = ["zeroize_derive"] } num-bigint.workspace = true digest = { workspace = true, features = ["alloc"] } itertools.workspace = true -allocative = { version = "0.3.4", optional = true } +allocative = "0.3.4" [dev-dependencies] ark-test-curves = { workspace = true, features = [ @@ -53,4 +53,3 @@ default = [] std = ["ark-std/std", "ark-serialize/std", "itertools/use_std"] parallel = ["std", "rayon", "ark-std/parallel", "ark-serialize/parallel"] asm = [] -allocative = ["dep:allocative"] diff --git a/ff/src/biginteger/arithmetic.rs b/ff/src/biginteger/arithmetic.rs index 493758ae7..6f6a67f26 100644 --- a/ff/src/biginteger/arithmetic.rs +++ b/ff/src/biginteger/arithmetic.rs @@ -135,7 +135,9 @@ pub fn add_limbs_shifted_inplace( let mut i = 0usize; while i < limbs.len() { let idx = lane_offset + i; - if idx >= N { break; } + if idx >= N { + break; + } let tmp = (acc[idx] as u128) + (limbs[i] as u128) + (carry as u128); acc[idx] = tmp as u64; carry = (tmp >> 64) as u64; diff --git a/ff/src/biginteger/i8_or_i96.rs b/ff/src/biginteger/i8_or_i96.rs index d1d94ecc8..bec6ee803 100644 --- a/ff/src/biginteger/i8_or_i96.rs +++ b/ff/src/biginteger/i8_or_i96.rs @@ -1,4 +1,5 @@ -use core::ops::{Add, Sub, Mul, AddAssign, SubAssign, MulAssign}; +use crate::biginteger::{S160, S224}; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; /// Compact signed integer optimized for the common `i8` case, widening to a 96-bit /// split representation when needed (low 64 bits in `large_lo`, next 32 bits in `large_hi`). @@ -22,7 +23,9 @@ use core::ops::{Add, Sub, Mul, AddAssign, SubAssign, MulAssign}; /// This is equivalent to sign-extending `large_hi` and zero-extending `large_lo`. /// /// ## Notes: -/// - Arithmetic uses exact `i128` semantics (no modular reduction, no saturation). +/// - Arithmetic uses `i128` for intermediate computation but the representation is signed 96-bit. +/// Results are canonicalized to the smallest fitting form. If a result does not fit in 96 bits, +/// it is truncated to signed 96-bit two's complement (wrapping modulo 2^96). /// - The `neg` implementation avoids `i8` overflow by widening `i8::MIN` to the wide form. /// - Conversions are total: `to_i128()` always returns the exact value. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -315,7 +318,7 @@ impl Ord for I8OrI96 { Ordering::Equal => { // If high parts are the same, the order is determined by the low parts. self_lo.cmp(&other_lo) - } + }, order => order, } } @@ -445,4 +448,155 @@ impl core::convert::TryFrom for I8OrI96 { Err(TryFromU128Error) } } -} \ No newline at end of file +} + +impl Mul for I8OrI96 { + type Output = S224; + + #[inline] + fn mul(self, rhs: S160) -> Self::Output { + // Determine sign of self + let self_is_positive = if self.is_small { + self.small_i8 >= 0 + } else { + self.large_hi >= 0 + }; + + // Extract rhs magnitude limbs + let rhs_lo = rhs.magnitude_lo(); + let b0 = rhs_lo[0]; + let b1 = rhs_lo[1]; + let b2 = rhs.magnitude_hi() as u64; // widen for math + let b1_is_zero = b1 == 0; + let b2_is_zero = b2 == 0; + let rhs_is_zero = (b0 | b1 | b2) == 0; + + if rhs_is_zero { + return S224::zero(); + } + + // Compute absolute magnitude of self as 96-bit split (x0: u64, x1: u32) + let (x0, x1_u32) = if self.is_small { + let v = self.small_i8; + let k = if v < 0 { + (-(v as i16)) as u64 + } else { + v as u64 + }; + (k, 0u32) + } else { + let hi = self.large_hi; + if hi >= 0 { + (self.large_lo, hi as u32) + } else { + // Two's-complement absolute: (~lo, ~hi) + 1 + let inv_lo = !self.large_lo; + let inv_hi = !(hi as u32); + let sum_lo = inv_lo.wrapping_add(1); + let carry = (sum_lo == 0) as u32; + let sum_hi = inv_hi.wrapping_add(carry); + (sum_lo, sum_hi) + } + }; + + // Compute magnitude product truncated to 224 bits: [r0,r1,r2] + hi32 + let (r0, r1, r2, hi32) = if x1_u32 == 0 { + // Fast path: scalar (<= 8-bit) * 160-bit + let k = x0; + let mut c0 = 0u64; + let r0 = mac_with_carry!(0u64, b0, k, &mut c0); + + if b1_is_zero { + if b2_is_zero { + // Only 64-bit rhs + let r1 = c0; + (r0, r1, 0u64, 0u32) + } else { + // 128-bit rhs via b2 only + let r1 = c0; + let mut hi = 0u64; + let r2 = mac_with_carry!(0u64, b2, k, &mut hi); + let hi32 = hi as u32; + (r0, r1, r2, hi32) + } + } else if b2_is_zero { + // 128-bit rhs via b1 only + let mut c1 = c0; + let r1p = mac_with_carry!(0u64, b1, k, &mut c1); + let r1 = adc!(r1p, 0u64, &mut c1); + let r2 = c1; + (r0, r1, r2, 0u32) + } else { + // Full 160-bit rhs + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, b1, k, &mut c1); + + let mut c2 = 0u64; + let mut r2 = mac_with_carry!(0u64, b2, k, &mut c2); + r2 = adc!(r2, c1, &mut c2); + let hi32 = c2 as u32; + (r0, r1, r2, hi32) + } + } else { + // General 96-bit (2 limbs: x0, x1_u32) times 160-bit (3 limbs: b0, b1, b2) + let x1 = x1_u32 as u64; + + let mut c0 = 0u64; + let r0 = mac_with_carry!(0u64, x0, b0, &mut c0); + + if b1_is_zero { + if b2_is_zero { + // Only 64-bit rhs + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, x1, b0, &mut c1); + let r2 = c1; + (r0, r1, r2, 0u32) + } else { + // No b1, but have b2 + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, x1, b0, &mut c1); + + let mut c2 = c1; + let r2 = mac_with_carry!(0u64, x0, b2, &mut c2); + + let mut carry_hi = c2; + crate::biginteger::arithmetic::mac_discard(carry_hi, x1, b2, &mut carry_hi); + let hi32 = carry_hi as u32; + (r0, r1, r2, hi32) + } + } else if b2_is_zero { + // No b2, but have b1 + let mut c1 = c0; + let mut r1 = mac_with_carry!(0u64, x0, b1, &mut c1); + r1 = mac_with_carry!(r1, x1, b0, &mut c1); + + let mut c2 = c1; + let r2 = mac_with_carry!(0u64, x1, b1, &mut c2); + let hi32 = c2 as u32; + (r0, r1, r2, hi32) + } else { + // Full 160-bit rhs + let mut c1 = c0; + let mut r1 = mac_with_carry!(0u64, x0, b1, &mut c1); + r1 = mac_with_carry!(r1, x1, b0, &mut c1); + + let mut c2 = c1; + let mut r2 = mac_with_carry!(0u64, x0, b2, &mut c2); + r2 = mac_with_carry!(r2, x1, b1, &mut c2); + + let mut carry_hi = c2; + crate::biginteger::arithmetic::mac_discard(carry_hi, x1, b2, &mut carry_hi); + let hi32 = carry_hi as u32; + (r0, r1, r2, hi32) + } + }; + + // Combine sign; canonicalize zero to positive + let mut is_positive = !(self_is_positive ^ rhs.is_positive()); + if (r0 | r1 | r2) == 0 && hi32 == 0 { + is_positive = true; + } + + S224::new([r0, r1, r2], hi32, is_positive) + } +} diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 151b47232..55bbb55f2 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -2,6 +2,7 @@ use crate::{ bits::{BitIteratorBE, BitIteratorLE}, const_for, UniformRand, }; +use allocative::Allocative; #[allow(unused)] use ark_ff_macros::unroll_for_loops; use ark_serialize::{ @@ -30,15 +31,15 @@ use zeroize::Zeroize; pub mod arithmetic; pub mod signed; -pub use signed::{SignedBigInt, S64, S128, S196, S256}; +pub use signed::{SignedBigInt, S128, S196, S256, S64}; pub mod signed_hi_32; -pub use signed_hi_32::{SignedBigIntHi32, S96, S160, S224}; +pub use signed_hi_32::{SignedBigIntHi32, S160, S224, S96}; pub mod i8_or_i96; pub use i8_or_i96::I8OrI96; -#[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize, Allocative)] pub struct BigInt(pub [u64; N]); impl Default for BigInt { @@ -388,7 +389,11 @@ impl BigInt { /// Fused multiply-add with truncation: acc += self * other, fitting into P limbs; overflow is ignored. /// This is a generic version for arbitrary limb widths of `self` and `other`. #[inline] - pub fn fmadd_trunc(&self, other: &BigInt, acc: &mut BigInt

) { + pub fn fmadd_trunc( + &self, + other: &BigInt, + acc: &mut BigInt

, + ) { let i_limit = core::cmp::min(N, P); for i in 0..i_limit { let mut carry = 0u64; @@ -621,11 +626,7 @@ impl BigInteger for BigInt { #[inline] #[unroll_for_loops(8)] - fn fmu64a_carry_propagating( - &self, - other: u64, - acc: &mut BigInt, - ) { + fn fmu64a_carry_propagating(&self, other: u64, acc: &mut BigInt) { // ensure NPLUS2 is the correct size (N + 2 limbs) debug_assert!(NPLUS2 == N + 2); if other == 0 || self.is_zero() { @@ -1593,11 +1594,7 @@ pub trait BigInteger: /// NEW! Fused multiply-accumulate with a u64 multiplier and explicit overflow propagation. /// Accumulates `self * other` into `acc`, which must have two extra limbs (N + 2). /// Any overflow from limb N is carried into limb N+1 instead of wrapping. - fn fmu64a_carry_propagating( - &self, - other: u64, - acc: &mut BigInt, - ); + fn fmu64a_carry_propagating(&self, other: u64, acc: &mut BigInt); /// NEW! Multiplies self by a u128, returning a bigint with two extra limbs to hold overflow. fn mul_u128_w_carry( diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index fd3a0d675..62d540ddc 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -1,12 +1,14 @@ use crate::biginteger::{BigInt, BigInteger}; +use allocative::Allocative; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; use core::cmp::Ordering; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -#[cfg(feature = "allocative")] -use allocative::Allocative; /// A signed big integer using arkworks BigInt for magnitude and a sign bit -#[cfg_attr(feature = "allocative", derive(Allocative))] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] pub struct SignedBigInt { pub magnitude: BigInt, pub is_positive: bool, @@ -26,8 +28,12 @@ impl SignedBigInt { let idx = i - 1; let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; - if a > b { return Ordering::Greater; } - if a < b { return Ordering::Less; } + if a > b { + return Ordering::Greater; + } + if a < b { + return Ordering::Less; + } i -= 1; } Ordering::Equal @@ -44,19 +50,28 @@ impl SignedBigInt { /// Construct from an existing BigInt magnitude and sign. #[inline] pub fn from_bigint(magnitude: BigInt, is_positive: bool) -> Self { - Self { magnitude, is_positive } + Self { + magnitude, + is_positive, + } } /// Zero value with a positive sign (negative zero allowed elsewhere). #[inline] pub fn zero() -> Self { - Self { magnitude: BigInt::from(0u64), is_positive: true } + Self { + magnitude: BigInt::from(0u64), + is_positive: true, + } } /// One with a positive sign. #[inline] pub fn one() -> Self { - Self { magnitude: BigInt::from(1u64), is_positive: true } + Self { + magnitude: BigInt::from(1u64), + is_positive: true, + } } /// Return true if magnitude is zero (sign is not considered). @@ -67,15 +82,21 @@ impl SignedBigInt { /// Borrow the magnitude (absolute value). #[inline] - pub fn as_magnitude(&self) -> &BigInt { &self.magnitude } + pub fn as_magnitude(&self) -> &BigInt { + &self.magnitude + } /// Return the magnitude limbs by value (copy). #[inline] - pub fn magnitude_limbs(&self) -> [u64; N] { self.magnitude.0 } + pub fn magnitude_limbs(&self) -> [u64; N] { + self.magnitude.0 + } /// Borrow the magnitude limbs as a slice (avoids copying the array). #[inline] - pub fn magnitude_slice(&self) -> &[u64] { self.magnitude.as_ref() } + pub fn magnitude_slice(&self) -> &[u64] { + self.magnitude.as_ref() + } /// Return true iff the value is non-negative. #[inline] @@ -85,15 +106,24 @@ impl SignedBigInt { /// Compute self + other modulo 2^(64*N); carry beyond N limbs is dropped. #[inline] - pub fn add(mut self, other: Self) -> Self { self += other; self } + pub fn add(mut self, other: Self) -> Self { + self += other; + self + } /// Compute self - other modulo 2^(64*N); borrow beyond N limbs is dropped. #[inline] - pub fn sub(mut self, other: Self) -> Self { self -= other; self } + pub fn sub(mut self, other: Self) -> Self { + self -= other; + self + } /// Compute self * other and keep only the low N limbs; high limbs are discarded. #[inline] - pub fn mul(mut self, other: Self) -> Self { self *= other; self } + pub fn mul(mut self, other: Self) -> Self { + self *= other; + self + } /// Flip the sign; zero is not canonicalized (negative zero may occur). #[inline] @@ -112,13 +142,13 @@ impl SignedBigInt { match self.magnitude.cmp(&rhs.magnitude) { Ordering::Greater | Ordering::Equal => { let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); - } + }, Ordering::Less => { // Minimize copies: move rhs magnitude into place and subtract old self let old = core::mem::replace(&mut self.magnitude, rhs.magnitude); let _borrow = self.magnitude.sub_with_borrow(&old); self.is_positive = rhs.is_positive; - } + }, } } } @@ -135,13 +165,13 @@ impl SignedBigInt { Ordering::Greater | Ordering::Equal => { let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); // sign stays the same - } + }, Ordering::Less => { // Result takes rhs magnitude minus self magnitude, sign flips let old = core::mem::replace(&mut self.magnitude, rhs.magnitude); let _borrow = self.magnitude.sub_with_borrow(&old); self.is_positive = !self.is_positive; - } + }, } } } @@ -176,7 +206,10 @@ impl SignedBigInt { if lim < M { res.0[lim] = carry as u64; } - SignedBigInt:: { magnitude: res, is_positive: self.is_positive } + SignedBigInt:: { + magnitude: res, + is_positive: self.is_positive, + } } else { // Different signs -> subtract smaller magnitude from larger match self.magnitude.cmp(&rhs.magnitude) { @@ -195,8 +228,11 @@ impl SignedBigInt { borrow = b1; } } - SignedBigInt:: { magnitude: res, is_positive: self.is_positive } - } + SignedBigInt:: { + magnitude: res, + is_positive: self.is_positive, + } + }, Ordering::Less => { let mut res = BigInt::::zero(); let lim = core::cmp::min(N, M); @@ -212,8 +248,11 @@ impl SignedBigInt { borrow = b1; } } - SignedBigInt:: { magnitude: res, is_positive: rhs.is_positive } - } + SignedBigInt:: { + magnitude: res, + is_positive: rhs.is_positive, + } + }, } } } @@ -235,7 +274,10 @@ impl SignedBigInt { if lim < M { res.0[lim] = carry as u64; } - SignedBigInt:: { magnitude: res, is_positive: self.is_positive } + SignedBigInt:: { + magnitude: res, + is_positive: self.is_positive, + } } else { // different signs wrt subtraction => subtract magnitudes match self.magnitude.cmp(&rhs.magnitude) { @@ -254,8 +296,11 @@ impl SignedBigInt { borrow = b1; } } - SignedBigInt:: { magnitude: res, is_positive: self.is_positive } - } + SignedBigInt:: { + magnitude: res, + is_positive: self.is_positive, + } + }, Ordering::Less => { let mut res = BigInt::::zero(); let lim = core::cmp::min(N, M); @@ -271,8 +316,11 @@ impl SignedBigInt { borrow = b1; } } - SignedBigInt:: { magnitude: res, is_positive: !self.is_positive } - } + SignedBigInt:: { + magnitude: res, + is_positive: !self.is_positive, + } + }, } } } @@ -280,7 +328,10 @@ impl SignedBigInt { /// Truncated mixed-width addition: compute (self + rhs) where rhs can have a /// different limb count, and fit into P limbs; overflow is ignored. #[inline] - pub fn add_trunc_mixed(&self, rhs: &SignedBigInt) -> SignedBigInt

{ + pub fn add_trunc_mixed( + &self, + rhs: &SignedBigInt, + ) -> SignedBigInt

{ // Case 1: same signs => add magnitudes, sign = self.is_positive if self.is_positive == rhs.is_positive { let mut res = BigInt::

::zero(); @@ -310,8 +361,13 @@ impl SignedBigInt { k += 1; } } - if k < P { res.0[k] = carry as u64; } - return SignedBigInt::

{ magnitude: res, is_positive: self.is_positive }; + if k < P { + res.0[k] = carry as u64; + } + return SignedBigInt::

{ + magnitude: res, + is_positive: self.is_positive, + }; } // Case 2: different signs => subtract smaller magnitude from larger @@ -348,8 +404,11 @@ impl SignedBigInt { k += 1; } } - SignedBigInt::

{ magnitude: res, is_positive: self.is_positive } - } + SignedBigInt::

{ + magnitude: res, + is_positive: self.is_positive, + } + }, Ordering::Less => { // res_mag = rhs.mag - self.mag; sign = rhs.is_positive let mut res = BigInt::

::zero(); @@ -380,22 +439,35 @@ impl SignedBigInt { k += 1; } } - SignedBigInt::

{ magnitude: res, is_positive: rhs.is_positive } - } + SignedBigInt::

{ + magnitude: res, + is_positive: rhs.is_positive, + } + }, } } /// Truncated mul: compute self * rhs and fit into P limbs; no assumption on P; overflow ignored. #[inline] - pub fn mul_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt

{ + pub fn mul_trunc( + &self, + rhs: &SignedBigInt, + ) -> SignedBigInt

{ let mag = self.magnitude.mul_trunc::(&rhs.magnitude); let sign = self.is_positive == rhs.is_positive; - SignedBigInt::

{ magnitude: mag, is_positive: sign } + SignedBigInt::

{ + magnitude: mag, + is_positive: sign, + } } /// Fused multiply-add: acc += self * rhs, fitted into P limbs; overflow is ignored. #[inline] - pub fn fmadd_trunc(&self, rhs: &SignedBigInt, acc: &mut SignedBigInt

) { + pub fn fmadd_trunc( + &self, + rhs: &SignedBigInt, + acc: &mut SignedBigInt

, + ) { let prod_mag = self.magnitude.mul_trunc::(&rhs.magnitude); let prod_sign = self.is_positive == rhs.is_positive; if acc.is_positive == prod_sign { @@ -404,12 +476,12 @@ impl SignedBigInt { match acc.magnitude.cmp(&prod_mag) { Ordering::Greater | Ordering::Equal => { let _ = acc.magnitude.sub_with_borrow(&prod_mag); - } + }, Ordering::Less => { let old = core::mem::replace(&mut acc.magnitude, prod_mag); let _ = acc.magnitude.sub_with_borrow(&old); acc.is_positive = prod_sign; - } + }, } } } @@ -463,7 +535,10 @@ impl SignedBigInt { /// Truncated mixed-width subtraction: compute (self - rhs) where rhs can have a /// different limb count, and fit into P limbs; overflow is ignored. #[inline] - pub fn sub_trunc_mixed(&self, rhs: &SignedBigInt) -> SignedBigInt

{ + pub fn sub_trunc_mixed( + &self, + rhs: &SignedBigInt, + ) -> SignedBigInt

{ // Case 1: different signs => addition of magnitudes, sign = self.is_positive if self.is_positive != rhs.is_positive { let mut res = BigInt::

::zero(); @@ -476,7 +551,10 @@ impl SignedBigInt { res.0[i] = s2; carry = (c1 as u8) | (c2 as u8); } - return SignedBigInt::

{ magnitude: res, is_positive: self.is_positive }; + return SignedBigInt::

{ + magnitude: res, + is_positive: self.is_positive, + }; } // Case 2: same signs => subtract smaller magnitude from larger; sign accordingly @@ -490,8 +568,14 @@ impl SignedBigInt { let idx = i - 1; let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; - if a > b { ordering = Ordering::Greater; break; } - if a < b { ordering = Ordering::Less; break; } + if a > b { + ordering = Ordering::Greater; + break; + } + if a < b { + ordering = Ordering::Less; + break; + } i -= 1; } ordering @@ -515,8 +599,11 @@ impl SignedBigInt { borrow = b1; } } - SignedBigInt::

{ magnitude: res, is_positive: self.is_positive } - } + SignedBigInt::

{ + magnitude: res, + is_positive: self.is_positive, + } + }, Ordering::Less => { // res_mag = rhs.mag - self.mag; sign = !self.is_positive let mut res = BigInt::

::zero(); @@ -534,8 +621,11 @@ impl SignedBigInt { borrow = b1; } } - SignedBigInt::

{ magnitude: res, is_positive: !self.is_positive } - } + SignedBigInt::

{ + magnitude: res, + is_positive: !self.is_positive, + } + }, } } } @@ -583,16 +673,26 @@ impl From for SignedBigInt { } // Specializations for common sizes -impl SignedBigInt<1> { +impl S64 { /// Convert to i128; any u64 magnitude fits for both signs. #[inline] pub fn to_i128(&self) -> i128 { let magnitude = self.magnitude.0[0]; - if self.is_positive { magnitude as i128 } else { -(magnitude as i128) } + if self.is_positive { + magnitude as i128 + } else { + -(magnitude as i128) + } + } + + /// Return the magnitude as u64 + #[inline] + pub fn magnitude_as_u64(&self) -> u64 { + self.magnitude.0[0] } } -impl SignedBigInt<2> { +impl S128 { /// Convert to i128 using 2^127 bounds: positive requires mag <= i128::MAX; negative allows mag == 2^127. #[inline] pub fn to_i128(&self) -> Option { @@ -600,7 +700,9 @@ impl SignedBigInt<2> { let lo = self.magnitude.0[0]; let hi_top_bit = hi >> 63; // bit 127 if self.is_positive { - if hi_top_bit != 0 { return None; } + if hi_top_bit != 0 { + return None; + } let mag = ((hi as u128) << 64) | (lo as u128); Some(mag as i128) } else { @@ -620,6 +722,40 @@ impl SignedBigInt<2> { pub fn magnitude_as_u128(&self) -> u128 { (self.magnitude.0[1] as u128) << 64 | (self.magnitude.0[0] as u128) } + + /// Construct from u128 and sign + #[inline] + pub fn from_u128_and_sign(value: u128, is_positive: bool) -> Self { + Self::new([value as u64, (value >> 64) as u64], is_positive) + } + + /// Exact product of u64 and i64 into S128 (u64 × s64 -> s128) + #[inline] + pub fn from_u64_mul_i64(u: u64, s: i64) -> Self { + let mag = (u as u128) * (s.unsigned_abs() as u128); + Self::from_u128_and_sign(mag, s >= 0) + } + + /// Exact product of i64 and u64 into S128 (s64 × u64 -> s128) + #[inline] + pub fn from_i64_mul_u64(s: i64, u: u64) -> Self { + Self::from_u64_mul_i64(u, s) + } + + /// Exact product of two u64 into S128 (u64 × u64 -> s128, non-negative) + #[inline] + pub fn from_u64_mul_u64(a: u64, b: u64) -> Self { + let mag = (a as u128) * (b as u128); + Self::from_u128_and_sign(mag, true) + } + + /// Exact product of two i64 into S128 (s64 × s64 -> s128) + #[inline] + pub fn from_i64_mul_i64(a: i64, b: i64) -> Self { + let mag = (a.unsigned_abs() as u128) * (b.unsigned_abs() as u128); + let is_positive = (a >= 0) == (b >= 0); + Self::from_u128_and_sign(mag, is_positive) + } } /// Helper function for single u64 signed arithmetic @@ -778,3 +914,74 @@ impl core::ops::Mul for &SignedBigInt { out } } + +// =============================================== +// Ordering and canonical serialization +// =============================================== + +impl core::cmp::PartialOrd for SignedBigInt { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl core::cmp::Ord for SignedBigInt { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + match (self.is_positive, other.is_positive) { + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + _ => { + let ord = self.magnitude.cmp(&other.magnitude); + if self.is_positive { + ord + } else { + ord.reverse() + } + }, + } + } +} + +impl CanonicalSerialize for SignedBigInt { + #[inline] + fn serialize_with_mode( + &self, + mut w: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // encode sign as a single byte then magnitude + (self.is_positive as u8).serialize_with_mode(&mut w, compress)?; + self.magnitude.serialize_with_mode(w, compress) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + (self.is_positive as u8).serialized_size(compress) + + self.magnitude.serialized_size(compress) + } +} + +impl CanonicalDeserialize for SignedBigInt { + #[inline] + fn deserialize_with_mode( + mut r: R, + compress: Compress, + validate: Validate, + ) -> Result { + let sign_u8 = u8::deserialize_with_mode(&mut r, compress, validate)?; + let mag = BigInt::::deserialize_with_mode(r, compress, validate)?; + Ok(SignedBigInt { + magnitude: mag, + is_positive: sign_u8 != 0, + }) + } +} + +impl Valid for SignedBigInt { + #[inline] + fn check(&self) -> Result<(), SerializationError> { + self.magnitude.check() + } +} diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index 7a777a7cc..da6f98e1d 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -1,18 +1,24 @@ -use core::ops::{Add, Sub, Mul, Neg, AddAssign, SubAssign, MulAssign}; +use allocative::Allocative; use ark_std::cmp::Ordering; use ark_std::vec::Vec; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -/// Compact signed big-integer parameterized by limb count `N`, with top limb being u32 +/// Compact signed big-integer parameterized by limb count `N` (total width = `N*64 + 32` bits). /// -/// Representation: -/// - `magnitude_lo: [u64; N]` stores low limbs in little-endian order (index 0 is least significant). -/// - `magnitude_hi: i32` is the high 32-bit tail -/// - `is_positive: bool` is the sign +/// Representation (sign-magnitude): +/// - `magnitude_lo: [u64; N]` holds the low limbs in little-endian order (index 0 is least significant). +/// - `magnitude_hi: u32` holds the high 32-bit tail of the magnitude. +/// - `is_positive: bool` is the sign flag. The magnitude stores the absolute value. +/// +/// Arithmetic semantics: +/// - Addition, subtraction, and multiplication operate on magnitudes modulo `2^(64*N + 32)` +/// and then set the sign via standard sign rules. +/// - Zero is not normalized: a zero magnitude can be paired with either sign. Equality is structural, +/// so `+0 != -0`. Callers that require canonical zero should normalize externally. /// /// Notes: -/// - For most applications, `N` is typically ≤ 3, but the API supports larger `N`. -#[cfg_attr(feature = "allocative", derive(Allocative))] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +/// - Specialized fast paths exist for `N ∈ {0,1,2}`; larger `N` uses a generic path. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] pub struct SignedBigIntHi32 { /// Little-endian low limbs: limb 0 = low 64 bits, limb 1 = next 64 bits, and so on magnitude_lo: [u64; N], @@ -84,7 +90,8 @@ impl SignedBigIntHi32 { self.magnitude_hi } - /// Returns `true` if the number is non-negative. + /// Returns the sign flag (`true` for a positive sign). + /// Note: zero is not canonicalized; a zero magnitude can have either sign. pub const fn is_positive(&self) -> bool { self.is_positive } @@ -130,13 +137,13 @@ impl SignedBigIntHi32 { let (lo, hi, _borrow) = self.sub_magnitudes_with_borrow(rhs); self.magnitude_lo = lo; self.magnitude_hi = hi; - } + }, Ordering::Less => { let (lo, hi, _borrow) = rhs.sub_magnitudes_with_borrow(self); self.magnitude_lo = lo; self.magnitude_hi = hi; self.is_positive = rhs.is_positive; - } + }, } } } @@ -166,9 +173,7 @@ impl SignedBigIntHi32 { let t0 = (a0 as u128) * (b0 as u128); let lo0 = t0 as u64; - let cross = (t0 >> 64) - + (a0 as u128) * (b1 as u128) - + (a1 as u128) * (b0 as u128); + let cross = (t0 >> 64) + (a0 as u128) * (b1 as u128) + (a1 as u128) * (b0 as u128); let hi = (cross as u64 & 0xFFFF_FFFF) as u32; let mut lo = [0u64; N]; @@ -230,8 +235,7 @@ impl SignedBigIntHi32 { let mut carry: u128 = 0; for j in 0..other_limbs.len() { let idx = i + j; - let p = (self_limbs[i] as u128) - * (other_limbs[j] as u128) + let p = (self_limbs[i] as u128) * (other_limbs[j] as u128) + (prod[idx] as u128) + carry; prod[idx] = p as u64; @@ -262,8 +266,7 @@ impl SignedBigIntHi32 { let mut carry: u128 = 0; for i in 0..N { - let sum = - (self.magnitude_lo[i] as u128) + (other.magnitude_lo[i] as u128) + carry; + let sum = (self.magnitude_lo[i] as u128) + (other.magnitude_lo[i] as u128) + carry; magnitude_lo[i] = sum as u64; carry = sum >> 64; } @@ -427,6 +430,42 @@ impl<'a, const N: usize> Mul for &'a SignedBigIntHi32 { } } +// ------------------------------------------------------------------------------------------------ +// Symmetric mul: S160 * I8OrI96 -> S224 (for ergonomics) +// ------------------------------------------------------------------------------------------------ + +impl core::ops::Mul for S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: crate::biginteger::I8OrI96) -> Self::Output { + rhs * self + } +} + +impl core::ops::Mul<&crate::biginteger::I8OrI96> for S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: &crate::biginteger::I8OrI96) -> Self::Output { + (*rhs) * self + } +} + +impl core::ops::Mul for &S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: crate::biginteger::I8OrI96) -> Self::Output { + rhs * *self + } +} + +impl core::ops::Mul<&crate::biginteger::I8OrI96> for &S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: &crate::biginteger::I8OrI96) -> Self::Output { + (*rhs) * *self + } +} + // ------------------------------------------------------------------------------------------------ // From traits // ------------------------------------------------------------------------------------------------ @@ -460,3 +499,12 @@ impl From for S160 { Self::new([lo, hi], 0, true) } } + +impl From for crate::biginteger::BigInt<4> { + #[inline] + fn from(val: S224) -> Self { + let lo = val.magnitude_lo(); + let hi = val.magnitude_hi() as u64; + crate::biginteger::BigInt::<4>([lo[0], lo[1], lo[2], hi]) + } +} diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 550f6de0c..29cfbf905 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -1,1025 +1,1050 @@ #[cfg(test)] pub mod tests { -use crate::{biginteger::{BigInteger, SignedBigInt}, UniformRand}; -use num_bigint::BigUint; - -// Test elementary math operations for BigInteger. -fn biginteger_arithmetic_test(a: B, b: B, zero: B, max: B) { - // zero == zero - assert_eq!(zero, zero); - - // zero.is_zero() == true - assert_eq!(zero.is_zero(), true); - - // a == a - assert_eq!(a, a); - - // a + 0 = a - let mut a0_add = a; - let carry = a0_add.add_with_carry(&zero); - assert_eq!(a0_add, a); - assert_eq!(carry, false); - - // a - 0 = a - let mut a0_sub = a; - let borrow = a0_sub.sub_with_borrow(&zero); - assert_eq!(a0_sub, a); - assert_eq!(borrow, false); - - // a - a = 0 - let mut aa_sub = a; - let borrow = aa_sub.sub_with_borrow(&a); - assert_eq!(aa_sub, zero); - assert_eq!(borrow, false); - - // a + b = b + a - let mut ab_add = a; - let ab_carry = ab_add.add_with_carry(&b); - let mut ba_add = b; - let ba_carry = ba_add.add_with_carry(&a); - assert_eq!(ab_add, ba_add); - assert_eq!(ab_carry, ba_carry); - - // a * 1 = a - let mut a_mul1 = a; - a_mul1 <<= 0; - assert_eq!(a_mul1, a); - - // a * 2 = a + a - let mut a_mul2 = a; - a_mul2.mul2(); - let mut a_plus_a = a; - let carry_a_plus_a = a_plus_a.add_with_carry(&a); // Won't assert anything about carry bit. - assert_eq!(a_mul2, a_plus_a); - - // a * 1 = a - assert_eq!(a.mul_low(&B::from(1u64)), a); - - // a * 2 = a - assert_eq!(a.mul_low(&B::from(2u64)), a_plus_a); - - // a * b = b * a - assert_eq!(a.mul_low(&b), b.mul_low(&a)); - - // a * 2 * b * 0 = 0 - assert!(a.mul_low(&zero).is_zero()); - - // a * 2 * ... * 2 = a * 2^n - let mut a_mul_n = a; - for _ in 0..20 { - a_mul_n = a_mul_n.mul_low(&B::from(2u64)); + use crate::{ + biginteger::{BigInteger, SignedBigInt}, + UniformRand, + }; + use num_bigint::BigUint; + + // Test elementary math operations for BigInteger. + fn biginteger_arithmetic_test(a: B, b: B, zero: B, max: B) { + // zero == zero + assert_eq!(zero, zero); + + // zero.is_zero() == true + assert_eq!(zero.is_zero(), true); + + // a == a + assert_eq!(a, a); + + // a + 0 = a + let mut a0_add = a; + let carry = a0_add.add_with_carry(&zero); + assert_eq!(a0_add, a); + assert_eq!(carry, false); + + // a - 0 = a + let mut a0_sub = a; + let borrow = a0_sub.sub_with_borrow(&zero); + assert_eq!(a0_sub, a); + assert_eq!(borrow, false); + + // a - a = 0 + let mut aa_sub = a; + let borrow = aa_sub.sub_with_borrow(&a); + assert_eq!(aa_sub, zero); + assert_eq!(borrow, false); + + // a + b = b + a + let mut ab_add = a; + let ab_carry = ab_add.add_with_carry(&b); + let mut ba_add = b; + let ba_carry = ba_add.add_with_carry(&a); + assert_eq!(ab_add, ba_add); + assert_eq!(ab_carry, ba_carry); + + // a * 1 = a + let mut a_mul1 = a; + a_mul1 <<= 0; + assert_eq!(a_mul1, a); + + // a * 2 = a + a + let mut a_mul2 = a; + a_mul2.mul2(); + let mut a_plus_a = a; + let carry_a_plus_a = a_plus_a.add_with_carry(&a); // Won't assert anything about carry bit. + assert_eq!(a_mul2, a_plus_a); + + // a * 1 = a + assert_eq!(a.mul_low(&B::from(1u64)), a); + + // a * 2 = a + assert_eq!(a.mul_low(&B::from(2u64)), a_plus_a); + + // a * b = b * a + assert_eq!(a.mul_low(&b), b.mul_low(&a)); + + // a * 2 * b * 0 = 0 + assert!(a.mul_low(&zero).is_zero()); + + // a * 2 * ... * 2 = a * 2^n + let mut a_mul_n = a; + for _ in 0..20 { + a_mul_n = a_mul_n.mul_low(&B::from(2u64)); + } + assert_eq!(a_mul_n, a << 20); + + // a * 0 = (0, 0) + assert_eq!(a.mul(&zero), (zero, zero)); + + // a * 1 = (a, 0) + assert_eq!(a.mul(&B::from(1u64)), (a, zero)); + + // a * 1 = 0 (high part of the result) + assert_eq!(a.mul_high(&B::from(1u64)), (zero)); + + // a * 0 = 0 (high part of the result) + assert!(a.mul_high(&zero).is_zero()); + + // If a + a has a carry + if carry_a_plus_a { + // a + a has a carry: high part of a * 2 is not zero + assert_ne!(a.mul_high(&B::from(2u64)), zero); + } else { + // a + a has no carry: high part of a * 2 is zero + assert_eq!(a.mul_high(&B::from(2u64)), zero); + } + + // max + max = max * 2 + let mut max_plus_max = max; + max_plus_max.add_with_carry(&max); + assert_eq!(max.mul(&B::from(2u64)), (max_plus_max, B::from(1u64))); + assert_eq!(max.mul_high(&B::from(2u64)), B::from(1u64)); } - assert_eq!(a_mul_n, a << 20); - // a * 0 = (0, 0) - assert_eq!(a.mul(&zero), (zero, zero)); - - // a * 1 = (a, 0) - assert_eq!(a.mul(&B::from(1u64)), (a, zero)); - - // a * 1 = 0 (high part of the result) - assert_eq!(a.mul_high(&B::from(1u64)), (zero)); + fn biginteger_shr() { + let mut rng = ark_std::test_rng(); + let a = B::rand(&mut rng); + assert_eq!(a >> 0, a); + + // Binary simple test + let a = B::from(256u64); + assert_eq!(a >> 2, B::from(64u64)); + + // Test saturated underflow + let a = B::from(1u64); + assert_eq!(a >> 5, B::from(0u64)); + + // Test null bits + let a = B::rand(&mut rng); + let b = a >> 3; + assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 1), false); + assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 2), false); + assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 3), false); + } - // a * 0 = 0 (high part of the result) - assert!(a.mul_high(&zero).is_zero()); + fn biginteger_shl() { + let mut rng = ark_std::test_rng(); + let a = B::rand(&mut rng); + assert_eq!(a << 0, a); + + // Binary simple test + let a = B::from(64u64); + assert_eq!(a << 2, B::from(256u64)); + + // Testing saturated overflow + let a = B::rand(&mut rng); + assert_eq!(a << ((B::NUM_LIMBS as u32) * 64), B::from(0u64)); + + // Test null bits + let a = B::rand(&mut rng); + let b = a << 3; + assert_eq!(b.get_bit(0), false); + assert_eq!(b.get_bit(1), false); + assert_eq!(b.get_bit(2), false); + } - // If a + a has a carry - if carry_a_plus_a { - // a + a has a carry: high part of a * 2 is not zero - assert_ne!(a.mul_high(&B::from(2u64)), zero); - } else { - // a + a has no carry: high part of a * 2 is zero - assert_eq!(a.mul_high(&B::from(2u64)), zero); + // Test for BigInt's bitwise operations + fn biginteger_bitwise_ops_test() { + let mut rng = ark_std::test_rng(); + + // Test XOR + // a xor a = 0 + let a = B::rand(&mut rng); + assert_eq!(a ^ &a, B::from(0_u64)); + + // Testing a xor b xor b + let a = B::rand(&mut rng); + let b = B::rand(&mut rng); + let xor_ab = a ^ b; + assert_eq!(xor_ab ^ b, a); + + // Test OR + // a or a = a + let a = B::rand(&mut rng); + assert_eq!(a | &a, a); + + // Testing a or b or b + let a = B::rand(&mut rng); + let b = B::rand(&mut rng); + let or_ab = a | b; + assert_eq!(or_ab | &b, a | b); + + // Test AND + // a and a = a + let a = B::rand(&mut rng); + assert_eq!(a & (&a), a); + + // Testing a and a and b. + let a = B::rand(&mut rng); + let b = B::rand(&mut rng); + let b_clone = b.clone(); + let and_ab = a & b; + assert_eq!(and_ab & b_clone, a & b); + + // Testing De Morgan's law + let a = 0x1234567890abcdef_u64; + let b = 0xfedcba0987654321_u64; + let de_morgan_lhs = B::from(!(a | b)); + let de_morgan_rhs = B::from(!a) & B::from(!b); + assert_eq!(de_morgan_lhs, de_morgan_rhs); } - // max + max = max * 2 - let mut max_plus_max = max; - max_plus_max.add_with_carry(&max); - assert_eq!(max.mul(&B::from(2u64)), (max_plus_max, B::from(1u64))); - assert_eq!(max.mul_high(&B::from(2u64)), B::from(1u64)); -} + // Test correctness of BigInteger's bit values + fn biginteger_bits_test() { + let mut one = B::from(1u64); + // 0th bit of BigInteger representing 1 is 1 + assert!(one.get_bit(0)); + // 1st bit of BigInteger representing 1 is not 1 + assert!(!one.get_bit(1)); + one <<= 5; + let thirty_two = one; + // 0th bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(0)); + // 1st bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(1)); + // 2nd bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(2)); + // 3rd bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(3)); + // 4th bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(4)); + // 5th bit of BigInteger representing 32 is 1 + assert!(thirty_two.get_bit(5), "{:?}", thirty_two); + + // Generates a random BigInteger and tests bit construction methods. + let mut rng = ark_std::test_rng(); + let a: B = UniformRand::rand(&mut rng); + assert_eq!(B::from_bits_be(&a.to_bits_be()), a); + assert_eq!(B::from_bits_le(&a.to_bits_le()), a); + } -fn biginteger_shr() { - let mut rng = ark_std::test_rng(); - let a = B::rand(&mut rng); - assert_eq!(a >> 0, a); - - // Binary simple test - let a = B::from(256u64); - assert_eq!(a >> 2, B::from(64u64)); - - // Test saturated underflow - let a = B::from(1u64); - assert_eq!(a >> 5, B::from(0u64)); - - // Test null bits - let a = B::rand(&mut rng); - let b = a >> 3; - assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 1), false); - assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 2), false); - assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 3), false); -} + // Test conversion from BigInteger to BigUint + fn biginteger_conversion_test() { + let mut rng = ark_std::test_rng(); -fn biginteger_shl() { - let mut rng = ark_std::test_rng(); - let a = B::rand(&mut rng); - assert_eq!(a << 0, a); - - // Binary simple test - let a = B::from(64u64); - assert_eq!(a << 2, B::from(256u64)); - - // Testing saturated overflow - let a = B::rand(&mut rng); - assert_eq!(a << ((B::NUM_LIMBS as u32) * 64), B::from(0u64)); - - // Test null bits - let a = B::rand(&mut rng); - let b = a << 3; - assert_eq!(b.get_bit(0), false); - assert_eq!(b.get_bit(1), false); - assert_eq!(b.get_bit(2), false); -} + let x: B = UniformRand::rand(&mut rng); + let x_bigint: BigUint = x.into(); + let x_recovered = B::try_from(x_bigint).ok().unwrap(); -// Test for BigInt's bitwise operations -fn biginteger_bitwise_ops_test() { - let mut rng = ark_std::test_rng(); - - // Test XOR - // a xor a = 0 - let a = B::rand(&mut rng); - assert_eq!(a ^ &a, B::from(0_u64)); - - // Testing a xor b xor b - let a = B::rand(&mut rng); - let b = B::rand(&mut rng); - let xor_ab = a ^ b; - assert_eq!(xor_ab ^ b, a); - - // Test OR - // a or a = a - let a = B::rand(&mut rng); - assert_eq!(a | &a, a); - - // Testing a or b or b - let a = B::rand(&mut rng); - let b = B::rand(&mut rng); - let or_ab = a | b; - assert_eq!(or_ab | &b, a | b); - - // Test AND - // a and a = a - let a = B::rand(&mut rng); - assert_eq!(a & (&a), a); - - // Testing a and a and b. - let a = B::rand(&mut rng); - let b = B::rand(&mut rng); - let b_clone = b.clone(); - let and_ab = a & b; - assert_eq!(and_ab & b_clone, a & b); - - // Testing De Morgan's law - let a = 0x1234567890abcdef_u64; - let b = 0xfedcba0987654321_u64; - let de_morgan_lhs = B::from(!(a | b)); - let de_morgan_rhs = B::from(!a) & B::from(!b); - assert_eq!(de_morgan_lhs, de_morgan_rhs); -} + assert_eq!(x, x_recovered); + } -// Test correctness of BigInteger's bit values -fn biginteger_bits_test() { - let mut one = B::from(1u64); - // 0th bit of BigInteger representing 1 is 1 - assert!(one.get_bit(0)); - // 1st bit of BigInteger representing 1 is not 1 - assert!(!one.get_bit(1)); - one <<= 5; - let thirty_two = one; - // 0th bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(0)); - // 1st bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(1)); - // 2nd bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(2)); - // 3rd bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(3)); - // 4th bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(4)); - // 5th bit of BigInteger representing 32 is 1 - assert!(thirty_two.get_bit(5), "{:?}", thirty_two); - - // Generates a random BigInteger and tests bit construction methods. - let mut rng = ark_std::test_rng(); - let a: B = UniformRand::rand(&mut rng); - assert_eq!(B::from_bits_be(&a.to_bits_be()), a); - assert_eq!(B::from_bits_le(&a.to_bits_le()), a); -} + // Wrapper test function for BigInteger + fn test_biginteger(max: B, zero: B) { + let mut rng = ark_std::test_rng(); + let a: B = UniformRand::rand(&mut rng); + let b: B = UniformRand::rand(&mut rng); + biginteger_arithmetic_test(a, b, zero, max); + biginteger_bits_test::(); + biginteger_conversion_test::(); + biginteger_bitwise_ops_test::(); + biginteger_shr::(); + biginteger_shl::(); + } -// Test conversion from BigInteger to BigUint -fn biginteger_conversion_test() { - let mut rng = ark_std::test_rng(); + #[test] + fn test_biginteger64() { + use crate::biginteger::BigInteger64 as B; + test_biginteger(B::new([u64::MAX; 1]), B::new([0u64; 1])); + } - let x: B = UniformRand::rand(&mut rng); - let x_bigint: BigUint = x.into(); - let x_recovered = B::try_from(x_bigint).ok().unwrap(); + #[test] + fn test_biginteger128() { + use crate::biginteger::BigInteger128 as B; + test_biginteger(B::new([u64::MAX; 2]), B::new([0u64; 2])); + } - assert_eq!(x, x_recovered); -} + #[test] + fn test_biginteger256() { + use crate::biginteger::BigInteger256 as B; + test_biginteger(B::new([u64::MAX; 4]), B::new([0u64; 4])); + } -// Wrapper test function for BigInteger -fn test_biginteger(max: B, zero: B) { - let mut rng = ark_std::test_rng(); - let a: B = UniformRand::rand(&mut rng); - let b: B = UniformRand::rand(&mut rng); - biginteger_arithmetic_test(a, b, zero, max); - biginteger_bits_test::(); - biginteger_conversion_test::(); - biginteger_bitwise_ops_test::(); - biginteger_shr::(); - biginteger_shl::(); -} + #[test] + fn test_biginteger384() { + use crate::biginteger::BigInteger384 as B; + test_biginteger(B::new([u64::MAX; 6]), B::new([0u64; 6])); + } -#[test] -fn test_biginteger64() { - use crate::biginteger::BigInteger64 as B; - test_biginteger(B::new([u64::MAX; 1]), B::new([0u64; 1])); -} + #[test] + fn test_biginteger448() { + use crate::biginteger::BigInteger448 as B; + test_biginteger(B::new([u64::MAX; 7]), B::new([0u64; 7])); + } -#[test] -fn test_biginteger128() { - use crate::biginteger::BigInteger128 as B; - test_biginteger(B::new([u64::MAX; 2]), B::new([0u64; 2])); -} + #[test] + fn test_biginteger768() { + use crate::biginteger::BigInteger768 as B; + test_biginteger(B::new([u64::MAX; 12]), B::new([0u64; 12])); + } -#[test] -fn test_biginteger256() { - use crate::biginteger::BigInteger256 as B; - test_biginteger(B::new([u64::MAX; 4]), B::new([0u64; 4])); -} + #[test] + fn test_biginteger832() { + use crate::biginteger::BigInteger832 as B; + test_biginteger(B::new([u64::MAX; 13]), B::new([0u64; 13])); + } -#[test] -fn test_biginteger384() { - use crate::biginteger::BigInteger384 as B; - test_biginteger(B::new([u64::MAX; 6]), B::new([0u64; 6])); -} + // Tests for NEW functions + use crate::biginteger::BigInteger256; + + #[test] + fn test_mul_u64_in_place() { + let mut a = BigInteger256::from(0x123456789ABCDEFu64); + let b = 0x987654321u64; + + // Test against reference implementation + let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(b); + a.mul_u64_in_place(b); + assert_eq!(BigUint::from(a), expected); + + // Test zero multiplication + let mut zero = BigInteger256::zero(); + zero.mul_u64_in_place(12345); + assert!(zero.is_zero()); + + // Test multiplication by zero + let mut a = BigInteger256::from(12345u64); + a.mul_u64_in_place(0); + assert!(a.is_zero()); + + // Test multiplication by one + let orig = BigInteger256::from(0xDEADBEEFu64); + let mut a = orig; + a.mul_u64_in_place(1); + assert_eq!(a, orig); + } -#[test] -fn test_biginteger448() { - use crate::biginteger::BigInteger448 as B; - test_biginteger(B::new([u64::MAX; 7]), B::new([0u64; 7])); -} + #[test] + fn test_mul_u64_w_carry() { + let a = BigInteger256::from(u64::MAX); + let b = u64::MAX; + + // Test against reference implementation + let expected = BigUint::from(u64::MAX) * BigUint::from(u64::MAX); + let result = a.mul_u64_w_carry::<5>(b); + assert_eq!(BigUint::from(result), expected); + + // Test with small numbers + let a = BigInteger256::from(12345u64); + let b = 67890u64; + let expected = BigUint::from(12345u64) * BigUint::from(67890u64); + let result = a.mul_u64_w_carry::<5>(b); + assert_eq!(BigUint::from(result), expected); + + // Test zero cases + let zero = BigInteger256::zero(); + let result = zero.mul_u64_w_carry::<5>(12345); + assert!(result.is_zero()); + + let a = BigInteger256::from(12345u64); + let result = a.mul_u64_w_carry::<5>(0); + assert!(result.is_zero()); + + // Test multiplication by one + let a = BigInteger256::from(0xDEADBEEFu64); + let result = a.mul_u64_w_carry::<5>(1); + let expected_bytes = a.to_bytes_le(); + let result_bytes = result.to_bytes_le(); + assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); + } -#[test] -fn test_biginteger768() { - use crate::biginteger::BigInteger768 as B; - test_biginteger(B::new([u64::MAX; 12]), B::new([0u64; 12])); -} + #[test] + fn test_fmu64a() { + let a = BigInteger256::from(12345u64); + let b = 67890u64; + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + + // Perform fused multiply-accumulate + a.fmu64a(b, &mut acc); + + // Compare against separate multiply and add + let expected_mul = BigUint::from(12345u64) * BigUint::from(67890u64); + let expected_total = expected_mul + BigUint::from(11111u64); + assert_eq!(BigUint::from(acc), expected_total); + + // Test zero cases + let zero = BigInteger256::zero(); + let mut acc = BigInteger256::from(12345u64).mul_u64_w_carry::<5>(1); + let acc_copy = acc; + zero.fmu64a(67890, &mut acc); + assert_eq!(acc, acc_copy); // Should be unchanged + + // Test multiplication by zero + let a = BigInteger256::from(12345u64); + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + let acc_copy = acc; + a.fmu64a(0, &mut acc); + assert_eq!(acc, acc_copy); // Should be unchanged + + // Test multiplication by one (should be just addition) + let a = BigInteger256::from(12345u64); + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + a.fmu64a(1, &mut acc); + let expected = BigUint::from(12345u64) + BigUint::from(11111u64); + assert_eq!(BigUint::from(acc), expected); + } -#[test] -fn test_biginteger832() { - use crate::biginteger::BigInteger832 as B; - test_biginteger(B::new([u64::MAX; 13]), B::new([0u64; 13])); -} + #[test] + fn test_mul_u128_w_carry() { + let a = BigInteger256::from(0x123456789ABCDEFu64); + let b = 0x987654321DEADBEEFu128; + + // Test against reference implementation + let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(0x987654321DEADBEEFu128); + let result = a.mul_u128_w_carry::<5, 6>(b); + assert_eq!(BigUint::from(result), expected); + + // Test with u64 value (should be same as mul_u64_w_carry) + let b_u64 = 0x987654321u64; + let result_u128 = a.mul_u128_w_carry::<5, 6>(b_u64 as u128); + let result_u64 = a.mul_u64_w_carry::<5>(b_u64); + + // Compare first 5 limbs (u64 result size) + for i in 0..5 { + assert_eq!(result_u128.0[i], result_u64.0[i]); + } + assert_eq!(result_u128.0[5], 0); // Extra limb should be zero + + // Test zero cases + let zero = BigInteger256::zero(); + let result = zero.mul_u128_w_carry::<5, 6>(12345); + assert!(result.is_zero()); + + let a = BigInteger256::from(12345u64); + let result = a.mul_u128_w_carry::<5, 6>(0); + assert!(result.is_zero()); + + // Test multiplication by one + let a = BigInteger256::from(0xDEADBEEFu64); + let result = a.mul_u128_w_carry::<5, 6>(1); + let expected_bytes = a.to_bytes_le(); + let result_bytes = result.to_bytes_le(); + assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); + } -// Tests for NEW functions -use crate::biginteger::BigInteger256; - -#[test] -fn test_mul_u64_in_place() { - let mut a = BigInteger256::from(0x123456789ABCDEFu64); - let b = 0x987654321u64; - - // Test against reference implementation - let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(b); - a.mul_u64_in_place(b); - assert_eq!(BigUint::from(a), expected); - - // Test zero multiplication - let mut zero = BigInteger256::zero(); - zero.mul_u64_in_place(12345); - assert!(zero.is_zero()); - - // Test multiplication by zero - let mut a = BigInteger256::from(12345u64); - a.mul_u64_in_place(0); - assert!(a.is_zero()); - - // Test multiplication by one - let orig = BigInteger256::from(0xDEADBEEFu64); - let mut a = orig; - a.mul_u64_in_place(1); - assert_eq!(a, orig); -} + #[test] + fn test_fm128a_basic_and_edges() { + use crate::biginteger::BigInteger256 as B; + // Basic reference check against BigUint + let a = B::from(0x123456789ABCDEFu64); + let b = 0x987654321DEADBEEFu128; + let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); // zero-extended accumulator (6 limbs) + a.fm128a::<6>(b, &mut acc); + let expected = num_bigint::BigUint::from(0x123456789ABCDEFu64) + * num_bigint::BigUint::from(0x987654321DEADBEEFu128); + assert_eq!(num_bigint::BigUint::from(acc), expected); + + // Zero multiplier: no change + let a = B::from(12345u64); + let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); + let acc_copy = acc; + a.fm128a::<6>(0, &mut acc); + assert_eq!(acc, acc_copy); + + // One multiplier: reduces to addition + let a = B::from(12345u64); + let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); + a.fm128a::<6>(1, &mut acc); + let expected = num_bigint::BigUint::from(12345u64) + num_bigint::BigUint::from(11111u64); + assert_eq!(num_bigint::BigUint::from(acc), expected); + + // Overflow propagation from limb N into highest limb + let a = B::new([u64::MAX; 4]); + let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); + // Pre-fill limb N to force overflow when adding the final carry from low pass + acc.0[4] = u64::MAX; // limb N + acc.0[5] = 0; // highest limb + // cause carry=1 from low pass (a * 2) + a.fm128a::<6>(2, &mut acc); + // Expect highest limb incremented by 1 due to overflow from limb N + assert_eq!(acc.0[5], 1); + } -#[test] -fn test_mul_u64_w_carry() { - let a = BigInteger256::from(u64::MAX); - let b = u64::MAX; - - // Test against reference implementation - let expected = BigUint::from(u64::MAX) * BigUint::from(u64::MAX); - let result = a.mul_u64_w_carry::<5>(b); - assert_eq!(BigUint::from(result), expected); - - // Test with small numbers - let a = BigInteger256::from(12345u64); - let b = 67890u64; - let expected = BigUint::from(12345u64) * BigUint::from(67890u64); - let result = a.mul_u64_w_carry::<5>(b); - assert_eq!(BigUint::from(result), expected); - - // Test zero cases - let zero = BigInteger256::zero(); - let result = zero.mul_u64_w_carry::<5>(12345); - assert!(result.is_zero()); - - let a = BigInteger256::from(12345u64); - let result = a.mul_u64_w_carry::<5>(0); - assert!(result.is_zero()); - - // Test multiplication by one - let a = BigInteger256::from(0xDEADBEEFu64); - let result = a.mul_u64_w_carry::<5>(1); - let expected_bytes = a.to_bytes_le(); - let result_bytes = result.to_bytes_le(); - assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); -} + #[test] + fn test_overflow_behavior_fmu64a() { + // Test that overflow in the highest limb wraps around as documented + let a = BigInteger256::new([u64::MAX; 4]); + let mut acc = BigInteger256::new([0, 0, 0, 0]).mul_u64_w_carry::<5>(1); + acc.0[4] = u64::MAX; // Set highest limb to max -#[test] -fn test_fmu64a() { - let a = BigInteger256::from(12345u64); - let b = 67890u64; - let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); - - // Perform fused multiply-accumulate - a.fmu64a(b, &mut acc); - - // Compare against separate multiply and add - let expected_mul = BigUint::from(12345u64) * BigUint::from(67890u64); - let expected_total = expected_mul + BigUint::from(11111u64); - assert_eq!(BigUint::from(acc), expected_total); - - // Test zero cases - let zero = BigInteger256::zero(); - let mut acc = BigInteger256::from(12345u64).mul_u64_w_carry::<5>(1); - let acc_copy = acc; - zero.fmu64a(67890, &mut acc); - assert_eq!(acc, acc_copy); // Should be unchanged - - // Test multiplication by zero - let a = BigInteger256::from(12345u64); - let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); - let acc_copy = acc; - a.fmu64a(0, &mut acc); - assert_eq!(acc, acc_copy); // Should be unchanged - - // Test multiplication by one (should be just addition) - let a = BigInteger256::from(12345u64); - let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); - a.fmu64a(1, &mut acc); - let expected = BigUint::from(12345u64) + BigUint::from(11111u64); - assert_eq!(BigUint::from(acc), expected); -} + // This should cause overflow in the highest limb + a.fmu64a(2, &mut acc); -#[test] -fn test_mul_u128_w_carry() { - let a = BigInteger256::from(0x123456789ABCDEFu64); - let b = 0x987654321DEADBEEFu128; - - // Test against reference implementation - let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(0x987654321DEADBEEFu128); - let result = a.mul_u128_w_carry::<5, 6>(b); - assert_eq!(BigUint::from(result), expected); - - // Test with u64 value (should be same as mul_u64_w_carry) - let b_u64 = 0x987654321u64; - let result_u128 = a.mul_u128_w_carry::<5, 6>(b_u64 as u128); - let result_u64 = a.mul_u64_w_carry::<5>(b_u64); - - // Compare first 5 limbs (u64 result size) - for i in 0..5 { - assert_eq!(result_u128.0[i], result_u64.0[i]); + // The overflow should wrap around + // u64::MAX * 2 = 2^65 - 2, which when added to u64::MAX = 2^65 + u64::MAX - 2 + // This wraps to u64::MAX - 2 with a carry of 1 that itself wraps + assert_eq!(acc.0[4], u64::MAX.wrapping_add(1)); // Wrapped result } - assert_eq!(result_u128.0[5], 0); // Extra limb should be zero - - // Test zero cases - let zero = BigInteger256::zero(); - let result = zero.mul_u128_w_carry::<5, 6>(12345); - assert!(result.is_zero()); - - let a = BigInteger256::from(12345u64); - let result = a.mul_u128_w_carry::<5, 6>(0); - assert!(result.is_zero()); - - // Test multiplication by one - let a = BigInteger256::from(0xDEADBEEFu64); - let result = a.mul_u128_w_carry::<5, 6>(1); - let expected_bytes = a.to_bytes_le(); - let result_bytes = result.to_bytes_le(); - assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); -} -#[test] -fn test_fm128a_basic_and_edges() { - use crate::biginteger::BigInteger256 as B; - // Basic reference check against BigUint - let a = B::from(0x123456789ABCDEFu64); - let b = 0x987654321DEADBEEFu128; - let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); // zero-extended accumulator (6 limbs) - a.fm128a::<6>(b, &mut acc); - let expected = num_bigint::BigUint::from(0x123456789ABCDEFu64) - * num_bigint::BigUint::from(0x987654321DEADBEEFu128); - assert_eq!(num_bigint::BigUint::from(acc), expected); - - // Zero multiplier: no change - let a = B::from(12345u64); - let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); - let acc_copy = acc; - a.fm128a::<6>(0, &mut acc); - assert_eq!(acc, acc_copy); - - // One multiplier: reduces to addition - let a = B::from(12345u64); - let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); - a.fm128a::<6>(1, &mut acc); - let expected = num_bigint::BigUint::from(12345u64) + num_bigint::BigUint::from(11111u64); - assert_eq!(num_bigint::BigUint::from(acc), expected); - - // Overflow propagation from limb N into highest limb - let a = B::new([u64::MAX; 4]); - let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); - // Pre-fill limb N to force overflow when adding the final carry from low pass - acc.0[4] = u64::MAX; // limb N - acc.0[5] = 0; // highest limb - // cause carry=1 from low pass (a * 2) - a.fm128a::<6>(2, &mut acc); - // Expect highest limb incremented by 1 due to overflow from limb N - assert_eq!(acc.0[5], 1); -} + #[test] + fn test_edge_cases_large_numbers() { + // Test with maximum values + let max_bi = BigInteger256::new([u64::MAX; 4]); -#[test] -fn test_overflow_behavior_fmu64a() { - // Test that overflow in the highest limb wraps around as documented - let a = BigInteger256::new([u64::MAX; 4]); - let mut acc = BigInteger256::new([0, 0, 0, 0]).mul_u64_w_carry::<5>(1); - acc.0[4] = u64::MAX; // Set highest limb to max - - // This should cause overflow in the highest limb - a.fmu64a(2, &mut acc); - - // The overflow should wrap around - // u64::MAX * 2 = 2^65 - 2, which when added to u64::MAX = 2^65 + u64::MAX - 2 - // This wraps to u64::MAX - 2 with a carry of 1 that itself wraps - assert_eq!(acc.0[4], u64::MAX.wrapping_add(1)); // Wrapped result -} + // mul_u64_w_carry with max values + let result = max_bi.mul_u64_w_carry::<5>(u64::MAX); + let expected = BigUint::from(max_bi) * BigUint::from(u64::MAX); + assert_eq!(BigUint::from(result), expected); -#[test] -fn test_edge_cases_large_numbers() { - // Test with maximum values - let max_bi = BigInteger256::new([u64::MAX; 4]); - - // mul_u64_w_carry with max values - let result = max_bi.mul_u64_w_carry::<5>(u64::MAX); - let expected = BigUint::from(max_bi) * BigUint::from(u64::MAX); - assert_eq!(BigUint::from(result), expected); - - // mul_u128_w_carry with max values - let result = max_bi.mul_u128_w_carry::<5, 6>(u128::MAX); - let expected = BigUint::from(max_bi) * BigUint::from(u128::MAX); - assert_eq!(BigUint::from(result), expected); -} + // mul_u128_w_carry with max values + let result = max_bi.mul_u128_w_carry::<5, 6>(u128::MAX); + let expected = BigUint::from(max_bi) * BigUint::from(u128::MAX); + assert_eq!(BigUint::from(result), expected); + } -#[test] -fn test_fmu64a_into_nplus4_correctness_and_edges() { - use crate::biginteger::{BigInt, BigInteger256 as B}; - let a = B::from(0xDEADBEEFCAFEBABEu64); - let other = 0xFEDCBA9876543210u64; - let mut acc = BigInt::<8>::zero(); // N+4 accumulator for N=4 - - // Reference: (a * other + acc_before) mod 2^(64*(N+4)) - let before = BigUint::from(acc.clone()); - a.fmu64a_into_nplus4::<8>(other, &mut acc); - let mut expected = BigUint::from(a); - expected *= BigUint::from(other); - expected += before; - let modulus = BigUint::from(1u8) << (64 * 8); - expected %= &modulus; - assert_eq!(BigUint::from(acc.clone()), expected); - - // Zero multiplier is no-op - let mut acc2 = acc.clone(); - a.fmu64a_into_nplus4::<8>(0, &mut acc2); - assert_eq!(acc2, acc); - - // One multiplier reduces to addition - let mut acc3 = BigInt::<8>::zero(); - acc3.0[0] = 11111; - let before3 = BigUint::from(acc3.clone()); - a.fmu64a_into_nplus4::<8>(1, &mut acc3); - let mut expected3 = BigUint::from(a); - expected3 += before3; - expected3 %= &modulus; - assert_eq!(BigUint::from(acc3), expected3); - - // Force cascading carry across N..=N+3 - let a = B::new([u64::MAX; 4]); - let mut acc4 = BigInt::<8>::zero(); - acc4.0[4] = u64::MAX; // limb N - acc4.0[5] = u64::MAX; // limb N+1 - acc4.0[6] = u64::MAX; // limb N+2 - acc4.0[7] = 0; // limb N+3 (top) - // Use multiplier 2 so the low pass produces a carry=1 - a.fmu64a_into_nplus4::<8>(2, &mut acc4); - assert_eq!(acc4.0[7], 1); -} + #[test] + fn test_fmu64a_into_nplus4_correctness_and_edges() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0xDEADBEEFCAFEBABEu64); + let other = 0xFEDCBA9876543210u64; + let mut acc = BigInt::<8>::zero(); // N+4 accumulator for N=4 + + // Reference: (a * other + acc_before) mod 2^(64*(N+4)) + let before = BigUint::from(acc.clone()); + a.fmu64a_into_nplus4::<8>(other, &mut acc); + let mut expected = BigUint::from(a); + expected *= BigUint::from(other); + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Zero multiplier is no-op + let mut acc2 = acc.clone(); + a.fmu64a_into_nplus4::<8>(0, &mut acc2); + assert_eq!(acc2, acc); + + // One multiplier reduces to addition + let mut acc3 = BigInt::<8>::zero(); + acc3.0[0] = 11111; + let before3 = BigUint::from(acc3.clone()); + a.fmu64a_into_nplus4::<8>(1, &mut acc3); + let mut expected3 = BigUint::from(a); + expected3 += before3; + expected3 %= &modulus; + assert_eq!(BigUint::from(acc3), expected3); + + // Force cascading carry across N..=N+3 + let a = B::new([u64::MAX; 4]); + let mut acc4 = BigInt::<8>::zero(); + acc4.0[4] = u64::MAX; // limb N + acc4.0[5] = u64::MAX; // limb N+1 + acc4.0[6] = u64::MAX; // limb N+2 + acc4.0[7] = 0; // limb N+3 (top) + // Use multiplier 2 so the low pass produces a carry=1 + a.fmu64a_into_nplus4::<8>(2, &mut acc4); + assert_eq!(acc4.0[7], 1); + } -#[test] -fn test_fm2x64a_into_nplus4_correctness() { - use crate::biginteger::{BigInt, BigInteger256 as B}; - let a = B::from(0x1234567890ABCDEFu64); - let other = [0x0FEDCBA987654321u64, 0x0011223344556677u64]; - let mut acc = BigInt::<8>::zero(); - - let before = BigUint::from(acc.clone()); - a.fm2x64a_into_nplus4::<8>(other, &mut acc); - - // Expected: a * (lo + (hi << 64)) + acc_before mod 2^(64*8) - let hi = BigUint::from(other[1]); - let lo = BigUint::from(other[0]); - let factor = (hi << 64) + lo; - let mut expected = BigUint::from(a); - expected *= factor; - expected += before; - let modulus = BigUint::from(1u8) << (64 * 8); - expected %= &modulus; - assert_eq!(BigUint::from(acc.clone()), expected); - - // Zero limbs are no-op - let mut acc2 = acc.clone(); - a.fm2x64a_into_nplus4::<8>([0, 0], &mut acc2); - assert_eq!(acc2, acc); -} + #[test] + fn test_fm2x64a_into_nplus4_correctness() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0x1234567890ABCDEFu64); + let other = [0x0FEDCBA987654321u64, 0x0011223344556677u64]; + let mut acc = BigInt::<8>::zero(); + + let before = BigUint::from(acc.clone()); + a.fm2x64a_into_nplus4::<8>(other, &mut acc); + + // Expected: a * (lo + (hi << 64)) + acc_before mod 2^(64*8) + let hi = BigUint::from(other[1]); + let lo = BigUint::from(other[0]); + let factor = (hi << 64) + lo; + let mut expected = BigUint::from(a); + expected *= factor; + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Zero limbs are no-op + let mut acc2 = acc.clone(); + a.fm2x64a_into_nplus4::<8>([0, 0], &mut acc2); + assert_eq!(acc2, acc); + } -#[test] -fn test_fm3x64a_into_nplus4_correctness() { - use crate::biginteger::{BigInt, BigInteger256 as B}; - let a = B::from(0x0F0E0D0C0B0A0908u64); - let other = [0x89ABCDEF01234567u64, 0x76543210FEDCBA98u64, 0x1122334455667788u64]; - let mut acc = BigInt::<8>::zero(); - - let before = BigUint::from(acc.clone()); - a.fm3x64a_into_nplus4::<8>(other, &mut acc); - - // Expected: a * (o0 + (o1<<64) + (o2<<128)) + acc_before mod 2^(64*8) - let term0 = BigUint::from(other[0]); - let term1 = BigUint::from(other[1]) << 64; - let term2 = BigUint::from(other[2]) << 128; - let factor = term0 + term1 + term2; - let mut expected = BigUint::from(a); - expected *= factor; - expected += before; - let modulus = BigUint::from(1u8) << (64 * 8); - expected %= &modulus; - assert_eq!(BigUint::from(acc.clone()), expected); - - // Edge: ensure offset accumulation lands in correct limbs - // Fill acc with a pattern, then accumulate using only the highest limb to ensure writes start at index 2 - let a = B::from(3u64); - let mut acc2 = BigInt::<8>::zero(); - acc2.0[0] = 5; - acc2.0[1] = 7; - let other2 = [0, 0, 2]; // Only offset by 2 limbs - let before2 = BigUint::from(acc2.clone()); - a.fm3x64a_into_nplus4::<8>(other2, &mut acc2); - let mut expected2 = BigUint::from(a); - expected2 *= BigUint::from(2u64) << 128; - expected2 += before2; - let modulus = BigUint::from(1u8) << (64 * 8); - expected2 %= &modulus; - assert_eq!(BigUint::from(acc2), expected2); -} + #[test] + fn test_fm3x64a_into_nplus4_correctness() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0x0F0E0D0C0B0A0908u64); + let other = [ + 0x89ABCDEF01234567u64, + 0x76543210FEDCBA98u64, + 0x1122334455667788u64, + ]; + let mut acc = BigInt::<8>::zero(); + + let before = BigUint::from(acc.clone()); + a.fm3x64a_into_nplus4::<8>(other, &mut acc); + + // Expected: a * (o0 + (o1<<64) + (o2<<128)) + acc_before mod 2^(64*8) + let term0 = BigUint::from(other[0]); + let term1 = BigUint::from(other[1]) << 64; + let term2 = BigUint::from(other[2]) << 128; + let factor = term0 + term1 + term2; + let mut expected = BigUint::from(a); + expected *= factor; + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Edge: ensure offset accumulation lands in correct limbs + // Fill acc with a pattern, then accumulate using only the highest limb to ensure writes start at index 2 + let a = B::from(3u64); + let mut acc2 = BigInt::<8>::zero(); + acc2.0[0] = 5; + acc2.0[1] = 7; + let other2 = [0, 0, 2]; // Only offset by 2 limbs + let before2 = BigUint::from(acc2.clone()); + a.fm3x64a_into_nplus4::<8>(other2, &mut acc2); + let mut expected2 = BigUint::from(a); + expected2 *= BigUint::from(2u64) << 128; + expected2 += before2; + let modulus = BigUint::from(1u8) << (64 * 8); + expected2 %= &modulus; + assert_eq!(BigUint::from(acc2), expected2); + } -// ============================== -// SignedBigInt tests -// ============================== - -#[test] -fn test_signed_construction() { - // zero and one - let z = SignedBigInt::<1>::zero(); - assert!(z.is_zero()); - assert!(z.is_positive); - let o = SignedBigInt::<1>::one(); - assert!(!o.is_zero()); - assert!(o.is_positive); - - // from_u64 - let p = SignedBigInt::<1>::from_u64(42); - assert_eq!(p.magnitude.0[0], 42); - assert!(p.is_positive); - let n = SignedBigInt::<1>::from((42u64, false)); - assert_eq!(n.magnitude.0[0], 42); - assert!(!n.is_positive); -} + // ============================== + // SignedBigInt tests + // ============================== + + #[test] + fn test_signed_construction() { + // zero and one + let z = SignedBigInt::<1>::zero(); + assert!(z.is_zero()); + assert!(z.is_positive); + let o = SignedBigInt::<1>::one(); + assert!(!o.is_zero()); + assert!(o.is_positive); + + // from_u64 + let p = SignedBigInt::<1>::from_u64(42); + assert_eq!(p.magnitude.0[0], 42); + assert!(p.is_positive); + let n = SignedBigInt::<1>::from((42u64, false)); + assert_eq!(n.magnitude.0[0], 42); + assert!(!n.is_positive); + } -#[test] -fn test_signed_add_sub_mul_neg() { - let a = SignedBigInt::<1>::from_u64(10); - let b = SignedBigInt::<1>::from_u64(5); - assert_eq!((a + b).magnitude.0[0], 15); - assert_eq!((a - b).magnitude.0[0], 5); - assert_eq!((a * b).magnitude.0[0], 50); - let neg = -a; - assert_eq!(neg.magnitude.0[0], 10); - assert!(!neg.is_positive); - - // opposite signs - let x = SignedBigInt::<1>::from_u64(30); - let y = SignedBigInt::<1>::from((20u64, false)); - let r = x + y; // 30 - 20 - assert!(r.is_positive); - assert_eq!(r.magnitude.0[0], 10); - - let x2 = SignedBigInt::<1>::from((20u64, false)); - let y2 = SignedBigInt::<1>::from_u64(30); - let r2 = x2 + y2; // -20 + 30 - assert!(r2.is_positive); - assert_eq!(r2.magnitude.0[0], 10); -} + #[test] + fn test_signed_add_sub_mul_neg() { + let a = SignedBigInt::<1>::from_u64(10); + let b = SignedBigInt::<1>::from_u64(5); + assert_eq!((a + b).magnitude.0[0], 15); + assert_eq!((a - b).magnitude.0[0], 5); + assert_eq!((a * b).magnitude.0[0], 50); + let neg = -a; + assert_eq!(neg.magnitude.0[0], 10); + assert!(!neg.is_positive); + + // opposite signs + let x = SignedBigInt::<1>::from_u64(30); + let y = SignedBigInt::<1>::from((20u64, false)); + let r = x + y; // 30 - 20 + assert!(r.is_positive); + assert_eq!(r.magnitude.0[0], 10); + + let x2 = SignedBigInt::<1>::from((20u64, false)); + let y2 = SignedBigInt::<1>::from_u64(30); + let r2 = x2 + y2; // -20 + 30 + assert!(r2.is_positive); + assert_eq!(r2.magnitude.0[0], 10); + } -#[test] -fn test_signed_to_i128_and_mag_helpers() { - let p = SignedBigInt::<1>::from_u64(100); - assert_eq!(p.to_i128(), 100); - let n = SignedBigInt::<1>::from((100u64, false)); - assert_eq!(n.to_i128(), -100); - - let d = SignedBigInt::<2>::from_u128(0x1234_5678_9abc_def0_1111_2222_3333_4444u128); - assert_eq!(d.magnitude.0[0], 0x1111_2222_3333_4444); - assert_eq!(d.magnitude.0[1], 0x1234_5678_9abc_def0); - // Positive below 2^127 should convert - let expected_i128 = 0x1234_5678_9abc_def0_1111_2222_3333_4444u128 as i128; - assert_eq!(d.to_i128(), Some(expected_i128)); - - // Positive at 2^127 should fail - let too_big_pos = SignedBigInt::<2>::from_u128(1u128 << 127); - assert_eq!(too_big_pos.to_i128(), None); - - let small = SignedBigInt::<2>::new([100, 0], true); - assert_eq!(small.to_i128(), Some(100)); - assert_eq!(small.magnitude_as_u128(), 100u128); -} + #[test] + fn test_signed_to_i128_and_mag_helpers() { + let p = SignedBigInt::<1>::from_u64(100); + assert_eq!(p.to_i128(), 100); + let n = SignedBigInt::<1>::from((100u64, false)); + assert_eq!(n.to_i128(), -100); + + let d = SignedBigInt::<2>::from_u128(0x1234_5678_9abc_def0_1111_2222_3333_4444u128); + assert_eq!(d.magnitude.0[0], 0x1111_2222_3333_4444); + assert_eq!(d.magnitude.0[1], 0x1234_5678_9abc_def0); + // Positive below 2^127 should convert + let expected_i128 = 0x1234_5678_9abc_def0_1111_2222_3333_4444u128 as i128; + assert_eq!(d.to_i128(), Some(expected_i128)); + + // Positive at 2^127 should fail + let too_big_pos = SignedBigInt::<2>::from_u128(1u128 << 127); + assert_eq!(too_big_pos.to_i128(), None); + + let small = SignedBigInt::<2>::new([100, 0], true); + assert_eq!(small.to_i128(), Some(100)); + assert_eq!(small.magnitude_as_u128(), 100u128); + } -#[test] -fn test_add_with_sign_u64_helper() { - let (mag, sign) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, true); - assert_eq!(mag, 15); - assert!(sign); - let (mag2, sign2) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, false); - assert_eq!(mag2, 5); - assert!(sign2); - let (mag3, sign3) = crate::biginteger::signed::add_with_sign_u64(5, true, 10, false); - assert_eq!(mag3, 5); - assert!(!sign3); -} + #[test] + fn test_add_with_sign_u64_helper() { + let (mag, sign) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, true); + assert_eq!(mag, 15); + assert!(sign); + let (mag2, sign2) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, false); + assert_eq!(mag2, 5); + assert!(sign2); + let (mag3, sign3) = crate::biginteger::signed::add_with_sign_u64(5, true, 10, false); + assert_eq!(mag3, 5); + assert!(!sign3); + } -#[test] -fn test_signed_truncated_add_sub() { - use crate::biginteger::SignedBigInt as S; - let a = S::<2>::from_u128(0x0000_0000_0000_0001_ffff_ffff_ffff_ffff); - let b = S::<2>::from_u128(0x0000_0000_0000_0001_0000_0000_0000_0001); - // Add and truncate to 1 limb - let r1 = a.add_trunc::<1>(&b); - // expected low limb wrap of the low words, ignoring carry to limb1 - let expected_low = (0xffff_ffff_ffff_ffffu64).wrapping_add(0x0000_0000_0000_0001u64); - assert_eq!(r1.magnitude.0[0], expected_low); - assert!(r1.is_positive); - - // Different signs: subtraction path - let a = S::<2>::from_u128(0x2); - let b = S::<2>::from(-3i128); // -3 - let r2 = a.add_trunc::<1>(&b); // 2 + (-3) = -1, truncated to 64-bit - assert_eq!(r2.magnitude.0[0], 1); - assert!(!r2.is_positive); - - // sub_trunc uses add_trunc internally - let x = S::<1>::from_u64(10); - let y = S::<1>::from_u64(7); - let r3 = x.sub_trunc::<1>(&y); - assert_eq!(r3.magnitude.0[0], 3); - assert!(r3.is_positive); -} + #[test] + fn test_signed_truncated_add_sub() { + use crate::biginteger::SignedBigInt as S; + let a = S::<2>::from_u128(0x0000_0000_0000_0001_ffff_ffff_ffff_ffff); + let b = S::<2>::from_u128(0x0000_0000_0000_0001_0000_0000_0000_0001); + // Add and truncate to 1 limb + let r1 = a.add_trunc::<1>(&b); + // expected low limb wrap of the low words, ignoring carry to limb1 + let expected_low = (0xffff_ffff_ffff_ffffu64).wrapping_add(0x0000_0000_0000_0001u64); + assert_eq!(r1.magnitude.0[0], expected_low); + assert!(r1.is_positive); + + // Different signs: subtraction path + let a = S::<2>::from_u128(0x2); + let b = S::<2>::from(-3i128); // -3 + let r2 = a.add_trunc::<1>(&b); // 2 + (-3) = -1, truncated to 64-bit + assert_eq!(r2.magnitude.0[0], 1); + assert!(!r2.is_positive); + + // sub_trunc uses add_trunc internally + let x = S::<1>::from_u64(10); + let y = S::<1>::from_u64(7); + let r3 = x.sub_trunc::<1>(&y); + assert_eq!(r3.magnitude.0[0], 3); + assert!(r3.is_positive); + } -#[test] -fn test_signed_truncated_mul_and_fmadd() { - use crate::biginteger::SignedBigInt as S; - // 128-bit x 64-bit -> truncated to 2 limbs (128-bit) - let a = S::<2>::from_u128(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128); - let b = S::<1>::from_u64(0x2); - let p = a.mul_trunc::<1, 2>(&b); - // Expected low 128 bits of the product - let expected = num_bigint::BigUint::from(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128) - * num_bigint::BigUint::from(2u64); - let got = num_bigint::BigUint::from(p.magnitude); - assert_eq!(got, expected & ((num_bigint::BigUint::from(1u8) << 128) - 1u8)); - assert!(p.is_positive); - - // fmadd into 1-limb accumulator (truncate to 64 bits) - let a = S::<1>::from_u64(0xFFFF_FFFF_FFFF_FFFF); - let b = S::<1>::from_u64(0x2); - let mut acc = S::<1>::from_u64(1); - a.fmadd_trunc::<1, 1>(&b, &mut acc); // acc = 1 + (a*b) mod 2^64 with sign + - // a*b = (2^64 - 1)*2 = 2^65 - 2 => low 64 = (2^64 - 2) - let expected_low = (u64::MAX).wrapping_sub(1); - assert_eq!(acc.magnitude.0[0], expected_low.wrapping_add(1)); -} + #[test] + fn test_signed_truncated_mul_and_fmadd() { + use crate::biginteger::SignedBigInt as S; + // 128-bit x 64-bit -> truncated to 2 limbs (128-bit) + let a = S::<2>::from_u128(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128); + let b = S::<1>::from_u64(0x2); + let p = a.mul_trunc::<1, 2>(&b); + // Expected low 128 bits of the product + let expected = num_bigint::BigUint::from(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128) + * num_bigint::BigUint::from(2u64); + let got = num_bigint::BigUint::from(p.magnitude); + assert_eq!( + got, + expected & ((num_bigint::BigUint::from(1u8) << 128) - 1u8) + ); + assert!(p.is_positive); + + // fmadd into 1-limb accumulator (truncate to 64 bits) + let a = S::<1>::from_u64(0xFFFF_FFFF_FFFF_FFFF); + let b = S::<1>::from_u64(0x2); + let mut acc = S::<1>::from_u64(1); + a.fmadd_trunc::<1, 1>(&b, &mut acc); // acc = 1 + (a*b) mod 2^64 with sign + + // a*b = (2^64 - 1)*2 = 2^65 - 2 => low 64 = (2^64 - 2) + let expected_low = (u64::MAX).wrapping_sub(1); + assert_eq!(acc.magnitude.0[0], expected_low.wrapping_add(1)); + } -#[test] -fn test_signed_truncated_add_sub_mixed() { - use crate::biginteger::SignedBigInt as S; - // Same sign, different widths, ensure carry handling and sign preservation - let a = S::<2>::from_u128(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFF); - let b = S::<1>::from_u64(0x0000_0000_0000_0002); - let r = a.add_trunc_mixed::<1, 2>(&b); // 128-bit result - let expected = num_bigint::BigUint::from(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFFu128) - + num_bigint::BigUint::from(2u64); - assert_eq!(num_bigint::BigUint::from(r.magnitude), expected); - assert!(r.is_positive); - - // Different signs, |a| > |b|: result sign should be sign(a) - let a2 = S::<2>::from_u128(5000); - let b2 = S::<1>::from((3000u64, false)); // -3000 - let r2 = a2.add_trunc_mixed::<1, 2>(&b2); - assert!(r2.is_positive); - assert_eq!(r2.magnitude.0[0], 2000); - - // Different signs, |b| > |a|: result sign should be sign(b) - let a3 = S::<2>::from_u128(1000); - let b3 = S::<1>::from((3000u64, false)); // -3000 - let r3 = a3.add_trunc_mixed::<1, 2>(&b3); - assert!(!r3.is_positive); - assert_eq!(r3.magnitude.0[0], 2000); - - // sub_trunc_mixed basic checks - let a4 = S::<2>::from_u128(10000); - let b4 = S::<1>::from_u64(9999); - let r4 = a4.sub_trunc_mixed::<1, 2>(&b4); - assert!(r4.is_positive); - assert_eq!(r4.magnitude.0[0], 1); - - let a5 = S::<2>::from_u128(1000); - let b5 = S::<1>::from_u64(5000); - let r5 = a5.sub_trunc_mixed::<1, 2>(&b5); - assert!(!r5.is_positive); - assert_eq!(r5.magnitude.0[0], 4000); -} + #[test] + fn test_signed_truncated_add_sub_mixed() { + use crate::biginteger::SignedBigInt as S; + // Same sign, different widths, ensure carry handling and sign preservation + let a = S::<2>::from_u128(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFF); + let b = S::<1>::from_u64(0x0000_0000_0000_0002); + let r = a.add_trunc_mixed::<1, 2>(&b); // 128-bit result + let expected = num_bigint::BigUint::from(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFFu128) + + num_bigint::BigUint::from(2u64); + assert_eq!(num_bigint::BigUint::from(r.magnitude), expected); + assert!(r.is_positive); + + // Different signs, |a| > |b|: result sign should be sign(a) + let a2 = S::<2>::from_u128(5000); + let b2 = S::<1>::from((3000u64, false)); // -3000 + let r2 = a2.add_trunc_mixed::<1, 2>(&b2); + assert!(r2.is_positive); + assert_eq!(r2.magnitude.0[0], 2000); + + // Different signs, |b| > |a|: result sign should be sign(b) + let a3 = S::<2>::from_u128(1000); + let b3 = S::<1>::from((3000u64, false)); // -3000 + let r3 = a3.add_trunc_mixed::<1, 2>(&b3); + assert!(!r3.is_positive); + assert_eq!(r3.magnitude.0[0], 2000); + + // sub_trunc_mixed basic checks + let a4 = S::<2>::from_u128(10000); + let b4 = S::<1>::from_u64(9999); + let r4 = a4.sub_trunc_mixed::<1, 2>(&b4); + assert!(r4.is_positive); + assert_eq!(r4.magnitude.0[0], 1); + + let a5 = S::<2>::from_u128(1000); + let b5 = S::<1>::from_u64(5000); + let r5 = a5.sub_trunc_mixed::<1, 2>(&b5); + assert!(!r5.is_positive); + assert_eq!(r5.magnitude.0[0], 4000); + } -#[test] -fn test_signed_fmadd_trunc_mixed_width_and_signs() { - use crate::biginteger::SignedBigInt as S; - // Case 1: same sign => pure addition of magnitudes - let a = S::<2>::from_u128(30000); - let b = S::<1>::from_u64(7); - let mut acc = S::<2>::from_u128(1000000); - a.fmadd_trunc::<1, 2>(&b, &mut acc); // acc += 210000 - assert!(acc.is_positive); - assert_eq!(acc.magnitude.0[0] as u128 + ((acc.magnitude.0[1] as u128) << 64), 1210000u128); - - // Case 2: different sign, |prod| < |acc| => sign preserved - let a2 = S::<2>::from_u128(30000); - let b2 = S::<1>::from((7u64, false)); // -7 - let mut acc2 = S::<2>::from_u128(1000000); - a2.fmadd_trunc::<1, 2>(&b2, &mut acc2); // acc2 -= 210000 => 790000 - assert!(acc2.is_positive); - assert_eq!(acc2.magnitude.0[0] as u128 + ((acc2.magnitude.0[1] as u128) << 64), 790000u128); - - // Case 3: different sign, |prod| > |acc| => sign flips to prod_sign - let a3 = S::<2>::from_u128(300); - let b3 = S::<1>::from((7u64, false)); // -7 => prod = -2100 - let mut acc3 = S::<2>::from_u128(1000); - a3.fmadd_trunc::<1, 2>(&b3, &mut acc3); // 1000 - 2100 = -1100 - assert!(!acc3.is_positive); - assert_eq!(acc3.magnitude.0[0], 1100); -} + #[test] + fn test_signed_fmadd_trunc_mixed_width_and_signs() { + use crate::biginteger::SignedBigInt as S; + // Case 1: same sign => pure addition of magnitudes + let a = S::<2>::from_u128(30000); + let b = S::<1>::from_u64(7); + let mut acc = S::<2>::from_u128(1000000); + a.fmadd_trunc::<1, 2>(&b, &mut acc); // acc += 210000 + assert!(acc.is_positive); + assert_eq!( + acc.magnitude.0[0] as u128 + ((acc.magnitude.0[1] as u128) << 64), + 1210000u128 + ); + + // Case 2: different sign, |prod| < |acc| => sign preserved + let a2 = S::<2>::from_u128(30000); + let b2 = S::<1>::from((7u64, false)); // -7 + let mut acc2 = S::<2>::from_u128(1000000); + a2.fmadd_trunc::<1, 2>(&b2, &mut acc2); // acc2 -= 210000 => 790000 + assert!(acc2.is_positive); + assert_eq!( + acc2.magnitude.0[0] as u128 + ((acc2.magnitude.0[1] as u128) << 64), + 790000u128 + ); + + // Case 3: different sign, |prod| > |acc| => sign flips to prod_sign + let a3 = S::<2>::from_u128(300); + let b3 = S::<1>::from((7u64, false)); // -7 => prod = -2100 + let mut acc3 = S::<2>::from_u128(1000); + a3.fmadd_trunc::<1, 2>(&b3, &mut acc3); // 1000 - 2100 = -1100 + assert!(!acc3.is_positive); + assert_eq!(acc3.magnitude.0[0], 1100); + } -#[test] -fn test_prop_add_sub_trunc_mixed_random() { - use crate::biginteger::SignedBigInt as S; - use ark_std::rand::Rng; - let mut rng = ark_std::test_rng(); - - // Helper to validate a single pair for given consts - macro_rules! run_case { - ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ - for _ in 0..$iters { - let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); - let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); - let a_pos = (rng.gen::() & 1) == 1; - let b_pos = (rng.gen::() & 1) == 1; - let a = S::<$n>::from_bigint(a_mag, a_pos); - let b = S::<$m>::from_bigint(b_mag, b_pos); - - // add_trunc_mixed - let r_add = a.add_trunc_mixed::<$m, $p>(&b); - let a_bu = num_bigint::BigUint::from(a.magnitude); - let b_bu = num_bigint::BigUint::from(b.magnitude); - let (exp_add_mag, exp_add_pos) = if a_pos == b_pos { - (&a_bu + &b_bu, a_pos) - } else if a_bu >= b_bu { - (&a_bu - &b_bu, a_pos) - } else { - (&b_bu - &a_bu, b_pos) - }; - let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); - let exp_add_mag_mod = exp_add_mag % &modulus; - assert_eq!(num_bigint::BigUint::from(r_add.magnitude), exp_add_mag_mod); - if exp_add_mag_mod != num_bigint::BigUint::from(0u8) { - assert_eq!(r_add.is_positive, exp_add_pos); + #[test] + fn test_prop_add_sub_trunc_mixed_random() { + use crate::biginteger::SignedBigInt as S; + use ark_std::rand::Rng; + let mut rng = ark_std::test_rng(); + + // Helper to validate a single pair for given consts + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); + let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); + let a_pos = (rng.gen::() & 1) == 1; + let b_pos = (rng.gen::() & 1) == 1; + let a = S::<$n>::from_bigint(a_mag, a_pos); + let b = S::<$m>::from_bigint(b_mag, b_pos); + + // add_trunc_mixed + let r_add = a.add_trunc_mixed::<$m, $p>(&b); + let a_bu = num_bigint::BigUint::from(a.magnitude); + let b_bu = num_bigint::BigUint::from(b.magnitude); + let (exp_add_mag, exp_add_pos) = if a_pos == b_pos { + (&a_bu + &b_bu, a_pos) + } else if a_bu >= b_bu { + (&a_bu - &b_bu, a_pos) + } else { + (&b_bu - &a_bu, b_pos) + }; + let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); + let exp_add_mag_mod = exp_add_mag % &modulus; + assert_eq!(num_bigint::BigUint::from(r_add.magnitude), exp_add_mag_mod); + if exp_add_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(r_add.is_positive, exp_add_pos); + } + + // sub_trunc_mixed: a - b + let r_sub = a.sub_trunc_mixed::<$m, $p>(&b); + let (exp_sub_mag, exp_sub_pos) = if a_pos != b_pos { + (&a_bu + &b_bu, a_pos) + } else if a_bu >= b_bu { + (&a_bu - &b_bu, a_pos) + } else { + (&b_bu - &a_bu, !a_pos) + }; + let exp_sub_mag_mod = exp_sub_mag % &modulus; + assert_eq!(num_bigint::BigUint::from(r_sub.magnitude), exp_sub_mag_mod); + if exp_sub_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(r_sub.is_positive, exp_sub_pos); + } } + }}; + } - // sub_trunc_mixed: a - b - let r_sub = a.sub_trunc_mixed::<$m, $p>(&b); - let (exp_sub_mag, exp_sub_pos) = if a_pos != b_pos { - (&a_bu + &b_bu, a_pos) - } else if a_bu >= b_bu { - (&a_bu - &b_bu, a_pos) - } else { - (&b_bu - &a_bu, !a_pos) - }; - let exp_sub_mag_mod = exp_sub_mag % &modulus; - assert_eq!(num_bigint::BigUint::from(r_sub.magnitude), exp_sub_mag_mod); - if exp_sub_mag_mod != num_bigint::BigUint::from(0u8) { - assert_eq!(r_sub.is_positive, exp_sub_pos); - } - } - }}; + run_case!(2, 3, 2, 200); + run_case!(3, 1, 2, 200); + run_case!(1, 2, 1, 200); } - run_case!(2, 3, 2, 200); - run_case!(3, 1, 2, 200); - run_case!(1, 2, 1, 200); -} - -#[test] -fn test_prop_fmadd_trunc_random() { - use crate::biginteger::SignedBigInt as S; - use ark_std::rand::Rng; - let mut rng = ark_std::test_rng(); - - macro_rules! run_case { - ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ - for _ in 0..$iters { - let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); - let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); - let acc_mag: crate::biginteger::BigInt<$p> = UniformRand::rand(&mut rng); - let a_pos = (rng.gen::() & 1) == 1; - let b_pos = (rng.gen::() & 1) == 1; - let acc_pos = (rng.gen::() & 1) == 1; - let a = S::<$n>::from_bigint(a_mag, a_pos); - let b = S::<$m>::from_bigint(b_mag, b_pos); - let mut acc = S::<$p>::from_bigint(acc_mag, acc_pos); - - // expected via BigUint with truncation of the product BEFORE combining signs - let a_bu = num_bigint::BigUint::from(a.magnitude); - let b_bu = num_bigint::BigUint::from(b.magnitude); - let acc_bu = num_bigint::BigUint::from(acc.magnitude); - let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); - let prod_mod = (&a_bu * &b_bu) % &modulus; - let prod_pos = a_pos == b_pos; - let (exp_mag_mod, exp_pos) = if acc_pos == prod_pos { - ((acc_bu + &prod_mod) % &modulus, acc_pos) - } else if acc_bu >= prod_mod { - (acc_bu - &prod_mod, acc_pos) - } else { - (prod_mod - &acc_bu, prod_pos) - }; - - a.fmadd_trunc::<$m, $p>(&b, &mut acc); - - assert_eq!(num_bigint::BigUint::from(acc.magnitude), exp_mag_mod); - if exp_mag_mod != num_bigint::BigUint::from(0u8) { - assert_eq!(acc.is_positive, exp_pos); + #[test] + fn test_prop_fmadd_trunc_random() { + use crate::biginteger::SignedBigInt as S; + use ark_std::rand::Rng; + let mut rng = ark_std::test_rng(); + + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); + let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); + let acc_mag: crate::biginteger::BigInt<$p> = UniformRand::rand(&mut rng); + let a_pos = (rng.gen::() & 1) == 1; + let b_pos = (rng.gen::() & 1) == 1; + let acc_pos = (rng.gen::() & 1) == 1; + let a = S::<$n>::from_bigint(a_mag, a_pos); + let b = S::<$m>::from_bigint(b_mag, b_pos); + let mut acc = S::<$p>::from_bigint(acc_mag, acc_pos); + + // expected via BigUint with truncation of the product BEFORE combining signs + let a_bu = num_bigint::BigUint::from(a.magnitude); + let b_bu = num_bigint::BigUint::from(b.magnitude); + let acc_bu = num_bigint::BigUint::from(acc.magnitude); + let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); + let prod_mod = (&a_bu * &b_bu) % &modulus; + let prod_pos = a_pos == b_pos; + let (exp_mag_mod, exp_pos) = if acc_pos == prod_pos { + ((acc_bu + &prod_mod) % &modulus, acc_pos) + } else if acc_bu >= prod_mod { + (acc_bu - &prod_mod, acc_pos) + } else { + (prod_mod - &acc_bu, prod_pos) + }; + + a.fmadd_trunc::<$m, $p>(&b, &mut acc); + + assert_eq!(num_bigint::BigUint::from(acc.magnitude), exp_mag_mod); + if exp_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(acc.is_positive, exp_pos); + } } - } - }}; + }}; + } + + run_case!(2, 1, 2, 200); + run_case!(3, 2, 2, 200); } - run_case!(2, 1, 2, 200); - run_case!(3, 2, 2, 200); -} + // ============================== + // Tests for add_trunc and add_assign_trunc (unsigned BigInt) + // ============================== -// ============================== -// Tests for add_trunc and add_assign_trunc (unsigned BigInt) -// ============================== + #[test] + fn test_add_trunc_correctness_random() { + use crate::biginteger::BigInt; + let mut rng = ark_std::test_rng(); -#[test] -fn test_add_trunc_correctness_random() { - use crate::biginteger::BigInt; - let mut rng = ark_std::test_rng(); + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a: BigInt<$n> = UniformRand::rand(&mut rng); + let b: BigInt<$m> = UniformRand::rand(&mut rng); - macro_rules! run_case { - ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ - for _ in 0..$iters { - let a: BigInt<$n> = UniformRand::rand(&mut rng); - let b: BigInt<$m> = UniformRand::rand(&mut rng); + let res = a.add_trunc::<$m, $p>(&b); - let res = a.add_trunc::<$m, $p>(&b); + let a_bu = BigUint::from(a); + let b_bu = BigUint::from(b); + let modulus = BigUint::from(1u8) << (64 * $p); + let expected = (a_bu + b_bu) % &modulus; + assert_eq!(BigUint::from(res), expected); + } + }}; + } + + // Same-width, truncated equal width + run_case!(4, 4, 4, 200); + // Same-width, truncate to fewer limbs + run_case!(4, 4, 3, 200); + // Mixed widths, truncate to min and to max + run_case!(4, 2, 3, 200); + run_case!(2, 4, 2, 200); + } - let a_bu = BigUint::from(a); - let b_bu = BigUint::from(b); - let modulus = BigUint::from(1u8) << (64 * $p); - let expected = (a_bu + b_bu) % &modulus; - assert_eq!(BigUint::from(res), expected); + #[test] + fn test_add_assign_trunc_correctness_and_zeroing() { + use crate::biginteger::BigInt; + let mut rng = ark_std::test_rng(); + + // Case 1: N = 4, M = 4, P = 4 (no truncation); compare against add_trunc and add_with_carry + for _ in 0..200 { + let a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<4> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<4, 4>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 4>(&b); + assert_eq!(a2, r_trunc); + + // Regular add_with_carry should match lower 4 limbs modulo 2^(256) + let mut a3 = a; + a3.add_with_carry(&b); + assert_eq!(a3, r_trunc); + } + + // Case 2: N = 4, M = 4, P = 3 (truncation) -> self's limb 3 must be zeroed + for _ in 0..200 { + let a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<4> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<4, 3>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 3>(&b); + // Low 3 limbs match result + for i in 0..3 { + assert_eq!(a2.0[i], r_trunc.0[i]); + } + // Higher limbs of self must be zero + for i in 3..4 { + assert_eq!(a2.0[i], 0); + } + } + + // Case 3: Mixed widths N = 4, M = 2, P = 3 + for _ in 0..200 { + let a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<2> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<2, 3>(&b); + let mut a2 = a; + a2.add_assign_trunc::<2, 3>(&b); + for i in 0..3 { + assert_eq!(a2.0[i], r_trunc.0[i]); } - }}; + // Truncated limb 3.. must be zero + for i in 3..4 { + assert_eq!(a2.0[i], 0); + } + } + + // Case 4: Mixed widths N = 2, M = 4, P = 2 (limit is N so no zeroing beyond N) + for _ in 0..200 { + let a: BigInt<2> = UniformRand::rand(&mut rng); + let b: BigInt<4> = UniformRand::rand(&mut rng); + let r_trunc = a.add_trunc::<4, 2>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4, 2>(&b); + assert_eq!(a2, r_trunc); + } } - // Same-width, truncated equal width - run_case!(4, 4, 4, 200); - // Same-width, truncate to fewer limbs - run_case!(4, 4, 3, 200); - // Mixed widths, truncate to min and to max - run_case!(4, 2, 3, 200); - run_case!(2, 4, 2, 200); -} - -#[test] -fn test_add_assign_trunc_correctness_and_zeroing() { - use crate::biginteger::BigInt; - let mut rng = ark_std::test_rng(); + #[test] + fn test_add_trunc_and_add_assign_trunc_overflow_edges() { + use crate::biginteger::BigInt; - // Case 1: N = 4, M = 4, P = 4 (no truncation); compare against add_trunc and add_with_carry - for _ in 0..200 { - let a: BigInt<4> = UniformRand::rand(&mut rng); - let b: BigInt<4> = UniformRand::rand(&mut rng); - let r_trunc = a.add_trunc::<4, 4>(&b); + // All-ones + all-ones with truncation + let a = BigInt::<4>::new([u64::MAX; 4]); + let b = BigInt::<4>::new([u64::MAX; 4]); + // P = 4: result should be wrapping add modulo 2^256 + let r = a.add_trunc::<4, 4>(&b); let mut a2 = a; a2.add_assign_trunc::<4, 4>(&b); - assert_eq!(a2, r_trunc); + assert_eq!(a2, r); - // Regular add_with_carry should match lower 4 limbs modulo 2^(256) + // P = 3: ensure high limb is zeroed in mutating version + let r3 = a.add_trunc::<4, 3>(&b); let mut a3 = a; - a3.add_with_carry(&b); - assert_eq!(a3, r_trunc); - } - - // Case 2: N = 4, M = 4, P = 3 (truncation) -> self's limb 3 must be zeroed - for _ in 0..200 { - let a: BigInt<4> = UniformRand::rand(&mut rng); - let b: BigInt<4> = UniformRand::rand(&mut rng); - let r_trunc = a.add_trunc::<4, 3>(&b); - let mut a2 = a; - a2.add_assign_trunc::<4, 3>(&b); - // Low 3 limbs match result - for i in 0..3 { assert_eq!(a2.0[i], r_trunc.0[i]); } - // Higher limbs of self must be zero - for i in 3..4 { assert_eq!(a2.0[i], 0); } - } - - // Case 3: Mixed widths N = 4, M = 2, P = 3 - for _ in 0..200 { - let a: BigInt<4> = UniformRand::rand(&mut rng); - let b: BigInt<2> = UniformRand::rand(&mut rng); - let r_trunc = a.add_trunc::<2, 3>(&b); - let mut a2 = a; - a2.add_assign_trunc::<2, 3>(&b); - for i in 0..3 { assert_eq!(a2.0[i], r_trunc.0[i]); } - // Truncated limb 3.. must be zero - for i in 3..4 { assert_eq!(a2.0[i], 0); } + a3.add_assign_trunc::<4, 3>(&b); + for i in 0..3 { + assert_eq!(a3.0[i], r3.0[i]); + } + assert_eq!(a3.0[3], 0); } - - // Case 4: Mixed widths N = 2, M = 4, P = 2 (limit is N so no zeroing beyond N) - for _ in 0..200 { - let a: BigInt<2> = UniformRand::rand(&mut rng); - let b: BigInt<4> = UniformRand::rand(&mut rng); - let r_trunc = a.add_trunc::<4, 2>(&b); - let mut a2 = a; - a2.add_assign_trunc::<4, 2>(&b); - assert_eq!(a2, r_trunc); - } -} - -#[test] -fn test_add_trunc_and_add_assign_trunc_overflow_edges() { - use crate::biginteger::BigInt; - - // All-ones + all-ones with truncation - let a = BigInt::<4>::new([u64::MAX; 4]); - let b = BigInt::<4>::new([u64::MAX; 4]); - // P = 4: result should be wrapping add modulo 2^256 - let r = a.add_trunc::<4, 4>(&b); - let mut a2 = a; - a2.add_assign_trunc::<4, 4>(&b); - assert_eq!(a2, r); - - // P = 3: ensure high limb is zeroed in mutating version - let r3 = a.add_trunc::<4, 3>(&b); - let mut a3 = a; - a3.add_assign_trunc::<4, 3>(&b); - for i in 0..3 { assert_eq!(a3.0[i], r3.0[i]); } - assert_eq!(a3.0[3], 0); -} - } diff --git a/ff/src/fields/models/fp/mod.rs b/ff/src/fields/models/fp/mod.rs index 3fb2110a3..30af6f46f 100644 --- a/ff/src/fields/models/fp/mod.rs +++ b/ff/src/fields/models/fp/mod.rs @@ -2,7 +2,6 @@ use crate::{ AdditiveGroup, BigInt, BigInteger, FftField, Field, LegendreSymbol, One, PrimeField, SqrtPrecomputation, Zero, }; -#[cfg(feature = "allocative")] use allocative::Allocative; use ark_serialize::{ buffer_byte_size, CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize, @@ -120,7 +119,6 @@ pub struct Fp, const N: usize>( #[doc(hidden)] pub PhantomData

, ); -#[cfg(feature = "allocative")] impl, const N: usize> Allocative for Fp { fn visit<'a, 'b: 'a>(&self, _visitor: &'a mut allocative::Visitor<'b>) {} } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index 405a11797..bdddd4432 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -1,4 +1,3 @@ - use super::{Fp, FpConfig}; use crate::{ biginteger::arithmetic as fa, BigInt, BigInteger, PrimeField, SqrtPrecomputation, Zero, @@ -1064,7 +1063,10 @@ impl, const N: usize> Fp, N> { /// Two-phase (schoolbook+REDC) multiply with a RHS whose highest K limbs are provided /// in `rhs_hi` and lower limbs are zero. #[inline] - const fn mul_without_cond_subtract_rhs_hi(mut self, rhs_hi: &crate::BigInt) -> (bool, Self) { + const fn mul_without_cond_subtract_rhs_hi( + mut self, + rhs_hi: &crate::BigInt, + ) -> (bool, Self) { let (mut lo, mut hi) = ([0u64; N], [0u64; N]); // Schoolbook: only columns j in [N-K, N) crate::const_for!((i in 0..N) { @@ -1247,19 +1249,15 @@ impl, const N: usize> Fp, N> { /// This avoids creating temporary BigInt objects. #[inline(always)] #[unroll_for_loops(8)] - fn mul_u64_accumulate( - acc: &mut BigInt, - a: &BigInt, - b: u64 - ) { + fn mul_u64_accumulate(acc: &mut BigInt, a: &BigInt, b: u64) { debug_assert!(NPLUS1 == N + 1); use crate::biginteger::arithmetic as fa; - + let mut carry = 0u64; for i in 0..N { acc.0[i] = fa::mac_with_carry(acc.0[i], a.0[i], b, &mut carry); } - + // Add final carry to the high limb let final_carry = fa::adc(&mut acc.0[N], carry, 0); debug_assert!(final_carry == 0, "overflow in mul_u64_accumulate"); @@ -1269,20 +1267,21 @@ impl, const N: usize> Fp, N> { /// Performs unreduced accumulation in BigInt, then one final reduction. /// This is more efficient than individual multiplications and additions. #[inline(always)] - pub fn linear_combination_u64( - pairs: &[(Self, u64)] - ) -> Self { + pub fn linear_combination_u64(pairs: &[(Self, u64)]) -> Self { debug_assert!(NPLUS1 == N + 1); - debug_assert!(!pairs.is_empty(), "linear_combination_u64 requires at least one pair"); - + debug_assert!( + !pairs.is_empty(), + "linear_combination_u64 requires at least one pair" + ); + // Start with first term - let mut acc = pairs[0].0.0.mul_u64_w_carry::(pairs[0].1); - + let mut acc = pairs[0].0 .0.mul_u64_w_carry::(pairs[0].1); + // Accumulate remaining terms using multiply-accumulate to avoid temporaries for (a, b) in &pairs[1..] { Self::mul_u64_accumulate::(&mut acc, &a.0, *b); } - + Self::from_unchecked_nplus1::(acc) } @@ -1291,37 +1290,43 @@ impl, const N: usize> Fp, N> { /// sums are computed separately and subtracted. One final reduction is performed. #[inline(always)] pub fn linear_combination_i64( - pos: &[(Self, u64)], - neg: &[(Self, u64)] + pos: &[(Self, u64)], + neg: &[(Self, u64)], ) -> Self { debug_assert!(NPLUS1 == N + 1); - debug_assert!(!pos.is_empty(), "linear_combination_i64 requires at least one positive term"); - debug_assert!(!neg.is_empty(), "linear_combination_i64 requires at least one negative term"); - + debug_assert!( + !pos.is_empty(), + "linear_combination_i64 requires at least one positive term" + ); + debug_assert!( + !neg.is_empty(), + "linear_combination_i64 requires at least one negative term" + ); + // Compute unreduced positive sum - let mut pos_lc = pos[0].0.0.mul_u64_w_carry::(pos[0].1); + let mut pos_lc = pos[0].0 .0.mul_u64_w_carry::(pos[0].1); for (a, b) in &pos[1..] { Self::mul_u64_accumulate::(&mut pos_lc, &a.0, *b); } - + // Compute unreduced negative sum - let mut neg_lc = neg[0].0.0.mul_u64_w_carry::(neg[0].1); + let mut neg_lc = neg[0].0 .0.mul_u64_w_carry::(neg[0].1); for (a, b) in &neg[1..] { Self::mul_u64_accumulate::(&mut neg_lc, &a.0, *b); } - + // Subtract and reduce once match pos_lc.cmp(&neg_lc) { core::cmp::Ordering::Greater => { let borrow = pos_lc.sub_with_borrow(&neg_lc); debug_assert!(!borrow, "borrow in linear_combination_i64"); Self::from_unchecked_nplus1::(pos_lc) - } + }, core::cmp::Ordering::Less => { let borrow = neg_lc.sub_with_borrow(&pos_lc); debug_assert!(!borrow, "borrow in linear_combination_i64"); -Self::from_unchecked_nplus1::(neg_lc) - } + }, core::cmp::Ordering::Equal => Self::zero(), } } @@ -1330,11 +1335,13 @@ impl, const N: usize> Fp, N> { /// Avoids slice overhead and loop setup costs. #[inline(always)] pub fn linear_combination_u64_2( - a1: &Self, b1: u64, - a2: &Self, b2: u64 + a1: &Self, + b1: u64, + a2: &Self, + b2: u64, ) -> Self { debug_assert!(NPLUS1 == N + 1); - + let mut acc = a1.0.mul_u64_w_carry::(b1); Self::mul_u64_accumulate::(&mut acc, &a2.0, b2); Self::from_unchecked_nplus1::(acc) @@ -1343,12 +1350,15 @@ impl, const N: usize> Fp, N> { /// Optimized version for exactly 3 terms: a₁×b₁ + a₂×b₂ + a₃×b₃ #[inline(always)] pub fn linear_combination_u64_3( - a1: &Self, b1: u64, - a2: &Self, b2: u64, - a3: &Self, b3: u64 + a1: &Self, + b1: u64, + a2: &Self, + b2: u64, + a3: &Self, + b3: u64, ) -> Self { debug_assert!(NPLUS1 == N + 1); - + let mut acc = a1.0.mul_u64_w_carry::(b1); Self::mul_u64_accumulate::(&mut acc, &a2.0, b2); Self::mul_u64_accumulate::(&mut acc, &a3.0, b3); @@ -1594,9 +1604,10 @@ fn barrett_reduce_nplus1_to_n, const N: usize, const NPLUS1: us // Compute r_tmp = c - m * 2p (result is ([u64; N], u64)) let m_times_2p = ( m2p.0[0..N].try_into().unwrap(), // Convert to ([u64; N], u64) - m2p.0[N] // High limb remains as u64 + m2p.0[N], // High limb remains as u64 ); - let (r_tmp, r_tmp_borrow) = sub_bigint_plus_one_prime((c.0[0], c.0[1..N+1].try_into().unwrap()), m_times_2p); + let (r_tmp, r_tmp_borrow) = + sub_bigint_plus_one_prime((c.0[0], c.0[1..N + 1].try_into().unwrap()), m_times_2p); // A borrow here implies c was smaller than m*2p, which shouldn't happen with correct m. debug_assert!(!r_tmp_borrow, "Borrow occurred calculating c - m*2p"); // Change formats again! @@ -1604,7 +1615,7 @@ fn barrett_reduce_nplus1_to_n, const N: usize, const NPLUS1: us // Alternative simple BigInt subtraction (much slower for some reason): /*let (r_tmp_bigint, r_borrow) = c.const_sub_with_borrow(&m2p); debug_assert!(!r_borrow, "Borrow occurred calculating c - m*2p");*/ - + // Use the optimized conditional subtraction to go from N+1 limbs to N limbs. barrett_cond_subtract::(r_tmp_bigint) } diff --git a/test-curves/benches/bigint.rs b/test-curves/benches/bigint.rs index 66319b37a..f0d7ab1f9 100644 --- a/test-curves/benches/bigint.rs +++ b/test-curves/benches/bigint.rs @@ -1,6 +1,6 @@ // Benchmark for BigInt operations #[cfg(feature = "bn254")] -use ark_ff::{BigInteger, BigInt}; +use ark_ff::{BigInt, BigInteger}; #[cfg(feature = "bn254")] use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; #[cfg(feature = "bn254")] @@ -14,21 +14,25 @@ fn bigint_add_bench(c: &mut Criterion) { // Generate random BigInt<4> instances for benchmarking let a_bigints = (0..SAMPLES) - .map(|_| BigInt::<4>([ - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - ])) + .map(|_| { + BigInt::<4>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); let b_bigints = (0..SAMPLES) - .map(|_| BigInt::<4>([ - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - ])) + .map(|_| { + BigInt::<4>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); let mut group = c.benchmark_group("BigInt<4> Addition Comparison"); @@ -117,9 +121,7 @@ fn bigint_add_bench(c: &mut Criterion) { // Test case: addition that would overflow to compare truncation behavior let max_bigints = (0..SAMPLES) - .map(|_| BigInt::<4>([ - u64::MAX, u64::MAX, u64::MAX, u64::MAX, - ])) + .map(|_| BigInt::<4>([u64::MAX, u64::MAX, u64::MAX, u64::MAX])) .collect::>(); group.bench_function("add_trunc overflow case", |bench| { diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index eff9e544c..965a34035 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -1,6 +1,6 @@ // This bench prefers bn254; if not enabled, provide a no-op main #[cfg(feature = "bn254")] -use ark_ff::{UniformRand, BigInteger}; +use ark_ff::{BigInteger, UniformRand}; #[cfg(feature = "bn254")] use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; #[cfg(feature = "bn254")] @@ -16,68 +16,83 @@ fn mul_small_bench(c: &mut Criterion) { // Use a fixed seed for reproducibility let mut rng = StdRng::seed_from_u64(0u64); - let a_s = (0..SAMPLES) - .map(|_| Fr::rand(&mut rng)) - .collect::>(); + let a_s = (0..SAMPLES).map(|_| Fr::rand(&mut rng)).collect::>(); // let a_limbs_s = a_s.iter().map(|a| a.0.0).collect::>(); - let b_u64_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_u64_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); // Convert u64 to Fr for standard multiplication benchmark let b_fr_s = b_u64_s.iter().map(|&b| Fr::from(b)).collect::>(); let b_u64_as_u128_s = b_u64_s.iter().map(|&b| b as u128).collect::>(); - let b_i64_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_i64_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); - let b_u128_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_u128_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); - let b_i128_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_i128_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); // Generate another set of random Fr elements for addition - let c_s = (0..SAMPLES) - .map(|_| Fr::rand(&mut rng)) - .collect::>(); + let c_s = (0..SAMPLES).map(|_| Fr::rand(&mut rng)).collect::>(); // Generate test data for reduction benchmarks use ark_ff::BigInt; // Extract BigInt<4> from Fr elements for mul_u64_w_carry benchmark let a_bigints = a_s.iter().map(|a| a.0).collect::>(); - + // For Montgomery reduction: 2N-limb inputs (N=4 for bn254, so 2N=8) let bigint_2n_s = (0..SAMPLES) - .map(|_| BigInt::<8>([ - rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), - rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), - ])) + .map(|_| { + BigInt::<8>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); - + // For Barrett reductions: N+1, N+2, N+3 limb inputs let bigint_nplus1_s = (0..SAMPLES) - .map(|_| BigInt::<5>([ - rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), - ])) + .map(|_| { + BigInt::<5>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); - + let bigint_nplus2_s = (0..SAMPLES) - .map(|_| BigInt::<6>([ - rng.gen::(), rng.gen::(), rng.gen::(), - rng.gen::(), rng.gen::(), rng.gen::(), - ])) + .map(|_| { + BigInt::<6>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); - + let bigint_nplus3_s = (0..SAMPLES) - .map(|_| BigInt::<7>([ - rng.gen::(), rng.gen::(), rng.gen::(), rng.gen::(), - rng.gen::(), rng.gen::(), rng.gen::(), - ])) + .map(|_| { + BigInt::<7>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); let mut group = c.benchmark_group("Fr Arithmetic Comparison"); @@ -259,8 +274,10 @@ fn mul_small_bench(c: &mut Criterion) { bench.iter(|| { i = (i + 1) % SAMPLES; criterion::black_box(Fr::linear_combination_u64_2::<5>( - &a_s[i], b_u64_s[i], - &c_s[i], b_u64_s[(i + 1) % SAMPLES] + &a_s[i], + b_u64_s[i], + &c_s[i], + b_u64_s[(i + 1) % SAMPLES], )) }) }); @@ -270,7 +287,7 @@ fn mul_small_bench(c: &mut Criterion) { bench.iter(|| { i = (i + 1) % SAMPLES; let pairs = [ - (a_s[i], b_u64_s[i]), + (a_s[i], b_u64_s[i]), (c_s[i], b_u64_s[(i + 1) % SAMPLES]), (a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES]), (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES]), @@ -284,9 +301,12 @@ fn mul_small_bench(c: &mut Criterion) { bench.iter(|| { i = (i + 1) % SAMPLES; criterion::black_box(Fr::linear_combination_u64_3::<5>( - &a_s[i], b_u64_s[i], - &c_s[i], b_u64_s[(i + 1) % SAMPLES], - &a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES] + &a_s[i], + b_u64_s[i], + &c_s[i], + b_u64_s[(i + 1) % SAMPLES], + &a_s[(i + 2) % SAMPLES], + b_u64_s[(i + 2) % SAMPLES], )) }) }); @@ -296,8 +316,10 @@ fn mul_small_bench(c: &mut Criterion) { bench.iter(|| { i = (i + 1) % SAMPLES; let pos = [(a_s[i], b_u64_s[i]), (c_s[i], b_u64_s[(i + 1) % SAMPLES])]; - let neg = [(a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES]), - (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES])]; + let neg = [ + (a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES]), + (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES]), + ]; criterion::black_box(Fr::linear_combination_i64::<5>(&pos, &neg)) }) }); @@ -334,4 +356,4 @@ criterion_group!(benches, mul_small_bench); criterion_main!(benches); #[cfg(not(feature = "bn254"))] -fn main() {} \ No newline at end of file +fn main() {} diff --git a/test-curves/src/bn254/fq.rs b/test-curves/src/bn254/fq.rs index 001d94836..6bddf9bc0 100644 --- a/test-curves/src/bn254/fq.rs +++ b/test-curves/src/bn254/fq.rs @@ -10,4 +10,4 @@ pub struct FqConfig; pub type Fq = Fp256>; pub const FQ_ONE: Fq = ark_ff::MontFp!("1"); -pub const FQ_ZERO: Fq = ark_ff::MontFp!("0"); \ No newline at end of file +pub const FQ_ZERO: Fq = ark_ff::MontFp!("0"); diff --git a/test-curves/src/bn254/fr.rs b/test-curves/src/bn254/fr.rs index 4caef8e7c..4de077431 100644 --- a/test-curves/src/bn254/fr.rs +++ b/test-curves/src/bn254/fr.rs @@ -14,4 +14,4 @@ pub struct FrConfig; pub type Fr = Fp256>; pub const FR_ONE: Fr = ark_ff::MontFp!("1"); -pub const FR_ZERO: Fr = ark_ff::MontFp!("0"); \ No newline at end of file +pub const FR_ZERO: Fr = ark_ff::MontFp!("0"); diff --git a/test-curves/src/bn254/g1.rs b/test-curves/src/bn254/g1.rs index 2b3c5a0c5..608278db8 100644 --- a/test-curves/src/bn254/g1.rs +++ b/test-curves/src/bn254/g1.rs @@ -1,8 +1,6 @@ -use ark_ec::models::short_weierstrass::{ - Affine, Projective, SWCurveConfig, -}; +use ark_ec::models::short_weierstrass::{Affine, Projective, SWCurveConfig}; use ark_ec::CurveConfig; -use ark_ff::{Field, MontFp, Zero, AdditiveGroup}; +use ark_ff::{AdditiveGroup, Field, MontFp, Zero}; use crate::bn254::{Fq, Fr}; // Assuming Fq is defined in fq.rs diff --git a/test-curves/src/bn254/test.rs b/test-curves/src/bn254/test.rs index 51a9c691e..176467a7a 100644 --- a/test-curves/src/bn254/test.rs +++ b/test-curves/src/bn254/test.rs @@ -2,7 +2,9 @@ use ark_ec::{ models::short_weierstrass::SWCurveConfig, // Keep this as G1 is SW pairing::Pairing, - AffineRepr, CurveGroup, PrimeGroup, + AffineRepr, + CurveGroup, + PrimeGroup, }; use ark_ff::{Field, One, UniformRand, Zero}; use ark_std::{rand::Rng, test_rng}; From 7fac06549a4c301e775e0ea548c81f1fca261a21 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 12:01:22 -0400 Subject: [PATCH 20/38] (de)serialization & allocative for i8 or i96 --- ff/src/biginteger/i8_or_i96.rs | 63 +++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/ff/src/biginteger/i8_or_i96.rs b/ff/src/biginteger/i8_or_i96.rs index bec6ee803..a84392af6 100644 --- a/ff/src/biginteger/i8_or_i96.rs +++ b/ff/src/biginteger/i8_or_i96.rs @@ -1,5 +1,10 @@ use crate::biginteger::{S160, S224}; use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; +use allocative::Allocative; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; /// Compact signed integer optimized for the common `i8` case, widening to a 96-bit /// split representation when needed (low 64 bits in `large_lo`, next 32 bits in `large_hi`). @@ -28,7 +33,7 @@ use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; /// it is truncated to signed 96-bit two's complement (wrapping modulo 2^96). /// - The `neg` implementation avoids `i8` overflow by widening `i8::MIN` to the wide form. /// - Conversions are total: `to_i128()` always returns the exact value. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] pub struct I8OrI96 { /// The lower 64 bits of the constant value. large_lo: u64, @@ -600,3 +605,59 @@ impl Mul for I8OrI96 { S224::new([r0, r1, r2], hi32, is_positive) } } + +// ------------------------------------------------------------------------------------------------ +// Canonical serialization +// ------------------------------------------------------------------------------------------------ + +impl CanonicalSerialize for I8OrI96 { + #[inline] + fn serialize_with_mode( + &self, + mut w: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // Print whether it is small (computed canonically from value), and the numeric value + // as 96-bit two's complement represented by (hi: i32, lo: u64), derived from to_i128(). + let v = self.to_i128(); + let is_small_value = v >= i8::MIN as i128 && v <= i8::MAX as i128; + let is_small_u8: u8 = if is_small_value { 1 } else { 0 }; + is_small_u8.serialize_with_mode(&mut w, compress)?; + + let hi: i32 = (v >> 64) as i32; + let lo: u64 = v as u64; + hi.serialize_with_mode(&mut w, compress)?; + lo.serialize_with_mode(w, compress) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + (self.is_small as u8).serialized_size(compress) + + (0i32).serialized_size(compress) + + (0u64).serialized_size(compress) + } +} + +impl CanonicalDeserialize for I8OrI96 { + #[inline] + fn deserialize_with_mode( + mut r: R, + compress: Compress, + validate: Validate, + ) -> Result { + let _is_small = u8::deserialize_with_mode(&mut r, compress, validate)?; + let hi = i32::deserialize_with_mode(&mut r, compress, validate)?; + let lo = u64::deserialize_with_mode(r, compress, validate)?; + let v: i128 = ((hi as i128) << 64) | (lo as i128); + Ok(I8OrI96::from_i128(v)) + } +} + +impl Valid for I8OrI96 { + #[inline] + fn check(&self) -> Result<(), SerializationError> { + // All bit patterns of the struct represent either a small i8 or a 96-bit two's complement value. + // No additional invariants beyond that; always valid. + Ok(()) + } +} From 93250cc8e7d0288d559e7722243c9f18bb50d76f Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 17:00:29 -0400 Subject: [PATCH 21/38] add conversion from (signed) bigint to field elts --- ff/src/biginteger/mod.rs | 12 ++++ ff/src/biginteger/signed_hi_32.rs | 15 +++++ ff/src/fields/models/fp/montgomery_backend.rs | 65 ++++++++++++++++++- 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 55bbb55f2..3684e0764 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -464,6 +464,18 @@ impl BigInt { crate::const_helpers::R2Buffer::([0u64; N], [0u64; N], 1); const_modulo!(two_pow_n_times_64_square, self) } + + /// Zero-extend a smaller BigInt into BigInt (little-endian limbs). + /// Debug-asserts that M <= N. + #[inline] + pub fn zero_extend_from(smaller: &BigInt) -> BigInt { + debug_assert!(M <= N, "cannot zero-extend: source has more limbs than destination"); + let mut limbs = [0u64; N]; + let copy_len = if M < N { M } else { N }; + limbs[..copy_len].copy_from_slice(&smaller.0[..copy_len]); + BigInt::(limbs) + } + } impl BigInteger for BigInt { diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index da6f98e1d..b904df0cf 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -2,6 +2,7 @@ use allocative::Allocative; use ark_std::cmp::Ordering; use ark_std::vec::Vec; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use crate::biginteger::BigInt; /// Compact signed big-integer parameterized by limb count `N` (total width = `N*64 + 32` bits). /// @@ -296,6 +297,20 @@ impl SignedBigIntHi32 { (magnitude_lo, hi2, final_borrow) } + + /// Return the unsigned magnitude as a BigInt with N+1 limbs (little-endian), + /// packing `magnitude_lo` followed by `magnitude_hi` (widened to u64). + /// This ignores the sign; pair with `is_positive()` if you need a signed value. + #[inline] + pub fn magnitude_as_bigint_nplus1(&self) -> BigInt { + debug_assert!(NPLUS1 == N + 1, "NPLUS1 must be N+1 for SignedBigIntHi32 magnitude pack"); + let mut limbs = [0u64; NPLUS1]; + if N > 0 { + limbs[..N].copy_from_slice(&self.magnitude_lo); + } + limbs[N] = self.magnitude_hi as u64; + BigInt::(limbs) + } } // ------------------------------------------------------------------------------------------------ diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index bdddd4432..b95ed0213 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -474,6 +474,44 @@ pub trait MontConfig: 'static + Sync + Send + Sized { } } + /// Construct from a smaller-width BigInt by zero-extending into N limbs. + /// Returns None if the resulting N-limb value is >= modulus. + #[inline] + fn from_bigint_mixed(r: BigInt) -> Fp, N> { + debug_assert!(M <= N, "from_bigint_mixed requires M <= N"); + let r_n = BigInt::::zero_extend_from::(&r); + Self::from_bigint(r_n).expect("from_bigint_mixed: value >= modulus") + } + + /// Construct from a signed big integer with M 64-bit limbs (sign-magnitude). + /// Returns None if |x| >= modulus. + #[inline] + fn from_signed_bigint( + x: crate::biginteger::SignedBigInt, + ) -> Fp, N> { + // if x.is_zero() { + // return Fp::zero(); + // } + let fe = Self::from_bigint_mixed::(x.magnitude); + if x.is_positive { fe } else { -fe } + } + + /// Construct from a signed big integer with high 32-bit tail and K low 64-bit limbs. + /// KPLUS1 must be K+1; the magnitude packs as [lo[0..K], hi32 as u64]. + /// Returns None if |x| >= modulus. + #[inline] + fn from_signed_bigint_hi32( + x: crate::biginteger::SignedBigIntHi32, + ) -> Fp, N> { + debug_assert!(KPLUS1 == K + 1, "from_signed_bigint_hi32 requires KPLUS1 = K + 1"); + // if x.is_zero() { + // return Fp::zero(); + // } + let mag = x.magnitude_as_bigint_nplus1::(); + let fe = Self::from_bigint_mixed::(mag); + if x.is_positive() { fe } else { -fe } + } + #[inline] #[cfg_attr(not(target_family = "wasm"), unroll_for_loops(12))] #[cfg_attr(target_family = "wasm", unroll_for_loops(6))] @@ -840,7 +878,7 @@ impl, const N: usize> Fp, N> { Self(element, PhantomData) } - /// NEW! Construct a new field element from a BigInt + /// Construct a new field element from a BigInt /// which is in montgomery form and just needs to be reduced /// via a barrett reduction. #[inline(always)] @@ -850,7 +888,7 @@ impl, const N: usize> Fp, N> { Self::new_unchecked(r) } - /// NEW! Construct a new field element from a BigInt + /// Construct a new field element from a BigInt /// which is in montgomery form and just needs to be reduced /// via a barrett reduction. #[inline] @@ -867,6 +905,29 @@ impl, const N: usize> Fp, N> { Self::new_unchecked(r2) } + /// Construct from a smaller-width BigInt by zero-extending into N limbs. + /// Panics if the resulting value is >= modulus. + #[inline] + pub fn from_bigint_mixed(r: BigInt) -> Self { + T::from_bigint_mixed::(r) + } + + /// Construct from a signed big integer (sign-magnitude with M limbs). + /// Panics if |x| >= modulus. + #[inline] + pub fn from_signed_bigint(x: crate::biginteger::SignedBigInt) -> Self { + T::from_signed_bigint::(x) + } + + /// Construct from a signed big integer with high 32-bit tail and K low 64-bit limbs. + /// KPLUS1 must be K+1. Panics if |x| >= modulus. + #[inline] + pub fn from_signed_bigint_hi32( + x: crate::biginteger::SignedBigIntHi32, + ) -> Self { + T::from_signed_bigint_hi32::(x) + } + const fn const_is_zero(&self) -> bool { self.0.const_is_zero() } From 1e03fd6dfdba44daf16c47dce8e51958e2902e47 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 18:26:55 -0400 Subject: [PATCH 22/38] starting refactor / unification of new bigint ops --- ff/src/biginteger/mod.rs | 323 +++++++-------------------- ff/src/biginteger/signed.rs | 360 ++++-------------------------- ff/src/biginteger/signed_hi_32.rs | 183 ++++++++++++++- 3 files changed, 314 insertions(+), 552 deletions(-) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 3684e0764..754b6e9fc 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -31,7 +31,7 @@ use zeroize::Zeroize; pub mod arithmetic; pub mod signed; -pub use signed::{SignedBigInt, S128, S196, S256, S64}; +pub use signed::{SignedBigInt, S128, S192, S256, S64}; pub mod signed_hi_32; pub use signed_hi_32::{SignedBigIntHi32, S160, S224, S96}; @@ -350,6 +350,29 @@ impl BigInt { res } + /// Truncated-width subtraction: compute self - other and fit into P limbs; borrow is ignored beyond P limbs. + #[inline] + pub fn sub_trunc(&self, other: &BigInt) -> BigInt

{ + let mut res = BigInt::

::zero(); + let mut borrow = false; + + for i in 0..P { + let a = if i < N { self.0[i] } else { 0u64 }; + let b = if i < M { other.0[i] } else { 0u64 }; + let (d1, b1) = a.overflowing_sub(b); + if borrow { + let (d2, b2) = d1.overflowing_sub(1); + res.0[i] = d2; + borrow = b1 || b2; + } else { + res.0[i] = d1; + borrow = b1; + } + } + + res + } + /// Truncated-width addition that mutates self: self += other and fit result into P limbs; overflow is ignored. #[inline] pub fn add_assign_trunc(&mut self, other: &BigInt) { @@ -408,6 +431,58 @@ impl BigInt { } } + /// Internal core engine: accumulate self * other_limbs into acc starting at lane_offset. + /// If carry_propagate is true, propagate spill from the highest updated limb forward within P; + /// otherwise, wrap in-place (discard further carry), matching existing wrapper semantics. + #[inline] + #[unroll_for_loops(6)] + pub(crate) fn fm_limbs_into( + &self, + other_limbs: &[u64], + acc: &mut BigInt

, + lane_offset: usize, + carry_propagate: bool, + ) { + if self.is_zero() { + return; + } + for (j, &mul_limb) in other_limbs.iter().enumerate() { + if mul_limb == 0 { + continue; + } + let base = lane_offset + j; + let mut carry = 0u64; + // Accumulate across self's limbs + for i in 0..N { + let idx = base + i; + if idx >= P { + // Out of truncation range; compute carry but discard writes + // We still need to advance carry for correctness within truncated semantics? No: any + // contribution beyond P is dropped modulo 2^(64*P), so we can break. + break; + } + acc.0[idx] = mac_with_carry!(acc.0[idx], self.0[i], mul_limb, &mut carry); + } + // Add remaining carry into next limb if within width + let next = base + N; + if next < P { + let (v, mut of) = acc.0[next].overflowing_add(carry); + acc.0[next] = v; + if carry_propagate && of { + // propagate into higher limbs until carry consumed or width exhausted + let mut k = next + 1; + while of && k < P { + let (nv, nof) = acc.0[k].overflowing_add(1); + acc.0[k] = nv; + of = nof; + k += 1; + } + } + } + // else: spill beyond P is dropped by truncation + } + } + #[inline] pub(crate) const fn const_sub_with_borrow(mut self, other: &Self) -> (Self, bool) { let mut borrow = 0; @@ -611,277 +686,45 @@ impl BigInteger for BigInt { } #[inline] - #[unroll_for_loops(8)] fn fmu64a(&self, other: u64, acc: &mut BigInt) { - // ensure NPLUS1 is the correct size debug_assert!(NPLUS1 == N + 1); - // special cases for 0 and 1 - if other == 0 || self.is_zero() { - // idempotent - return; - } else if other == 1 { - // just addition - let mut carry = 0; - for i in 0..N { - carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); - } - acc.0[N] = acc.0[N].wrapping_add(carry as u64); - return; - } - // otherwise fma - let mut carry = 0; - for i in 0..N { - acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry); - } - acc.0[N] = acc.0[N].wrapping_add(carry as u64); + self.fm_limbs_into::(&[other], acc, 0, false); } #[inline] #[unroll_for_loops(8)] fn fmu64a_carry_propagating(&self, other: u64, acc: &mut BigInt) { - // ensure NPLUS2 is the correct size (N + 2 limbs) debug_assert!(NPLUS2 == N + 2); - if other == 0 || self.is_zero() { - return; - } - if other == 1 { - let mut carry: u8 = 0; - for i in 0..N { - carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); - } - let (new_n, of1) = acc.0[N].overflowing_add(carry as u64); - acc.0[N] = new_n; - if of1 { - acc.0[N + 1] = acc.0[N + 1].wrapping_add(1); - } - return; - } - let mut carry = 0u64; - for i in 0..N { - acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry); - } - let (new_n, of1) = acc.0[N].overflowing_add(carry); - acc.0[N] = new_n; - if of1 { - acc.0[N + 1] = acc.0[N + 1].wrapping_add(1); - } + self.fm_limbs_into::(&[other], acc, 0, true); } #[inline] #[unroll_for_loops(8)] fn fm128a(&self, other: u128, acc: &mut BigInt) { - // ensure NPLUS2 is the correct size (N + 2 limbs) debug_assert!(NPLUS2 == N + 2); - // special cases for 0 and 1 - // if other == 0 || self.is_zero() { - // // idempotent - // return; - // } else if other == 1 { - // // just addition into lower N limbs; propagate final carry into acc[N] - // let mut carry = 0; - // for i in 0..N { - // carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); - // } - // // carry is at most 1; fold into limb N (wrapping into highest limb if needed later) - // acc.0[N] = acc.0[N].wrapping_add(carry as u64); - // return; - // } - - let other_lo = other as u64; - let other_hi = (other >> 64) as u64; - - // Accumulate self * other_lo into acc[0..=N] - let mut carry = 0u64; - for i in 0..N { - acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other_lo, &mut carry); - } - // Add final carry into limb N, propagating into highest limb if it overflows - let (new_n, of1) = acc.0[N].overflowing_add(carry); - acc.0[N] = new_n; - if of1 { - acc.0[N + 1] = acc.0[N + 1].wrapping_add(1); - } - - // Accumulate self * other_hi into acc[1..=N+1] - let mut carry2 = 0u64; - for i in 0..N { - acc.0[i + 1] = mac_with_carry!(acc.0[i + 1], self.0[i], other_hi, &mut carry2); - } - acc.0[N + 1] = acc.0[N + 1].wrapping_add(carry2); + let limbs = [other as u64, (other >> 64) as u64]; + self.fm_limbs_into::(&limbs, acc, 0, true); } #[inline] #[unroll_for_loops(8)] fn fmu64a_into_nplus4(&self, other: u64, acc: &mut BigInt) { debug_assert!(NPLUS4 == N + 4); - if other == 0 || self.is_zero() { - return; - } - if other == 1 { - let mut carry: u8 = 0; - for i in 0..N { - carry = arithmetic::adc_for_add_with_carry(&mut acc.0[i], self.0[i], carry); - } - if carry != 0 { - let (n0, of0) = acc.0[N].overflowing_add(1); - acc.0[N] = n0; - if of0 { - let (n1, of1) = acc.0[N + 1].overflowing_add(1); - acc.0[N + 1] = n1; - if of1 { - let (n2, of2) = acc.0[N + 2].overflowing_add(1); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } - } - return; - } - let mut carry0 = 0u64; - for i in 0..N { - acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], other, &mut carry0); - } - if carry0 != 0 { - let (n0, of0) = acc.0[N].overflowing_add(carry0); - acc.0[N] = n0; - if of0 { - let (n1, of1) = acc.0[N + 1].overflowing_add(1); - acc.0[N + 1] = n1; - if of1 { - let (n2, of2) = acc.0[N + 2].overflowing_add(1); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } - } + self.fm_limbs_into::(&[other], acc, 0, true); } #[inline] #[unroll_for_loops(8)] fn fm2x64a_into_nplus4(&self, other: [u64; 2], acc: &mut BigInt) { debug_assert!(NPLUS4 == N + 4); - let lo = other[0]; - let hi = other[1]; - if (lo | hi) == 0 || self.is_zero() { - return; - } - - if lo != 0 { - let mut carry0 = 0u64; - for i in 0..N { - acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], lo, &mut carry0); - } - if carry0 != 0 { - let (n0, of0) = acc.0[N].overflowing_add(carry0); - acc.0[N] = n0; - if of0 { - let (n1, of1) = acc.0[N + 1].overflowing_add(1); - acc.0[N + 1] = n1; - if of1 { - let (n2, of2) = acc.0[N + 2].overflowing_add(1); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } - } - } - - if hi != 0 { - let mut carry1 = 0u64; - for i in 0..N { - acc.0[i + 1] = mac_with_carry!(acc.0[i + 1], self.0[i], hi, &mut carry1); - } - if carry1 != 0 { - let (n1, of1) = acc.0[N + 1].overflowing_add(carry1); - acc.0[N + 1] = n1; - if of1 { - let (n2, of2) = acc.0[N + 2].overflowing_add(1); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } - } + self.fm_limbs_into::(&other, acc, 0, true); } #[inline] #[unroll_for_loops(8)] fn fm3x64a_into_nplus4(&self, other: [u64; 3], acc: &mut BigInt) { debug_assert!(NPLUS4 == N + 4); - let o0 = other[0]; - let o1 = other[1]; - let o2 = other[2]; - if (o0 | o1 | o2) == 0 || self.is_zero() { - return; - } - - if o0 != 0 { - let mut carry0 = 0u64; - for i in 0..N { - acc.0[i] = mac_with_carry!(acc.0[i], self.0[i], o0, &mut carry0); - } - if carry0 != 0 { - let (n0, of0) = acc.0[N].overflowing_add(carry0); - acc.0[N] = n0; - if of0 { - let (n1, of1) = acc.0[N + 1].overflowing_add(1); - acc.0[N + 1] = n1; - if of1 { - let (n2, of2) = acc.0[N + 2].overflowing_add(1); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } - } - } - - if o1 != 0 { - let mut carry1 = 0u64; - for i in 0..N { - acc.0[i + 1] = mac_with_carry!(acc.0[i + 1], self.0[i], o1, &mut carry1); - } - if carry1 != 0 { - let (n1, of1) = acc.0[N + 1].overflowing_add(carry1); - acc.0[N + 1] = n1; - if of1 { - let (n2, of2) = acc.0[N + 2].overflowing_add(1); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } - } - - if o2 != 0 { - let mut carry2 = 0u64; - for i in 0..N { - acc.0[i + 2] = mac_with_carry!(acc.0[i + 2], self.0[i], o2, &mut carry2); - } - if carry2 != 0 { - let (n2, of2) = acc.0[N + 2].overflowing_add(carry2); - acc.0[N + 2] = n2; - if of2 { - let (n3, _of3) = acc.0[N + 3].overflowing_add(1); - acc.0[N + 3] = n3; - } - } - } + self.fm_limbs_into::(&other, acc, 0, true); } #[inline] diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index 62d540ddc..c2335294a 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -16,7 +16,7 @@ pub struct SignedBigInt { pub type S64 = SignedBigInt<1>; pub type S128 = SignedBigInt<2>; -pub type S196 = SignedBigInt<3>; +pub type S192 = SignedBigInt<3>; pub type S256 = SignedBigInt<4>; impl SignedBigInt { @@ -183,6 +183,16 @@ impl SignedBigInt { self.magnitude = low; self.is_positive = self.is_positive == rhs.is_positive; } + + /// Zero-extend a smaller-width signed big integer into N limbs (little-endian). + /// Preserves the sign bit; only the magnitude is widened by zero-extension. + /// Debug-asserts that M <= N. + #[inline] + pub fn zero_extend_from(smaller: &SignedBigInt) -> SignedBigInt { + debug_assert!(M <= N, "cannot zero-extend: source has more limbs than destination"); + let widened_mag = BigInt::::zero_extend_from::(&smaller.magnitude); + SignedBigInt::from_bigint(widened_mag, smaller.is_positive) + } } impl SignedBigInt { @@ -192,68 +202,18 @@ impl SignedBigInt { #[inline] pub fn add_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt { if self.is_positive == rhs.is_positive { - // Same sign -> truncate limbwise sum - let mut res = BigInt::::zero(); - let mut carry: u8 = 0; - let lim = core::cmp::min(N, M); - for i in 0..lim { - let (s1, c1) = self.magnitude.0[i].overflowing_add(rhs.magnitude.0[i]); - let (s2, c2) = s1.overflowing_add(carry as u64); - res.0[i] = s2; - carry = (c1 as u8) | (c2 as u8); - } - // propagate carry into next limb if within M, else drop - if lim < M { - res.0[lim] = carry as u64; - } - SignedBigInt:: { - magnitude: res, - is_positive: self.is_positive, - } - } else { - // Different signs -> subtract smaller magnitude from larger - match self.magnitude.cmp(&rhs.magnitude) { - Ordering::Greater | Ordering::Equal => { - let mut res = BigInt::::zero(); - let lim = core::cmp::min(N, M); - let mut borrow: bool = false; - for i in 0..lim { - let (d1, b1) = self.magnitude.0[i].overflowing_sub(rhs.magnitude.0[i]); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - SignedBigInt:: { - magnitude: res, - is_positive: self.is_positive, - } - }, - Ordering::Less => { - let mut res = BigInt::::zero(); - let lim = core::cmp::min(N, M); - let mut borrow: bool = false; - for i in 0..lim { - let (d1, b1) = rhs.magnitude.0[i].overflowing_sub(self.magnitude.0[i]); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - SignedBigInt:: { - magnitude: res, - is_positive: rhs.is_positive, - } - }, - } + let mag = self.magnitude.add_trunc::(&rhs.magnitude); + return SignedBigInt:: { magnitude: mag, is_positive: self.is_positive }; + } + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let mag = self.magnitude.sub_trunc::(&rhs.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: self.is_positive } + }, + Ordering::Less => { + let mag = rhs.magnitude.sub_trunc::(&self.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: rhs.is_positive } + }, } } @@ -261,67 +221,18 @@ impl SignedBigInt { #[inline] pub fn sub_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt { if self.is_positive != rhs.is_positive { - // same as addition path - let mut res = BigInt::::zero(); - let mut carry: u8 = 0; - let lim = core::cmp::min(N, M); - for i in 0..lim { - let (s1, c1) = self.magnitude.0[i].overflowing_add(rhs.magnitude.0[i]); - let (s2, c2) = s1.overflowing_add(carry as u64); - res.0[i] = s2; - carry = (c1 as u8) | (c2 as u8); - } - if lim < M { - res.0[lim] = carry as u64; - } - SignedBigInt:: { - magnitude: res, - is_positive: self.is_positive, - } - } else { - // different signs wrt subtraction => subtract magnitudes - match self.magnitude.cmp(&rhs.magnitude) { - Ordering::Greater | Ordering::Equal => { - let mut res = BigInt::::zero(); - let lim = core::cmp::min(N, M); - let mut borrow: bool = false; - for i in 0..lim { - let (d1, b1) = self.magnitude.0[i].overflowing_sub(rhs.magnitude.0[i]); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - SignedBigInt:: { - magnitude: res, - is_positive: self.is_positive, - } - }, - Ordering::Less => { - let mut res = BigInt::::zero(); - let lim = core::cmp::min(N, M); - let mut borrow: bool = false; - for i in 0..lim { - let (d1, b1) = rhs.magnitude.0[i].overflowing_sub(self.magnitude.0[i]); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - SignedBigInt:: { - magnitude: res, - is_positive: !self.is_positive, - } - }, - } + let mag = self.magnitude.add_trunc::(&rhs.magnitude); + return SignedBigInt:: { magnitude: mag, is_positive: self.is_positive }; + } + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let mag = self.magnitude.sub_trunc::(&rhs.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: self.is_positive } + }, + Ordering::Less => { + let mag = rhs.magnitude.sub_trunc::(&self.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: !self.is_positive } + }, } } @@ -332,117 +243,18 @@ impl SignedBigInt { &self, rhs: &SignedBigInt, ) -> SignedBigInt

{ - // Case 1: same signs => add magnitudes, sign = self.is_positive if self.is_positive == rhs.is_positive { - let mut res = BigInt::

::zero(); - let mut carry: u8 = 0; - let overlap = core::cmp::min(core::cmp::min(N, M), P); - for i in 0..overlap { - let (s1, c1) = self.magnitude.0[i].overflowing_add(rhs.magnitude.0[i]); - let (s2, c2) = s1.overflowing_add(carry as u64); - res.0[i] = s2; - carry = (c1 as u8) | (c2 as u8); - } - let mut k = overlap; - if N > M { - let end = core::cmp::min(N, P); - while k < end { - let (s1, c1) = self.magnitude.0[k].overflowing_add(carry as u64); - res.0[k] = s1; - carry = c1 as u8; - k += 1; - } - } else if M > N { - let end = core::cmp::min(M, P); - while k < end { - let (s1, c1) = rhs.magnitude.0[k].overflowing_add(carry as u64); - res.0[k] = s1; - carry = c1 as u8; - k += 1; - } - } - if k < P { - res.0[k] = carry as u64; - } - return SignedBigInt::

{ - magnitude: res, - is_positive: self.is_positive, - }; + let mag = self.magnitude.add_trunc::(&rhs.magnitude); + return SignedBigInt::

{ magnitude: mag, is_positive: self.is_positive }; } - - // Case 2: different signs => subtract smaller magnitude from larger - let ord = self.cmp_magnitude_mixed(rhs); - - match ord { + match self.cmp_magnitude_mixed(rhs) { Ordering::Greater | Ordering::Equal => { - // res_mag = self.mag - rhs.mag; sign = self.is_positive - let mut res = BigInt::

::zero(); - let mut borrow = false; - let overlap = core::cmp::min(core::cmp::min(N, M), P); - for i in 0..overlap { - let (d1, b1) = self.magnitude.0[i].overflowing_sub(rhs.magnitude.0[i]); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - let mut k = overlap; - if N > M { - let end = core::cmp::min(N, P); - while k < end { - if borrow { - let (d2, b2) = self.magnitude.0[k].overflowing_sub(1); - res.0[k] = d2; - borrow = b2; - } else { - res.0[k] = self.magnitude.0[k]; - } - k += 1; - } - } - SignedBigInt::

{ - magnitude: res, - is_positive: self.is_positive, - } + let mag = self.magnitude.sub_trunc::(&rhs.magnitude); + SignedBigInt::

{ magnitude: mag, is_positive: self.is_positive } }, Ordering::Less => { - // res_mag = rhs.mag - self.mag; sign = rhs.is_positive - let mut res = BigInt::

::zero(); - let mut borrow = false; - let overlap = core::cmp::min(core::cmp::min(N, M), P); - for i in 0..overlap { - let (d1, b1) = rhs.magnitude.0[i].overflowing_sub(self.magnitude.0[i]); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - let mut k = overlap; - if M > N { - let end = core::cmp::min(M, P); - while k < end { - if borrow { - let (d2, b2) = rhs.magnitude.0[k].overflowing_sub(1); - res.0[k] = d2; - borrow = b2; - } else { - res.0[k] = rhs.magnitude.0[k]; - } - k += 1; - } - } - SignedBigInt::

{ - magnitude: res, - is_positive: rhs.is_positive, - } + let mag = rhs.magnitude.sub_trunc::(&self.magnitude); + SignedBigInt::

{ magnitude: mag, is_positive: rhs.is_positive } }, } } @@ -539,92 +351,18 @@ impl SignedBigInt { &self, rhs: &SignedBigInt, ) -> SignedBigInt

{ - // Case 1: different signs => addition of magnitudes, sign = self.is_positive if self.is_positive != rhs.is_positive { - let mut res = BigInt::

::zero(); - let mut carry: u8 = 0; - for i in 0..P { - let a = if i < N { self.magnitude.0[i] } else { 0u64 }; - let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; - let (s1, c1) = a.overflowing_add(b); - let (s2, c2) = s1.overflowing_add(carry as u64); - res.0[i] = s2; - carry = (c1 as u8) | (c2 as u8); - } - return SignedBigInt::

{ - magnitude: res, - is_positive: self.is_positive, - }; + let mag = self.magnitude.add_trunc::(&rhs.magnitude); + return SignedBigInt::

{ magnitude: mag, is_positive: self.is_positive }; } - - // Case 2: same signs => subtract smaller magnitude from larger; sign accordingly - // Mixed-width magnitude comparison (zero-extended to max(N, M)) - let ord = { - // Compare from most significant limb down to 0 - let max_limbs = if N > M { N } else { M }; - let mut i = max_limbs; - let mut ordering = Ordering::Equal; - while i > 0 { - let idx = i - 1; - let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; - let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; - if a > b { - ordering = Ordering::Greater; - break; - } - if a < b { - ordering = Ordering::Less; - break; - } - i -= 1; - } - ordering - }; - - match ord { + match self.cmp_magnitude_mixed(rhs) { Ordering::Greater | Ordering::Equal => { - // res_mag = self.mag - rhs.mag; sign = self.is_positive - let mut res = BigInt::

::zero(); - let mut borrow = false; - for i in 0..P { - let a = if i < N { self.magnitude.0[i] } else { 0u64 }; - let b = if i < M { rhs.magnitude.0[i] } else { 0u64 }; - let (d1, b1) = a.overflowing_sub(b); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - SignedBigInt::

{ - magnitude: res, - is_positive: self.is_positive, - } + let mag = self.magnitude.sub_trunc::(&rhs.magnitude); + SignedBigInt::

{ magnitude: mag, is_positive: self.is_positive } }, Ordering::Less => { - // res_mag = rhs.mag - self.mag; sign = !self.is_positive - let mut res = BigInt::

::zero(); - let mut borrow = false; - for i in 0..P { - let a = if i < M { rhs.magnitude.0[i] } else { 0u64 }; - let b = if i < N { self.magnitude.0[i] } else { 0u64 }; - let (d1, b1) = a.overflowing_sub(b); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - SignedBigInt::

{ - magnitude: res, - is_positive: !self.is_positive, - } + let mag = rhs.magnitude.sub_trunc::(&self.magnitude); + SignedBigInt::

{ magnitude: mag, is_positive: !self.is_positive } }, } } diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index b904df0cf..f1c3dcbbd 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -2,7 +2,11 @@ use allocative::Allocative; use ark_std::cmp::Ordering; use ark_std::vec::Vec; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use crate::biginteger::BigInt; +use crate::biginteger::{BigInt, SignedBigInt}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; /// Compact signed big-integer parameterized by limb count `N` (total width = `N*64 + 32` bits). /// @@ -311,6 +315,52 @@ impl SignedBigIntHi32 { limbs[N] = self.magnitude_hi as u64; BigInt::(limbs) } + + /// Zero-extend a smaller-width SignedBigIntHi32 into width N (little-endian). + /// Moves the 32-bit head of the smaller value into the next low 64-bit limb on widen, + /// and clears the head in the widened representation to preserve the numeric value. + /// Debug-asserts that M <= N. + #[inline] + pub fn zero_extend_from(smaller: &SignedBigIntHi32) -> SignedBigIntHi32 { + debug_assert!(M <= N, "cannot zero-extend: source has more limbs than destination"); + if N == M { + return SignedBigIntHi32::::new( + // copy to avoid borrowing issues + { + let mut lo = [0u64; N]; + if N > 0 { + lo.copy_from_slice(smaller.magnitude_lo()); + } + lo + }, + smaller.magnitude_hi(), + smaller.is_positive(), + ); + } + // N > M + let mut lo = [0u64; N]; + if M > 0 { + lo[..M].copy_from_slice(smaller.magnitude_lo()); + } + // Place the 32-bit head into limb M + lo[M] = smaller.magnitude_hi() as u64; + SignedBigIntHi32::::new(lo, 0u32, smaller.is_positive()) + } + + /// Convert this hi-32 representation into a standard SignedBigInt with N+1 limbs. + /// Packs the low limbs verbatim and writes the 32-bit head into the highest limb. + /// Debug-asserts that NPLUS1 == N + 1. + #[inline] + pub fn to_signed_bigint_nplus1(&self) -> SignedBigInt { + debug_assert!(NPLUS1 == N + 1, "to_signed_bigint_nplus1 requires NPLUS1 = N + 1"); + let mut limbs = [0u64; NPLUS1]; + if N > 0 { + limbs[..N].copy_from_slice(self.magnitude_lo()); + } + limbs[N] = self.magnitude_hi() as u64; + let mag = BigInt::(limbs); + SignedBigInt::from_bigint(mag, self.is_positive()) + } } // ------------------------------------------------------------------------------------------------ @@ -445,6 +495,137 @@ impl<'a, const N: usize> Mul for &'a SignedBigIntHi32 { } } +// ------------------------------------------------------------------------------------------------ +// S160-specific inherent constructors (ergonomic helpers) +// ------------------------------------------------------------------------------------------------ + +impl S160 { + /// Construct from the signed difference of two u64 values: returns |a - b| with sign a>=b. + #[inline] + pub fn from_diff_u64(a: u64, b: u64) -> Self { + let mag = a.abs_diff(b); + let is_positive = a >= b; + S160::new([mag, 0], 0, is_positive) + } + + /// Construct from a u128 magnitude and an explicit sign. + #[inline] + pub fn from_magnitude_u128(mag: u128, is_positive: bool) -> Self { + let lo = mag as u64; + let hi = (mag >> 64) as u64; + S160::new([lo, hi], 0, is_positive) + } + + /// Construct from the signed difference of two u128 values: returns |u1 - u2| with sign u1>=u2. + #[inline] + pub fn from_diff_u128(u1: u128, u2: u128) -> Self { + if u1 >= u2 { + S160::from_magnitude_u128(u1 - u2, true) + } else { + S160::from_magnitude_u128(u2 - u1, false) + } + } + + /// Construct from the sum of two u128 values, preserving carry into the top 32-bit head. + #[inline] + pub fn from_sum_u128(u1: u128, u2: u128) -> Self { + let u1_lo = u1 as u64; + let u1_hi = (u1 >> 64) as u64; + let u2_lo = u2 as u64; + let u2_hi = (u2 >> 64) as u64; + let (sum_lo, carry0) = u1_lo.overflowing_add(u2_lo); + let (sum_hi1, carry1) = u1_hi.overflowing_add(u2_hi); + let (sum_hi, carry2) = sum_hi1.overflowing_add(if carry0 { 1 } else { 0 }); + let carry_out = (carry1 as u8 | carry2 as u8) != 0; + S160::new([sum_lo, sum_hi], if carry_out { 1 } else { 0 }, true) + } + + /// Construct from (u128 - i128) with full-width integer semantics. + #[inline] + pub fn from_u128_minus_i128(u: u128, i: i128) -> Self { + if i >= 0 { + S160::from_diff_u128(u, i as u128) + } else { + let abs_i: u128 = i.unsigned_abs(); + S160::from_sum_u128(u, abs_i) + } + } +} + +// ------------------------------------------------------------------------------------------------ +// Ordering and canonical serialization +// ------------------------------------------------------------------------------------------------ + +impl core::cmp::PartialOrd for SignedBigIntHi32 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl core::cmp::Ord for SignedBigIntHi32 { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + match (self.is_positive, other.is_positive) { + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + _ => { + let ord = self.compare_magnitudes(other); + if self.is_positive { ord } else { ord.reverse() } + }, + } + } +} + +impl CanonicalSerialize for SignedBigIntHi32 { + #[inline] + fn serialize_with_mode( + &self, + mut w: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // Encode sign, then (hi, lo) + (self.is_positive as u8).serialize_with_mode(&mut w, compress)?; + (self.magnitude_hi as i32).serialize_with_mode(&mut w, compress)?; + for i in 0..N { + self.magnitude_lo[i].serialize_with_mode(&mut w, compress)?; + } + Ok(()) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + (self.is_positive as u8).serialized_size(compress) + + (self.magnitude_hi as i32).serialized_size(compress) + + (0u64).serialized_size(compress) * N + } +} + +impl CanonicalDeserialize for SignedBigIntHi32 { + #[inline] + fn deserialize_with_mode( + mut r: R, + compress: Compress, + validate: Validate, + ) -> Result { + let sign_u8 = u8::deserialize_with_mode(&mut r, compress, validate)?; + let hi = i32::deserialize_with_mode(&mut r, compress, validate)?; + let mut lo = [0u64; N]; + for i in 0..N { + lo[i] = u64::deserialize_with_mode(&mut r, compress, validate)?; + } + Ok(SignedBigIntHi32::new(lo, hi as u32, sign_u8 != 0)) + } +} + +impl Valid for SignedBigIntHi32 { + #[inline] + fn check(&self) -> Result<(), SerializationError> { + // No additional invariants beyond structural fields + Ok(()) + } +} + // ------------------------------------------------------------------------------------------------ // Symmetric mul: S160 * I8OrI96 -> S224 (for ergonomics) // ------------------------------------------------------------------------------------------------ From 7277b2e7411710f59d902bc3e9cd7ed4f9cddf6b Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 19:12:02 -0400 Subject: [PATCH 23/38] add msm for s64 and s128 --- ec/src/scalar_mul/variable_base/mod.rs | 97 ++++++++++++++++++++++++++ ff/src/biginteger/signed.rs | 18 +++-- ff/src/biginteger/signed_hi_32.rs | 5 ++ 3 files changed, 114 insertions(+), 6 deletions(-) diff --git a/ec/src/scalar_mul/variable_base/mod.rs b/ec/src/scalar_mul/variable_base/mod.rs index 9d83dfa53..aca9e1726 100644 --- a/ec/src/scalar_mul/variable_base/mod.rs +++ b/ec/src/scalar_mul/variable_base/mod.rs @@ -1,4 +1,5 @@ use ark_ff::prelude::*; +use ark_ff::biginteger::{S128, S64}; use ark_std::{ borrow::Borrow, cfg_chunks, cfg_into_iter, cfg_iter, @@ -626,6 +627,102 @@ pub fn msm_i128( } } +pub fn msm_s64( + mut bases: &[V::MulBase], + mut scalars: &[S64], + serial: bool, +) -> V { + let (negative_bases, non_negative_bases): (Vec, Vec) = + bases.iter().enumerate().partition_map(|(i, b)| { + if !scalars[i].sign() { + Either::Left(b) + } else { + Either::Right(b) + } + }); + let (negative_scalars, non_negative_scalars): (Vec, Vec) = scalars + .iter() + .partition_map(|s| { + let mag = s.magnitude_as_u64(); + if !s.sign() { + Either::Left(mag) + } else { + Either::Right(mag) + } + }); + if serial { + return msm_serial::(&non_negative_bases, &non_negative_scalars) + - msm_serial::(&negative_bases, &negative_scalars); + } else { + let chunk_size = match preamble(&mut bases, &mut scalars, serial) { + Some(chunk_size) => chunk_size, + None => return V::zero(), + }; + + let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size) + .zip(cfg_chunks!(non_negative_scalars, chunk_size)) + .map(|(non_negative_bases, non_negative_scalars)| { + msm_serial::(non_negative_bases, non_negative_scalars) + }) + .sum(); + let negative_msm: V = cfg_chunks!(negative_bases, chunk_size) + .zip(cfg_chunks!(negative_scalars, chunk_size)) + .map(|(negative_bases, negative_scalars)| { + msm_serial::(negative_bases, negative_scalars) + }) + .sum(); + non_negative_msm - negative_msm + } +} + +pub fn msm_s128( + mut bases: &[V::MulBase], + mut scalars: &[S128], + serial: bool, +) -> V { + let (negative_bases, non_negative_bases): (Vec, Vec) = + bases.iter().enumerate().partition_map(|(i, b)| { + if !scalars[i].sign() { + Either::Left(b) + } else { + Either::Right(b) + } + }); + let (negative_scalars, non_negative_scalars): (Vec, Vec) = scalars + .iter() + .partition_map(|s| { + let mag = s.magnitude_as_u128(); + if !s.sign() { + Either::Left(mag) + } else { + Either::Right(mag) + } + }); + if serial { + return msm_serial::(&non_negative_bases, &non_negative_scalars) + - msm_serial::(&negative_bases, &negative_scalars); + } else { + let chunk_size = match preamble(&mut bases, &mut scalars, serial) { + Some(chunk_size) => chunk_size, + None => return V::zero(), + }; + + let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size) + .zip(cfg_chunks!(non_negative_scalars, chunk_size)) + .map(|(non_negative_bases, non_negative_scalars)| { + msm_serial::(non_negative_bases, non_negative_scalars) + }) + .sum(); + let negative_msm: V = cfg_chunks!(negative_bases, chunk_size) + .zip(cfg_chunks!(negative_scalars, chunk_size)) + .map(|(negative_bases, negative_scalars)| { + msm_serial::(negative_bases, negative_scalars) + }) + .sum(); + non_negative_msm - negative_msm + } +} + pub fn msm_u128( mut bases: &[V::MulBase], mut scalars: &[u128], diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index c2335294a..7f6927ae5 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -7,7 +7,13 @@ use ark_serialize::{ use core::cmp::Ordering; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -/// A signed big integer using arkworks BigInt for magnitude and a sign bit +/// A signed big integer using arkworks BigInt for magnitude and a sign bit. +/// +/// Notes: +/// - Zero is not canonicalized: a zero magnitude can be paired with either sign. +/// Structural equality distinguishes `+0` and `-0` (since the sign bit differs). +/// - Ordering treats `+0` and `-0` as equal: comparisons return `Ordering::Equal` when +/// both magnitudes are zero regardless of sign. #[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] pub struct SignedBigInt { pub magnitude: BigInt, @@ -667,16 +673,16 @@ impl core::cmp::PartialOrd for SignedBigInt { impl core::cmp::Ord for SignedBigInt { #[inline] fn cmp(&self, other: &Self) -> Ordering { + // Treat +0 and -0 as equal in ordering semantics + if self.magnitude.is_zero() && other.magnitude.is_zero() { + return Ordering::Equal; + } match (self.is_positive, other.is_positive) { (true, false) => Ordering::Greater, (false, true) => Ordering::Less, _ => { let ord = self.magnitude.cmp(&other.magnitude); - if self.is_positive { - ord - } else { - ord.reverse() - } + if self.is_positive { ord } else { ord.reverse() } }, } } diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index f1c3dcbbd..353253d8a 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -22,6 +22,8 @@ use ark_serialize::{ /// so `+0 != -0`. Callers that require canonical zero should normalize externally. /// /// Notes: +/// - Zero is not normalized: a zero magnitude can be positive or negative. Structural equality +/// distinguishes `+0` and `-0`, but ordering treats them as equal. /// - Specialized fast paths exist for `N ∈ {0,1,2}`; larger `N` uses a generic path. #[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] pub struct SignedBigIntHi32 { @@ -566,6 +568,9 @@ impl core::cmp::PartialOrd for SignedBigIntHi32 { impl core::cmp::Ord for SignedBigIntHi32 { #[inline] fn cmp(&self, other: &Self) -> Ordering { + if self.is_zero() && other.is_zero() { + return Ordering::Equal; + } match (self.is_positive, other.is_positive) { (true, false) => Ordering::Greater, (false, true) => Ordering::Less, From ddb9805027bd5bb6c867fbf8fb4332acd42ae4ed Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 19:19:08 -0400 Subject: [PATCH 24/38] add conversion traits for S160 --- ff/src/biginteger/signed_hi_32.rs | 51 +++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index 353253d8a..248a82f0f 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -2,7 +2,7 @@ use allocative::Allocative; use ark_std::cmp::Ordering; use ark_std::vec::Vec; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use crate::biginteger::{BigInt, SignedBigInt}; +use crate::biginteger::{BigInt, SignedBigInt, S64, S128}; use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, Write, @@ -672,18 +672,58 @@ impl core::ops::Mul<&crate::biginteger::I8OrI96> for &S160 { // ------------------------------------------------------------------------------------------------ impl From for S96 { + #[inline] fn from(val: i64) -> Self { Self::new([val.unsigned_abs()], 0, val.is_positive()) } } impl From for S96 { + #[inline] fn from(val: u64) -> Self { Self::new([val], 0, true) } } +impl From for S96 { + #[inline] + fn from(val: S64) -> Self { + Self::new([val.magnitude.0[0]], 0, val.is_positive) + } +} + +impl From for S160 { + #[inline] + fn from(val: i64) -> Self { + Self::new([val.unsigned_abs(), 0], 0, val.is_positive()) + } +} + +impl From for S160 { + #[inline] + fn from(val: u64) -> Self { + Self::new([val, 0], 0, true) + } +} + +impl From for S160 { + #[inline] + fn from(val: S64) -> Self { + Self::new([val.magnitude.0[0], 0], 0, val.is_positive) + } +} + +impl From for S160 { + #[inline] + fn from(val: u128) -> Self { + let lo = val as u64; + let hi = (val >> 64) as u64; + Self::new([lo, hi], 0, true) + } +} + impl From for S160 { + #[inline] fn from(val: i128) -> Self { let is_positive = val.is_positive(); let mag = val.unsigned_abs(); @@ -693,11 +733,10 @@ impl From for S160 { } } -impl From for S160 { - fn from(val: u128) -> Self { - let lo = val as u64; - let hi = (val >> 64) as u64; - Self::new([lo, hi], 0, true) +impl From for S160 { + #[inline] + fn from(val: S128) -> Self { + Self::new([val.magnitude.0[0], val.magnitude.0[1]], 0, val.is_positive) } } From 1901417362a4bb38c7998e127843b37b11691aff Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Tue, 16 Sep 2025 19:48:57 -0400 Subject: [PATCH 25/38] added default to SignedBigInt --- ff/src/biginteger/signed.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index 7f6927ae5..72ca2fcc5 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -20,6 +20,13 @@ pub struct SignedBigInt { pub is_positive: bool, } +impl Default for SignedBigInt { + #[inline] + fn default() -> Self { + Self::zero() + } +} + pub type S64 = SignedBigInt<1>; pub type S128 = SignedBigInt<2>; pub type S192 = SignedBigInt<3>; From 6cb3aa13855187bfa0a0c54da09c9100774751cb Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Wed, 17 Sep 2025 09:26:25 -0400 Subject: [PATCH 26/38] fixing multiplication bug --- ff/src/biginteger/i8_or_i96.rs | 15 +++---- ff/src/biginteger/signed_hi_32.rs | 5 ++- ff/src/biginteger/tests.rs | 74 +++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 9 deletions(-) diff --git a/ff/src/biginteger/i8_or_i96.rs b/ff/src/biginteger/i8_or_i96.rs index a84392af6..9d7d1e3b8 100644 --- a/ff/src/biginteger/i8_or_i96.rs +++ b/ff/src/biginteger/i8_or_i96.rs @@ -527,8 +527,7 @@ impl Mul for I8OrI96 { } else if b2_is_zero { // 128-bit rhs via b1 only let mut c1 = c0; - let r1p = mac_with_carry!(0u64, b1, k, &mut c1); - let r1 = adc!(r1p, 0u64, &mut c1); + let r1 = mac_with_carry!(0u64, b1, k, &mut c1); let r2 = c1; (r0, r1, r2, 0u32) } else { @@ -564,9 +563,9 @@ impl Mul for I8OrI96 { let mut c2 = c1; let r2 = mac_with_carry!(0u64, x0, b2, &mut c2); - let mut carry_hi = c2; - crate::biginteger::arithmetic::mac_discard(carry_hi, x1, b2, &mut carry_hi); - let hi32 = carry_hi as u32; + let r3_low = ((c2 as u128) + + crate::biginteger::arithmetic::widening_mul(x1, b2)) as u64; + let hi32 = (r3_low & 0xFFFF_FFFF) as u32; (r0, r1, r2, hi32) } } else if b2_is_zero { @@ -589,9 +588,9 @@ impl Mul for I8OrI96 { let mut r2 = mac_with_carry!(0u64, x0, b2, &mut c2); r2 = mac_with_carry!(r2, x1, b1, &mut c2); - let mut carry_hi = c2; - crate::biginteger::arithmetic::mac_discard(carry_hi, x1, b2, &mut carry_hi); - let hi32 = carry_hi as u32; + let r3_low = ((c2 as u128) + + crate::biginteger::arithmetic::widening_mul(x1, b2)) as u64; + let hi32 = (r3_low & 0xFFFF_FFFF) as u32; (r0, r1, r2, hi32) } }; diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index 248a82f0f..6dbb86897 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -206,12 +206,15 @@ impl SignedBigIntHi32 { let r1 = sum1 as u64; let carry1 = sum1 >> 64; - // word 2 (only need low 32 bits) + // word 2 let sum2 = carry1 + (a0 as u128) * (b2 as u128) + (a1 as u128) * (b1 as u128) + (a2 as u128) * (b0 as u128); let r2 = sum2 as u64; + let _carry2 = (sum2 >> 64) as u64; + + // For a 160-bit result, the head (bits 128..159) is the low 32 bits of word2. let hi = (r2 & 0xFFFF_FFFF) as u32; let mut lo = [0u64; N]; lo[0] = r0; diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 29cfbf905..883eed7a5 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -103,6 +103,80 @@ pub mod tests { assert_eq!(max.mul_high(&B::from(2u64)), B::from(1u64)); } + #[test] + fn test_i8_or_i96_mul_s160_edges() { + use crate::biginteger::{I8OrI96, S160, S224}; + + // Helper to convert S224 to BigUint reference value (unsigned magnitude) and sign + fn s224_to_biguint_and_sign(v: &S224) -> (num_bigint::BigUint, bool) { + let lo = v.magnitude_lo(); + let hi32 = v.magnitude_hi() as u64; + let mut limbs = [0u64; 4]; + limbs[0] = lo[0]; + limbs[1] = lo[1]; + limbs[2] = lo[2]; + limbs[3] = hi32; + (num_bigint::BigUint::from(crate::biginteger::BigInt::<4>(limbs)), v.is_positive()) + } + + // Case 1: small i8 * b1-only rhs + let k = I8OrI96::from_i8(7); // x1 = 0 + let rhs = S160::new([0, 5], 0, true); // b0=0, b1=5, b2=0 + let out = k * rhs; + let (mag_bu, sign) = s224_to_biguint_and_sign(&out); + let expected = num_bigint::BigUint::from(7u64) * (num_bigint::BigUint::from(5u64) << 64) + % (num_bigint::BigUint::from(1u8) << 224); + assert!(sign); + assert_eq!(mag_bu, expected); + + // Case 2: large x with x1!=0, b2!=0, b1==0 (hits hi32 path) + // x = 2^80 + 3 => hi32 = 2^(80-64)=2^16, lo=3 + let x = I8OrI96::from_i128(((1i128) << 80) + 3); + let rhs2 = S160::new([11, 0], 9, true); // b0=11, b1=0, b2=9 + let out2 = x * rhs2; + let (mag_bu2, sign2) = s224_to_biguint_and_sign(&out2); + let x_bi = (num_bigint::BigUint::from(1u8) << 80) + num_bigint::BigUint::from(3u8); + let rhs_bi = (num_bigint::BigUint::from(9u8) << 128) + (num_bigint::BigUint::from(11u8)); + let exp2 = x_bi * rhs_bi % (num_bigint::BigUint::from(1u8) << 224); + assert!(sign2); + assert_eq!(mag_bu2, exp2); + + // Case 3: negative small i8 * nonzero rhs, zero result canonicalizes to positive + let k3 = I8OrI96::from_i8(-1); + let rhs3 = S160::zero(); + let out3 = k3 * rhs3; + let (_, sign3) = s224_to_biguint_and_sign(&out3); + assert!(sign3); + } + + #[test] + fn test_s160_mul_s160_hi32_consistency() { + use crate::biginteger::{BigInt, S160}; + + // Spot-check a configuration that exercises the hi32 accumulation path + let a = S160::new([1u64 << 63, 0], 1, true); // a2=1, a0 has high bit + let b = S160::new([0, 1u64 << 63], 1, true); // b2=1, b1 has high bit + let got = &a * &b; // S160 result + + // Convert S160 value to BigUint: pack [lo0, lo1, hi32] into BigInt<3> + let mut pack = [0u64; 3]; + pack[0] = got.magnitude_lo()[0]; + pack[1] = got.magnitude_lo()[1]; + pack[2] = got.magnitude_hi() as u64; + let got_bu = num_bigint::BigUint::from(BigInt::<3>(pack)); + + // Reference BigUint modulo 2^160 + let a_bu = (num_bigint::BigUint::from(a.magnitude_lo()[1]) << 64) + + num_bigint::BigUint::from(a.magnitude_lo()[0]) + + (num_bigint::BigUint::from(a.magnitude_hi() as u64) << 128); + let b_bu = (num_bigint::BigUint::from(b.magnitude_lo()[1]) << 64) + + num_bigint::BigUint::from(b.magnitude_lo()[0]) + + (num_bigint::BigUint::from(b.magnitude_hi() as u64) << 128); + let prod = (a_bu * b_bu) % (num_bigint::BigUint::from(1u8) << 160); + + assert_eq!(got_bu, prod); + } + fn biginteger_shr() { let mut rng = ark_std::test_rng(); let a = B::rand(&mut rng); From 76c35b6e9cc49b38e13d486c7846f2ebd74256cb Mon Sep 17 00:00:00 2001 From: markosg04 Date: Mon, 22 Sep 2025 15:14:13 -0400 Subject: [PATCH 27/38] refactor: cleanup step utils --- jolt-optimizations/src/fq12_poly.rs | 249 ++++++++++------ jolt-optimizations/src/steps.rs | 296 +++++++++++++------ jolt-optimizations/tests/steps_debug_test.rs | 20 +- jolt-optimizations/tests/steps_test.rs | 10 +- 4 files changed, 377 insertions(+), 198 deletions(-) diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs index 751844af6..b603ffa75 100644 --- a/jolt-optimizations/src/fq12_poly.rs +++ b/jolt-optimizations/src/fq12_poly.rs @@ -2,149 +2,208 @@ use ark_bn254::{Fq, Fq12}; use ark_ff::{Field, One, Zero}; -/// Flatten Fq12 to 12 base-field coefficients for a(X)=Σ c_i X^i, X=w, -/// with the relation g(X) = X^12 - 18 X^6 + 82. -/// -/// The BN254 Fq12 field is constructed as a tower extension: -/// - Fq2 = Fq[u]/(u^2 + 1) -/// - Fq6 = Fq2[v]/(v^3 - (9 + u)) -/// - Fq12 = Fq6[w]/(w^2 - v) -/// -/// This function maps an Fq12 element to its polynomial representation -/// in Fq[X] where X = w, using the mapping: -/// (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6}, for k∈{0..5}. -/// @TODO(markosg04) provide proof? +/// Constant for the tower extension mapping +const NINE: u64 = 9; + +/// Newtype wrapper for degree-12 polynomial coefficients +#[derive(Clone, Debug, Default)] +pub struct Poly12([Fq; 12]); + +impl Poly12 { + pub fn new(coeffs: [Fq; 12]) -> Self { + Self(coeffs) + } + + pub fn coeffs(&self) -> &[Fq; 12] { + &self.0 + } + + pub fn coeffs_mut(&mut self) -> &mut [Fq; 12] { + &mut self.0 + } + + pub fn to_vec(&self) -> Vec { + self.0.to_vec() + } + + /// Evaluate at a point using Horner's method + pub fn eval(&self, r: &Fq) -> Fq { + self.0.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) + } +} + +/// Tower basis mapping for Fq12 -> polynomial conversion +struct TowerBasis { + /// Maps basis elements to power indices: [(element, power_of_w)] + mappings: [(usize, usize, usize); 6], // (c0/c1, inner_idx, w_power) +} + +impl TowerBasis { + const fn new() -> Self { + Self { + mappings: [ + (0, 0, 0), // a.c0.c0 → w^0 + (0, 1, 2), // a.c0.c1 → w^2 + (0, 2, 4), // a.c0.c2 → w^4 + (1, 0, 1), // a.c1.c0 → w^1 + (1, 1, 3), // a.c1.c1 → w^3 + (1, 2, 5), // a.c1.c2 → w^5 + ], + } + } + + fn apply(&self, a: &Fq12) -> Poly12 { + let nine = Fq::from(NINE); + let mut coeffs = [Fq::zero(); 12]; + + for &(outer, inner, w_power) in &self.mappings { + let fp2 = match (outer, inner) { + (0, 0) => &a.c0.c0, + (0, 1) => &a.c0.c1, + (0, 2) => &a.c0.c2, + (1, 0) => &a.c1.c0, + (1, 1) => &a.c1.c1, + (1, 2) => &a.c1.c2, + _ => unreachable!(), + }; + + let (x, y) = (fp2.c0, fp2.c1); + // Apply: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} + coeffs[w_power] += x - nine * y; + coeffs[w_power + 6] += y; + } + + Poly12::new(coeffs) + } +} + +static TOWER_BASIS: TowerBasis = TowerBasis::new(); + +/// Convert Fq12 to polynomial representation pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { - let nine = Fq::from(9u64); - let mut c = [Fq::zero(); 12]; - - // (term, k) pairs mapping Fq12 basis elements to powers of w: - // 1, v, v^2, w, v·w, v^2·w ↔ w^0, w^2, w^4, w^1, w^3, w^5 - let terms = [ - (&a.c0.c0, 0usize), // 1 → w^0 - (&a.c0.c1, 2usize), // v → w^2 - (&a.c0.c2, 4usize), // v^2 → w^4 - (&a.c1.c0, 1usize), // w → w^1 - (&a.c1.c1, 3usize), // v·w → w^3 - (&a.c1.c2, 5usize), // v^2·w → w^5 - ]; - - for (fp2, k) in terms { - let x = fp2.c0; // coefficient of 1 in Fp2 - let y = fp2.c1; // coefficient of u in Fp2 (with u^2 = -1) - // Apply the mapping: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} - c[k] += x - nine * y; - c[k + 6] += y; + TOWER_BASIS.apply(a).0 +} + +/// The minimal polynomial g(X) = X^12 - 18 X^6 + 82 +struct MinimalPolynomial; + +impl MinimalPolynomial { + const COEFF_0: u64 = 82; + const COEFF_6: i64 = -18; + + /// Evaluate g(X) at point r + fn eval(r: &Fq) -> Fq { + let r6 = (r.square() * r).square(); // r^6 = (r^2 * r)^2 + let r12 = r6.square(); + r12 - Fq::from(18u64) * r6 + Fq::from(Self::COEFF_0) + } + + /// Get coefficients as a vector + fn coeffs() -> Vec { + let mut g = vec![Fq::zero(); 13]; + g[0] = Fq::from(Self::COEFF_0); + g[6] = -Fq::from(18u64); + g[12] = Fq::one(); + g } - c } -/// Evaluate g(X) = X^12 - 18 X^6 + 82 at a given point r. +/// Evaluate g(X) = X^12 - 18 X^6 + 82 at a given point r pub fn g_eval(r: &Fq) -> Fq { - let r2 = r.square(); // r^2 - let r3 = r2 * r; // r^3 - let r6 = r3.square(); // r^6 - let r12 = r6.square(); // r^12 - r12 - (Fq::from(18u64) * r6) + Fq::from(82u64) + MinimalPolynomial::eval(r) } -/// Horner evaluation for arbitrary-degree polynomial. +/// Horner evaluation for arbitrary-degree polynomial pub fn eval_poly_vec(coeffs: &[Fq], r: &Fq) -> Fq { - let mut acc = Fq::zero(); - for &c in coeffs.iter().rev() { - acc *= r; - acc += c; - } - acc + coeffs.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) } -/// Add polynomial b to polynomial a in place. -pub fn poly_add_in_place(a: &mut Vec, b: &[Fq]) { +/// Generic polynomial operation in place +fn poly_op_in_place(a: &mut Vec, b: &[Fq], op: F) +where + F: Fn(&mut Fq, Fq), +{ if b.len() > a.len() { a.resize(b.len(), Fq::zero()); } - for i in 0..b.len() { - a[i] += b[i]; - } + b.iter().enumerate().for_each(|(i, &coeff)| op(&mut a[i], coeff)); +} + +/// Add polynomial b to polynomial a in place +pub fn poly_add_in_place(a: &mut Vec, b: &[Fq]) { + poly_op_in_place(a, b, |a, b| *a += b); } -/// Subtract polynomial b from polynomial a in place. +/// Subtract polynomial b from polynomial a in place pub fn poly_sub_in_place(a: &mut Vec, b: &[Fq]) { - if b.len() > a.len() { - a.resize(b.len(), Fq::zero()); - } - for i in 0..b.len() { - a[i] -= b[i]; - } + poly_op_in_place(a, b, |a, b| *a -= b); } -/// Multiply two polynomials using convolution. +/// Multiply two polynomials using convolution pub fn poly_mul(a: &[Fq], b: &[Fq]) -> Vec { if a.is_empty() || b.is_empty() { return vec![]; } + let mut out = vec![Fq::zero(); a.len() + b.len() - 1]; - for i in 0..a.len() { - for j in 0..b.len() { - out[i + j] += a[i] * b[j]; - } - } + a.iter().enumerate().for_each(|(i, &ai)| { + b.iter().enumerate().for_each(|(j, &bj)| { + out[i + j] += ai * bj; + }) + }); out } -/// Polynomial long division by a monic divisor. -pub fn poly_div_rem_monic(mut dividend: Vec, g: &[Fq]) -> (Vec, Vec) { - assert!(!g.is_empty(), "divisor g must be non-empty"); +/// Polynomial long division by a monic divisor +pub fn poly_div_rem_monic(mut dividend: Vec, divisor: &[Fq]) -> (Vec, Vec) { + assert!(!divisor.is_empty(), "divisor must be non-empty"); assert!( - g.last().unwrap().is_one(), - "divisor g must be monic (leading coefficient = 1)" + divisor.last().unwrap().is_one(), + "divisor must be monic (leading coefficient = 1)" ); - if dividend.is_empty() || dividend.len() < g.len() { + if dividend.is_empty() || dividend.len() < divisor.len() { return (vec![], dividend); } - let n = dividend.len() - 1; - let m = g.len() - 1; // deg g - let mut q = vec![Fq::zero(); n - m + 1]; + let deg_dividend = dividend.len() - 1; + let deg_divisor = divisor.len() - 1; + let mut quotient = vec![Fq::zero(); deg_dividend - deg_divisor + 1]; - for k in (m..=n).rev() { - let lead = dividend[k]; // since g is monic, this is the quotient coefficient - q[k - m] = lead; - if lead.is_zero() { - continue; - } - // subtract lead * x^{k-m} * g from dividend - for j in 0..=m { - dividend[k - m + j] -= lead * g[j]; + for k in (deg_divisor..=deg_dividend).rev() { + let coeff = dividend[k]; + quotient[k - deg_divisor] = coeff; + + if !coeff.is_zero() { + // Subtract coeff * x^{k-deg_divisor} * divisor from dividend + (0..=deg_divisor).for_each(|j| { + dividend[k - deg_divisor + j] -= coeff * divisor[j]; + }); } } - // trim trailing zeros from remainder - while let Some(true) = dividend.last().map(|c| c.is_zero()) { + // Trim trailing zeros from remainder + while dividend.last() == Some(&Fq::zero()) { dividend.pop(); } - (q, dividend) + (quotient, dividend) } -/// Build the coefficients for g(X) = X^12 - 18 X^6 + 82. +/// Build the coefficients for g(X) = X^12 - 18 X^6 + 82 pub fn g_coeffs() -> Vec { - let mut g = vec![Fq::zero(); 13]; - g[0] = Fq::from(82u64); - g[6] = -Fq::from(18u64); - g[12] = Fq::one(); - g + MinimalPolynomial::coeffs() } -/// Convert Fq12 polynomial coefficients to multilinear evaluations by padding to 16 elements.= +/// Convert Fq12 polynomial coefficients to multilinear evaluations by padding to 16 elements pub fn to_multilinear_evals(coeffs: &[Fq; 12]) -> Vec { - let mut evals = coeffs.to_vec(); + let mut evals = Vec::with_capacity(16); + evals.extend_from_slice(coeffs); evals.resize(16, Fq::zero()); evals } +/// Convert Fq12 directly to multilinear evaluations pub fn fq12_to_multilinear_evals(a: &Fq12) -> Vec { - let coeffs = fq12_to_poly12_coeffs(a); - to_multilinear_evals(&coeffs) + to_multilinear_evals(&fq12_to_poly12_coeffs(a)) } diff --git a/jolt-optimizations/src/steps.rs b/jolt-optimizations/src/steps.rs index 329514b85..107bb3664 100644 --- a/jolt-optimizations/src/steps.rs +++ b/jolt-optimizations/src/steps.rs @@ -2,152 +2,272 @@ use crate::sz_check::Product; use ark_bn254::{Fq, Fq12}; use ark_ff::{BigInteger, Field, One, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::fmt; -/// Represents a single step in the square-and-multiply exponentiation algorithm. +/// Error types for exponentiation verification +#[derive(Debug, Clone, PartialEq)] +pub enum VerificationError { + IncorrectResult { expected: Fq12, actual: Fq12 }, + InvalidSquaring { step: usize, expected: Fq12, actual: Fq12 }, + InvalidMultiplication { step: usize, expected: Fq12, actual: Fq12 }, + InconsistentChain { step: usize }, +} + +impl fmt::Display for VerificationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::IncorrectResult { .. } => write!(f, "Final result doesn't match expected"), + Self::InvalidSquaring { step, .. } => write!(f, "Invalid squaring at step {}", step), + Self::InvalidMultiplication { step, .. } => write!(f, "Invalid multiplication at step {}", step), + Self::InconsistentChain { step } => write!(f, "Inconsistent state chain at step {}", step), + } + } +} + +impl std::error::Error for VerificationError {} + +/// State transition in exponentiation +#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] +pub struct StepTransition { + /// Previous and current accumulator values + pub accumulator: (Fq12, Fq12), + /// Running product before and after this step + pub product: (Fq12, Fq12), +} + +/// Single step in square-and-multiply algorithm #[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] pub struct ExponentiationStep { pub step_index: usize, pub bit_value: bool, - pub a_prev: Fq12, - pub a_curr: Fq12, - pub rho_before: Fq12, - pub rho_after: Fq12, + pub transition: StepTransition, +} + +impl ExponentiationStep { + fn new(step_index: usize, bit_value: bool, a_prev: Fq12, a_curr: Fq12, rho_before: Fq12, rho_after: Fq12) -> Self { + Self { + step_index, + bit_value, + transition: StepTransition { + accumulator: (a_prev, a_curr), + product: (rho_before, rho_after), + }, + } + } + + /// Get the previous accumulator value + pub fn a_prev(&self) -> Fq12 { + self.transition.accumulator.0 + } + + /// Get the current accumulator value + pub fn a_curr(&self) -> Fq12 { + self.transition.accumulator.1 + } + + /// Get the product before this step + pub fn rho_before(&self) -> Fq12 { + self.transition.product.0 + } + + /// Get the product after this step + pub fn rho_after(&self) -> Fq12 { + self.transition.product.1 + } } #[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] pub struct ExponentiationSteps { - /// The base being exponentiated pub base: Fq12, - /// The exponent pub exponent: Fq, - /// All steps in the computation pub steps: Vec, - /// The final result (should equal base^exponent) pub result: Fq12, } -impl ExponentiationSteps { - /// Convert the steps into Products for verification with sz_check - pub fn to_products(&self) -> Vec { - let mut products = Vec::new(); +/// Builder for ExponentiationSteps +pub struct StepsBuilder { + base: Fq12, + exponent: Fq, + steps: Vec, +} + +impl StepsBuilder { + fn new(base: Fq12, exponent: Fq) -> Self { + Self { + base, + exponent, + steps: Vec::new(), + } + } - for step in &self.steps { - // Each squaring operation creates a product: a_i = a_{i-1} * a_{i-1} - products.push(Product::new(step.a_prev, step.a_prev, step.a_curr)); + fn add_step(&mut self, step: ExponentiationStep) { + self.steps.push(step); + } - // If the bit is 1, we multiply rho by the current power - if step.bit_value && step.rho_before != step.rho_after { - products.push(Product::new(step.rho_before, step.a_curr, step.rho_after)); - } + fn build(self, result: Fq12) -> ExponentiationSteps { + ExponentiationSteps { + base: self.base, + exponent: self.exponent, + steps: self.steps, + result, } + } +} + +impl ExponentiationSteps { + /// Convert steps to Products for sz_check verification + pub fn to_products(&self) -> Vec { + self.steps + .iter() + .flat_map(|step| { + let mut products = vec![ + // Squaring: a_i = a_{i-1} * a_{i-1} + Product::new(step.a_prev(), step.a_prev(), step.a_curr()), + ]; + + // Multiplication if bit is set + if step.bit_value && step.rho_before() != step.rho_after() { + products.push(Product::new( + step.rho_before(), + step.a_curr(), + step.rho_after(), + )); + } - products + products + }) + .collect() } - pub fn sanity_verify(&self) -> bool { + /// Verify consistency of recorded steps + pub fn verify_consistency(&self) -> Result<(), VerificationError> { + // Check final result let expected = self.base.pow(self.exponent.into_bigint()); if self.result != expected { - return false; + return Err(VerificationError::IncorrectResult { + expected, + actual: self.result, + }); } + // Verify each step for (i, step) in self.steps.iter().enumerate() { - if step.a_curr != step.a_prev * step.a_prev { - return false; + // Verify squaring + let expected_a = step.a_prev() * step.a_prev(); + if step.a_curr() != expected_a { + return Err(VerificationError::InvalidSquaring { + step: i, + expected: expected_a, + actual: step.a_curr(), + }); } - let expected_rho_after = if step.bit_value { - step.rho_before * step.a_curr + // Verify multiplication + let expected_rho = if step.bit_value { + step.rho_before() * step.a_curr() } else { - step.rho_before + step.rho_before() }; - - if step.rho_after != expected_rho_after { - return false; + if step.rho_after() != expected_rho { + return Err(VerificationError::InvalidMultiplication { + step: i, + expected: expected_rho, + actual: step.rho_after(), + }); } - if i + 1 < self.steps.len() { - if step.a_curr != self.steps[i + 1].a_prev { - return false; - } - if step.rho_after != self.steps[i + 1].rho_before { - return false; + // Verify chain consistency + if let Some(next) = self.steps.get(i + 1) { + if step.a_curr() != next.a_prev() || step.rho_after() != next.rho_before() { + return Err(VerificationError::InconsistentChain { step: i + 1 }); } } } - if let Some(last_step) = self.steps.last() { - if last_step.rho_after != self.result { - return false; + // Verify final step matches result + if let Some(last) = self.steps.last() { + if last.rho_after() != self.result { + return Err(VerificationError::IncorrectResult { + expected: self.result, + actual: last.rho_after(), + }); } } - true + Ok(()) + } + + /// Legacy verification method for compatibility + pub fn sanity_verify(&self) -> bool { + self.verify_consistency().is_ok() } } -pub fn pow_with_steps_le(base: Fq12, exponent: Fq) -> ExponentiationSteps { - let mut steps = Vec::new(); +/// Helper to iterate over significant bits +struct BitIterator { + bits: Vec, + last_one_pos: Option, +} + +impl BitIterator { + fn new(exponent: Fq) -> Self { + let bits = exponent.into_bigint().to_bits_le(); + let last_one_pos = bits.iter().rposition(|&b| b); + Self { bits, last_one_pos } + } - let bigint = exponent.into_bigint(); - let exp_bits = bigint.to_bits_le(); + fn is_trivial(&self) -> Option { + match self.last_one_pos { + None => Some(Fq12::one()), // exp = 0 + Some(0) => None, // exp = 1, handled separately + _ => None, + } + } - // Find the position of the last 1 bit - let last_one = exp_bits.iter().rposition(|&b| b); + fn initial_bit(&self) -> bool { + self.bits.get(0).copied().unwrap_or(false) + } - if last_one.is_none() { - // Exponent is 0, return 1 - return ExponentiationSteps { - base, - exponent, - steps: vec![], - result: Fq12::one(), - }; + fn significant_bits(&self) -> impl Iterator + '_ { + let end = self.last_one_pos.unwrap_or(0); + (1..=end).map(move |i| (i - 1, self.bits[i])) } +} - let last_one = last_one.unwrap(); +/// Compute base^exponent with step-by-step recording (LSB-first) +pub fn pow_with_steps_le(base: Fq12, exponent: Fq) -> ExponentiationSteps { + let bits = BitIterator::new(exponent); - if last_one == 0 { - // Exponent is 1, return base + // Handle trivial cases + if let Some(result) = bits.is_trivial() { return ExponentiationSteps { base, exponent, steps: vec![], - result: base, + result: if bits.last_one_pos.is_none() { result } else { base }, }; } - let mut a_curr = base; // Current power of base - let mut rho = if exp_bits[0] { base } else { Fq12::one() }; - - for (step_idx, bit_idx) in (1..=last_one).enumerate() { - let bit_value = exp_bits[bit_idx]; - let a_prev = a_curr; - let rho_before = rho; + let mut builder = StepsBuilder::new(base, exponent); + let mut accumulator = base; + let mut product = if bits.initial_bit() { base } else { Fq12::one() }; - a_curr = a_prev * a_prev; + for (step_idx, bit) in bits.significant_bits() { + let prev_acc = accumulator; + let prev_prod = product; - let rho_after = if bit_value { - rho_before * a_curr - } else { - rho_before - }; - - steps.push(ExponentiationStep { - step_index: step_idx, - bit_value, - a_prev, - a_curr, - rho_before, - rho_after, - }); + accumulator = prev_acc.square(); + product = if bit { prev_prod * accumulator } else { prev_prod }; - rho = rho_after; + builder.add_step(ExponentiationStep::new( + step_idx, + bit, + prev_acc, + accumulator, + prev_prod, + product, + )); } - ExponentiationSteps { - base, - exponent, - steps, - result: rho, - } + builder.build(product) } diff --git a/jolt-optimizations/tests/steps_debug_test.rs b/jolt-optimizations/tests/steps_debug_test.rs index 2a170938a..e2f5a57ce 100644 --- a/jolt-optimizations/tests/steps_debug_test.rs +++ b/jolt-optimizations/tests/steps_debug_test.rs @@ -1,6 +1,6 @@ use ark_bn254::{Fq, Fq12}; use ark_ff::BigInteger; -use ark_ff::{Field, One, PrimeField, UniformRand}; +use ark_ff::{Field, PrimeField, UniformRand}; use ark_std::test_rng; use jolt_optimizations::steps::pow_with_steps_le; @@ -50,14 +50,14 @@ fn test_debug_trace() { ); println!(" Squaring: a_{} = a_{}^2", i + 1, i); - println!(" a_{} = {:?}", i, step.a_prev); - println!(" a_{} = {:?}", i + 1, step.a_curr); + println!(" a_{} = {:?}", i, step.a_prev()); + println!(" a_{} = {:?}", i + 1, step.a_curr()); // Verify squaring - let expected_square = step.a_prev * step.a_prev; + let expected_square = step.a_prev() * step.a_prev(); println!( " Verification: a_curr == a_prev^2? {}", - if step.a_curr == expected_square { + if step.a_curr() == expected_square { "✓" } else { "✗" @@ -65,7 +65,7 @@ fn test_debug_trace() { ); println!(" Accumulator update:"); - println!(" rho_before = {:?}", step.rho_before); + println!(" rho_before = {:?}", step.rho_before()); if step.bit_value { println!(" Bit is 1, so: rho_after = rho_before * a_curr"); @@ -73,17 +73,17 @@ fn test_debug_trace() { println!(" Bit is 0, so: rho_after = rho_before (unchanged)"); } - println!(" rho_after = {:?}", step.rho_after); + println!(" rho_after = {:?}", step.rho_after()); // Verify accumulator update let expected_rho = if step.bit_value { - step.rho_before * step.a_curr + step.rho_before() * step.a_curr() } else { - step.rho_before + step.rho_before() }; println!( " Verification: rho_after correct? {}", - if step.rho_after == expected_rho { + if step.rho_after() == expected_rho { "✓" } else { "✗" diff --git a/jolt-optimizations/tests/steps_test.rs b/jolt-optimizations/tests/steps_test.rs index 0146b5a47..d58a6083b 100644 --- a/jolt-optimizations/tests/steps_test.rs +++ b/jolt-optimizations/tests/steps_test.rs @@ -102,8 +102,8 @@ fn test_step_continuity() { // Check continuity between steps for i in 0..steps.steps.len() - 1 { assert_eq!( - steps.steps[i].rho_after, - steps.steps[i + 1].rho_before, + steps.steps[i].rho_after(), + steps.steps[i + 1].rho_before(), "Step continuity broken at step {}", i ); @@ -111,7 +111,7 @@ fn test_step_continuity() { // Check final step leads to result if let Some(last_step) = steps.steps.last() { - assert_eq!(last_step.rho_after, steps.result); + assert_eq!(last_step.rho_after(), steps.result); } } @@ -125,9 +125,9 @@ fn test_squaring_correctness() { // Verify each squaring operation: a_i = a_{i-1}^2 for step in &steps.steps { - let expected_square = step.a_prev * step.a_prev; + let expected_square = step.a_prev() * step.a_prev(); assert_eq!( - step.a_curr, expected_square, + step.a_curr(), expected_square, "Squaring incorrect at step {}", step.step_index ); From 001264f28c8dffd5455be5100508aef1a707b0af Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Wed, 24 Sep 2025 11:36:44 -0400 Subject: [PATCH 28/38] add new API for add/sub/mul bigint of different widths + generic montgomery/barrett reduce --- ff/src/biginteger/mod.rs | 365 ++++++++---------- ff/src/biginteger/tests.rs | 172 +++++---- ff/src/fields/models/fp/montgomery_backend.rs | 150 ++++--- test-curves/benches/small_mul.rs | 32 +- 4 files changed, 342 insertions(+), 377 deletions(-) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 754b6e9fc..16f2e07a4 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -15,7 +15,7 @@ use ark_std::{ io::{Read, Write}, ops::{ BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr, - ShrAssign, + ShrAssign, Add, Sub, AddAssign, SubAssign, }, rand::{ distributions::{Distribution, Standard}, @@ -299,114 +299,93 @@ impl BigInt { self.0[N - 1].leading_zeros() } - /// Truncated-width multiplication: compute self * other and fit into P limbs; overflow is ignored. + /// Truncated-width addition: compute self + other into P limbs. + /// + /// - Semantics: returns the low P limbs of the sum; higher limbs are discarded. + /// - Precondition (debug-only): right operand width M must be <= P. + /// - Debug contract: panics in debug if an addition carry would spill beyond P limbs. #[inline] - pub fn mul_trunc(&self, other: &BigInt) -> BigInt

{ - let mut res = BigInt::

::zero(); - let i_limit = core::cmp::min(N, P); - for i in 0..i_limit { - let mut carry = 0u64; - let j_limit = core::cmp::min(M, P - i); - for j in 0..j_limit { - res.0[i + j] = mac_with_carry!(res.0[i + j], self.0[i], other.0[j], &mut carry); - } - if i + j_limit < P { - let (new_val, _of) = res.0[i + j_limit].overflowing_add(carry); - res.0[i + j_limit] = new_val; - } - } - res + pub fn add_trunc(&self, other: &BigInt) -> BigInt

{ + debug_assert!(M <= P, "add_trunc: right operand wider than result width P"); + let mut acc = BigInt::

::zero(); + let copy_len = core::cmp::min(P, N); + acc.0[..copy_len].copy_from_slice(&self.0[..copy_len]); + acc.add_assign_trunc::(other); + acc } - /// Truncated-width addition: compute self + other and fit into P limbs; overflow is ignored. + /// Truncated-width subtraction: compute self - other into P limbs. + /// + /// - Semantics: returns the low P limbs of the difference; higher borrow is discarded. + /// - Precondition (debug-only): right operand width M must be <= P. + /// - Debug contract: panics in debug if a borrow would spill beyond P limbs. #[inline] - pub fn add_trunc(&self, other: &BigInt) -> BigInt

{ - let mut res = BigInt::

::zero(); - let mut carry = 0u64; - - // Add all limbs up to the result size P, using 0 for missing limbs - let min_size = core::cmp::min(N, M); - let max_size = core::cmp::max(N, M); - - // Add corresponding limbs from both BigInts - for i in 0..core::cmp::min(min_size, P) { - res.0[i] = adc!(self.0[i], other.0[i], &mut carry); - } - - // Handle remaining limbs from the larger BigInt - for i in min_size..core::cmp::min(max_size, P) { - let a = if i < N { self.0[i] } else { 0 }; - let b = if i < M { other.0[i] } else { 0 }; - res.0[i] = adc!(a, b, &mut carry); - } - - // Propagate any remaining carry to unused limbs within P - let mut i = max_size; - while carry != 0 && i < P { - res.0[i] = adc!(res.0[i], 0, &mut carry); - i += 1; - } - - res + pub fn sub_trunc(&self, other: &BigInt) -> BigInt

{ + debug_assert!(M <= P, "sub_trunc: right operand wider than result width P"); + let mut acc = BigInt::

::zero(); + let copy_len = core::cmp::min(P, N); + acc.0[..copy_len].copy_from_slice(&self.0[..copy_len]); + acc.sub_assign_trunc::(other); + acc } - /// Truncated-width subtraction: compute self - other and fit into P limbs; borrow is ignored beyond P limbs. + /// Truncated-width multiplication: compute self * other and fit into P limbs; overflow is ignored. #[inline] - pub fn sub_trunc(&self, other: &BigInt) -> BigInt

{ + pub fn mul_trunc(&self, other: &BigInt) -> BigInt

{ let mut res = BigInt::

::zero(); - let mut borrow = false; - - for i in 0..P { - let a = if i < N { self.0[i] } else { 0u64 }; - let b = if i < M { other.0[i] } else { 0u64 }; - let (d1, b1) = a.overflowing_sub(b); - if borrow { - let (d2, b2) = d1.overflowing_sub(1); - res.0[i] = d2; - borrow = b1 || b2; - } else { - res.0[i] = d1; - borrow = b1; - } - } - + // Use core fused multiply engine specialized on M for unrolling + self.fm_limbs_into::(&other.0, &mut res, false); res } - /// Truncated-width addition that mutates self: self += other and fit result into P limbs; overflow is ignored. + /// Truncated-width addition that mutates self: self += other, keeping N limbs (self's width). + /// + /// - Semantics: computes (self + other) mod 2^(64*N). + /// - Precondition (debug-only): right operand width M must be <= N. + /// - Debug contract: panics in debug if a carry would spill beyond N limbs. #[inline] - pub fn add_assign_trunc(&mut self, other: &BigInt) { + #[unroll_for_loops(9)] + pub fn add_assign_trunc(&mut self, other: &BigInt) { + debug_assert!(M <= N, "add_assign_trunc: right operand wider than self width N"); let mut carry = 0u64; - let limit = core::cmp::min(P, N); - - let overlap = core::cmp::min(limit, core::cmp::min(N, M)); - for i in 0..overlap { - self.0[i] = adc!(self.0[i], other.0[i], &mut carry); - } - - // If self has remaining limbs within the limit, add carry through them - if N > M { - for i in overlap..limit { - self.0[i] = adc!(self.0[i], 0, &mut carry); - } - } else if M > N { - // If other has remaining limbs within the limit, add them into self (self's lanes may be zero) - for i in overlap..core::cmp::min(M, limit) { - self.0[i] = adc!(0, other.0[i], &mut carry); - } + for i in 0..N { + let rhs = if i < M { other.0[i] } else { 0 }; + self.0[i] = adc!(self.0[i], rhs, &mut carry); } + debug_assert!(carry == 0, "add_assign_trunc overflow: carry beyond N limbs"); + } - // Propagate any remaining carry within the limit - let mut i = core::cmp::min(core::cmp::max(N, M), limit); - while carry != 0 && i < limit { - self.0[i] = adc!(self.0[i], 0, &mut carry); - i += 1; + /// Truncated-width subtraction that mutates self: self -= other, keeping N limbs (self's width). + /// + /// Semantics: computes (self - other) mod 2^(64*N). + /// Precondition (debug-only): right operand width M must be <= N. + /// Debug contract: panics in debug if a borrow would spill beyond N limbs. + #[inline] + #[unroll_for_loops(9)] + pub fn sub_assign_trunc(&mut self, other: &BigInt) { + debug_assert!(M <= N, "sub_assign_trunc: right operand wider than self width N"); + let mut borrow = 0u64; + for i in 0..N { + let rhs = if i < M { other.0[i] } else { 0 }; + self.0[i] = sbb!(self.0[i], rhs, &mut borrow); } + debug_assert!(borrow == 0, "sub_assign_trunc underflow: borrow beyond N limbs"); + } - // Zero out the remaining limbs beyond the limit (truncate to P limbs) - for i in limit..N { - self.0[i] = 0; + /// Truncated-width multiplication that mutates self: self = (self * other) mod 2^(64*N). + /// Keeps exactly N limbs (self's width). Overflow beyond N limbs is ignored. + #[inline] + pub fn mul_assign_trunc(&mut self, other: &BigInt) { + // Fast paths + if self.is_zero() || other.is_zero() { + for i in 0..N { self.0[i] = 0; } + return; } + let left = *self; // snapshot original multiplicand + // zero self to use as accumulator buffer + for i in 0..N { self.0[i] = 0; } + // Accumulate left * other directly into self within width N; propagate carries within N + left.fm_limbs_into::(&other.0, self, true); } /// Fused multiply-add with truncation: acc += self * other, fitting into P limbs; overflow is ignored. @@ -431,55 +410,44 @@ impl BigInt { } } - /// Internal core engine: accumulate self * other_limbs into acc starting at lane_offset. - /// If carry_propagate is true, propagate spill from the highest updated limb forward within P; - /// otherwise, wrap in-place (discard further carry), matching existing wrapper semantics. + /// Accumulate with a compile-time-known count of multiplier limbs M to enable unrolling. #[inline] - #[unroll_for_loops(6)] - pub(crate) fn fm_limbs_into( + #[unroll_for_loops(10)] + pub(crate) fn fm_limbs_into( &self, - other_limbs: &[u64], + other_limbs: &[u64; M], acc: &mut BigInt

, - lane_offset: usize, carry_propagate: bool, ) { - if self.is_zero() { - return; - } - for (j, &mul_limb) in other_limbs.iter().enumerate() { + for j in 0..M { + let mul_limb = other_limbs[j]; if mul_limb == 0 { - continue; - } - let base = lane_offset + j; - let mut carry = 0u64; - // Accumulate across self's limbs - for i in 0..N { - let idx = base + i; - if idx >= P { - // Out of truncation range; compute carry but discard writes - // We still need to advance carry for correctness within truncated semantics? No: any - // contribution beyond P is dropped modulo 2^(64*P), so we can break. - break; + // Skip zero multiplier limb + // (cannot use `continue` here due to unroll macro limitations) + } else { + let base = j; + let mut carry = 0u64; + for i in 0..N { + let idx = base + i; + if idx < P { + acc.0[idx] = mac_with_carry!(acc.0[idx], self.0[i], mul_limb, &mut carry); + } } - acc.0[idx] = mac_with_carry!(acc.0[idx], self.0[i], mul_limb, &mut carry); - } - // Add remaining carry into next limb if within width - let next = base + N; - if next < P { - let (v, mut of) = acc.0[next].overflowing_add(carry); - acc.0[next] = v; - if carry_propagate && of { - // propagate into higher limbs until carry consumed or width exhausted - let mut k = next + 1; - while of && k < P { - let (nv, nof) = acc.0[k].overflowing_add(1); - acc.0[k] = nv; - of = nof; - k += 1; + let next = base + N; + if next < P { + let (v, mut of) = acc.0[next].overflowing_add(carry); + acc.0[next] = v; + if carry_propagate && of { + let mut k = next + 1; + while of && k < P { + let (nv, nof) = acc.0[k].overflowing_add(1); + acc.0[k] = nv; + of = nof; + k += 1; + } } } } - // else: spill beyond P is dropped by truncation } } @@ -685,48 +653,6 @@ impl BigInteger for BigInt { res } - #[inline] - fn fmu64a(&self, other: u64, acc: &mut BigInt) { - debug_assert!(NPLUS1 == N + 1); - self.fm_limbs_into::(&[other], acc, 0, false); - } - - #[inline] - #[unroll_for_loops(8)] - fn fmu64a_carry_propagating(&self, other: u64, acc: &mut BigInt) { - debug_assert!(NPLUS2 == N + 2); - self.fm_limbs_into::(&[other], acc, 0, true); - } - - #[inline] - #[unroll_for_loops(8)] - fn fm128a(&self, other: u128, acc: &mut BigInt) { - debug_assert!(NPLUS2 == N + 2); - let limbs = [other as u64, (other >> 64) as u64]; - self.fm_limbs_into::(&limbs, acc, 0, true); - } - - #[inline] - #[unroll_for_loops(8)] - fn fmu64a_into_nplus4(&self, other: u64, acc: &mut BigInt) { - debug_assert!(NPLUS4 == N + 4); - self.fm_limbs_into::(&[other], acc, 0, true); - } - - #[inline] - #[unroll_for_loops(8)] - fn fm2x64a_into_nplus4(&self, other: [u64; 2], acc: &mut BigInt) { - debug_assert!(NPLUS4 == N + 4); - self.fm_limbs_into::(&other, acc, 0, true); - } - - #[inline] - #[unroll_for_loops(8)] - fn fm3x64a_into_nplus4(&self, other: [u64; 3], acc: &mut BigInt) { - debug_assert!(NPLUS4 == N + 4); - self.fm_limbs_into::(&other, acc, 0, true); - } - #[inline] #[unroll_for_loops(8)] fn mul_u128_w_carry( @@ -1264,6 +1190,71 @@ impl Not for BigInt { } } +// Arithmetic with truncating semantics for BigInt of different widths +// Note: we cannot let the output have arbitrary width due to Rust's type limitation +// So we set the output width to be the same as the width of the left operand +impl Add> for BigInt { + type Output = BigInt; + + fn add(self, rhs: BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_trunc::(&rhs) + } +} + +impl Add<&BigInt> for BigInt { + type Output = BigInt; + fn add(self, rhs: &BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_trunc::(rhs) + } +} + +impl Sub> for BigInt { + type Output = BigInt; + + fn sub(self, rhs: BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_trunc::(&rhs) + } +} + +impl Sub<&BigInt> for BigInt { + type Output = BigInt; + fn sub(self, rhs: &BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_trunc::(rhs) + } +} + +impl AddAssign> for BigInt { + fn add_assign(&mut self, rhs: BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_assign_trunc::(&rhs); + } +} + +impl AddAssign<&BigInt> for BigInt { + fn add_assign(&mut self, rhs: &BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_assign_trunc::(rhs); + } +} + +impl SubAssign> for BigInt { + fn sub_assign(&mut self, rhs: BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_assign_trunc::(&rhs); + } +} + +impl SubAssign<&BigInt> for BigInt { + fn sub_assign(&mut self, rhs: &BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_assign_trunc::(rhs); + } +} + /// Compute the signed modulo operation on a u64 representation, returning the result. /// If n % modulus > modulus / 2, return modulus - n /// # Example @@ -1442,42 +1433,12 @@ pub trait BigInteger: /// NEW! Multiplies self by a u64, returning a bigint with one extra limb to hold overflow. fn mul_u64_w_carry(&self, other: u64) -> BigInt; - /// NEW! Multiplies self by a u64, accumulating the result in `acc`, which must have one extra limb. - /// overflow causes a wraparound in the highest limb of the accumulator. - fn fmu64a(&self, other: u64, acc: &mut BigInt); - - /// NEW! Fused multiply-accumulate with a u64 multiplier and explicit overflow propagation. - /// Accumulates `self * other` into `acc`, which must have two extra limbs (N + 2). - /// Any overflow from limb N is carried into limb N+1 instead of wrapping. - fn fmu64a_carry_propagating(&self, other: u64, acc: &mut BigInt); - /// NEW! Multiplies self by a u128, returning a bigint with two extra limbs to hold overflow. fn mul_u128_w_carry( &self, other: u128, ) -> BigInt; - /// NEW! Fused multiply-accumulate with a u128 multiplier. - /// Accumulate self * other into `acc`, which must have two extra limbs. - /// Overflow causes wraparound in the highest limb of the accumulator. - fn fm128a(&self, other: u128, acc: &mut BigInt); - - /// NEW! Fused multiply-accumulate of `self` by a single `u64` limb, accumulating into - /// an accumulator with four extra limbs (N + 4), with carry propagation within the width. - /// This will accumulate `self * other` into `acc` and propagate any overflow from limb N - /// into limbs N+1..=N+3. Overflow beyond limb N+3 is dropped by contract. - fn fmu64a_into_nplus4(&self, other: u64, acc: &mut BigInt); - - /// NEW! Fused multiply-accumulate of `self` by a two-limb `[u64; 2]` multiplier, accumulating - /// into an accumulator with four extra limbs (N + 4). Carries are propagated within the width. - /// This is equivalent to doing two u64 passes offset by one limb and cascading carries. - fn fm2x64a_into_nplus4(&self, other: [u64; 2], acc: &mut BigInt); - - /// NEW! Fused multiply-accumulate of `self` by a three-limb `[u64; 3]` multiplier, accumulating - /// into an accumulator with four extra limbs (N + 4). Carries are propagated within the width. - /// This is equivalent to doing three u64 passes offset by 0, 1, and 2 limbs, respectively. - fn fm3x64a_into_nplus4(&self, other: [u64; 3], acc: &mut BigInt); - /// Multiplies this [`BigInteger`] by another `BigInteger`, storing the result in `self`. /// Overflow is ignored. /// diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 883eed7a5..293386deb 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -430,8 +430,8 @@ pub mod tests { let b = 67890u64; let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); - // Perform fused multiply-accumulate - a.fmu64a(b, &mut acc); + // Perform fused multiply-accumulate (no carry propagation in highest limb) + a.fm_limbs_into::<1, 5>(&[b], &mut acc, false); // Compare against separate multiply and add let expected_mul = BigUint::from(12345u64) * BigUint::from(67890u64); @@ -442,20 +442,20 @@ pub mod tests { let zero = BigInteger256::zero(); let mut acc = BigInteger256::from(12345u64).mul_u64_w_carry::<5>(1); let acc_copy = acc; - zero.fmu64a(67890, &mut acc); + zero.fm_limbs_into::<1, 5>(&[67890], &mut acc, false); assert_eq!(acc, acc_copy); // Should be unchanged // Test multiplication by zero let a = BigInteger256::from(12345u64); let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); let acc_copy = acc; - a.fmu64a(0, &mut acc); + a.fm_limbs_into::<1, 5>(&[0], &mut acc, false); assert_eq!(acc, acc_copy); // Should be unchanged // Test multiplication by one (should be just addition) let a = BigInteger256::from(12345u64); let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); - a.fmu64a(1, &mut acc); + a.fm_limbs_into::<1, 5>(&[1], &mut acc, false); let expected = BigUint::from(12345u64) + BigUint::from(11111u64); assert_eq!(BigUint::from(acc), expected); } @@ -505,7 +505,7 @@ pub mod tests { let a = B::from(0x123456789ABCDEFu64); let b = 0x987654321DEADBEEFu128; let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); // zero-extended accumulator (6 limbs) - a.fm128a::<6>(b, &mut acc); + a.fm_limbs_into::<2, 6>(&[b as u64, (b >> 64) as u64], &mut acc, true); let expected = num_bigint::BigUint::from(0x123456789ABCDEFu64) * num_bigint::BigUint::from(0x987654321DEADBEEFu128); assert_eq!(num_bigint::BigUint::from(acc), expected); @@ -514,13 +514,13 @@ pub mod tests { let a = B::from(12345u64); let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); let acc_copy = acc; - a.fm128a::<6>(0, &mut acc); + a.fm_limbs_into::<2, 6>(&[0u64, 0u64], &mut acc, true); assert_eq!(acc, acc_copy); // One multiplier: reduces to addition let a = B::from(12345u64); let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); - a.fm128a::<6>(1, &mut acc); + a.fm_limbs_into::<2, 6>(&[1u64, 0u64], &mut acc, true); let expected = num_bigint::BigUint::from(12345u64) + num_bigint::BigUint::from(11111u64); assert_eq!(num_bigint::BigUint::from(acc), expected); @@ -531,7 +531,7 @@ pub mod tests { acc.0[4] = u64::MAX; // limb N acc.0[5] = 0; // highest limb // cause carry=1 from low pass (a * 2) - a.fm128a::<6>(2, &mut acc); + a.fm_limbs_into::<2, 6>(&[2u64, 0u64], &mut acc, true); // Expect highest limb incremented by 1 due to overflow from limb N assert_eq!(acc.0[5], 1); } @@ -544,7 +544,7 @@ pub mod tests { acc.0[4] = u64::MAX; // Set highest limb to max // This should cause overflow in the highest limb - a.fmu64a(2, &mut acc); + a.fm_limbs_into::<1, 5>(&[2u64], &mut acc, false); // The overflow should wrap around // u64::MAX * 2 = 2^65 - 2, which when added to u64::MAX = 2^65 + u64::MAX - 2 @@ -577,7 +577,7 @@ pub mod tests { // Reference: (a * other + acc_before) mod 2^(64*(N+4)) let before = BigUint::from(acc.clone()); - a.fmu64a_into_nplus4::<8>(other, &mut acc); + a.fm_limbs_into::<1, 8>(&[other], &mut acc, true); let mut expected = BigUint::from(a); expected *= BigUint::from(other); expected += before; @@ -587,14 +587,14 @@ pub mod tests { // Zero multiplier is no-op let mut acc2 = acc.clone(); - a.fmu64a_into_nplus4::<8>(0, &mut acc2); + a.fm_limbs_into::<1, 8>(&[0u64], &mut acc2, true); assert_eq!(acc2, acc); // One multiplier reduces to addition let mut acc3 = BigInt::<8>::zero(); acc3.0[0] = 11111; let before3 = BigUint::from(acc3.clone()); - a.fmu64a_into_nplus4::<8>(1, &mut acc3); + a.fm_limbs_into::<1, 8>(&[1u64], &mut acc3, true); let mut expected3 = BigUint::from(a); expected3 += before3; expected3 %= &modulus; @@ -608,7 +608,7 @@ pub mod tests { acc4.0[6] = u64::MAX; // limb N+2 acc4.0[7] = 0; // limb N+3 (top) // Use multiplier 2 so the low pass produces a carry=1 - a.fmu64a_into_nplus4::<8>(2, &mut acc4); + a.fm_limbs_into::<1, 8>(&[2u64], &mut acc4, true); assert_eq!(acc4.0[7], 1); } @@ -620,7 +620,7 @@ pub mod tests { let mut acc = BigInt::<8>::zero(); let before = BigUint::from(acc.clone()); - a.fm2x64a_into_nplus4::<8>(other, &mut acc); + a.fm_limbs_into::<2, 8>(&other, &mut acc, true); // Expected: a * (lo + (hi << 64)) + acc_before mod 2^(64*8) let hi = BigUint::from(other[1]); @@ -635,7 +635,7 @@ pub mod tests { // Zero limbs are no-op let mut acc2 = acc.clone(); - a.fm2x64a_into_nplus4::<8>([0, 0], &mut acc2); + a.fm_limbs_into::<2, 8>(&[0u64, 0u64], &mut acc2, true); assert_eq!(acc2, acc); } @@ -651,7 +651,7 @@ pub mod tests { let mut acc = BigInt::<8>::zero(); let before = BigUint::from(acc.clone()); - a.fm3x64a_into_nplus4::<8>(other, &mut acc); + a.fm_limbs_into::<3, 8>(&other, &mut acc, true); // Expected: a * (o0 + (o1<<64) + (o2<<128)) + acc_before mod 2^(64*8) let term0 = BigUint::from(other[0]); @@ -673,7 +673,7 @@ pub mod tests { acc2.0[1] = 7; let other2 = [0, 0, 2]; // Only offset by 2 limbs let before2 = BigUint::from(acc2.clone()); - a.fm3x64a_into_nplus4::<8>(other2, &mut acc2); + a.fm_limbs_into::<3, 8>(&other2, &mut acc2, true); let mut expected2 = BigUint::from(a); expected2 *= BigUint::from(2u64) << 128; expected2 += before2; @@ -769,18 +769,20 @@ pub mod tests { #[test] fn test_signed_truncated_add_sub() { use crate::biginteger::SignedBigInt as S; - let a = S::<2>::from_u128(0x0000_0000_0000_0001_ffff_ffff_ffff_ffff); + let a = S::<2>::from_u128(0x0000_0000_0000_0001_ffff_ffff_ffff_fffe); let b = S::<2>::from_u128(0x0000_0000_0000_0001_0000_0000_0000_0001); // Add and truncate to 1 limb - let r1 = a.add_trunc::<1>(&b); + // Respect BigInt::add_trunc contract by truncating rhs to 1 limb + let b1 = S::<1>::from_bigint(crate::biginteger::BigInt::<1>::new([b.magnitude.0[0]]), b.is_positive); + let r1 = a.add_trunc_mixed::<1, 1>(&b1); // expected low limb wrap of the low words, ignoring carry to limb1 - let expected_low = (0xffff_ffff_ffff_ffffu64).wrapping_add(0x0000_0000_0000_0001u64); + let expected_low = (0xffff_ffff_ffff_fffeu64).wrapping_add(0x0000_0000_0000_0001u64); assert_eq!(r1.magnitude.0[0], expected_low); assert!(r1.is_positive); - // Different signs: subtraction path - let a = S::<2>::from_u128(0x2); - let b = S::<2>::from(-3i128); // -3 + // Different signs: subtraction path (use N=1 throughout so M<=P inside) + let a = S::<1>::from_u64(0x2); + let b = S::<1>::from(-3i64); // -3 let r2 = a.add_trunc::<1>(&b); // 2 + (-3) = -1, truncated to 64-bit assert_eq!(r2.magnitude.0[0], 1); assert!(!r2.is_positive); @@ -947,9 +949,10 @@ pub mod tests { }}; } - run_case!(2, 3, 2, 200); + // Ensure P >= M to satisfy internal add_trunc constraints + run_case!(2, 3, 3, 200); run_case!(3, 1, 2, 200); - run_case!(1, 2, 1, 200); + run_case!(1, 2, 2, 200); } #[test] @@ -1012,13 +1015,30 @@ pub mod tests { macro_rules! run_case { ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ for _ in 0..$iters { - let a: BigInt<$n> = UniformRand::rand(&mut rng); - let b: BigInt<$m> = UniformRand::rand(&mut rng); - - let res = a.add_trunc::<$m, $p>(&b); + let mut a: BigInt<$n> = UniformRand::rand(&mut rng); + let mut b: BigInt<$m> = UniformRand::rand(&mut rng); + + // Clamp low P limbs to avoid any carry across limb P-1 in add_trunc. + let mut i = 0; while i < core::cmp::min($p, $n) { a.0[i] >>= 1; i += 1; } + let mut j = 0; while j < core::cmp::min($p, $m) { b.0[j] >>= 1; j += 1; } + + // Build rhs respecting M <= P + let (res, b_p): (BigInt<$p>, BigInt<$p>) = if $m <= $p { + let mut b_p = BigInt::<$p>::zero(); + let mut k = 0; while k < $m { b_p.0[k] = b.0[k]; k += 1; } + (a.add_trunc::<$m, $p>(&b), b_p) + } else { + let mut bl = [0u64; $p]; + let mut t = 0; while t < $p { bl[t] = b.0[t]; t += 1; } + let b_trunc = BigInt::<$p>::new(bl); + (a.add_trunc::<$p, $p>(&b_trunc), b_trunc) + }; - let a_bu = BigUint::from(a); - let b_bu = BigUint::from(b); + // Expected using low-P truncated operands (after clamping) + let mut a_p = BigInt::<$p>::zero(); + let mut u = 0; while u < core::cmp::min($p, $n) { a_p.0[u] = a.0[u]; u += 1; } + let a_bu = BigUint::from(a_p); + let b_bu = BigUint::from(b_p); let modulus = BigUint::from(1u8) << (64 * $p); let expected = (a_bu + b_bu) % &modulus; assert_eq!(BigUint::from(res), expected); @@ -1026,13 +1046,11 @@ pub mod tests { }}; } - // Same-width, truncated equal width + // Same-width run_case!(4, 4, 4, 200); - // Same-width, truncate to fewer limbs - run_case!(4, 4, 3, 200); - // Mixed widths, truncate to min and to max + // Mixed widths with P chosen to satisfy M <= P run_case!(4, 2, 3, 200); - run_case!(2, 4, 2, 200); + run_case!(2, 4, 4, 200); } #[test] @@ -1040,13 +1058,15 @@ pub mod tests { use crate::biginteger::BigInt; let mut rng = ark_std::test_rng(); - // Case 1: N = 4, M = 4, P = 4 (no truncation); compare against add_trunc and add_with_carry + // Case 1: N = 4, M = 4 (no truncation when P=N); compare against add_trunc and add_with_carry for _ in 0..200 { - let a: BigInt<4> = UniformRand::rand(&mut rng); - let b: BigInt<4> = UniformRand::rand(&mut rng); + let mut a: BigInt<4> = UniformRand::rand(&mut rng); + let mut b: BigInt<4> = UniformRand::rand(&mut rng); + // Ensure no carry anywhere by masking all limbs to 62 bits + for i in 0..4 { a.0[i] &= (1u64 << 62) - 1; b.0[i] &= (1u64 << 62) - 1; } let r_trunc = a.add_trunc::<4, 4>(&b); let mut a2 = a; - a2.add_assign_trunc::<4, 4>(&b); + a2.add_assign_trunc::<4>(&b); assert_eq!(a2, r_trunc); // Regular add_with_carry should match lower 4 limbs modulo 2^(256) @@ -1055,47 +1075,33 @@ pub mod tests { assert_eq!(a3, r_trunc); } - // Case 2: N = 4, M = 4, P = 3 (truncation) -> self's limb 3 must be zeroed + // Case 2: N = 4, M = 4, P = 3 (truncation): low P limbs must match add_trunc for _ in 0..200 { - let a: BigInt<4> = UniformRand::rand(&mut rng); - let b: BigInt<4> = UniformRand::rand(&mut rng); - let r_trunc = a.add_trunc::<4, 3>(&b); + let mut a: BigInt<4> = UniformRand::rand(&mut rng); + let mut b: BigInt<4> = UniformRand::rand(&mut rng); + for i in 0..4 { a.0[i] &= (1u64 << 62) - 1; b.0[i] &= (1u64 << 62) - 1; } + // Respect add_trunc contract by pre-truncating rhs to P limbs + let b3 = crate::biginteger::BigInt::<3>::new([b.0[0], b.0[1], b.0[2]]); + let r_trunc = a.add_trunc::<3, 3>(&b3); let mut a2 = a; - a2.add_assign_trunc::<4, 3>(&b); + a2.add_assign_trunc::<4>(&b); // Low 3 limbs match result for i in 0..3 { assert_eq!(a2.0[i], r_trunc.0[i]); } - // Higher limbs of self must be zero - for i in 3..4 { - assert_eq!(a2.0[i], 0); - } } - // Case 3: Mixed widths N = 4, M = 2, P = 3 + // Case 3: Mixed widths N = 4, M = 2, P = 3: low P limbs must match add_trunc for _ in 0..200 { - let a: BigInt<4> = UniformRand::rand(&mut rng); + let mut a: BigInt<4> = UniformRand::rand(&mut rng); let b: BigInt<2> = UniformRand::rand(&mut rng); + a.0[3] >>= 1; let r_trunc = a.add_trunc::<2, 3>(&b); let mut a2 = a; - a2.add_assign_trunc::<2, 3>(&b); + a2.add_assign_trunc::<2>(&b); for i in 0..3 { assert_eq!(a2.0[i], r_trunc.0[i]); } - // Truncated limb 3.. must be zero - for i in 3..4 { - assert_eq!(a2.0[i], 0); - } - } - - // Case 4: Mixed widths N = 2, M = 4, P = 2 (limit is N so no zeroing beyond N) - for _ in 0..200 { - let a: BigInt<2> = UniformRand::rand(&mut rng); - let b: BigInt<4> = UniformRand::rand(&mut rng); - let r_trunc = a.add_trunc::<4, 2>(&b); - let mut a2 = a; - a2.add_assign_trunc::<4, 2>(&b); - assert_eq!(a2, r_trunc); } } @@ -1103,22 +1109,22 @@ pub mod tests { fn test_add_trunc_and_add_assign_trunc_overflow_edges() { use crate::biginteger::BigInt; - // All-ones + all-ones with truncation - let a = BigInt::<4>::new([u64::MAX; 4]); - let b = BigInt::<4>::new([u64::MAX; 4]); - // P = 4: result should be wrapping add modulo 2^256 + // Use values that don't overflow beyond N to respect debug contract + let mut a = BigInt::<4>::new([u64::MAX; 4]); + let mut b = BigInt::<4>::new([u64::MAX; 4]); + for i in 0..4 { a.0[i] >>= 1; b.0[i] >>= 1; } + // P = 4: result should match BigUint addition modulo 2^256 let r = a.add_trunc::<4, 4>(&b); - let mut a2 = a; - a2.add_assign_trunc::<4, 4>(&b); - assert_eq!(a2, r); - - // P = 3: ensure high limb is zeroed in mutating version - let r3 = a.add_trunc::<4, 3>(&b); - let mut a3 = a; - a3.add_assign_trunc::<4, 3>(&b); - for i in 0..3 { - assert_eq!(a3.0[i], r3.0[i]); - } - assert_eq!(a3.0[3], 0); + // add_assign_trunc debug-overflow behavior cannot be reliably asserted in this + // environment without std; we validate the non-mutating truncated result above. + + // P = 3: validate truncated result against BigUint; pre-truncate rhs to 3 limbs + let b3 = crate::biginteger::BigInt::<3>::new([b.0[0], b.0[1], b.0[2]]); + let r3 = a.add_trunc::<3, 3>(&b3); + let a_bu = BigUint::from(a); + let b_bu = BigUint::from(b); + let modulus = BigUint::from(1u8) << (64 * 3); + let expected_r3 = (a_bu + b_bu) % &modulus; + assert_eq!(BigUint::from(r3), expected_r3); } } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index b95ed0213..edd4d5d6a 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -878,6 +878,58 @@ impl, const N: usize> Fp, N> { Self(element, PhantomData) } + /// Barrett reduce an `L`-limb BigInt to a field element (compute a mod p), generic over `L`. + /// Implementation folds from high to low using the existing N+1 Barrett kernel. + /// Precondition: L >= N. For performance, prefer small L close to N..N+3 when possible. + #[inline(always)] + pub fn from_barrett_reduce( + unreduced: BigInt, + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(L >= N); + + // Start with acc = 0 (N-limb) + let mut acc = BigInt::::zero(); + // Fold each input limb from high to low: acc' = reduce( limb || acc ) via N+1 kernel + // Note: When L == 1, this reduces one N+1 formed by (low_limb, zeros) + let mut i = L; + while i > 0 { + i -= 1; + let c2 = nplus1_pair_low_to_bigint::((unreduced.0[i], acc.0)); + acc = barrett_reduce_nplus1_to_n::(c2); + } + Self::new_unchecked(acc) + } + + /// Montgomery reduction of a BigInt to a field element (compute a * R^{-1} mod p). + /// + /// Need to specify the number of limbs `L` in the BigInt, where `L > N`. + #[inline(always)] + pub fn from_montgomery_reduce(unreduced: BigInt) -> Self { + debug_assert!( + L > N, + "from_montgomery_reduce requires L > N for a reduction to be necessary" + ); + let mut limbs = unreduced; + let steps = L - N; + + let (carry, _steps_done) = Self::montgomery_steps_in_place::(&mut limbs, steps); + + // The result is in the upper N limbs of the buffer. + let mut result_limbs = [0u64; N]; + result_limbs.copy_from_slice(&limbs.0[steps..]); + + let mut result = Self::new_unchecked(BigInt::(result_limbs)); + + // Final conditional subtraction to bring the result into the canonical range. + if T::MODULUS_HAS_SPARE_BIT { + result.subtract_modulus(); + } else { + result.subtract_modulus_with_carry(carry != 0); + } + result + } + /// Construct a new field element from a BigInt /// which is in montgomery form and just needs to be reduced /// via a barrett reduction. @@ -1169,40 +1221,48 @@ impl, const N: usize> Fp, N> { /// Montgomery reduction for 2N-limb inputs (standard Montgomery reduction) /// Takes a 2N-limb BigInt that represents a product in "unreduced" form /// and reduces it to N limbs in Montgomery form. + /// Keep this for now for backwards compatibility. #[inline(always)] pub fn montgomery_reduce_2n(input: BigInt) -> Self { - debug_assert!(TWON == 2 * N); - // Work in-place over the owned 2N-limb buffer - let mut limbs = input.0; - let (lo, hi) = limbs.split_at_mut(N); + Self::from_montgomery_reduce::(input) + } - // Montgomery reduction - mirrors mul_without_cond_subtract - let mut carry2 = 0u64; - for i in 0..N { - let tmp = lo[i].wrapping_mul(T::INV); - let mut carry = 0u64; - fa::mac_discard(lo[i], tmp, T::MODULUS.0[0], &mut carry); - for j in 1..N { - let k = i + j; - if k >= N { - hi[k - N] = fa::mac_with_carry(hi[k - N], tmp, T::MODULUS.0[j], &mut carry); - } else { - lo[k] = fa::mac_with_carry(lo[k], tmp, T::MODULUS.0[j], &mut carry); - } - } - carry2 = fa::adc(&mut hi[i], carry, carry2); + /// Perform one Montgomery reduction step at position `i` over a contiguous limb buffer. + /// Operates on a `BigInt` that is treated as `[lo[0..N), hi[0..N), extra...]`. + /// Precondition (debug-asserted): `L >= N + i + 1` so all indices accessed are in-bounds. + /// Returns the carry-out from the top of this step. + #[inline(always)] + pub fn montgomery_step_once_at(limbs: &mut BigInt, i: usize) -> u64 { + debug_assert!(L >= N + i + 1, "montgomery_step_once_at: L too small for step i"); + let limbs_slice = &mut limbs.0; + // Compute tmp = limbs[i] * INV (mod 2^64) + let tmp = limbs_slice[i].wrapping_mul(T::INV); + // Accumulate tmp * MODULUS into columns starting at i + let mut carry = 0u64; + fa::mac_discard(limbs_slice[i], tmp, T::MODULUS.0[0], &mut carry); + for j in 1..N { + let k = i + j; + limbs_slice[k] = mac_with_carry!(limbs_slice[k], tmp, T::MODULUS.0[j], &mut carry); } + // Propagate the final carry into limbs[i + N] + fa::adc(&mut limbs_slice[i + N], carry, 0) + } - // Move the high half into the output BigInt - let mut hi_out = [0u64; N]; - hi_out.copy_from_slice(hi); - let mut result = Self::new_unchecked(BigInt::(hi_out)); - if T::MODULUS_HAS_SPARE_BIT { - result.subtract_modulus(); - } else { - result.subtract_modulus_with_carry(carry2 != 0); + /// Perform up to `steps` Montgomery steps starting at i = 0 over an `L`-limb buffer. + /// Returns (last_carry, steps_done). In debug, asserts `L >= N + steps`; in release, saturates. + #[inline(always)] + pub fn montgomery_steps_in_place( + limbs: &mut BigInt, + steps: usize, + ) -> (u64, usize) { + let max_steps = L.saturating_sub(N); + debug_assert!(steps <= max_steps, "steps exceed capacity: L < N + steps"); + let steps_done = core::cmp::min(steps, max_steps); + let mut last_carry = 0u64; + for i in 0..steps_done { + last_carry = Self::montgomery_step_once_at::(limbs, i); } - result + (last_carry, steps_done) } #[inline(always)] @@ -1391,40 +1451,6 @@ impl, const N: usize> Fp, N> { core::cmp::Ordering::Equal => Self::zero(), } } - - /// Optimized version for exactly 2 terms: a₁×b₁ + a₂×b₂ - /// Avoids slice overhead and loop setup costs. - #[inline(always)] - pub fn linear_combination_u64_2( - a1: &Self, - b1: u64, - a2: &Self, - b2: u64, - ) -> Self { - debug_assert!(NPLUS1 == N + 1); - - let mut acc = a1.0.mul_u64_w_carry::(b1); - Self::mul_u64_accumulate::(&mut acc, &a2.0, b2); - Self::from_unchecked_nplus1::(acc) - } - - /// Optimized version for exactly 3 terms: a₁×b₁ + a₂×b₂ + a₃×b₃ - #[inline(always)] - pub fn linear_combination_u64_3( - a1: &Self, - b1: u64, - a2: &Self, - b2: u64, - a3: &Self, - b3: u64, - ) -> Self { - debug_assert!(NPLUS1 == N + 1); - - let mut acc = a1.0.mul_u64_w_carry::(b1); - Self::mul_u64_accumulate::(&mut acc, &a2.0, b2); - Self::mul_u64_accumulate::(&mut acc, &a3.0, b3); - Self::from_unchecked_nplus1::(acc) - } } #[inline(always)] diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index 965a34035..bd6948e23 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -227,11 +227,11 @@ fn mul_small_bench(c: &mut Criterion) { }); // Reduction benchmarks - group.bench_function("montgomery_reduce_2n", |bench| { + group.bench_function("from_montgomery_reduce (L=2N)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(Fr::montgomery_reduce_2n::<8>(bigint_2n_s[i])) + criterion::black_box(Fr::from_montgomery_reduce::<8>(bigint_2n_s[i])) }) }); @@ -269,19 +269,6 @@ fn mul_small_bench(c: &mut Criterion) { }) }); - group.bench_function("linear_combination_u64_2 (optimized)", |bench| { - let mut i = 0; - bench.iter(|| { - i = (i + 1) % SAMPLES; - criterion::black_box(Fr::linear_combination_u64_2::<5>( - &a_s[i], - b_u64_s[i], - &c_s[i], - b_u64_s[(i + 1) % SAMPLES], - )) - }) - }); - group.bench_function("linear_combination_u64 (4 terms)", |bench| { let mut i = 0; bench.iter(|| { @@ -296,21 +283,6 @@ fn mul_small_bench(c: &mut Criterion) { }) }); - group.bench_function("linear_combination_u64_3 (optimized)", |bench| { - let mut i = 0; - bench.iter(|| { - i = (i + 1) % SAMPLES; - criterion::black_box(Fr::linear_combination_u64_3::<5>( - &a_s[i], - b_u64_s[i], - &c_s[i], - b_u64_s[(i + 1) % SAMPLES], - &a_s[(i + 2) % SAMPLES], - b_u64_s[(i + 2) % SAMPLES], - )) - }) - }); - group.bench_function("linear_combination_i64 (2+2 terms)", |bench| { let mut i = 0; bench.iter(|| { From 4ac8b2058c31a5b77c3ee3504f367bca247383ac Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Wed, 24 Sep 2025 15:04:48 -0400 Subject: [PATCH 29/38] fix generic mont reduce --- ff/src/biginteger/tests.rs | 1 - ff/src/fields/models/fp/montgomery_backend.rs | 164 ++++++++++++------ test-curves/benches/small_mul.rs | 2 +- 3 files changed, 115 insertions(+), 52 deletions(-) diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 293386deb..95561e27e 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -1114,7 +1114,6 @@ pub mod tests { let mut b = BigInt::<4>::new([u64::MAX; 4]); for i in 0..4 { a.0[i] >>= 1; b.0[i] >>= 1; } // P = 4: result should match BigUint addition modulo 2^256 - let r = a.add_trunc::<4, 4>(&b); // add_assign_trunc debug-overflow behavior cannot be reliably asserted in this // environment without std; we validate the non-mutating truncated result above. diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index edd4d5d6a..f465570e9 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -901,27 +901,55 @@ impl, const N: usize> Fp, N> { Self::new_unchecked(acc) } - /// Montgomery reduction of a BigInt to a field element (compute a * R^{-1} mod p). + /// Montgomery reduction for arbitrary input width L >= 2N. /// - /// Need to specify the number of limbs `L` in the BigInt, where `L > N`. + /// Runs exactly N Montgomery steps (i = 0..N-1) over the L-limb buffer to compute + /// t' = (unreduced + q * MODULUS) / R, where R = b^N. The remaining (L - N) limbs + /// store t' in base-b. For L > 2N, we first fold the entire tail (indices N..L) down + /// to an N-limb accumulator using the N+1 Barrett reducer (interpreting the tail as a + /// base-b number), place that as the high N limbs to form a 2N-limb buffer, and then + /// perform the standard N-step Montgomery reduction on that 2N-limb buffer. + /// + /// Preconditions: + /// - L >= 2N (buffer must be large enough to perform N steps safely) + /// + /// Computes: unreduced * R^{-1} mod MODULUS. #[inline(always)] - pub fn from_montgomery_reduce(unreduced: BigInt) -> Self { - debug_assert!( - L > N, - "from_montgomery_reduce requires L > N for a reduction to be necessary" - ); - let mut limbs = unreduced; - let steps = L - N; + pub fn from_montgomery_reduce( + unreduced: BigInt, + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(L >= N + N, "from_montgomery_reduce_var requires L >= 2N"); + + let mut limbs = unreduced; // reuse storage for the buffer + + // If L > 2N, first fold the extra high limbs down. + if L > 2 * N { + // Fold the tail (indices N..L) into an N-limb accumulator via Barrett. + let mut acc = BigInt::::zero(); + let mut i = L; + while i > N { + i -= 1; + let c2 = nplus1_pair_low_to_bigint::((limbs.0[i], acc.0)); + acc = barrett_reduce_nplus1_to_n::(c2); + } - let (carry, _steps_done) = Self::montgomery_steps_in_place::(&mut limbs, steps); + // Recompose buffer: [low_N | acc | zeros...] + limbs.0[N..(N + N)].copy_from_slice(&acc.0); + let mut j = 2 * N; + while j < L { + limbs.0[j] = 0; + j += 1; + } + } - // The result is in the upper N limbs of the buffer. - let mut result_limbs = [0u64; N]; - result_limbs.copy_from_slice(&limbs.0[steps..]); + // Phase 2: run exactly N Montgomery steps on the 2N-limb buffer. + let carry = Self::montgomery_reduce_in_place::(&mut limbs); + // Extract result and finalize. + let mut result_limbs = [0u64; N]; + result_limbs.copy_from_slice(&limbs.0[N..(N + N)]); let mut result = Self::new_unchecked(BigInt::(result_limbs)); - - // Final conditional subtraction to bring the result into the canonical range. if T::MODULUS_HAS_SPARE_BIT { result.subtract_modulus(); } else { @@ -945,7 +973,7 @@ impl, const N: usize> Fp, N> { /// via a barrett reduction. #[inline] pub fn from_unchecked_nplus2( - element: BigInt<{ NPLUS2 }>, + element: BigInt, ) -> Self { debug_assert!(NPLUS1 == N + 1); debug_assert!(NPLUS2 == N + 2); @@ -1224,45 +1252,58 @@ impl, const N: usize> Fp, N> { /// Keep this for now for backwards compatibility. #[inline(always)] pub fn montgomery_reduce_2n(input: BigInt) -> Self { - Self::from_montgomery_reduce::(input) - } + debug_assert!(TWON == 2 * N, "montgomery_reduce_2n requires TWON == 2N"); + let mut limbs = input; + let carry = Self::montgomery_reduce_in_place::(&mut limbs); - /// Perform one Montgomery reduction step at position `i` over a contiguous limb buffer. - /// Operates on a `BigInt` that is treated as `[lo[0..N), hi[0..N), extra...]`. - /// Precondition (debug-asserted): `L >= N + i + 1` so all indices accessed are in-bounds. - /// Returns the carry-out from the top of this step. - #[inline(always)] - pub fn montgomery_step_once_at(limbs: &mut BigInt, i: usize) -> u64 { - debug_assert!(L >= N + i + 1, "montgomery_step_once_at: L too small for step i"); - let limbs_slice = &mut limbs.0; - // Compute tmp = limbs[i] * INV (mod 2^64) - let tmp = limbs_slice[i].wrapping_mul(T::INV); - // Accumulate tmp * MODULUS into columns starting at i - let mut carry = 0u64; - fa::mac_discard(limbs_slice[i], tmp, T::MODULUS.0[0], &mut carry); - for j in 1..N { - let k = i + j; - limbs_slice[k] = mac_with_carry!(limbs_slice[k], tmp, T::MODULUS.0[j], &mut carry); + // Extract the upper N limbs after exactly N REDC steps + let mut result_limbs = [0u64; N]; + result_limbs.copy_from_slice(&limbs.0[N..]); + + let mut result = Self::new_unchecked(BigInt::(result_limbs)); + if T::MODULUS_HAS_SPARE_BIT { + result.subtract_modulus(); + } else { + result.subtract_modulus_with_carry(carry != 0); } - // Propagate the final carry into limbs[i + N] - fa::adc(&mut limbs_slice[i + N], carry, 0) + result } - /// Perform up to `steps` Montgomery steps starting at i = 0 over an `L`-limb buffer. - /// Returns (last_carry, steps_done). In debug, asserts `L >= N + steps`; in release, saturates. + /// Perform exactly N Montgomery reduction steps over the leading 2N limbs of `limbs`, + /// using the canonical REDC subroutine from `mul_without_cond_subtract`. + /// Treats `limbs` as `[lo[0..N), hi[0..N), extra...]` and updates only the high half. + /// Returns the final carry-out (0 or 1) from the top of the reduction. #[inline(always)] - pub fn montgomery_steps_in_place( - limbs: &mut BigInt, - steps: usize, - ) -> (u64, usize) { - let max_steps = L.saturating_sub(N); - debug_assert!(steps <= max_steps, "steps exceed capacity: L < N + steps"); - let steps_done = core::cmp::min(steps, max_steps); - let mut last_carry = 0u64; - for i in 0..steps_done { - last_carry = Self::montgomery_step_once_at::(limbs, i); - } - (last_carry, steps_done) + pub fn montgomery_reduce_in_place(limbs: &mut BigInt) -> u64 { + debug_assert!(L >= 2 * N, "montgomery_reduce_in_place requires L >= 2N"); + + // Copy the leading 2N limbs into local halves to mirror the canonical subroutine. + let mut lo = [0u64; N]; + let mut hi = [0u64; N]; + lo.copy_from_slice(&limbs.0[0..N]); + hi.copy_from_slice(&limbs.0[N..(N + N)]); + + // Montgomery reduction (canonical form) + let mut carry2 = 0u64; + crate::const_for!((i in 0..N) { + let tmp = lo[i].wrapping_mul(T::INV); + let mut carry; + mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry); + crate::const_for!((j in 1..N) { + let k = i + j; + if k >= N { + hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry); + } else { + lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry); + } + }); + hi[i] = adc!(hi[i], carry, &mut carry2); + }); + + // Write the reduced high half back into the buffer; low half is discarded by callers. + limbs.0[N..(N + N)].copy_from_slice(&hi); + + carry2 } #[inline(always)] @@ -1857,4 +1898,27 @@ mod test { let sign_is_positive = sign != Sign::Minus; (sign_is_positive, limbs) } + + #[test] + fn test_from_montgomery_reduce_paths_l8_l9_match_field_mul() { + let mut rng = test_rng(); + for _ in 0..200 { + let a = Fr::rand(&mut rng); + let b = Fr::rand(&mut rng); + + let expected = a * b; + + // Compute 8-limb raw product of Montgomery residues + let prod8 = a.0.mul_trunc::<4, 8>(&b.0); + + // Reduce via Montgomery reduction with L = 8 + let alt8 = Fr::montgomery_reduce_2n::<8>(prod8); + assert_eq!(alt8, expected, "from_montgomery_reduce L=8 mismatch"); + + // Zero-extend to 9 limbs and reduce with L = 9 + let prod9 = ark_test_curves::ark_ff::BigInt::<9>::zero_extend_from::<8>(&prod8); + let alt9 = Fr::from_montgomery_reduce::<9, 5>(prod9); + assert_eq!(alt9, expected, "from_montgomery_reduce L=9 mismatch"); + } + } } diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index bd6948e23..a192b491b 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -231,7 +231,7 @@ fn mul_small_bench(c: &mut Criterion) { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(Fr::from_montgomery_reduce::<8>(bigint_2n_s[i])) + criterion::black_box(Fr::from_montgomery_reduce::<8, 5>(bigint_2n_s[i])) }) }); From 6d72edc18b4c7397eea445acb9e659b4101c1ed6 Mon Sep 17 00:00:00 2001 From: Quang Dao Date: Wed, 24 Sep 2025 15:17:12 -0400 Subject: [PATCH 30/38] fewer copy in generic mont reduce --- ff/src/fields/models/fp/montgomery_backend.rs | 25 +++---- test-curves/benches/small_mul.rs | 75 +++++++++++++------ 2 files changed, 65 insertions(+), 35 deletions(-) diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index f465570e9..ad487ffa6 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -1274,34 +1274,31 @@ impl, const N: usize> Fp, N> { /// Treats `limbs` as `[lo[0..N), hi[0..N), extra...]` and updates only the high half. /// Returns the final carry-out (0 or 1) from the top of the reduction. #[inline(always)] + #[unroll_for_loops(12)] pub fn montgomery_reduce_in_place(limbs: &mut BigInt) -> u64 { debug_assert!(L >= 2 * N, "montgomery_reduce_in_place requires L >= 2N"); - // Copy the leading 2N limbs into local halves to mirror the canonical subroutine. - let mut lo = [0u64; N]; - let mut hi = [0u64; N]; - lo.copy_from_slice(&limbs.0[0..N]); - hi.copy_from_slice(&limbs.0[N..(N + N)]); + // Work directly on the buffer to avoid copies: split into lo and hi views. + let (lo, rest) = limbs.0.split_at_mut(N); + let hi = &mut rest[..N]; // Montgomery reduction (canonical form) let mut carry2 = 0u64; - crate::const_for!((i in 0..N) { + for i in 0..N { let tmp = lo[i].wrapping_mul(T::INV); let mut carry; mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry); - crate::const_for!((j in 1..N) { + for j in 1..N { let k = i + j; if k >= N { - hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry); - } else { + let idx = k - N; + hi[idx] = mac_with_carry!(hi[idx], tmp, T::MODULUS.0[j], &mut carry); + } else { lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry); } - }); + } hi[i] = adc!(hi[i], carry, &mut carry2); - }); - - // Write the reduced high half back into the buffer; low half is discarded by callers. - limbs.0[N..(N + N)].copy_from_slice(&hi); + } carry2 } diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index a192b491b..0bed530ab 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -227,6 +227,15 @@ fn mul_small_bench(c: &mut Criterion) { }); // Reduction benchmarks + group.bench_function("montgomery_reduce_in_place core (L=8)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut x = bigint_2n_s[i]; + criterion::black_box(Fr::montgomery_reduce_in_place::<8>(&mut x)) + }) + }); + group.bench_function("from_montgomery_reduce (L=2N)", |bench| { let mut i = 0; bench.iter(|| { @@ -235,29 +244,53 @@ fn mul_small_bench(c: &mut Criterion) { }) }); - // group.bench_function("from_unchecked_nplus1 (Barrett N+1)", |bench| { - // let mut i = 0; - // bench.iter(|| { - // i = (i + 1) % SAMPLES; - // criterion::black_box(Fr::from_unchecked_nplus1::<5>(bigint_nplus1_s[i])) - // }) - // }); + // L=9 inputs: derive by zero-extending L=8 inputs + let bigint_9_s = bigint_2n_s + .iter() + .map(|b8| ark_ff::BigInt::<9>::zero_extend_from::<8>(b8)) + .collect::>(); - // group.bench_function("from_unchecked_nplus2 (Barrett N+2)", |bench| { - // let mut i = 0; - // bench.iter(|| { - // i = (i + 1) % SAMPLES; - // criterion::black_box(Fr::from_unchecked_nplus2::<5, 6>(bigint_nplus2_s[i])) - // }) - // }); + group.bench_function("montgomery_reduce_in_place core (L=9)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut x = bigint_9_s[i]; + criterion::black_box(Fr::montgomery_reduce_in_place::<9>(&mut x)) + }) + }); - // group.bench_function("from_unchecked_nplus3 (Barrett N+3)", |bench| { - // let mut i = 0; - // bench.iter(|| { - // i = (i + 1) % SAMPLES; - // criterion::black_box(Fr::from_unchecked_nplus3::<5, 6, 7>(bigint_nplus3_s[i])) - // }) - // }); + group.bench_function("from_montgomery_reduce (L=9)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_montgomery_reduce::<9, 5>(bigint_9_s[i])) + }) + }); + + // Barrett reductions + group.bench_function("from_barrett_reduce (L=5)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_barrett_reduce::<5, 5>(bigint_nplus1_s[i])) + }) + }); + + group.bench_function("from_barrett_reduce (L=6)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_barrett_reduce::<6, 5>(bigint_nplus2_s[i])) + }) + }); + + group.bench_function("from_barrett_reduce (L=7)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_barrett_reduce::<7, 5>(bigint_nplus3_s[i])) + }) + }); // Linear combination benchmarks group.bench_function("linear_combination_u64 (2 terms)", |bench| { From ffee1c3279ff378313bf803feee4e14b21bf4e89 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:56:05 -0400 Subject: [PATCH 31/38] Add From and From<[u64; N]> for bigint Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> --- ff/src/biginteger/mod.rs | 4 ++++ ff/src/biginteger/signed_hi_32.rs | 13 +++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 16f2e07a4..561203e66 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -519,6 +519,10 @@ impl BigInt { BigInt::(limbs) } +impl From<[u64; N]> for BigInt { + fn from(limbs: [u64; N]) -> Self { + BigInt(limbs) + } } impl BigInteger for BigInt { diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs index 6dbb86897..ae41fd37b 100644 --- a/ff/src/biginteger/signed_hi_32.rs +++ b/ff/src/biginteger/signed_hi_32.rs @@ -743,11 +743,20 @@ impl From for S160 { } } -impl From for crate::biginteger::BigInt<4> { +impl From for crate::biginteger::BigInt { #[inline] + #[allow(unsafe_code)] fn from(val: S224) -> Self { + if N != 4 { + panic!("FromS224 for BigInt only supports N=4, got N={N}"); + } let lo = val.magnitude_lo(); let hi = val.magnitude_hi() as u64; - crate::biginteger::BigInt::<4>([lo[0], lo[1], lo[2], hi]) + let bigint4 = crate::biginteger::BigInt::<4>([lo[0], lo[1], lo[2], hi]); + + unsafe { + let ptr = &bigint4 as *const BigInt<4> as *const BigInt; + ptr.read() + } } } From 7b02a6796b5e90dde8bb172cdb8acaa1c87a839c Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:00:26 -0400 Subject: [PATCH 32/38] fix compile error Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> --- ff/src/biginteger/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 561203e66..4539bff81 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -518,6 +518,7 @@ impl BigInt { limbs[..copy_len].copy_from_slice(&smaller.0[..copy_len]); BigInt::(limbs) } +} impl From<[u64; N]> for BigInt { fn from(limbs: [u64; N]) -> Self { From ad2e2866e87c513353838b911ca502f4aa1500e5 Mon Sep 17 00:00:00 2001 From: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:19:05 -0400 Subject: [PATCH 33/38] add Zero impl for BigInt Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> --- ff-asm/src/lib.rs | 2 ++ ff/src/biginteger/mod.rs | 69 +++++++++++++++++++++++-------------- ff/src/biginteger/signed.rs | 19 ++++++---- ff/src/biginteger/tests.rs | 1 + 4 files changed, 59 insertions(+), 32 deletions(-) diff --git a/ff-asm/src/lib.rs b/ff-asm/src/lib.rs index 1833af026..cd38350d6 100644 --- a/ff-asm/src/lib.rs +++ b/ff-asm/src/lib.rs @@ -59,6 +59,7 @@ pub fn x86_64_asm_mul(input: TokenStream) -> TokenStream { } else { panic!("The number of limbs must be a literal"); }; + #[allow(clippy::redundant_comparisons)] if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS { let impl_block = generate_impl(num_limbs, true); @@ -110,6 +111,7 @@ pub fn x86_64_asm_square(input: TokenStream) -> TokenStream { } else { panic!("The number of limbs must be a literal"); }; + #[allow(clippy::redundant_comparisons)] if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS { let impl_block = generate_impl(num_limbs, false); diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 4539bff81..210f9c433 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -14,8 +14,8 @@ use ark_std::{ fmt::{Debug, Display, UpperHex}, io::{Read, Write}, ops::{ - BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr, - ShrAssign, Add, Sub, AddAssign, SubAssign, + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, + ShlAssign, Shr, ShrAssign, Sub, SubAssign, }, rand::{ distributions::{Distribution, Standard}, @@ -23,6 +23,7 @@ use ark_std::{ }, str::FromStr, vec::*, + Zero, }; use num_bigint::BigUint; use zeroize::Zeroize; @@ -346,13 +347,19 @@ impl BigInt { #[inline] #[unroll_for_loops(9)] pub fn add_assign_trunc(&mut self, other: &BigInt) { - debug_assert!(M <= N, "add_assign_trunc: right operand wider than self width N"); + debug_assert!( + M <= N, + "add_assign_trunc: right operand wider than self width N" + ); let mut carry = 0u64; for i in 0..N { let rhs = if i < M { other.0[i] } else { 0 }; self.0[i] = adc!(self.0[i], rhs, &mut carry); } - debug_assert!(carry == 0, "add_assign_trunc overflow: carry beyond N limbs"); + debug_assert!( + carry == 0, + "add_assign_trunc overflow: carry beyond N limbs" + ); } /// Truncated-width subtraction that mutates self: self -= other, keeping N limbs (self's width). @@ -363,13 +370,19 @@ impl BigInt { #[inline] #[unroll_for_loops(9)] pub fn sub_assign_trunc(&mut self, other: &BigInt) { - debug_assert!(M <= N, "sub_assign_trunc: right operand wider than self width N"); + debug_assert!( + M <= N, + "sub_assign_trunc: right operand wider than self width N" + ); let mut borrow = 0u64; for i in 0..N { let rhs = if i < M { other.0[i] } else { 0 }; self.0[i] = sbb!(self.0[i], rhs, &mut borrow); } - debug_assert!(borrow == 0, "sub_assign_trunc underflow: borrow beyond N limbs"); + debug_assert!( + borrow == 0, + "sub_assign_trunc underflow: borrow beyond N limbs" + ); } /// Truncated-width multiplication that mutates self: self = (self * other) mod 2^(64*N). @@ -378,12 +391,16 @@ impl BigInt { pub fn mul_assign_trunc(&mut self, other: &BigInt) { // Fast paths if self.is_zero() || other.is_zero() { - for i in 0..N { self.0[i] = 0; } + for i in 0..N { + self.0[i] = 0; + } return; } let left = *self; // snapshot original multiplicand - // zero self to use as accumulator buffer - for i in 0..N { self.0[i] = 0; } + // zero self to use as accumulator buffer + for i in 0..N { + self.0[i] = 0; + } // Accumulate left * other directly into self within width N; propagate carries within N left.fm_limbs_into::(&other.0, self, true); } @@ -512,7 +529,10 @@ impl BigInt { /// Debug-asserts that M <= N. #[inline] pub fn zero_extend_from(smaller: &BigInt) -> BigInt { - debug_assert!(M <= N, "cannot zero-extend: source has more limbs than destination"); + debug_assert!( + M <= N, + "cannot zero-extend: source has more limbs than destination" + ); let mut limbs = [0u64; N]; let copy_len = if M < N { M } else { N }; limbs[..copy_len].copy_from_slice(&smaller.0[..copy_len]); @@ -520,6 +540,18 @@ impl BigInt { } } +impl Zero for BigInt { + #[inline] + fn zero() -> Self { + Self::zero() + } + + #[inline] + fn is_zero(&self) -> bool { + self.0.iter().all(|&limb| limb == 0) + } +} + impl From<[u64; N]> for BigInt { fn from(limbs: [u64; N]) -> Self { BigInt(limbs) @@ -795,11 +827,6 @@ impl BigInteger for BigInt { !self.is_odd() } - #[inline] - fn is_zero(&self) -> bool { - self.0.iter().all(|&e| e == 0) - } - #[inline] fn num_bits(&self) -> u32 { let mut ret = N as u32 * 64; @@ -1308,6 +1335,7 @@ pub trait BigInteger: + 'static + UniformRand + Zeroize + + Zero + AsMut<[u64]> + AsRef<[u64]> + From @@ -1572,17 +1600,6 @@ pub trait BigInteger: /// ``` fn is_even(&self) -> bool; - /// Returns true iff this number is zero. - /// # Example - /// - /// ``` - /// use ark_ff::{biginteger::BigInteger64 as B, BigInteger as _}; - /// - /// let mut zero = B::from(0u64); - /// assert!(zero.is_zero()); - /// ``` - fn is_zero(&self) -> bool; - /// Compute the minimum number of bits needed to encode this number. /// # Example /// ``` diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs index 72ca2fcc5..42b20d00a 100644 --- a/ff/src/biginteger/signed.rs +++ b/ff/src/biginteger/signed.rs @@ -4,6 +4,7 @@ use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, Write, }; +use ark_std::Zero; use core::cmp::Ordering; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -27,6 +28,18 @@ impl Default for SignedBigInt { } } +impl Zero for SignedBigInt { + #[inline] + fn zero() -> Self { + Self::zero() + } + + #[inline] + fn is_zero(&self) -> bool { + self.magnitude.is_zero() + } +} + pub type S64 = SignedBigInt<1>; pub type S128 = SignedBigInt<2>; pub type S192 = SignedBigInt<3>; @@ -87,12 +100,6 @@ impl SignedBigInt { } } - /// Return true if magnitude is zero (sign is not considered). - #[inline] - pub fn is_zero(&self) -> bool { - self.magnitude.is_zero() - } - /// Borrow the magnitude (absolute value). #[inline] pub fn as_magnitude(&self) -> &BigInt { diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 95561e27e..1aad5211d 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -5,6 +5,7 @@ pub mod tests { biginteger::{BigInteger, SignedBigInt}, UniformRand, }; + use ark_std::Zero; use num_bigint::BigUint; // Test elementary math operations for BigInteger. From 8e26f70d877bfae7463246dc1fa8d7c591ff409b Mon Sep 17 00:00:00 2001 From: markosg04 Date: Tue, 30 Sep 2025 12:15:46 -0400 Subject: [PATCH 34/38] feat: multilinear witness gen --- curves/bn254/src/lib.rs | 1 + jolt-optimizations/src/decomp_2d.rs | 2 +- jolt-optimizations/src/expression.rs | 101 ------ jolt-optimizations/src/fq12_poly.rs | 178 ++++++---- jolt-optimizations/src/lib.rs | 11 +- jolt-optimizations/src/steps.rs | 273 ---------------- jolt-optimizations/src/sz_check.rs | 92 ------ jolt-optimizations/src/witness_gen.rs | 221 +++++++++++++ jolt-optimizations/tests/integration_tests.rs | 2 +- jolt-optimizations/tests/mle_tests.rs | 149 +++++++++ jolt-optimizations/tests/steps_debug_test.rs | 175 ---------- jolt-optimizations/tests/steps_test.rs | 135 -------- jolt-optimizations/tests/sz_check_tests.rs | 64 ---- jolt-optimizations/tests/witness_test.rs | 304 ++++++++++++++++++ 14 files changed, 789 insertions(+), 919 deletions(-) delete mode 100644 jolt-optimizations/src/expression.rs delete mode 100644 jolt-optimizations/src/steps.rs delete mode 100644 jolt-optimizations/src/sz_check.rs create mode 100644 jolt-optimizations/src/witness_gen.rs create mode 100644 jolt-optimizations/tests/mle_tests.rs delete mode 100644 jolt-optimizations/tests/steps_debug_test.rs delete mode 100644 jolt-optimizations/tests/steps_test.rs delete mode 100644 jolt-optimizations/tests/sz_check_tests.rs create mode 100644 jolt-optimizations/tests/witness_test.rs diff --git a/curves/bn254/src/lib.rs b/curves/bn254/src/lib.rs index 13bcef101..4fdd1e61c 100755 --- a/curves/bn254/src/lib.rs +++ b/curves/bn254/src/lib.rs @@ -39,6 +39,7 @@ mod fields; #[cfg(feature = "curve")] pub use curves::*; +#[allow(unused_imports)] pub use fields::*; #[cfg(feature = "r1cs")] diff --git a/jolt-optimizations/src/decomp_2d.rs b/jolt-optimizations/src/decomp_2d.rs index 5ee172057..8281e4764 100644 --- a/jolt-optimizations/src/decomp_2d.rs +++ b/jolt-optimizations/src/decomp_2d.rs @@ -11,7 +11,7 @@ use num_integer::Integer; use num_traits::{One, Signed}; /// GLV lambda for BN254 G1 -const LAMBDA: Fr = +const _LAMBDA: Fr = MontFp!("21888242871839275217838484774961031246154997185409878258781734729429964517155"); /// GLV endomorphism coefficient for BN254 G1 diff --git a/jolt-optimizations/src/expression.rs b/jolt-optimizations/src/expression.rs deleted file mode 100644 index e1871175e..000000000 --- a/jolt-optimizations/src/expression.rs +++ /dev/null @@ -1,101 +0,0 @@ -use crate::steps::{pow_with_steps_le, ExponentiationSteps}; -use crate::sz_check::Product; -use ark_bn254::{Fq, Fq12}; -use ark_ff::{Field, One, PrimeField}; - -#[derive(Clone)] -pub struct Term { - pub base: Fq12, - pub exponent: Fq, -} - -pub struct Expression { - pub terms: Vec, -} - -pub struct ExpressionSteps { - pub term_steps: Vec, - pub multiplication_products: Vec, -} - -impl Expression { - pub fn new(terms: Vec) -> Self { - Self { terms } - } - - pub fn to_products(&self) -> Vec { - let mut products = Vec::new(); - let mut current_result = Fq12::one(); - - for term in &self.terms { - let term_value = term.base.pow(term.exponent.into_bigint()); - let term_products = exponentiate_to_products(term.base, term.exponent); - - products.extend(term_products); - - if current_result != Fq12::one() { - // Multiply this term's result with the accumulated result - let new_result = current_result * term_value; - products.push(Product::new(current_result, term_value, new_result)); - current_result = new_result; - } else { - current_result = term_value; - } - } - - products - } - - /// Evaluate the expression and return both the result and all computation steps - pub fn evaluate_with_steps(&self) -> (Fq12, ExpressionSteps) { - let mut term_steps = Vec::new(); - let mut multiplication_products = Vec::new(); - let mut current_result = Fq12::one(); - - for term in &self.terms { - // Compute this term with steps - let steps = pow_with_steps_le(term.base, term.exponent); - let term_value = steps.result; - term_steps.push(steps); - - if current_result != Fq12::one() { - // Multiply this term's result with the accumulated result - let new_result = current_result * term_value; - multiplication_products.push(Product::new(current_result, term_value, new_result)); - current_result = new_result; - } else { - current_result = term_value; - } - } - - let expression_steps = ExpressionSteps { - term_steps, - multiplication_products, - }; - - (current_result, expression_steps) - } - - /// Convert expression steps to a flat list of products for verification - pub fn steps_to_products(steps: &ExpressionSteps) -> Vec { - let mut products = Vec::new(); - - // Add all products from individual term exponentiations - for term_step in &steps.term_steps { - products.extend(term_step.to_products()); - } - - // Add products from multiplying terms together - for product in &steps.multiplication_products { - products.push(product.clone()); - } - - products - } -} - -fn exponentiate_to_products(base: Fq12, exponent: Fq) -> Vec { - // Use the new stepped implementation to get products - let steps = pow_with_steps_le(base, exponent); - steps.to_products() -} diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs index b603ffa75..b30379e01 100644 --- a/jolt-optimizations/src/fq12_poly.rs +++ b/jolt-optimizations/src/fq12_poly.rs @@ -2,10 +2,9 @@ use ark_bn254::{Fq, Fq12}; use ark_ff::{Field, One, Zero}; -/// Constant for the tower extension mapping const NINE: u64 = 9; -/// Newtype wrapper for degree-12 polynomial coefficients +/// Newtype wrapper for degree-12 polys from Fq12 #[derive(Clone, Debug, Default)] pub struct Poly12([Fq; 12]); @@ -26,83 +25,70 @@ impl Poly12 { self.0.to_vec() } - /// Evaluate at a point using Horner's method + /// Horner's method pub fn eval(&self, r: &Fq) -> Fq { self.0.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) } } -/// Tower basis mapping for Fq12 -> polynomial conversion -struct TowerBasis { - /// Maps basis elements to power indices: [(element, power_of_w)] - mappings: [(usize, usize, usize); 6], // (c0/c1, inner_idx, w_power) -} - -impl TowerBasis { - const fn new() -> Self { - Self { - mappings: [ - (0, 0, 0), // a.c0.c0 → w^0 - (0, 1, 2), // a.c0.c1 → w^2 - (0, 2, 4), // a.c0.c2 → w^4 - (1, 0, 1), // a.c1.c0 → w^1 - (1, 1, 3), // a.c1.c1 → w^3 - (1, 2, 5), // a.c1.c2 → w^5 - ], - } - } - - fn apply(&self, a: &Fq12) -> Poly12 { - let nine = Fq::from(NINE); - let mut coeffs = [Fq::zero(); 12]; - - for &(outer, inner, w_power) in &self.mappings { - let fp2 = match (outer, inner) { - (0, 0) => &a.c0.c0, - (0, 1) => &a.c0.c1, - (0, 2) => &a.c0.c2, - (1, 0) => &a.c1.c0, - (1, 1) => &a.c1.c1, - (1, 2) => &a.c1.c2, - _ => unreachable!(), - }; - - let (x, y) = (fp2.c0, fp2.c1); - // Apply: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} - coeffs[w_power] += x - nine * y; - coeffs[w_power + 6] += y; - } - - Poly12::new(coeffs) +/// Convert Fq12 to polynomial representation using tower basis mapping +/// +/// Maps Fq12 basis elements to powers of w: +/// - (c0.c0, c0.c1, c0.c2, c1.c0, c1.c1, c1.c2) → (w^0, w^2, w^4, w^1, w^3, w^5) +/// - Applies the mapping: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} +pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { + // Tower basis element mappings: (outer_idx, inner_idx, w_power) + const MAPPINGS: [(usize, usize, usize); 6] = [ + (0, 0, 0), // a.c0.c0 → w^0 + (0, 1, 2), // a.c0.c1 → w^2 + (0, 2, 4), // a.c0.c2 → w^4 + (1, 0, 1), // a.c1.c0 → w^1 + (1, 1, 3), // a.c1.c1 → w^3 + (1, 2, 5), // a.c1.c2 → w^5 + ]; + + let nine = Fq::from(NINE); + let mut coeffs = [Fq::zero(); 12]; + + for &(outer, inner, w_power) in &MAPPINGS { + let fp2 = match (outer, inner) { + (0, 0) => &a.c0.c0, + (0, 1) => &a.c0.c1, + (0, 2) => &a.c0.c2, + (1, 0) => &a.c1.c0, + (1, 1) => &a.c1.c1, + (1, 2) => &a.c1.c2, + _ => unreachable!(), + }; + + let (x, y) = (fp2.c0, fp2.c1); + // Apply: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} + coeffs[w_power] += x - nine * y; + coeffs[w_power + 6] += y; } -} - -static TOWER_BASIS: TowerBasis = TowerBasis::new(); -/// Convert Fq12 to polynomial representation -pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { - TOWER_BASIS.apply(a).0 + coeffs } /// The minimal polynomial g(X) = X^12 - 18 X^6 + 82 -struct MinimalPolynomial; +struct IrreduciblePoly; -impl MinimalPolynomial { +impl IrreduciblePoly { const COEFF_0: u64 = 82; - const COEFF_6: i64 = -18; + const COEFF_6: u64 = 18; /// Evaluate g(X) at point r fn eval(r: &Fq) -> Fq { let r6 = (r.square() * r).square(); // r^6 = (r^2 * r)^2 let r12 = r6.square(); - r12 - Fq::from(18u64) * r6 + Fq::from(Self::COEFF_0) + r12 - Fq::from(Self::COEFF_6) * r6 + Fq::from(Self::COEFF_0) } /// Get coefficients as a vector fn coeffs() -> Vec { let mut g = vec![Fq::zero(); 13]; g[0] = Fq::from(Self::COEFF_0); - g[6] = -Fq::from(18u64); + g[6] = -Fq::from(Self::COEFF_6); g[12] = Fq::one(); g } @@ -110,15 +96,14 @@ impl MinimalPolynomial { /// Evaluate g(X) = X^12 - 18 X^6 + 82 at a given point r pub fn g_eval(r: &Fq) -> Fq { - MinimalPolynomial::eval(r) + IrreduciblePoly::eval(r) } -/// Horner evaluation for arbitrary-degree polynomial +/// Horner evaluation for arbitrary-degree poly pub fn eval_poly_vec(coeffs: &[Fq], r: &Fq) -> Fq { coeffs.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) } -/// Generic polynomial operation in place fn poly_op_in_place(a: &mut Vec, b: &[Fq], op: F) where F: Fn(&mut Fq, Fq), @@ -126,20 +111,19 @@ where if b.len() > a.len() { a.resize(b.len(), Fq::zero()); } - b.iter().enumerate().for_each(|(i, &coeff)| op(&mut a[i], coeff)); + b.iter() + .enumerate() + .for_each(|(i, &coeff)| op(&mut a[i], coeff)); } -/// Add polynomial b to polynomial a in place pub fn poly_add_in_place(a: &mut Vec, b: &[Fq]) { poly_op_in_place(a, b, |a, b| *a += b); } -/// Subtract polynomial b from polynomial a in place pub fn poly_sub_in_place(a: &mut Vec, b: &[Fq]) { poly_op_in_place(a, b, |a, b| *a -= b); } -/// Multiply two polynomials using convolution pub fn poly_mul(a: &[Fq], b: &[Fq]) -> Vec { if a.is_empty() || b.is_empty() { return vec![]; @@ -192,18 +176,72 @@ pub fn poly_div_rem_monic(mut dividend: Vec, divisor: &[Fq]) -> (Vec, Ve /// Build the coefficients for g(X) = X^12 - 18 X^6 + 82 pub fn g_coeffs() -> Vec { - MinimalPolynomial::coeffs() + IrreduciblePoly::coeffs() } -/// Convert Fq12 polynomial coefficients to multilinear evaluations by padding to 16 elements +/// Compute the multilinear extension (MLE) of a univariate polynomial. pub fn to_multilinear_evals(coeffs: &[Fq; 12]) -> Vec { - let mut evals = Vec::with_capacity(16); - evals.extend_from_slice(coeffs); - evals.resize(16, Fq::zero()); - evals + // Evaluate polynomial at points 0..16 + (0..16) + .map(|i| { + let x = Fq::from(i as u64); + eval_poly_vec(&coeffs[..], &x) + }) + .collect() } -/// Convert Fq12 directly to multilinear evaluations +/// Convert Fq12 element to multilinear extension evaluations. +/// First converts to polynomial representation, then computes MLE. pub fn fq12_to_multilinear_evals(a: &Fq12) -> Vec { to_multilinear_evals(&fq12_to_poly12_coeffs(a)) } + +/// Evaluate a multilinear polynomial at a given point. +pub fn eval_multilinear(evals: &[Fq], point: &[Fq]) -> Fq { + let n = point.len(); + assert_eq!( + evals.len(), + 1 << n, + "Number of evaluations must be 2^n where n is dimension" + ); + + let mut result = Fq::zero(); + for (i, &eval) in evals.iter().enumerate() { + let mut term = eval; + for j in 0..n { + let bit = (i >> j) & 1; + term *= if bit == 1 { + point[j] + } else { + Fq::one() - point[j] + }; + } + result += term; + } + result +} + +/// Compute equality function weights eq(z, x) for all x ∈ {0,1}^4 +/// Returns a vector of 16 weights where w[i] = eq(z, binary_decomposition(i)) +pub fn eq_weights(z: &[Fq]) -> Vec { + assert_eq!(z.len(), 4, "Point z must be 4-dimensional"); + let mut w = vec![Fq::zero(); 16]; + + for idx in 0..16 { + // Binary decomposition of idx + let x0 = if (idx & 1) != 0 { Fq::one() } else { Fq::zero() }; + let x1 = if (idx & 2) != 0 { Fq::one() } else { Fq::zero() }; + let x2 = if (idx & 4) != 0 { Fq::one() } else { Fq::zero() }; + let x3 = if (idx & 8) != 0 { Fq::one() } else { Fq::zero() }; + + // eq(z, x) = ∏ᵢ ((1-zᵢ)(1-xᵢ) + zᵢxᵢ) + let t0 = (Fq::one() - z[0]) * (Fq::one() - x0) + z[0] * x0; + let t1 = (Fq::one() - z[1]) * (Fq::one() - x1) + z[1] * x1; + let t2 = (Fq::one() - z[2]) * (Fq::one() - x2) + z[2] * x2; + let t3 = (Fq::one() - z[3]) * (Fq::one() - x3) + z[3] * x3; + + w[idx] = t0 * t1 * t2 * t3; + } + + w +} diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs index a58cc6745..2d8b97d7b 100644 --- a/jolt-optimizations/src/lib.rs +++ b/jolt-optimizations/src/lib.rs @@ -15,12 +15,10 @@ pub mod decomp_4d; pub mod dory_g1; pub mod dory_g2; pub mod dory_utils; -pub mod expression; pub mod fq12_poly; pub mod frobenius; pub mod glv_two; -pub mod steps; -pub mod sz_check; +pub mod witness_gen; mod glv_four; pub use glv_four::{ @@ -60,9 +58,8 @@ pub use dory_g2::{ pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi}; pub use fq12_poly::{ - fq12_to_multilinear_evals, fq12_to_poly12_coeffs, g_coeffs, g_eval, to_multilinear_evals, + eval_multilinear, fq12_to_multilinear_evals, fq12_to_poly12_coeffs, g_coeffs, g_eval, + to_multilinear_evals, }; -pub use steps::{pow_with_steps_le, ExponentiationStep, ExponentiationSteps}; - -pub use expression::{Expression, ExpressionSteps, Term}; +pub use witness_gen::{pow_with_steps_le, ExponentiationSteps}; diff --git a/jolt-optimizations/src/steps.rs b/jolt-optimizations/src/steps.rs deleted file mode 100644 index 107bb3664..000000000 --- a/jolt-optimizations/src/steps.rs +++ /dev/null @@ -1,273 +0,0 @@ -use crate::sz_check::Product; -use ark_bn254::{Fq, Fq12}; -use ark_ff::{BigInteger, Field, One, PrimeField}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use std::fmt; - -/// Error types for exponentiation verification -#[derive(Debug, Clone, PartialEq)] -pub enum VerificationError { - IncorrectResult { expected: Fq12, actual: Fq12 }, - InvalidSquaring { step: usize, expected: Fq12, actual: Fq12 }, - InvalidMultiplication { step: usize, expected: Fq12, actual: Fq12 }, - InconsistentChain { step: usize }, -} - -impl fmt::Display for VerificationError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::IncorrectResult { .. } => write!(f, "Final result doesn't match expected"), - Self::InvalidSquaring { step, .. } => write!(f, "Invalid squaring at step {}", step), - Self::InvalidMultiplication { step, .. } => write!(f, "Invalid multiplication at step {}", step), - Self::InconsistentChain { step } => write!(f, "Inconsistent state chain at step {}", step), - } - } -} - -impl std::error::Error for VerificationError {} - -/// State transition in exponentiation -#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] -pub struct StepTransition { - /// Previous and current accumulator values - pub accumulator: (Fq12, Fq12), - /// Running product before and after this step - pub product: (Fq12, Fq12), -} - -/// Single step in square-and-multiply algorithm -#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] -pub struct ExponentiationStep { - pub step_index: usize, - pub bit_value: bool, - pub transition: StepTransition, -} - -impl ExponentiationStep { - fn new(step_index: usize, bit_value: bool, a_prev: Fq12, a_curr: Fq12, rho_before: Fq12, rho_after: Fq12) -> Self { - Self { - step_index, - bit_value, - transition: StepTransition { - accumulator: (a_prev, a_curr), - product: (rho_before, rho_after), - }, - } - } - - /// Get the previous accumulator value - pub fn a_prev(&self) -> Fq12 { - self.transition.accumulator.0 - } - - /// Get the current accumulator value - pub fn a_curr(&self) -> Fq12 { - self.transition.accumulator.1 - } - - /// Get the product before this step - pub fn rho_before(&self) -> Fq12 { - self.transition.product.0 - } - - /// Get the product after this step - pub fn rho_after(&self) -> Fq12 { - self.transition.product.1 - } -} - -#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] -pub struct ExponentiationSteps { - pub base: Fq12, - pub exponent: Fq, - pub steps: Vec, - pub result: Fq12, -} - -/// Builder for ExponentiationSteps -pub struct StepsBuilder { - base: Fq12, - exponent: Fq, - steps: Vec, -} - -impl StepsBuilder { - fn new(base: Fq12, exponent: Fq) -> Self { - Self { - base, - exponent, - steps: Vec::new(), - } - } - - fn add_step(&mut self, step: ExponentiationStep) { - self.steps.push(step); - } - - fn build(self, result: Fq12) -> ExponentiationSteps { - ExponentiationSteps { - base: self.base, - exponent: self.exponent, - steps: self.steps, - result, - } - } -} - -impl ExponentiationSteps { - /// Convert steps to Products for sz_check verification - pub fn to_products(&self) -> Vec { - self.steps - .iter() - .flat_map(|step| { - let mut products = vec![ - // Squaring: a_i = a_{i-1} * a_{i-1} - Product::new(step.a_prev(), step.a_prev(), step.a_curr()), - ]; - - // Multiplication if bit is set - if step.bit_value && step.rho_before() != step.rho_after() { - products.push(Product::new( - step.rho_before(), - step.a_curr(), - step.rho_after(), - )); - } - - products - }) - .collect() - } - - /// Verify consistency of recorded steps - pub fn verify_consistency(&self) -> Result<(), VerificationError> { - // Check final result - let expected = self.base.pow(self.exponent.into_bigint()); - if self.result != expected { - return Err(VerificationError::IncorrectResult { - expected, - actual: self.result, - }); - } - - // Verify each step - for (i, step) in self.steps.iter().enumerate() { - // Verify squaring - let expected_a = step.a_prev() * step.a_prev(); - if step.a_curr() != expected_a { - return Err(VerificationError::InvalidSquaring { - step: i, - expected: expected_a, - actual: step.a_curr(), - }); - } - - // Verify multiplication - let expected_rho = if step.bit_value { - step.rho_before() * step.a_curr() - } else { - step.rho_before() - }; - if step.rho_after() != expected_rho { - return Err(VerificationError::InvalidMultiplication { - step: i, - expected: expected_rho, - actual: step.rho_after(), - }); - } - - // Verify chain consistency - if let Some(next) = self.steps.get(i + 1) { - if step.a_curr() != next.a_prev() || step.rho_after() != next.rho_before() { - return Err(VerificationError::InconsistentChain { step: i + 1 }); - } - } - } - - // Verify final step matches result - if let Some(last) = self.steps.last() { - if last.rho_after() != self.result { - return Err(VerificationError::IncorrectResult { - expected: self.result, - actual: last.rho_after(), - }); - } - } - - Ok(()) - } - - /// Legacy verification method for compatibility - pub fn sanity_verify(&self) -> bool { - self.verify_consistency().is_ok() - } -} - -/// Helper to iterate over significant bits -struct BitIterator { - bits: Vec, - last_one_pos: Option, -} - -impl BitIterator { - fn new(exponent: Fq) -> Self { - let bits = exponent.into_bigint().to_bits_le(); - let last_one_pos = bits.iter().rposition(|&b| b); - Self { bits, last_one_pos } - } - - fn is_trivial(&self) -> Option { - match self.last_one_pos { - None => Some(Fq12::one()), // exp = 0 - Some(0) => None, // exp = 1, handled separately - _ => None, - } - } - - fn initial_bit(&self) -> bool { - self.bits.get(0).copied().unwrap_or(false) - } - - fn significant_bits(&self) -> impl Iterator + '_ { - let end = self.last_one_pos.unwrap_or(0); - (1..=end).map(move |i| (i - 1, self.bits[i])) - } -} - -/// Compute base^exponent with step-by-step recording (LSB-first) -pub fn pow_with_steps_le(base: Fq12, exponent: Fq) -> ExponentiationSteps { - let bits = BitIterator::new(exponent); - - // Handle trivial cases - if let Some(result) = bits.is_trivial() { - return ExponentiationSteps { - base, - exponent, - steps: vec![], - result: if bits.last_one_pos.is_none() { result } else { base }, - }; - } - - let mut builder = StepsBuilder::new(base, exponent); - let mut accumulator = base; - let mut product = if bits.initial_bit() { base } else { Fq12::one() }; - - for (step_idx, bit) in bits.significant_bits() { - let prev_acc = accumulator; - let prev_prod = product; - - accumulator = prev_acc.square(); - product = if bit { prev_prod * accumulator } else { prev_prod }; - - builder.add_step(ExponentiationStep::new( - step_idx, - bit, - prev_acc, - accumulator, - prev_prod, - product, - )); - } - - builder.build(product) -} diff --git a/jolt-optimizations/src/sz_check.rs b/jolt-optimizations/src/sz_check.rs deleted file mode 100644 index 81ad797d4..000000000 --- a/jolt-optimizations/src/sz_check.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::panic; - -use crate::fq12_poly::{fq12_to_poly12_coeffs, g_coeffs, poly_div_rem_monic, poly_mul}; -use ark_bn254::{Fq, Fq12}; -use ark_ff::{Field, Zero}; - -#[derive(Clone)] -pub struct Product { - pub a: Fq12, - pub b: Fq12, - pub c: Fq12, - pub quotient: Vec, -} - -impl Product { - pub fn new(a: Fq12, b: Fq12, c: Fq12) -> Self { - let a_poly = fq12_to_poly12_coeffs(&a); - let b_poly = fq12_to_poly12_coeffs(&b); - let c_poly = fq12_to_poly12_coeffs(&c); - - let mut ab = poly_mul(&a_poly, &b_poly); - for i in 0..c_poly.len().min(ab.len()) { - ab[i] -= c_poly[i]; - } - - let (quotient, remainder) = poly_div_rem_monic(ab, &g_coeffs()); - - if !remainder.is_empty() && remainder.iter().any(|r| !r.is_zero()) { - panic!("invalid product: remainder is non-zero") - } - - Self { a, b, c, quotient } - } -} - -fn compute_r_powers(r: &Fq) -> [Fq; 12] { - let mut powers = [Fq::zero(); 12]; - powers[0] = Fq::from(1u64); - for i in 1..12 { - powers[i] = powers[i - 1] * r; - } - powers -} - -fn eval_with_powers(coeffs: &[Fq; 12], r_powers: &[Fq; 12]) -> Fq { - let mut result = Fq::zero(); - for i in 0..12 { - result += coeffs[i] * r_powers[i]; - } - result -} - -pub fn g_eval_optimized(r: &Fq) -> Fq { - let r2 = r.square(); - let r3 = r2 * r; - let r6 = r3.square(); - let r12 = r6.square(); - r12 - Fq::from(18u64) * r6 + Fq::from(82u64) -} - -pub fn batch_verify(products: &[Product], r: &Fq) -> bool { - let r_powers = compute_r_powers(r); - let g_r = g_eval_optimized(r); - - for product in products { - let a_coeffs = fq12_to_poly12_coeffs(&product.a); - let b_coeffs = fq12_to_poly12_coeffs(&product.b); - let c_coeffs = fq12_to_poly12_coeffs(&product.c); - - let a_r = eval_with_powers(&a_coeffs, &r_powers); - let b_r = eval_with_powers(&b_coeffs, &r_powers); - let c_r = eval_with_powers(&c_coeffs, &r_powers); - - let lhs = a_r * b_r - c_r; - - let mut q_r = Fq::zero(); - for (i, coeff) in product.quotient.iter().enumerate() { - if i < 12 { - q_r += *coeff * r_powers[i]; - } else { - panic!("this can't happen") - } - } - let rhs = q_r * g_r; - - if lhs != rhs { - return false; - } - } - - true -} diff --git a/jolt-optimizations/src/witness_gen.rs b/jolt-optimizations/src/witness_gen.rs new file mode 100644 index 000000000..8f64d31d8 --- /dev/null +++ b/jolt-optimizations/src/witness_gen.rs @@ -0,0 +1,221 @@ +use crate::fq12_poly::{ + eq_weights, eval_multilinear, fq12_to_multilinear_evals, g_coeffs, to_multilinear_evals, +}; +use ark_bn254::{Fq, Fq12, Fr}; +use ark_ff::{BigInteger, Field, One, PrimeField, Zero}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +/// square-and-multiply witness generation +#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] +pub struct ExponentiationSteps { + pub base: Fq12, // A (base) + pub exponent: Fr, // e (exponent) + pub result: Fq12, // Final result A^e + pub rho_mles: Vec>, // MLEs of ρ_0, ρ_1, ..., ρ_t + pub quotient_mles: Vec>, // MLEs of Q_1, Q_2, ..., Q_t + pub bits: Vec, // b_1, b_2, ..., b_t +} + +impl ExponentiationSteps { + /// Generate MLE witness for base^exponent using MSB-first square-and-multiply + pub fn new(base: Fq12, exponent: Fr) -> Self { + let bits_le = exponent.into_bigint().to_bits_le(); + + let msb_idx = match bits_le.iter().rposition(|&b| b) { + None => { + return Self { + base, + exponent, + result: Fq12::one(), + rho_mles: vec![fq12_to_multilinear_evals(&Fq12::one())], // ρ_0 = 1 + quotient_mles: vec![], + bits: vec![], + }; + }, + Some(i) => i, + }; + + // Special case: exponent == 1 ⇒ result = base + if msb_idx == 0 && bits_le[0] { + return Self { + base, + exponent, + result: base, + rho_mles: vec![ + fq12_to_multilinear_evals(&Fq12::one()), // ρ_0 + fq12_to_multilinear_evals(&base), // ρ_1 + ], + quotient_mles: vec![], // Could compute a single Q_1 if needed + bits: vec![true], + }; + } + + let bits_msb: Vec = (0..=msb_idx).rev().map(|i| bits_le[i]).collect(); + + // ρ_0 = 1 + let mut rho = Fq12::one(); + let mut rho_mles = vec![fq12_to_multilinear_evals(&rho)]; + let mut quotient_mles = vec![]; + let mut bits = vec![]; + + for &b in &bits_msb { + bits.push(b); + + let rho_prev = rho; // ρ_{i-1} + let rho_sq = rho_prev.square(); // ρ_{i-1}² + let rho_i = if b { rho_sq * base } else { rho_sq }; // ρ_i + + // One quotient per step for: ρ_i(X) - ρ_{i-1}(X)² * A(X)^{b} = Q_i(X) g(X) + let q_i = compute_step_quotient_msb(rho_prev, rho_i, base, b); + quotient_mles.push(quotient_to_mle(&q_i)); + + rho = rho_i; + rho_mles.push(fq12_to_multilinear_evals(&rho)); + } + + Self { + base, + exponent, + result: rho, + rho_mles, + quotient_mles, + bits, + } + } + + /// Verify that the final result matches base^exponent + pub fn verify_result(&self) -> bool { + self.result == self.base.pow(self.exponent.into_bigint()) + } + + /// Verify constraint at a Boolean cube point + /// Checks that the constraint holds at cube vertices where it was constructed to be zero + pub fn verify_constraint_at_cube_point(&self, step: usize, cube_index: usize) -> bool { + if step == 0 || step > self.quotient_mles.len() || cube_index >= 16 { + return false; + } + let point = index_to_boolean_point(cube_index); + + // Evaluate MLEs + let rho_prev = eval_mle_at_boolean_point(&self.rho_mles[step - 1], &point); + let rho_curr = eval_mle_at_boolean_point(&self.rho_mles[step], &point); + let quotient = eval_mle_at_boolean_point(&self.quotient_mles[step - 1], &point); + + let base_mle = fq12_to_multilinear_evals(&self.base); + let base_eval = eval_mle_at_boolean_point(&base_mle, &point); + let g_mle = get_g_mle(); + let g_eval = eval_mle_at_boolean_point(&g_mle, &point); + + // Compute constraint: ρ_i - ρ_{i-1}² * base^{b_i} - Q_i * g + let bit = self.bits[step - 1]; + let base_power = if bit { base_eval } else { Fq::one() }; + let constraint = rho_curr - rho_prev.square() * base_power - quotient * g_eval; + println!("constraint: {:?}", constraint); + constraint.is_zero() + } + + pub fn num_steps(&self) -> usize { + self.quotient_mles.len() + } +} + +/// Compute quotient MLE +fn compute_step_quotient_msb(rho_prev: Fq12, rho_i: Fq12, base: Fq12, bit: bool) -> Vec { + let rho_prev_mle = fq12_to_multilinear_evals(&rho_prev); + let rho_i_mle = fq12_to_multilinear_evals(&rho_i); + let base_mle = fq12_to_multilinear_evals(&base); + + let g_mle = get_g_mle(); + + // Compute the quotient MLE pointwise: Q_i(x) = (ρ_i(x) - ρ_{i-1}(x)² * base(x)^{b_i}) / g(x) + let mut quotient_mle = vec![Fq::zero(); 16]; + for j in 0..16 { + let rho_prev_sq = rho_prev_mle[j].square(); + let base_power = if bit { base_mle[j] } else { Fq::one() }; + let expected = rho_prev_sq * base_power; + + // Q_i(x) = (ρ_i(x) - expected) / g(x) + if !g_mle[j].is_zero() { + quotient_mle[j] = (rho_i_mle[j] - expected) / g_mle[j]; + } + } + + quotient_mle +} + +/// Get g as MLE evaluations over the Boolean cube {0,1}^4 +pub fn get_g_mle() -> Vec { + // Use the same encoding as fq12_to_multilinear_evals + // g(X) = X^12 - 18X^6 + 82 as coefficient array + let g_vec = g_coeffs(); + let mut g_array = [Fq::zero(); 12]; + for i in 0..12 { + if i < g_vec.len() { + g_array[i] = g_vec[i]; + } + } + to_multilinear_evals(&g_array) +} + +/// Convert quotient MLE to the format needed (already an MLE, just return it) +fn quotient_to_mle(quotient: &[Fq]) -> Vec { + // In the MLE paradigm, quotient is already an MLE + quotient.to_vec() +} + +/// Convert a cube index (0..15) to a Boolean point in {0,1}^4 +pub fn index_to_boolean_point(index: usize) -> Vec { + vec![ + Fq::from((index & 1) as u64), // bit 0 + Fq::from(((index >> 1) & 1) as u64), // bit 1 + Fq::from(((index >> 2) & 1) as u64), // bit 2 + Fq::from(((index >> 3) & 1) as u64), // bit 3 + ] +} + +/// Evaluate an MLE at a Boolean cube point +/// For Boolean points, this is equivalent to indexing but makes the evaluation explicit +fn eval_mle_at_boolean_point(mle: &[Fq], point: &[Fq]) -> Fq { + // For Boolean points, we could just index, but using eval_multilinear + // makes it clear we're doing MLE evaluation + eval_multilinear(mle, point) +} + +/// Compute H̃(z) via eq-weights (definition of MLE), not by multiplying opened MLEs +/// H(x) = ρᵢ(x) - ρᵢ₋₁(x)² · A(x)^{bᵢ} - Qᵢ(x) · g(x) for x ∈ {0,1}^4 +/// H̃(z) = Σ_{x∈{0,1}^4} eq(z,x) · H(x) +pub fn h_tilde_at_point( + rho_prev_mle: &[Fq], + rho_curr_mle: &[Fq], + base_mle: &[Fq], + q_mle: &[Fq], + g_mle: &[Fq], + bit: bool, + z: &[Fq], +) -> Fq { + assert_eq!(rho_prev_mle.len(), 16); + assert_eq!(rho_curr_mle.len(), 16); + assert_eq!(base_mle.len(), 16); + assert_eq!(q_mle.len(), 16); + assert_eq!(g_mle.len(), 16); + assert_eq!(z.len(), 4); + + let w = eq_weights(z); + let mut acc = Fq::zero(); + + for j in 0..16 { + // Compute H(x_j) where x_j is the j-th hypercube vertex + let prod = rho_prev_mle[j].square() * if bit { base_mle[j] } else { Fq::one() }; + let h_x = rho_curr_mle[j] - prod - q_mle[j] * g_mle[j]; + + // Add weighted contribution to MLE + acc += h_x * w[j]; + } + + acc // equals H̃(z) +} + +/// Legacy compatibility function +pub fn pow_with_steps_le(base: Fq12, exponent: Fr) -> ExponentiationSteps { + ExponentiationSteps::new(base, exponent) +} diff --git a/jolt-optimizations/tests/integration_tests.rs b/jolt-optimizations/tests/integration_tests.rs index db9aac73a..216ea3313 100644 --- a/jolt-optimizations/tests/integration_tests.rs +++ b/jolt-optimizations/tests/integration_tests.rs @@ -1,6 +1,6 @@ use ark_bn254::{Fr, G2Affine, G2Projective}; use ark_ec::PrimeGroup; -use ark_ec::{AdditiveGroup, AffineRepr, CurveGroup}; +use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::{test_rng, Zero}; use num_bigint::BigInt; diff --git a/jolt-optimizations/tests/mle_tests.rs b/jolt-optimizations/tests/mle_tests.rs new file mode 100644 index 000000000..7a71c6ac8 --- /dev/null +++ b/jolt-optimizations/tests/mle_tests.rs @@ -0,0 +1,149 @@ +use ark_bn254::Fq; +use ark_ff::{Field, One, UniformRand, Zero}; +use ark_std::test_rng; +use jolt_optimizations::fq12_poly::{eval_multilinear, eval_poly_vec, to_multilinear_evals}; + +/// Generate random polynomial coefficients for testing +fn random_poly12_coeffs() -> [Fq; 12] { + let mut rng = test_rng(); + let mut coeffs = [Fq::zero(); 12]; + for c in coeffs.iter_mut() { + *c = Fq::rand(&mut rng); + } + coeffs +} + +#[test] +fn test_mle_agreement_with_univariate() { + // Test that MLE agrees with original polynomial on domain {0..15} + let coeffs = random_poly12_coeffs(); + let mle_evals = to_multilinear_evals(&coeffs); + + for i in 0..16 { + let x = Fq::from(i as u64); + let univariate_eval = eval_poly_vec(&coeffs[..], &x); + + assert_eq!( + univariate_eval, mle_evals[i], + "MLE evaluation doesn't match univariate at point {}", + i + ); + // Binary decomposition of i: (b₀, b₁, b₂, b₃) where i = b₀ + 2b₁ + 4b₂ + 8b₃ + let binary_point = vec![ + Fq::from((i & 1) as u64), + Fq::from(((i >> 1) & 1) as u64), + Fq::from(((i >> 2) & 1) as u64), + Fq::from(((i >> 3) & 1) as u64), + ]; + let mle_eval = eval_multilinear(&mle_evals, &binary_point); + + assert_eq!( + univariate_eval, mle_eval, + "eval_multilinear doesn't agree with univariate at point {}", + i + ); + } +} + +#[test] +fn test_mle_is_multilinear() { + let mut rng = test_rng(); + let coeffs = random_poly12_coeffs(); + let mle_evals = to_multilinear_evals(&coeffs); + + // Test linearity in each variable + for var_idx in 0..4 { + let point = vec![ + Fq::rand(&mut rng), + Fq::rand(&mut rng), + Fq::rand(&mut rng), + Fq::rand(&mut rng), + ]; + + let mut p0 = point.clone(); + p0[var_idx] = Fq::zero(); + let eval0 = eval_multilinear(&mle_evals, &p0); + + let mut p1 = point.clone(); + p1[var_idx] = Fq::one(); + let eval1 = eval_multilinear(&mle_evals, &p1); + + let t = Fq::rand(&mut rng); + let mut pt = point.clone(); + pt[var_idx] = t; + let eval_t = eval_multilinear(&mle_evals, &pt); + let expected = eval0 * (Fq::one() - t) + eval1 * t; + + assert_eq!( + eval_t, expected, + "MLE is not linear in variable {}", + var_idx + ); + } +} + +#[test] +fn test_mle_special_cases() { + // Test 1: Zero polynomial + let zero_coeffs = [Fq::zero(); 12]; + let mle = to_multilinear_evals(&zero_coeffs); + assert!( + mle.iter().all(|&x| x.is_zero()), + "Zero polynomial MLE should be all zeros" + ); + + // Test 2: Constant polynomial p(x) = 42 + let const_val = Fq::from(42u64); + let mut const_coeffs = [Fq::zero(); 12]; + const_coeffs[0] = const_val; + let mle = to_multilinear_evals(&const_coeffs); + assert!( + mle.iter().all(|&x| x == const_val), + "Constant polynomial MLE should be constant" + ); + + // Test 3: Linear polynomial p(x) = x + let mut linear_coeffs = [Fq::zero(); 12]; + linear_coeffs[1] = Fq::one(); + let mle = to_multilinear_evals(&linear_coeffs); + for i in 0..16 { + assert_eq!( + mle[i], + Fq::from(i as u64), + "Linear polynomial p(x)=x should evaluate to {} at {}", + i, + i + ); + } + + // Test 4: Quadratic polynomial p(x) = x² + let mut quad_coeffs = [Fq::zero(); 12]; + quad_coeffs[2] = Fq::one(); + let mle = to_multilinear_evals(&quad_coeffs); + for i in 0..16 { + let expected = Fq::from((i * i) as u64); + assert_eq!( + mle[i], expected, + "Quadratic polynomial p(x)=x² should evaluate correctly at {}", + i + ); + } +} + +#[test] +fn test_mle_high_degree() { + // Test with maximum degree polynomial (degree 11) + let mut coeffs = [Fq::zero(); 12]; + coeffs[11] = Fq::one(); + + let mle = to_multilinear_evals(&coeffs); + for i in 0..16 { + let x = Fq::from(i as u64); + let expected = x.pow([11u64]); + assert_eq!( + mle[i], expected, + "p(x) = x^11 should evaluate correctly at {}", + i + ); + } +} diff --git a/jolt-optimizations/tests/steps_debug_test.rs b/jolt-optimizations/tests/steps_debug_test.rs deleted file mode 100644 index e2f5a57ce..000000000 --- a/jolt-optimizations/tests/steps_debug_test.rs +++ /dev/null @@ -1,175 +0,0 @@ -use ark_bn254::{Fq, Fq12}; -use ark_ff::BigInteger; -use ark_ff::{Field, PrimeField, UniformRand}; -use ark_std::test_rng; -use jolt_optimizations::steps::pow_with_steps_le; - -#[test] -#[ignore] // Run with: cargo test --test steps_debug_test test_debug_trace -- --nocapture --ignored -fn test_debug_trace() { - let mut rng = test_rng(); - - // Use a small exponent for readable output - let base = Fq12::rand(&mut rng); - let exponent = Fq::from(13u64); // Binary: 1101 - - println!("=== Square-and-Multiply Debug Trace ==="); - println!("Base: {:?}", base); - println!("Exponent: {} (binary: 1101)", 13u64); - println!(); - - let steps = pow_with_steps_le(base, exponent); - - // Print bit representation - let bigint = exponent.into_bigint(); - let exp_bits = bigint.to_bits_le(); - println!("Bit representation (LSB first):"); - for (i, bit) in exp_bits.iter().take(8).enumerate() { - println!(" Bit {}: {}", i, if *bit { "1" } else { "0" }); - } - println!(); - - // Print initial state - println!("Initial state:"); - println!(" a_0 = base"); - println!( - " rho_0 = {} (since bit 0 = {})", - if exp_bits[0] { "base" } else { "1" }, - if exp_bits[0] { "1" } else { "0" } - ); - println!(); - - // Print each step - println!("Steps:"); - for (i, step) in steps.steps.iter().enumerate() { - println!( - "Step {} (processing bit {} = {}):", - i + 1, - i + 1, - if step.bit_value { "1" } else { "0" } - ); - - println!(" Squaring: a_{} = a_{}^2", i + 1, i); - println!(" a_{} = {:?}", i, step.a_prev()); - println!(" a_{} = {:?}", i + 1, step.a_curr()); - - // Verify squaring - let expected_square = step.a_prev() * step.a_prev(); - println!( - " Verification: a_curr == a_prev^2? {}", - if step.a_curr() == expected_square { - "✓" - } else { - "✗" - } - ); - - println!(" Accumulator update:"); - println!(" rho_before = {:?}", step.rho_before()); - - if step.bit_value { - println!(" Bit is 1, so: rho_after = rho_before * a_curr"); - } else { - println!(" Bit is 0, so: rho_after = rho_before (unchanged)"); - } - - println!(" rho_after = {:?}", step.rho_after()); - - // Verify accumulator update - let expected_rho = if step.bit_value { - step.rho_before() * step.a_curr() - } else { - step.rho_before() - }; - println!( - " Verification: rho_after correct? {}", - if step.rho_after() == expected_rho { - "✓" - } else { - "✗" - } - ); - - println!(); - } - - // Print final result - println!("Final result: {:?}", steps.result); - - // Verify against standard pow - let expected = base.pow(exponent.into_bigint()); - println!("Expected (base^13): {:?}", expected); - println!( - "Results match: {}", - if steps.result == expected { - "✓" - } else { - "✗" - } - ); - - // Print summary of operations - println!(); - println!("=== Summary ==="); - let num_squarings = steps.steps.len(); - let num_multiplications = steps.steps.iter().filter(|s| s.bit_value).count(); - println!("Total squarings: {}", num_squarings); - println!("Total multiplications by base: {}", num_multiplications); - println!("Total operations: {}", num_squarings + num_multiplications); - - // Verify the steps - assert!(steps.sanity_verify(), "Steps verification failed"); - assert_eq!(steps.result, expected, "Result doesn't match expected"); -} - -#[test] -#[ignore] // Run with: cargo test --test steps_debug_test test_trace_products -- --nocapture --ignored -fn test_trace_products() { - let mut rng = test_rng(); - - let base = Fq12::rand(&mut rng); - let exponent = Fq::from(5u64); // Binary: 101 - - println!("=== Products Generated from Steps ==="); - println!("Exponent: 5 (binary: 101)"); - println!(); - - let steps = pow_with_steps_le(base, exponent); - let products = steps.to_products(); - - println!("Products generated:"); - for (i, product) in products.iter().enumerate() { - println!("Product {}:", i); - println!(" a * b = c"); - println!(" a: {:?}", product.a); - println!(" b: {:?}", product.b); - println!(" c: {:?}", product.c); - - // Verify the product - let expected_c = product.a * product.b; - println!( - " Verification: c == a * b? {}", - if product.c == expected_c { - "✓" - } else { - "✗" - } - ); - println!(); - } - - println!("Total products: {}", products.len()); - - // Test batch verification - use jolt_optimizations::sz_check::batch_verify; - let r = Fq::rand(&mut rng); - let batch_result = batch_verify(&products, &r); - println!( - "Batch verification with random r: {}", - if batch_result { - "✓ PASSED" - } else { - "✗ FAILED" - } - ); -} diff --git a/jolt-optimizations/tests/steps_test.rs b/jolt-optimizations/tests/steps_test.rs deleted file mode 100644 index d58a6083b..000000000 --- a/jolt-optimizations/tests/steps_test.rs +++ /dev/null @@ -1,135 +0,0 @@ -use ark_bn254::{Fq, Fq12}; -use ark_ff::{Field, One, PrimeField, UniformRand}; -use ark_std::test_rng; -use jolt_optimizations::expression::{Expression, Term}; -use jolt_optimizations::steps::pow_with_steps_le; -use jolt_optimizations::sz_check::batch_verify; - -#[test] -fn test_pow_with_steps_correctness() { - let mut rng = test_rng(); - - // Test with random base and exponent - let base = Fq12::rand(&mut rng); - let exponent = Fq::rand(&mut rng); - - // Compute with steps - let steps = pow_with_steps_le(base, exponent); - - // Verify the result matches standard pow - let expected = base.pow(exponent.into_bigint()); - assert_eq!(steps.result, expected, "Result mismatch"); - - // Verify the steps are internally consistent - assert!(steps.sanity_verify(), "Steps verification failed"); - - // Verify that products can be verified using batch_verify - let products = steps.to_products(); - let r = Fq::rand(&mut rng); - assert!(batch_verify(&products, &r), "Batch verification failed"); -} - -#[test] -fn test_pow_with_steps_edge_cases() { - let mut rng = test_rng(); - let base = Fq12::rand(&mut rng); - - // Test exponent = 0 - let steps = pow_with_steps_le(base, Fq::from(0u64)); - assert_eq!(steps.result, Fq12::one()); - assert_eq!(steps.steps.len(), 0); - - // Test exponent = 1 - let steps = pow_with_steps_le(base, Fq::from(1u64)); - assert_eq!(steps.result, base); - assert_eq!(steps.steps.len(), 0); - - // Test exponent = 2 - let steps = pow_with_steps_le(base, Fq::from(2u64)); - assert_eq!(steps.result, base * base); - assert_eq!(steps.steps.len(), 1); - assert!(steps.sanity_verify()); -} - -#[test] -fn test_expression_with_steps() { - let mut rng = test_rng(); - - // Create an expression with multiple terms - let terms = vec![ - Term { - base: Fq12::rand(&mut rng), - exponent: Fq::from(5u64), - }, - Term { - base: Fq12::rand(&mut rng), - exponent: Fq::from(3u64), - }, - ]; - - let expr = Expression::new(terms); - - // Evaluate with steps - let (result, steps) = expr.evaluate_with_steps(); - - // Verify result matches expected - let expected = expr.terms[0].base.pow(expr.terms[0].exponent.into_bigint()) - * expr.terms[1].base.pow(expr.terms[1].exponent.into_bigint()); - assert_eq!(result, expected); - - // Verify all steps - for term_step in &steps.term_steps { - assert!(term_step.sanity_verify()); - } - - // Convert to products and verify - let products = Expression::steps_to_products(&steps); - let r = Fq::rand(&mut rng); - assert!( - batch_verify(&products, &r), - "Batch verification of expression steps failed" - ); -} - -#[test] -fn test_step_continuity() { - let mut rng = test_rng(); - let base = Fq12::rand(&mut rng); - let exponent = Fq::from(255u64); // Use a reasonable sized exponent - - let steps = pow_with_steps_le(base, exponent); - - // Check continuity between steps - for i in 0..steps.steps.len() - 1 { - assert_eq!( - steps.steps[i].rho_after(), - steps.steps[i + 1].rho_before(), - "Step continuity broken at step {}", - i - ); - } - - // Check final step leads to result - if let Some(last_step) = steps.steps.last() { - assert_eq!(last_step.rho_after(), steps.result); - } -} - -#[test] -fn test_squaring_correctness() { - let mut rng = test_rng(); - let base = Fq12::rand(&mut rng); - let exponent = Fq::from(100u64); - - let steps = pow_with_steps_le(base, exponent); - - // Verify each squaring operation: a_i = a_{i-1}^2 - for step in &steps.steps { - let expected_square = step.a_prev() * step.a_prev(); - assert_eq!( - step.a_curr(), expected_square, - "Squaring incorrect at step {}", - step.step_index - ); - } -} diff --git a/jolt-optimizations/tests/sz_check_tests.rs b/jolt-optimizations/tests/sz_check_tests.rs deleted file mode 100644 index 212639979..000000000 --- a/jolt-optimizations/tests/sz_check_tests.rs +++ /dev/null @@ -1,64 +0,0 @@ -use ark_bn254::{Fq, Fq12}; -use ark_ff::{Field, PrimeField, UniformRand, Zero}; -use ark_std::test_rng; -use jolt_optimizations::expression::{Expression, Term}; -use jolt_optimizations::fq12_poly::{fq12_to_multilinear_evals, fq12_to_poly12_coeffs}; -use jolt_optimizations::sz_check::{batch_verify, Product}; - -#[test] -fn test_large_batch() { - let mut rng = test_rng(); - let k = 100000; - - let mut products = Vec::new(); - for _ in 0..k { - let a = Fq12::rand(&mut rng); - let b = Fq12::rand(&mut rng); - let c = a * b; - products.push(Product::new(a, b, c)); - } - - let r = Fq::rand(&mut rng); - - assert!(batch_verify(&products, &r)); -} - -#[test] -fn test_expression_to_sz_check() { - let mut rng = test_rng(); - let a1 = Fq12::rand(&mut rng); - let c1 = Fq::rand(&mut rng); - - let a2 = Fq12::rand(&mut rng); - let c2 = Fq::rand(&mut rng); - - let a3 = Fq12::rand(&mut rng); - let c3 = Fq::rand(&mut rng); - - let expected = a1.pow(c1.into_bigint()) * a2.pow(c2.into_bigint()) * a3.pow(c3.into_bigint()); - - let expr = Expression::new(vec![ - Term { - base: a1, - exponent: c1, - }, - Term { - base: a2, - exponent: c2, - }, - Term { - base: a3, - exponent: c3, - }, - ]); - - let products = expr.to_products(); - - let r = Fq::rand(&mut rng); - assert!(batch_verify(&products, &r)); - - if !products.is_empty() { - let final_result = products.last().unwrap().c; - assert_eq!(final_result, expected); - } -} diff --git a/jolt-optimizations/tests/witness_test.rs b/jolt-optimizations/tests/witness_test.rs new file mode 100644 index 000000000..4e26f788f --- /dev/null +++ b/jolt-optimizations/tests/witness_test.rs @@ -0,0 +1,304 @@ +use ark_bn254::{Fq, Fq12, Fr}; +use ark_ff::{Field, One, UniformRand, Zero}; +use ark_std::test_rng; +use jolt_optimizations::{ + eval_multilinear, fq12_to_multilinear_evals, g_coeffs, to_multilinear_evals, + witness_gen::{get_g_mle, h_tilde_at_point}, + ExponentiationSteps, +}; + +#[test] +fn test_witness_generation_and_constraints() { + let mut rng = test_rng(); + + for test_idx in 0..100 { + let base = Fq12::rand(&mut rng); + + let exponent = if test_idx == 0 { + Fr::from(5u64) + } else if test_idx == 1 { + Fr::from(100u64) + } else { + Fr::rand(&mut rng) + }; + + let witness = ExponentiationSteps::new(base, exponent); + + assert!( + witness.verify_result(), + "Final result should match base^exponent" + ); + + assert_eq!( + witness.rho_mles.len(), + witness.quotient_mles.len() + 1, + "Should have one more rho than quotients" + ); + + // Verify all MLEs have correct dimension (16 = 2^4 evaluations) + for mle in &witness.rho_mles { + assert_eq!(mle.len(), 16, "Rho MLEs should have 16 evaluations"); + } + for mle in &witness.quotient_mles { + assert_eq!(mle.len(), 16, "Quotient MLEs should have 16 evaluations"); + } + + // Test constraint verification at Boolean cube points + // The constraint should be zero at all 16 cube vertices + for cube_idx in 0..16 { + for step in 1..=witness.num_steps() { + assert!( + witness.verify_constraint_at_cube_point(step, cube_idx), + "Constraint failed at step {} for cube point {}", + step, + cube_idx + ); + } + } + } +} + +#[test] +fn test_trivial_cases() { + let mut rng = test_rng(); + let base = Fq12::rand(&mut rng); + + // Test exponent = 0 + let witness_zero = ExponentiationSteps::new(base, Fr::from(0u64)); + assert_eq!(witness_zero.result, Fq12::one()); + assert!(witness_zero.verify_result()); + assert_eq!(witness_zero.bits.len(), 0); + assert_eq!(witness_zero.rho_mles.len(), 1); // Just ρ_0 = 1 + assert_eq!(witness_zero.quotient_mles.len(), 0); + + // Test exponent = 1 + let witness_one = ExponentiationSteps::new(base, Fr::from(1u64)); + assert_eq!(witness_one.result, base); + assert!(witness_one.verify_result()); + assert_eq!(witness_one.bits, vec![true]); // Single bit: 1 + assert_eq!(witness_one.rho_mles.len(), 2); // ρ_0 = 1, ρ_1 = base + + // Test small known values to verify bit sequence + let witness_five = ExponentiationSteps::new(base, Fr::from(5u64)); + assert_eq!(witness_five.bits, vec![true, false, true]); // MSB to LSB: 101 + assert!(witness_five.verify_result()); + + let witness_ten = ExponentiationSteps::new(base, Fr::from(10u64)); + assert_eq!(witness_ten.bits, vec![true, false, true, false]); // MSB to LSB: 1010 + assert!(witness_ten.verify_result()); +} + +#[test] +fn test_witness_soundness() { + let mut rng = test_rng(); + + // Test soundness: tampering with witness should be detected + for test_idx in 0..20 { + let base = Fq12::rand(&mut rng); + let exponent = if test_idx == 0 { + Fr::from(10u64) + } else { + Fr::rand(&mut rng) + }; + + let mut witness = ExponentiationSteps::new(base, exponent); + + // Verify original witness is valid + assert!(witness.verify_result()); + let mut all_valid = true; + for cube_idx in 0..16 { + for step in 1..=witness.num_steps() { + if !witness.verify_constraint_at_cube_point(step, cube_idx) { + all_valid = false; + } + } + } + assert!(all_valid, "Original witness should be valid"); + + // Test 1: Tampering with ρ values breaks soundness + if witness.rho_mles.len() > 1 { + let tamper_idx = 1 + (test_idx % (witness.rho_mles.len() - 1)); + let point_idx = test_idx % 16; + let original = witness.rho_mles[tamper_idx][point_idx]; + + witness.rho_mles[tamper_idx][point_idx] += Fq::from(1u64); + + let mut soundness_broken = false; + for cube_idx in 0..16 { + if tamper_idx <= witness.num_steps() { + if !witness.verify_constraint_at_cube_point(tamper_idx, cube_idx) { + soundness_broken = true; + break; + } + } + if tamper_idx > 0 && tamper_idx + 1 <= witness.num_steps() { + if !witness.verify_constraint_at_cube_point(tamper_idx + 1, cube_idx) { + soundness_broken = true; + break; + } + } + } + + assert!(soundness_broken, "Tampering with ρ should break soundness"); + witness.rho_mles[tamper_idx][point_idx] = original; + } + + // Test 2: Tampering with quotient values breaks soundness + if !witness.quotient_mles.is_empty() { + let q_idx = test_idx % witness.quotient_mles.len(); + let point_idx = (test_idx * 7) % 16; + let original = witness.quotient_mles[q_idx][point_idx]; + + witness.quotient_mles[q_idx][point_idx] += Fq::from(1u64); + + let soundness_broken = !witness.verify_constraint_at_cube_point(q_idx + 1, point_idx); + + assert!( + soundness_broken, + "Tampering with quotient should break soundness" + ); + witness.quotient_mles[q_idx][point_idx] = original; + } + + // Test 3: Flipping bits breaks soundness + if !witness.bits.is_empty() { + let bit_idx = test_idx % witness.bits.len(); + witness.bits[bit_idx] = !witness.bits[bit_idx]; + + let mut soundness_broken = false; + for cube_idx in 0..16 { + if !witness.verify_constraint_at_cube_point(bit_idx + 1, cube_idx) { + soundness_broken = true; + break; + } + } + + assert!(soundness_broken, "Flipping bits should break soundness"); + witness.bits[bit_idx] = !witness.bits[bit_idx]; + } + + // Test 4: Tampering with final result breaks verification + let original_result = witness.result; + witness.result = witness.result + Fq12::one(); + assert!( + !witness.verify_result(), + "Modified result should fail verification" + ); + witness.result = original_result; + } +} + +#[test] +fn test_constraint_at_random_field_element() { + let mut rng = test_rng(); + + // Create witness for a simple exponentiation + let base = Fq12::rand(&mut rng); + let exponent = Fr::from(10000300u64); // Simple exponent: binary 111 + let witness = ExponentiationSteps::new(base, exponent); + + let base_mle = fq12_to_multilinear_evals(&base); + let g_mle = get_g_mle(); + + // Test at random field elements (not on hypercube) + for test_idx in 0..10000 { + // Generate random point z = (z0, z1, z2, z3) where zi ∈ Fq \ {0,1} + let z: Vec = (0..4) + .map(|_| { + let mut val = Fq::rand(&mut rng); + // Ensure it's not 0 or 1 (not on hypercube) + while val == Fq::zero() || val == Fq::one() { + val = Fq::rand(&mut rng); + } + val + }) + .collect(); + + // Pick a step to check + let step = 1 + (test_idx % witness.num_steps()); + let bit = witness.bits[step - 1]; + + // Compute H̃(z) using the correct MLE definition + let h = h_tilde_at_point( + &witness.rho_mles[step - 1], + &witness.rho_mles[step], + &base_mle, + &witness.quotient_mles[step - 1], + &g_mle, + bit, + &z, + ); + + // H̃(z) must be 0 at random z (Sumcheck-consistent) + assert!( + h.is_zero(), + "H̃(z) must be 0 at random z (test {}, step {}). Got: {:?}", + test_idx, + step, + h + ); + } + + println!("✓ Verified: H̃(z) = 0 at 20 random field elements (Sumcheck correct)"); + + // Also verify it works on the hypercube (sanity check) + for step in 1..=witness.num_steps() { + for cube_idx in 0..16 { + assert!( + witness.verify_constraint_at_cube_point(step, cube_idx), + "Constraint should be zero at hypercube point {} step {}", + cube_idx, + step + ); + } + } + + println!("✓ Verified: Constraints are zero on hypercube (sanity check)"); +} + +#[test] +fn test_zero_tampering_soundness() { + let mut rng = test_rng(); + let base = Fq12::rand(&mut rng); + + for exp_val in [2u64, 3, 7, 15, 31] { + let exponent = Fr::from(exp_val); + let mut witness = ExponentiationSteps::new(base, exponent); + + // Setting ρ to zero should break soundness + if witness.rho_mles.len() > 1 { + let original = witness.rho_mles[1][0]; + witness.rho_mles[1][0] = Fq::zero(); + + let mut soundness_broken = false; + for step in 1..=witness.num_steps() { + if !witness.verify_constraint_at_cube_point(step, 0) { + soundness_broken = true; + break; + } + } + + assert!( + soundness_broken || !witness.verify_result(), + "Setting ρ to zero should break soundness" + ); + witness.rho_mles[1][0] = original; + } + + // Setting quotient to zero should break soundness + if !witness.quotient_mles.is_empty() { + let original = witness.quotient_mles[0][0]; + witness.quotient_mles[0][0] = Fq::zero(); + + if original != Fq::zero() { + let soundness_broken = !witness.verify_constraint_at_cube_point(1, 0); + assert!( + soundness_broken, + "Setting non-zero quotient to zero should break soundness" + ); + } + + witness.quotient_mles[0][0] = original; + } + } +} From ab89e59c937f6f47336e4561242017e6eae1ac61 Mon Sep 17 00:00:00 2001 From: markosg04 Date: Tue, 30 Sep 2025 12:40:46 -0400 Subject: [PATCH 35/38] style: removing unused utils and brevity --- bench-templates/src/macros/field.rs | 6 +- ec/src/lib.rs | 6 +- ec/src/pairing.rs | 10 +- ec/src/scalar_mul/fixed_base.rs | 2 +- ec/src/scalar_mul/mod.rs | 2 +- ff/src/fields/prime.rs | 2 +- jolt-optimizations/Cargo.toml | 10 +- jolt-optimizations/benches/dory_all.rs | 2 +- jolt-optimizations/benches/dory_utils.rs | 6 +- .../benches/expression_bench.rs | 61 -------- .../benches/g1_scalar_multiplication.rs | 2 +- .../benches/scalar_multiplication.rs | 5 +- jolt-optimizations/benches/sz_check_bench.rs | 42 ----- jolt-optimizations/src/fq12_poly.rs | 143 ++---------------- jolt-optimizations/src/lib.rs | 2 +- jolt-optimizations/src/witness_gen.rs | 29 ++-- jolt-optimizations/tests/mle_tests.rs | 9 -- jolt-optimizations/tests/witness_test.rs | 68 +-------- test-curves/benches/small_mul.rs | 30 ++-- test-curves/src/bn254/fq.rs | 2 +- test-curves/src/bn254/fr.rs | 2 +- test-curves/src/bn254/g1.rs | 6 +- test-curves/src/bn254/test.rs | 4 +- 23 files changed, 71 insertions(+), 380 deletions(-) delete mode 100644 jolt-optimizations/benches/expression_bench.rs delete mode 100644 jolt-optimizations/benches/sz_check_bench.rs diff --git a/bench-templates/src/macros/field.rs b/bench-templates/src/macros/field.rs index 40fe597b1..db11baabf 100644 --- a/bench-templates/src/macros/field.rs +++ b/bench-templates/src/macros/field.rs @@ -405,16 +405,14 @@ macro_rules! prime_field { f[i].into_bigint() }) }); - let u64s = (0..SAMPLES) - .map(|_| rng.next_u64()) - .collect::>(); + let u64s = (0..SAMPLES).map(|_| rng.next_u64()).collect::>(); conversions.bench_function("From u64", |b| { let mut i = 0; b.iter(|| { i = (i + 1) % SAMPLES; <$F>::from_u64(u64s[i]) }) - }); + }); conversions.finish() } }; diff --git a/ec/src/lib.rs b/ec/src/lib.rs index ba99d4c87..47b8437e5 100644 --- a/ec/src/lib.rs +++ b/ec/src/lib.rs @@ -28,11 +28,7 @@ use ark_std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, vec::*, }; -pub use scalar_mul::{ - fixed_base::FixedBase, - variable_base::VariableBaseMSM, - ScalarMul, -}; +pub use scalar_mul::{fixed_base::FixedBase, variable_base::VariableBaseMSM, ScalarMul}; use zeroize::Zeroize; pub use ark_ff::AdditiveGroup; diff --git a/ec/src/pairing.rs b/ec/src/pairing.rs index a3aa83e26..f62d1be72 100644 --- a/ec/src/pairing.rs +++ b/ec/src/pairing.rs @@ -102,8 +102,14 @@ pub trait Pairing: Sized + 'static + Copy + Debug + Sync + Send + Eq { a: impl IntoIterator>, b: impl IntoIterator>, ) -> MillerLoopOutput { - let a_cloned = a.into_iter().map(|x| x.as_ref().clone()).collect::>(); - let b_cloned = b.into_iter().map(|x| x.as_ref().clone()).collect::>(); + let a_cloned = a + .into_iter() + .map(|x| x.as_ref().clone()) + .collect::>(); + let b_cloned = b + .into_iter() + .map(|x| x.as_ref().clone()) + .collect::>(); Self::multi_miller_loop(a_cloned, b_cloned) } diff --git a/ec/src/scalar_mul/fixed_base.rs b/ec/src/scalar_mul/fixed_base.rs index c9e5270d0..ce8001ccd 100644 --- a/ec/src/scalar_mul/fixed_base.rs +++ b/ec/src/scalar_mul/fixed_base.rs @@ -95,4 +95,4 @@ impl FixedBase { .map(|e| Self::windowed_mul::(outerc, window, table, e)) .collect::>() } -} \ No newline at end of file +} diff --git a/ec/src/scalar_mul/mod.rs b/ec/src/scalar_mul/mod.rs index 81a4c6595..cb38e432d 100644 --- a/ec/src/scalar_mul/mod.rs +++ b/ec/src/scalar_mul/mod.rs @@ -1,8 +1,8 @@ pub mod glv; pub mod wnaf; -pub mod variable_base; pub mod fixed_base; +pub mod variable_base; use crate::{ short_weierstrass::{Affine, Projective, SWCurveConfig}, diff --git a/ff/src/fields/prime.rs b/ff/src/fields/prime.rs index 28b896e59..80c04948b 100644 --- a/ff/src/fields/prime.rs +++ b/ff/src/fields/prime.rs @@ -57,7 +57,7 @@ pub trait PrimeField: /// Converts an element of the prime field into an integer in the range 0..(p - 1). fn into_bigint(self) -> Self::BigInt; - /// Creates a field element from a `u64`. + /// Creates a field element from a `u64`. /// Returns `None` if the `u64` is larger than or equal to the modulus. fn from_u64(val: u64) -> Option; diff --git a/jolt-optimizations/Cargo.toml b/jolt-optimizations/Cargo.toml index db228eb6f..e644c37e4 100644 --- a/jolt-optimizations/Cargo.toml +++ b/jolt-optimizations/Cargo.toml @@ -48,14 +48,6 @@ harness = false name = "g1_scalar_multiplication" harness = false -[[bench]] -name = "sz_check_bench" -harness = false - -[[bench]] -name = "expression_bench" -harness = false - [[bench]] name = "vector_scalar_mul_add_gamma_g2" harness = false @@ -65,4 +57,4 @@ name = "batch_addition" harness = false [[example]] -name = "memory_test" \ No newline at end of file +name = "memory_test" diff --git a/jolt-optimizations/benches/dory_all.rs b/jolt-optimizations/benches/dory_all.rs index 5015658ff..8e4866b2c 100644 --- a/jolt-optimizations/benches/dory_all.rs +++ b/jolt-optimizations/benches/dory_all.rs @@ -1,7 +1,7 @@ #![allow(non_snake_case)] use ark_bn254::{Fr, G1Projective, G2Projective}; -use ark_ec::{AdditiveGroup, PrimeGroup}; +use ark_ec::PrimeGroup; use ark_ff::PrimeField; use ark_std::UniformRand; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; diff --git a/jolt-optimizations/benches/dory_utils.rs b/jolt-optimizations/benches/dory_utils.rs index 9e86aef9b..89a867471 100644 --- a/jolt-optimizations/benches/dory_utils.rs +++ b/jolt-optimizations/benches/dory_utils.rs @@ -2,15 +2,13 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use std::time::Instant; use ark_bn254::{Fr, G2Affine, G2Projective}; -use ark_ec::PrimeGroup; -use ark_ec::{AdditiveGroup, AffineRepr}; +use ark_ec::{AffineRepr, PrimeGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::test_rng; use jolt_optimizations::{ vector_scalar_mul_add, vector_scalar_mul_add_online, vector_scalar_mul_add_precomputed, - vector_scalar_mul_v_add_g_online, vector_scalar_mul_v_add_g_precomputed, VectorScalarMulData, - VectorScalarMulVData, + VectorScalarMulData, }; fn bench_vector_scalar_mul_add(c: &mut Criterion) { diff --git a/jolt-optimizations/benches/expression_bench.rs b/jolt-optimizations/benches/expression_bench.rs deleted file mode 100644 index 2ea696fe9..000000000 --- a/jolt-optimizations/benches/expression_bench.rs +++ /dev/null @@ -1,61 +0,0 @@ -use ark_bn254::{Fq, Fq12}; -use ark_ff::{Field, PrimeField, UniformRand}; -use ark_std::test_rng; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use jolt_optimizations::expression::{Expression, Term}; -use jolt_optimizations::sz_check::batch_verify; - -fn benchmark_expression_verification(c: &mut Criterion) { - let mut rng = test_rng(); - - let configs = vec![(15, 6)]; - - for (n, m) in configs { - // Generate n expressions, each with m terms - let mut all_expressions = Vec::new(); - let mut all_expected_results = Vec::new(); - - for _ in 0..n { - let mut terms = Vec::new(); - let mut expected = Fq12::from(1u64); - - for _ in 0..m { - let base = Fq12::rand(&mut rng); - let exponent = Fq::rand(&mut rng); - terms.push(Term { base, exponent }); - expected *= base.pow(exponent.into_bigint()); - } - - all_expressions.push(Expression::new(terms)); - all_expected_results.push(expected); - } - - let mut all_products = Vec::new(); - for expr in &all_expressions { - all_products.extend(expr.to_products()); - } - - let r = Fq::rand(&mut rng); - - // naive computation - c.bench_function(&format!("naive_expr_{}x{}", n, m), |bench| { - bench.iter(|| { - for i in 0..n { - let mut result = Fq12::from(1u64); - for term in &all_expressions[i].terms { - result *= black_box(term.base.pow(term.exponent.into_bigint())); - } - black_box(result); - } - }); - }); - - // SZ check verification - c.bench_function(&format!("sz_check_expr_{}x{}", n, m), |bench| { - bench.iter(|| black_box(batch_verify(&all_products, &r))); - }); - } -} - -criterion_group!(benches, benchmark_expression_verification,); -criterion_main!(benches); diff --git a/jolt-optimizations/benches/g1_scalar_multiplication.rs b/jolt-optimizations/benches/g1_scalar_multiplication.rs index d2b302cea..0f48a9f6e 100644 --- a/jolt-optimizations/benches/g1_scalar_multiplication.rs +++ b/jolt-optimizations/benches/g1_scalar_multiplication.rs @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use rayon::prelude::*; use ark_bn254::{Fr, G1Affine, G1Projective}; -use ark_ec::{AdditiveGroup, AffineRepr, PrimeGroup}; +use ark_ec::{AffineRepr, PrimeGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::test_rng; diff --git a/jolt-optimizations/benches/scalar_multiplication.rs b/jolt-optimizations/benches/scalar_multiplication.rs index 30a4f3e3b..ab9db9f9d 100644 --- a/jolt-optimizations/benches/scalar_multiplication.rs +++ b/jolt-optimizations/benches/scalar_multiplication.rs @@ -2,14 +2,13 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use rayon::prelude::*; use ark_bn254::{Fr, G2Affine, G2Projective}; -use ark_ec::PrimeGroup; -use ark_ec::{AdditiveGroup, AffineRepr}; +use ark_ec::{AffineRepr, PrimeGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::test_rng; use jolt_optimizations::{ glv_four_precompute, glv_four_precompute_windowed2_signed, glv_four_scalar_mul, - glv_four_scalar_mul_online, glv_four_scalar_mul_windowed2_signed, + glv_four_scalar_mul_windowed2_signed, }; fn bench_scalar_multiplication(c: &mut Criterion) { diff --git a/jolt-optimizations/benches/sz_check_bench.rs b/jolt-optimizations/benches/sz_check_bench.rs deleted file mode 100644 index 92edf1599..000000000 --- a/jolt-optimizations/benches/sz_check_bench.rs +++ /dev/null @@ -1,42 +0,0 @@ -use ark_bn254::{Fq, Fq12}; -use ark_ff::UniformRand; -use ark_std::test_rng; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use jolt_optimizations::sz_check::{batch_verify, Product}; - -fn benchmark_sz_check(c: &mut Criterion) { - let mut rng = test_rng(); - let sizes = vec![100000]; - - for k in sizes { - let mut products = Vec::new(); - let mut a_values = Vec::new(); - let mut b_values = Vec::new(); - - for _ in 0..k { - let a = Fq12::rand(&mut rng); - let b = Fq12::rand(&mut rng); - let c = a * b; - a_values.push(a); - b_values.push(b); - products.push(Product::new(a, b, c)); - } - - let r = Fq::rand(&mut rng); - - c.bench_function(&format!("naive_verify_{}", k), |bench| { - bench.iter(|| { - for i in 0..k { - let _ = black_box(a_values[i] * b_values[i]); - } - }); - }); - - c.bench_function(&format!("sz_check_{}", k), |bench| { - bench.iter(|| black_box(batch_verify(&products, &r))); - }); - } -} - -criterion_group!(benches, benchmark_sz_check); -criterion_main!(benches); diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs index b30379e01..7db7d039e 100644 --- a/jolt-optimizations/src/fq12_poly.rs +++ b/jolt-optimizations/src/fq12_poly.rs @@ -4,33 +4,6 @@ use ark_ff::{Field, One, Zero}; const NINE: u64 = 9; -/// Newtype wrapper for degree-12 polys from Fq12 -#[derive(Clone, Debug, Default)] -pub struct Poly12([Fq; 12]); - -impl Poly12 { - pub fn new(coeffs: [Fq; 12]) -> Self { - Self(coeffs) - } - - pub fn coeffs(&self) -> &[Fq; 12] { - &self.0 - } - - pub fn coeffs_mut(&mut self) -> &mut [Fq; 12] { - &mut self.0 - } - - pub fn to_vec(&self) -> Vec { - self.0.to_vec() - } - - /// Horner's method - pub fn eval(&self, r: &Fq) -> Fq { - self.0.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) - } -} - /// Convert Fq12 to polynomial representation using tower basis mapping /// /// Maps Fq12 basis elements to powers of w: @@ -70,33 +43,15 @@ pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { coeffs } -/// The minimal polynomial g(X) = X^12 - 18 X^6 + 82 -struct IrreduciblePoly; - -impl IrreduciblePoly { - const COEFF_0: u64 = 82; - const COEFF_6: u64 = 18; - - /// Evaluate g(X) at point r - fn eval(r: &Fq) -> Fq { - let r6 = (r.square() * r).square(); // r^6 = (r^2 * r)^2 - let r12 = r6.square(); - r12 - Fq::from(Self::COEFF_6) * r6 + Fq::from(Self::COEFF_0) - } - - /// Get coefficients as a vector - fn coeffs() -> Vec { - let mut g = vec![Fq::zero(); 13]; - g[0] = Fq::from(Self::COEFF_0); - g[6] = -Fq::from(Self::COEFF_6); - g[12] = Fq::one(); - g - } -} +/// Coefficients for the minimal polynomial g(X) = X^12 - 18 X^6 + 82 +const G_COEFF_0: u64 = 82; +const G_COEFF_6: u64 = 18; /// Evaluate g(X) = X^12 - 18 X^6 + 82 at a given point r pub fn g_eval(r: &Fq) -> Fq { - IrreduciblePoly::eval(r) + let r6 = (r.square() * r).square(); // r^6 = (r^2 * r)^2 + let r12 = r6.square(); + r12 - Fq::from(G_COEFF_6) * r6 + Fq::from(G_COEFF_0) } /// Horner evaluation for arbitrary-degree poly @@ -104,79 +59,13 @@ pub fn eval_poly_vec(coeffs: &[Fq], r: &Fq) -> Fq { coeffs.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) } -fn poly_op_in_place(a: &mut Vec, b: &[Fq], op: F) -where - F: Fn(&mut Fq, Fq), -{ - if b.len() > a.len() { - a.resize(b.len(), Fq::zero()); - } - b.iter() - .enumerate() - .for_each(|(i, &coeff)| op(&mut a[i], coeff)); -} - -pub fn poly_add_in_place(a: &mut Vec, b: &[Fq]) { - poly_op_in_place(a, b, |a, b| *a += b); -} - -pub fn poly_sub_in_place(a: &mut Vec, b: &[Fq]) { - poly_op_in_place(a, b, |a, b| *a -= b); -} - -pub fn poly_mul(a: &[Fq], b: &[Fq]) -> Vec { - if a.is_empty() || b.is_empty() { - return vec![]; - } - - let mut out = vec![Fq::zero(); a.len() + b.len() - 1]; - a.iter().enumerate().for_each(|(i, &ai)| { - b.iter().enumerate().for_each(|(j, &bj)| { - out[i + j] += ai * bj; - }) - }); - out -} - -/// Polynomial long division by a monic divisor -pub fn poly_div_rem_monic(mut dividend: Vec, divisor: &[Fq]) -> (Vec, Vec) { - assert!(!divisor.is_empty(), "divisor must be non-empty"); - assert!( - divisor.last().unwrap().is_one(), - "divisor must be monic (leading coefficient = 1)" - ); - - if dividend.is_empty() || dividend.len() < divisor.len() { - return (vec![], dividend); - } - - let deg_dividend = dividend.len() - 1; - let deg_divisor = divisor.len() - 1; - let mut quotient = vec![Fq::zero(); deg_dividend - deg_divisor + 1]; - - for k in (deg_divisor..=deg_dividend).rev() { - let coeff = dividend[k]; - quotient[k - deg_divisor] = coeff; - - if !coeff.is_zero() { - // Subtract coeff * x^{k-deg_divisor} * divisor from dividend - (0..=deg_divisor).for_each(|j| { - dividend[k - deg_divisor + j] -= coeff * divisor[j]; - }); - } - } - - // Trim trailing zeros from remainder - while dividend.last() == Some(&Fq::zero()) { - dividend.pop(); - } - - (quotient, dividend) -} - /// Build the coefficients for g(X) = X^12 - 18 X^6 + 82 pub fn g_coeffs() -> Vec { - IrreduciblePoly::coeffs() + let mut g = vec![Fq::zero(); 13]; + g[0] = Fq::from(G_COEFF_0); + g[6] = -Fq::from(G_COEFF_6); + g[12] = Fq::one(); + g } /// Compute the multilinear extension (MLE) of a univariate polynomial. @@ -222,17 +111,17 @@ pub fn eval_multilinear(evals: &[Fq], point: &[Fq]) -> Fq { } /// Compute equality function weights eq(z, x) for all x ∈ {0,1}^4 -/// Returns a vector of 16 weights where w[i] = eq(z, binary_decomposition(i)) +/// Helper for testing in arkworks pub fn eq_weights(z: &[Fq]) -> Vec { assert_eq!(z.len(), 4, "Point z must be 4-dimensional"); let mut w = vec![Fq::zero(); 16]; for idx in 0..16 { // Binary decomposition of idx - let x0 = if (idx & 1) != 0 { Fq::one() } else { Fq::zero() }; - let x1 = if (idx & 2) != 0 { Fq::one() } else { Fq::zero() }; - let x2 = if (idx & 4) != 0 { Fq::one() } else { Fq::zero() }; - let x3 = if (idx & 8) != 0 { Fq::one() } else { Fq::zero() }; + let x0 = Fq::from((idx & 1) as u64); + let x1 = Fq::from(((idx >> 1) & 1) as u64); + let x2 = Fq::from(((idx >> 2) & 1) as u64); + let x3 = Fq::from(((idx >> 3) & 1) as u64); // eq(z, x) = ∏ᵢ ((1-zᵢ)(1-xᵢ) + zᵢxᵢ) let t0 = (Fq::one() - z[0]) * (Fq::one() - x0) + z[0] * x0; diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs index 2d8b97d7b..3cbf0f246 100644 --- a/jolt-optimizations/src/lib.rs +++ b/jolt-optimizations/src/lib.rs @@ -62,4 +62,4 @@ pub use fq12_poly::{ to_multilinear_evals, }; -pub use witness_gen::{pow_with_steps_le, ExponentiationSteps}; +pub use witness_gen::ExponentiationSteps; diff --git a/jolt-optimizations/src/witness_gen.rs b/jolt-optimizations/src/witness_gen.rs index 8f64d31d8..b553a2781 100644 --- a/jolt-optimizations/src/witness_gen.rs +++ b/jolt-optimizations/src/witness_gen.rs @@ -45,14 +45,13 @@ impl ExponentiationSteps { fq12_to_multilinear_evals(&Fq12::one()), // ρ_0 fq12_to_multilinear_evals(&base), // ρ_1 ], - quotient_mles: vec![], // Could compute a single Q_1 if needed + quotient_mles: vec![], bits: vec![true], }; } let bits_msb: Vec = (0..=msb_idx).rev().map(|i| bits_le[i]).collect(); - // ρ_0 = 1 let mut rho = Fq12::one(); let mut rho_mles = vec![fq12_to_multilinear_evals(&rho)]; let mut quotient_mles = vec![]; @@ -67,7 +66,7 @@ impl ExponentiationSteps { // One quotient per step for: ρ_i(X) - ρ_{i-1}(X)² * A(X)^{b} = Q_i(X) g(X) let q_i = compute_step_quotient_msb(rho_prev, rho_i, base, b); - quotient_mles.push(quotient_to_mle(&q_i)); + quotient_mles.push(q_i); rho = rho_i; rho_mles.push(fq12_to_multilinear_evals(&rho)); @@ -83,13 +82,15 @@ impl ExponentiationSteps { } } - /// Verify that the final result matches base^exponent + /// Verify that the final result matches base^exponent, + /// Used for testing pub fn verify_result(&self) -> bool { self.result == self.base.pow(self.exponent.into_bigint()) } /// Verify constraint at a Boolean cube point /// Checks that the constraint holds at cube vertices where it was constructed to be zero + /// Used for testing pub fn verify_constraint_at_cube_point(&self, step: usize, cube_index: usize) -> bool { if step == 0 || step > self.quotient_mles.len() || cube_index >= 16 { return false; @@ -110,7 +111,6 @@ impl ExponentiationSteps { let bit = self.bits[step - 1]; let base_power = if bit { base_eval } else { Fq::one() }; let constraint = rho_curr - rho_prev.square() * base_power - quotient * g_eval; - println!("constraint: {:?}", constraint); constraint.is_zero() } @@ -144,6 +144,7 @@ fn compute_step_quotient_msb(rho_prev: Fq12, rho_i: Fq12, base: Fq12, bit: bool) } /// Get g as MLE evaluations over the Boolean cube {0,1}^4 +/// Used for testing pub fn get_g_mle() -> Vec { // Use the same encoding as fq12_to_multilinear_evals // g(X) = X^12 - 18X^6 + 82 as coefficient array @@ -157,13 +158,8 @@ pub fn get_g_mle() -> Vec { to_multilinear_evals(&g_array) } -/// Convert quotient MLE to the format needed (already an MLE, just return it) -fn quotient_to_mle(quotient: &[Fq]) -> Vec { - // In the MLE paradigm, quotient is already an MLE - quotient.to_vec() -} - /// Convert a cube index (0..15) to a Boolean point in {0,1}^4 +/// Used for testing pub fn index_to_boolean_point(index: usize) -> Vec { vec![ Fq::from((index & 1) as u64), // bit 0 @@ -175,15 +171,16 @@ pub fn index_to_boolean_point(index: usize) -> Vec { /// Evaluate an MLE at a Boolean cube point /// For Boolean points, this is equivalent to indexing but makes the evaluation explicit +/// Used for testing fn eval_mle_at_boolean_point(mle: &[Fq], point: &[Fq]) -> Fq { // For Boolean points, we could just index, but using eval_multilinear // makes it clear we're doing MLE evaluation eval_multilinear(mle, point) } -/// Compute H̃(z) via eq-weights (definition of MLE), not by multiplying opened MLEs /// H(x) = ρᵢ(x) - ρᵢ₋₁(x)² · A(x)^{bᵢ} - Qᵢ(x) · g(x) for x ∈ {0,1}^4 /// H̃(z) = Σ_{x∈{0,1}^4} eq(z,x) · H(x) +/// Used for testing pub fn h_tilde_at_point( rho_prev_mle: &[Fq], rho_curr_mle: &[Fq], @@ -208,14 +205,8 @@ pub fn h_tilde_at_point( let prod = rho_prev_mle[j].square() * if bit { base_mle[j] } else { Fq::one() }; let h_x = rho_curr_mle[j] - prod - q_mle[j] * g_mle[j]; - // Add weighted contribution to MLE acc += h_x * w[j]; } - acc // equals H̃(z) -} - -/// Legacy compatibility function -pub fn pow_with_steps_le(base: Fq12, exponent: Fr) -> ExponentiationSteps { - ExponentiationSteps::new(base, exponent) + acc } diff --git a/jolt-optimizations/tests/mle_tests.rs b/jolt-optimizations/tests/mle_tests.rs index 7a71c6ac8..571fccae4 100644 --- a/jolt-optimizations/tests/mle_tests.rs +++ b/jolt-optimizations/tests/mle_tests.rs @@ -15,7 +15,6 @@ fn random_poly12_coeffs() -> [Fq; 12] { #[test] fn test_mle_agreement_with_univariate() { - // Test that MLE agrees with original polynomial on domain {0..15} let coeffs = random_poly12_coeffs(); let mle_evals = to_multilinear_evals(&coeffs); @@ -28,7 +27,6 @@ fn test_mle_agreement_with_univariate() { "MLE evaluation doesn't match univariate at point {}", i ); - // Binary decomposition of i: (b₀, b₁, b₂, b₃) where i = b₀ + 2b₁ + 4b₂ + 8b₃ let binary_point = vec![ Fq::from((i & 1) as u64), Fq::from(((i >> 1) & 1) as u64), @@ -51,7 +49,6 @@ fn test_mle_is_multilinear() { let coeffs = random_poly12_coeffs(); let mle_evals = to_multilinear_evals(&coeffs); - // Test linearity in each variable for var_idx in 0..4 { let point = vec![ Fq::rand(&mut rng), @@ -84,7 +81,6 @@ fn test_mle_is_multilinear() { #[test] fn test_mle_special_cases() { - // Test 1: Zero polynomial let zero_coeffs = [Fq::zero(); 12]; let mle = to_multilinear_evals(&zero_coeffs); assert!( @@ -92,7 +88,6 @@ fn test_mle_special_cases() { "Zero polynomial MLE should be all zeros" ); - // Test 2: Constant polynomial p(x) = 42 let const_val = Fq::from(42u64); let mut const_coeffs = [Fq::zero(); 12]; const_coeffs[0] = const_val; @@ -102,7 +97,6 @@ fn test_mle_special_cases() { "Constant polynomial MLE should be constant" ); - // Test 3: Linear polynomial p(x) = x let mut linear_coeffs = [Fq::zero(); 12]; linear_coeffs[1] = Fq::one(); let mle = to_multilinear_evals(&linear_coeffs); @@ -115,8 +109,6 @@ fn test_mle_special_cases() { i ); } - - // Test 4: Quadratic polynomial p(x) = x² let mut quad_coeffs = [Fq::zero(); 12]; quad_coeffs[2] = Fq::one(); let mle = to_multilinear_evals(&quad_coeffs); @@ -132,7 +124,6 @@ fn test_mle_special_cases() { #[test] fn test_mle_high_degree() { - // Test with maximum degree polynomial (degree 11) let mut coeffs = [Fq::zero(); 12]; coeffs[11] = Fq::one(); diff --git a/jolt-optimizations/tests/witness_test.rs b/jolt-optimizations/tests/witness_test.rs index 4e26f788f..b81679109 100644 --- a/jolt-optimizations/tests/witness_test.rs +++ b/jolt-optimizations/tests/witness_test.rs @@ -43,7 +43,6 @@ fn test_witness_generation_and_constraints() { assert_eq!(mle.len(), 16, "Quotient MLEs should have 16 evaluations"); } - // Test constraint verification at Boolean cube points // The constraint should be zero at all 16 cube vertices for cube_idx in 0..16 { for step in 1..=witness.num_steps() { @@ -68,23 +67,22 @@ fn test_trivial_cases() { assert_eq!(witness_zero.result, Fq12::one()); assert!(witness_zero.verify_result()); assert_eq!(witness_zero.bits.len(), 0); - assert_eq!(witness_zero.rho_mles.len(), 1); // Just ρ_0 = 1 + assert_eq!(witness_zero.rho_mles.len(), 1); assert_eq!(witness_zero.quotient_mles.len(), 0); // Test exponent = 1 let witness_one = ExponentiationSteps::new(base, Fr::from(1u64)); assert_eq!(witness_one.result, base); assert!(witness_one.verify_result()); - assert_eq!(witness_one.bits, vec![true]); // Single bit: 1 - assert_eq!(witness_one.rho_mles.len(), 2); // ρ_0 = 1, ρ_1 = base - + assert_eq!(witness_one.bits, vec![true]); + assert_eq!(witness_one.rho_mles.len(), 2); // Test small known values to verify bit sequence let witness_five = ExponentiationSteps::new(base, Fr::from(5u64)); - assert_eq!(witness_five.bits, vec![true, false, true]); // MSB to LSB: 101 + assert_eq!(witness_five.bits, vec![true, false, true]); assert!(witness_five.verify_result()); let witness_ten = ExponentiationSteps::new(base, Fr::from(10u64)); - assert_eq!(witness_ten.bits, vec![true, false, true, false]); // MSB to LSB: 1010 + assert_eq!(witness_ten.bits, vec![true, false, true, false]); assert!(witness_ten.verify_result()); } @@ -194,15 +192,13 @@ fn test_constraint_at_random_field_element() { // Create witness for a simple exponentiation let base = Fq12::rand(&mut rng); - let exponent = Fr::from(10000300u64); // Simple exponent: binary 111 + let exponent = Fr::from(10000300u64); let witness = ExponentiationSteps::new(base, exponent); let base_mle = fq12_to_multilinear_evals(&base); let g_mle = get_g_mle(); - // Test at random field elements (not on hypercube) for test_idx in 0..10000 { - // Generate random point z = (z0, z1, z2, z3) where zi ∈ Fq \ {0,1} let z: Vec = (0..4) .map(|_| { let mut val = Fq::rand(&mut rng); @@ -214,11 +210,9 @@ fn test_constraint_at_random_field_element() { }) .collect(); - // Pick a step to check let step = 1 + (test_idx % witness.num_steps()); let bit = witness.bits[step - 1]; - // Compute H̃(z) using the correct MLE definition let h = h_tilde_at_point( &witness.rho_mles[step - 1], &witness.rho_mles[step], @@ -229,7 +223,7 @@ fn test_constraint_at_random_field_element() { &z, ); - // H̃(z) must be 0 at random z (Sumcheck-consistent) + // H̃(z) must be 0 at random z assert!( h.is_zero(), "H̃(z) must be 0 at random z (test {}, step {}). Got: {:?}", @@ -241,7 +235,6 @@ fn test_constraint_at_random_field_element() { println!("✓ Verified: H̃(z) = 0 at 20 random field elements (Sumcheck correct)"); - // Also verify it works on the hypercube (sanity check) for step in 1..=witness.num_steps() { for cube_idx in 0..16 { assert!( @@ -255,50 +248,3 @@ fn test_constraint_at_random_field_element() { println!("✓ Verified: Constraints are zero on hypercube (sanity check)"); } - -#[test] -fn test_zero_tampering_soundness() { - let mut rng = test_rng(); - let base = Fq12::rand(&mut rng); - - for exp_val in [2u64, 3, 7, 15, 31] { - let exponent = Fr::from(exp_val); - let mut witness = ExponentiationSteps::new(base, exponent); - - // Setting ρ to zero should break soundness - if witness.rho_mles.len() > 1 { - let original = witness.rho_mles[1][0]; - witness.rho_mles[1][0] = Fq::zero(); - - let mut soundness_broken = false; - for step in 1..=witness.num_steps() { - if !witness.verify_constraint_at_cube_point(step, 0) { - soundness_broken = true; - break; - } - } - - assert!( - soundness_broken || !witness.verify_result(), - "Setting ρ to zero should break soundness" - ); - witness.rho_mles[1][0] = original; - } - - // Setting quotient to zero should break soundness - if !witness.quotient_mles.is_empty() { - let original = witness.quotient_mles[0][0]; - witness.quotient_mles[0][0] = Fq::zero(); - - if original != Fq::zero() { - let soundness_broken = !witness.verify_constraint_at_cube_point(1, 0); - assert!( - soundness_broken, - "Setting non-zero quotient to zero should break soundness" - ); - } - - witness.quotient_mles[0][0] = original; - } - } -} diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index d79b9c530..ce7823dad 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -10,35 +10,23 @@ fn mul_small_bench(c: &mut Criterion) { // Use a fixed seed for reproducibility let mut rng = StdRng::seed_from_u64(0u64); - let a_s = (0..SAMPLES) - .map(|_| Fr::rand(&mut rng)) - .collect::>(); - let a_limbs_s = a_s.iter().map(|a| a.0.0).collect::>(); - - let b_u64_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let a_s = (0..SAMPLES).map(|_| Fr::rand(&mut rng)).collect::>(); + let a_limbs_s = a_s.iter().map(|a| a.0 .0).collect::>(); + + let b_u64_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); // Convert u64 to Fr for standard multiplication benchmark let b_fr_s = b_u64_s.iter().map(|&b| Fr::from(b)).collect::>(); let b_u64_as_u128_s = b_u64_s.iter().map(|&b| b as u128).collect::>(); - let b_i64_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_i64_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); - let b_u128_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_u128_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); - let b_i128_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_i128_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); // Generate another set of random Fr elements for addition - let c_s = (0..SAMPLES) - .map(|_| Fr::rand(&mut rng)) - .collect::>(); + let c_s = (0..SAMPLES).map(|_| Fr::rand(&mut rng)).collect::>(); let mut group = c.benchmark_group("Fr Arithmetic Comparison"); @@ -118,4 +106,4 @@ fn mul_small_bench(c: &mut Criterion) { } criterion_group!(benches, mul_small_bench); -criterion_main!(benches); \ No newline at end of file +criterion_main!(benches); diff --git a/test-curves/src/bn254/fq.rs b/test-curves/src/bn254/fq.rs index 001d94836..6bddf9bc0 100644 --- a/test-curves/src/bn254/fq.rs +++ b/test-curves/src/bn254/fq.rs @@ -10,4 +10,4 @@ pub struct FqConfig; pub type Fq = Fp256>; pub const FQ_ONE: Fq = ark_ff::MontFp!("1"); -pub const FQ_ZERO: Fq = ark_ff::MontFp!("0"); \ No newline at end of file +pub const FQ_ZERO: Fq = ark_ff::MontFp!("0"); diff --git a/test-curves/src/bn254/fr.rs b/test-curves/src/bn254/fr.rs index 4caef8e7c..4de077431 100644 --- a/test-curves/src/bn254/fr.rs +++ b/test-curves/src/bn254/fr.rs @@ -14,4 +14,4 @@ pub struct FrConfig; pub type Fr = Fp256>; pub const FR_ONE: Fr = ark_ff::MontFp!("1"); -pub const FR_ZERO: Fr = ark_ff::MontFp!("0"); \ No newline at end of file +pub const FR_ZERO: Fr = ark_ff::MontFp!("0"); diff --git a/test-curves/src/bn254/g1.rs b/test-curves/src/bn254/g1.rs index 2b3c5a0c5..608278db8 100644 --- a/test-curves/src/bn254/g1.rs +++ b/test-curves/src/bn254/g1.rs @@ -1,8 +1,6 @@ -use ark_ec::models::short_weierstrass::{ - Affine, Projective, SWCurveConfig, -}; +use ark_ec::models::short_weierstrass::{Affine, Projective, SWCurveConfig}; use ark_ec::CurveConfig; -use ark_ff::{Field, MontFp, Zero, AdditiveGroup}; +use ark_ff::{AdditiveGroup, Field, MontFp, Zero}; use crate::bn254::{Fq, Fr}; // Assuming Fq is defined in fq.rs diff --git a/test-curves/src/bn254/test.rs b/test-curves/src/bn254/test.rs index 51a9c691e..176467a7a 100644 --- a/test-curves/src/bn254/test.rs +++ b/test-curves/src/bn254/test.rs @@ -2,7 +2,9 @@ use ark_ec::{ models::short_weierstrass::SWCurveConfig, // Keep this as G1 is SW pairing::Pairing, - AffineRepr, CurveGroup, PrimeGroup, + AffineRepr, + CurveGroup, + PrimeGroup, }; use ark_ff::{Field, One, UniformRand, Zero}; use ark_std::{rand::Rng, test_rng}; From 183faaaac33cac6ef95f4b10b84b9c44445ad038 Mon Sep 17 00:00:00 2001 From: Ari Date: Tue, 30 Sep 2025 19:41:14 +0200 Subject: [PATCH 36/38] adding in montu128 helpers --- ff/src/fields/models/fp/montgomery_backend.rs | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index ad487ffa6..765855a2f 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -493,7 +493,11 @@ pub trait MontConfig: 'static + Sync + Send + Sized { // return Fp::zero(); // } let fe = Self::from_bigint_mixed::(x.magnitude); - if x.is_positive { fe } else { -fe } + if x.is_positive { + fe + } else { + -fe + } } /// Construct from a signed big integer with high 32-bit tail and K low 64-bit limbs. @@ -503,13 +507,20 @@ pub trait MontConfig: 'static + Sync + Send + Sized { fn from_signed_bigint_hi32( x: crate::biginteger::SignedBigIntHi32, ) -> Fp, N> { - debug_assert!(KPLUS1 == K + 1, "from_signed_bigint_hi32 requires KPLUS1 = K + 1"); + debug_assert!( + KPLUS1 == K + 1, + "from_signed_bigint_hi32 requires KPLUS1 = K + 1" + ); // if x.is_zero() { // return Fp::zero(); // } let mag = x.magnitude_as_bigint_nplus1::(); let fe = Self::from_bigint_mixed::(mag); - if x.is_positive() { fe } else { -fe } + if x.is_positive() { + fe + } else { + -fe + } } #[inline] @@ -882,9 +893,7 @@ impl, const N: usize> Fp, N> { /// Implementation folds from high to low using the existing N+1 Barrett kernel. /// Precondition: L >= N. For performance, prefer small L close to N..N+3 when possible. #[inline(always)] - pub fn from_barrett_reduce( - unreduced: BigInt, - ) -> Self { + pub fn from_barrett_reduce(unreduced: BigInt) -> Self { debug_assert!(NPLUS1 == N + 1); debug_assert!(L >= N); @@ -1094,6 +1103,14 @@ impl, const N: usize> Fp, N> { *self = self.const_cios_mul_rhs_hi2(hi as u64, (hi >> 64) as u64); } + /// Returns self * rhs_high_limbs, where RHS is zero in low N-2 limbs and has its top two + /// limbs provided by `hi` (low 64 -> limb N-2, high 64 -> limb N-1). Equivalent to K=2. + /// At the cost 2 extra words of storage uses no bit shift instructions to extract higher limbs + /// as in mul_hi_u128 + #[inline] + pub const fn mul_hi_bigint_u128(self, big_int_repre: [u64; 4]) -> Self { + self.const_cios_mul_rhs_hi2(big_int_repre[2], big_int_repre[3]) + } /// Returns self * rhs_high_limbs, where RHS is zero in low N-2 limbs and has its top two /// limbs provided by `hi` (low 64 -> limb N-2, high 64 -> limb N-1). Equivalent to K=2. #[inline] From fee4ab253cf56c0eb054cde03ad2c599db0f0317 Mon Sep 17 00:00:00 2001 From: Ari Date: Tue, 30 Sep 2025 22:25:30 +0200 Subject: [PATCH 37/38] adding in unchecked to prevent unnecessary mont-reductions --- ff/src/fields/models/fp/mod.rs | 45 ++++++++----------- ff/src/fields/models/fp/montgomery_backend.rs | 8 ++++ ff/src/fields/prime.rs | 3 ++ 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/ff/src/fields/models/fp/mod.rs b/ff/src/fields/models/fp/mod.rs index 30af6f46f..993880be5 100644 --- a/ff/src/fields/models/fp/mod.rs +++ b/ff/src/fields/models/fp/mod.rs @@ -4,8 +4,9 @@ use crate::{ }; use allocative::Allocative; use ark_serialize::{ - buffer_byte_size, CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize, + CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize, CanonicalSerializeWithFlags, Compress, EmptyFlags, Flags, SerializationError, Valid, Validate, + buffer_byte_size, }; use ark_std::{ cmp::*, @@ -98,6 +99,11 @@ pub trait FpConfig: Send + Sync + 'static + Sized { /// this range. fn from_bigint(other: BigInt) -> Option>; + /// Construct a field element from an integer in the range + /// `0..(Self::MODULUS - 1)`. Returns `None` if the integer is outside + /// this range (but do not do any Reductions) + fn from_bigint_unchecked(other: BigInt) -> Option>; + /// Convert a field element to an integer in the range `0..(Self::MODULUS - /// 1)`. fn into_bigint(other: Fp) -> BigInt; @@ -371,6 +377,11 @@ impl, const N: usize> PrimeField for Fp { P::into_bigint(self) } + #[inline] + fn from_bigint_unchecked(r: BigInt) -> Option { + P::from_bigint_unchecked(r) + } + #[inline] fn from_u64(r: u64) -> Option { P::from_u64::(r) @@ -433,11 +444,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i128) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -464,11 +471,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i64) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -485,11 +488,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i32) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -506,11 +505,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i16) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -527,11 +522,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i8) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -554,7 +545,7 @@ impl, const N: usize> ark_std::rand::distributions::Distribution< u64::MAX >> shave_bits }; - if let Some(val) = tmp.0 .0.last_mut() { + if let Some(val) = tmp.0.0.last_mut() { *val &= mask } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index 765855a2f..d2de50090 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -462,6 +462,10 @@ pub trait MontConfig: 'static + Sync + Send + Sized { } } + fn from_bigint_unchecked(r: BigInt) -> Option, N>> { + Some(Fp::new_unchecked(r)) + } + fn from_bigint(r: BigInt) -> Option, N>> { let mut r = Fp::new_unchecked(r); if r.is_zero() { @@ -845,6 +849,10 @@ impl, const N: usize> FpConfig for MontBackend { T::from_bigint(r) } + fn from_bigint_unchecked(r: BigInt) -> Option> { + T::from_bigint_unchecked(r) + } + #[inline] #[allow(clippy::modulo_one)] fn into_bigint(a: Fp) -> BigInt { diff --git a/ff/src/fields/prime.rs b/ff/src/fields/prime.rs index 67a7f0db0..9ae26d83e 100644 --- a/ff/src/fields/prime.rs +++ b/ff/src/fields/prime.rs @@ -57,6 +57,9 @@ pub trait PrimeField: /// Converts an element of the prime field into an integer in the range 0..(p - 1). fn into_bigint(self) -> Self::BigInt; + /// Construct a prime field element from an integer in the range 0..(p - 1) (no reductions) + fn from_bigint_unchecked(repr: Self::BigInt) -> Option; + /// Creates a field element from a `u64`. /// Returns `None` if the `u64` is larger than or equal to the modulus. fn from_u64(val: u64) -> Option; From 0a4504d9a547861f3295cec38b43e1056eba5156 Mon Sep 17 00:00:00 2001 From: markosg04 Date: Fri, 3 Oct 2025 15:14:40 -0400 Subject: [PATCH 38/38] style: pr comments --- jolt-optimizations/src/fq12_poly.rs | 5 +--- jolt-optimizations/src/lib.rs | 6 ++--- jolt-optimizations/src/witness_gen.rs | 31 +++++++++--------------- jolt-optimizations/tests/mle_tests.rs | 5 +++- jolt-optimizations/tests/witness_test.rs | 18 ++++++-------- 5 files changed, 26 insertions(+), 39 deletions(-) diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs index 7db7d039e..9a5683ef4 100644 --- a/jolt-optimizations/src/fq12_poly.rs +++ b/jolt-optimizations/src/fq12_poly.rs @@ -2,8 +2,6 @@ use ark_bn254::{Fq, Fq12}; use ark_ff::{Field, One, Zero}; -const NINE: u64 = 9; - /// Convert Fq12 to polynomial representation using tower basis mapping /// /// Maps Fq12 basis elements to powers of w: @@ -20,7 +18,7 @@ pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { (1, 2, 5), // a.c1.c2 → w^5 ]; - let nine = Fq::from(NINE); + let nine = Fq::from(9); let mut coeffs = [Fq::zero(); 12]; for &(outer, inner, w_power) in &MAPPINGS { @@ -111,7 +109,6 @@ pub fn eval_multilinear(evals: &[Fq], point: &[Fq]) -> Fq { } /// Compute equality function weights eq(z, x) for all x ∈ {0,1}^4 -/// Helper for testing in arkworks pub fn eq_weights(z: &[Fq]) -> Vec { assert_eq!(z.len(), 4, "Point z must be 4-dimensional"); let mut w = vec![Fq::zero(); 16]; diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs index 3cbf0f246..2f40915e6 100644 --- a/jolt-optimizations/src/lib.rs +++ b/jolt-optimizations/src/lib.rs @@ -58,8 +58,8 @@ pub use dory_g2::{ pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi}; pub use fq12_poly::{ - eval_multilinear, fq12_to_multilinear_evals, fq12_to_poly12_coeffs, g_coeffs, g_eval, - to_multilinear_evals, + eq_weights, eval_multilinear, fq12_to_multilinear_evals, fq12_to_poly12_coeffs, g_coeffs, + g_eval, to_multilinear_evals, }; -pub use witness_gen::ExponentiationSteps; +pub use witness_gen::{get_g_mle, h_tilde_at_point, ExponentiationSteps}; diff --git a/jolt-optimizations/src/witness_gen.rs b/jolt-optimizations/src/witness_gen.rs index b553a2781..d5098982f 100644 --- a/jolt-optimizations/src/witness_gen.rs +++ b/jolt-optimizations/src/witness_gen.rs @@ -1,6 +1,4 @@ -use crate::fq12_poly::{ - eq_weights, eval_multilinear, fq12_to_multilinear_evals, g_coeffs, to_multilinear_evals, -}; +use crate::fq12_poly::{eq_weights, eval_multilinear, fq12_to_multilinear_evals, g_coeffs}; use ark_bn254::{Fq, Fq12, Fr}; use ark_ff::{BigInteger, Field, One, PrimeField, Zero}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; @@ -82,15 +80,13 @@ impl ExponentiationSteps { } } - /// Verify that the final result matches base^exponent, - /// Used for testing + /// Verify that the final result matches base^exponent pub fn verify_result(&self) -> bool { self.result == self.base.pow(self.exponent.into_bigint()) } /// Verify constraint at a Boolean cube point /// Checks that the constraint holds at cube vertices where it was constructed to be zero - /// Used for testing pub fn verify_constraint_at_cube_point(&self, step: usize, cube_index: usize) -> bool { if step == 0 || step > self.quotient_mles.len() || cube_index >= 16 { return false; @@ -144,23 +140,20 @@ fn compute_step_quotient_msb(rho_prev: Fq12, rho_i: Fq12, base: Fq12, bit: bool) } /// Get g as MLE evaluations over the Boolean cube {0,1}^4 -/// Used for testing pub fn get_g_mle() -> Vec { - // Use the same encoding as fq12_to_multilinear_evals - // g(X) = X^12 - 18X^6 + 82 as coefficient array + use crate::fq12_poly::eval_poly_vec; let g_vec = g_coeffs(); - let mut g_array = [Fq::zero(); 12]; - for i in 0..12 { - if i < g_vec.len() { - g_array[i] = g_vec[i]; - } - } - to_multilinear_evals(&g_array) + + (0..16) + .map(|i| { + let x = Fq::from(i as u64); + eval_poly_vec(&g_vec[..], &x) + }) + .collect() } /// Convert a cube index (0..15) to a Boolean point in {0,1}^4 -/// Used for testing -pub fn index_to_boolean_point(index: usize) -> Vec { +pub(crate) fn index_to_boolean_point(index: usize) -> Vec { vec![ Fq::from((index & 1) as u64), // bit 0 Fq::from(((index >> 1) & 1) as u64), // bit 1 @@ -171,7 +164,6 @@ pub fn index_to_boolean_point(index: usize) -> Vec { /// Evaluate an MLE at a Boolean cube point /// For Boolean points, this is equivalent to indexing but makes the evaluation explicit -/// Used for testing fn eval_mle_at_boolean_point(mle: &[Fq], point: &[Fq]) -> Fq { // For Boolean points, we could just index, but using eval_multilinear // makes it clear we're doing MLE evaluation @@ -180,7 +172,6 @@ fn eval_mle_at_boolean_point(mle: &[Fq], point: &[Fq]) -> Fq { /// H(x) = ρᵢ(x) - ρᵢ₋₁(x)² · A(x)^{bᵢ} - Qᵢ(x) · g(x) for x ∈ {0,1}^4 /// H̃(z) = Σ_{x∈{0,1}^4} eq(z,x) · H(x) -/// Used for testing pub fn h_tilde_at_point( rho_prev_mle: &[Fq], rho_curr_mle: &[Fq], diff --git a/jolt-optimizations/tests/mle_tests.rs b/jolt-optimizations/tests/mle_tests.rs index 571fccae4..c507aec64 100644 --- a/jolt-optimizations/tests/mle_tests.rs +++ b/jolt-optimizations/tests/mle_tests.rs @@ -1,7 +1,10 @@ use ark_bn254::Fq; use ark_ff::{Field, One, UniformRand, Zero}; use ark_std::test_rng; -use jolt_optimizations::fq12_poly::{eval_multilinear, eval_poly_vec, to_multilinear_evals}; +use jolt_optimizations::{ + eval_multilinear, + fq12_poly::{eval_poly_vec, to_multilinear_evals}, +}; /// Generate random polynomial coefficients for testing fn random_poly12_coeffs() -> [Fq; 12] { diff --git a/jolt-optimizations/tests/witness_test.rs b/jolt-optimizations/tests/witness_test.rs index b81679109..9a02838d7 100644 --- a/jolt-optimizations/tests/witness_test.rs +++ b/jolt-optimizations/tests/witness_test.rs @@ -1,10 +1,8 @@ use ark_bn254::{Fq, Fq12, Fr}; -use ark_ff::{Field, One, UniformRand, Zero}; +use ark_ff::{One, UniformRand, Zero}; use ark_std::test_rng; use jolt_optimizations::{ - eval_multilinear, fq12_to_multilinear_evals, g_coeffs, to_multilinear_evals, - witness_gen::{get_g_mle, h_tilde_at_point}, - ExponentiationSteps, + fq12_to_multilinear_evals, get_g_mle, h_tilde_at_point, ExponentiationSteps, }; #[test] @@ -90,7 +88,7 @@ fn test_trivial_cases() { fn test_witness_soundness() { let mut rng = test_rng(); - // Test soundness: tampering with witness should be detected + // Test soundness: tampering with witness for test_idx in 0..20 { let base = Fq12::rand(&mut rng); let exponent = if test_idx == 0 { @@ -113,7 +111,7 @@ fn test_witness_soundness() { } assert!(all_valid, "Original witness should be valid"); - // Test 1: Tampering with ρ values breaks soundness + // Test 1: Tampering with ρ values if witness.rho_mles.len() > 1 { let tamper_idx = 1 + (test_idx % (witness.rho_mles.len() - 1)); let point_idx = test_idx % 16; @@ -141,7 +139,7 @@ fn test_witness_soundness() { witness.rho_mles[tamper_idx][point_idx] = original; } - // Test 2: Tampering with quotient values breaks soundness + // Test 2: Tampering with quotient if !witness.quotient_mles.is_empty() { let q_idx = test_idx % witness.quotient_mles.len(); let point_idx = (test_idx * 7) % 16; @@ -158,7 +156,7 @@ fn test_witness_soundness() { witness.quotient_mles[q_idx][point_idx] = original; } - // Test 3: Flipping bits breaks soundness + // Test 3: Flipping bits if !witness.bits.is_empty() { let bit_idx = test_idx % witness.bits.len(); witness.bits[bit_idx] = !witness.bits[bit_idx]; @@ -175,7 +173,7 @@ fn test_witness_soundness() { witness.bits[bit_idx] = !witness.bits[bit_idx]; } - // Test 4: Tampering with final result breaks verification + // Test 4: Tampering with final result let original_result = witness.result; witness.result = witness.result + Fq12::one(); assert!( @@ -233,8 +231,6 @@ fn test_constraint_at_random_field_element() { ); } - println!("✓ Verified: H̃(z) = 0 at 20 random field elements (Sumcheck correct)"); - for step in 1..=witness.num_steps() { for cube_idx in 0..16 { assert!(