diff --git a/bench-templates/src/macros/field.rs b/bench-templates/src/macros/field.rs index 40fe597b1..db11baabf 100644 --- a/bench-templates/src/macros/field.rs +++ b/bench-templates/src/macros/field.rs @@ -405,16 +405,14 @@ macro_rules! prime_field { f[i].into_bigint() }) }); - let u64s = (0..SAMPLES) - .map(|_| rng.next_u64()) - .collect::>(); + let u64s = (0..SAMPLES).map(|_| rng.next_u64()).collect::>(); conversions.bench_function("From u64", |b| { let mut i = 0; b.iter(|| { i = (i + 1) % SAMPLES; <$F>::from_u64(u64s[i]) }) - }); + }); conversions.finish() } }; diff --git a/curves/bn254/src/lib.rs b/curves/bn254/src/lib.rs index 13bcef101..4fdd1e61c 100755 --- a/curves/bn254/src/lib.rs +++ b/curves/bn254/src/lib.rs @@ -39,6 +39,7 @@ mod fields; #[cfg(feature = "curve")] pub use curves::*; +#[allow(unused_imports)] pub use fields::*; #[cfg(feature = "r1cs")] diff --git a/ec/src/lib.rs b/ec/src/lib.rs index ba99d4c87..47b8437e5 100644 --- a/ec/src/lib.rs +++ b/ec/src/lib.rs @@ -28,11 +28,7 @@ use ark_std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, vec::*, }; -pub use scalar_mul::{ - fixed_base::FixedBase, - variable_base::VariableBaseMSM, - ScalarMul, -}; +pub use scalar_mul::{fixed_base::FixedBase, variable_base::VariableBaseMSM, ScalarMul}; use zeroize::Zeroize; pub use ark_ff::AdditiveGroup; diff --git a/ec/src/pairing.rs b/ec/src/pairing.rs index a3aa83e26..f62d1be72 100644 --- a/ec/src/pairing.rs +++ b/ec/src/pairing.rs @@ -102,8 +102,14 @@ pub trait Pairing: Sized + 'static + Copy + Debug + Sync + Send + Eq { a: impl IntoIterator>, b: impl IntoIterator>, ) -> MillerLoopOutput { - let a_cloned = a.into_iter().map(|x| x.as_ref().clone()).collect::>(); - let b_cloned = b.into_iter().map(|x| x.as_ref().clone()).collect::>(); + let a_cloned = a + .into_iter() + .map(|x| x.as_ref().clone()) + .collect::>(); + let b_cloned = b + .into_iter() + .map(|x| x.as_ref().clone()) + .collect::>(); Self::multi_miller_loop(a_cloned, b_cloned) } diff --git a/ec/src/scalar_mul/fixed_base.rs b/ec/src/scalar_mul/fixed_base.rs index c9e5270d0..ce8001ccd 100644 --- a/ec/src/scalar_mul/fixed_base.rs +++ b/ec/src/scalar_mul/fixed_base.rs @@ -95,4 +95,4 @@ impl FixedBase { .map(|e| Self::windowed_mul::(outerc, window, table, e)) .collect::>() } -} \ No newline at end of file +} diff --git a/ec/src/scalar_mul/mod.rs b/ec/src/scalar_mul/mod.rs index 81a4c6595..cb38e432d 100644 --- a/ec/src/scalar_mul/mod.rs +++ b/ec/src/scalar_mul/mod.rs @@ -1,8 +1,8 @@ pub mod glv; pub mod wnaf; -pub mod variable_base; pub mod fixed_base; +pub mod variable_base; use crate::{ short_weierstrass::{Affine, Projective, SWCurveConfig}, diff --git a/ec/src/scalar_mul/variable_base/mod.rs b/ec/src/scalar_mul/variable_base/mod.rs index 9d83dfa53..aca9e1726 100644 --- a/ec/src/scalar_mul/variable_base/mod.rs +++ b/ec/src/scalar_mul/variable_base/mod.rs @@ -1,4 +1,5 @@ use ark_ff::prelude::*; +use ark_ff::biginteger::{S128, S64}; use ark_std::{ borrow::Borrow, cfg_chunks, cfg_into_iter, cfg_iter, @@ -626,6 +627,102 @@ pub fn msm_i128( } } +pub fn msm_s64( + mut bases: &[V::MulBase], + mut scalars: &[S64], + serial: bool, +) -> V { + let (negative_bases, non_negative_bases): (Vec, Vec) = + bases.iter().enumerate().partition_map(|(i, b)| { + if !scalars[i].sign() { + Either::Left(b) + } else { + Either::Right(b) + } + }); + let (negative_scalars, non_negative_scalars): (Vec, Vec) = scalars + .iter() + .partition_map(|s| { + let mag = s.magnitude_as_u64(); + if !s.sign() { + Either::Left(mag) + } else { + 
Either::Right(mag) + } + }); + if serial { + return msm_serial::(&non_negative_bases, &non_negative_scalars) + - msm_serial::(&negative_bases, &negative_scalars); + } else { + let chunk_size = match preamble(&mut bases, &mut scalars, serial) { + Some(chunk_size) => chunk_size, + None => return V::zero(), + }; + + let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size) + .zip(cfg_chunks!(non_negative_scalars, chunk_size)) + .map(|(non_negative_bases, non_negative_scalars)| { + msm_serial::(non_negative_bases, non_negative_scalars) + }) + .sum(); + let negative_msm: V = cfg_chunks!(negative_bases, chunk_size) + .zip(cfg_chunks!(negative_scalars, chunk_size)) + .map(|(negative_bases, negative_scalars)| { + msm_serial::(negative_bases, negative_scalars) + }) + .sum(); + non_negative_msm - negative_msm + } +} + +pub fn msm_s128( + mut bases: &[V::MulBase], + mut scalars: &[S128], + serial: bool, +) -> V { + let (negative_bases, non_negative_bases): (Vec, Vec) = + bases.iter().enumerate().partition_map(|(i, b)| { + if !scalars[i].sign() { + Either::Left(b) + } else { + Either::Right(b) + } + }); + let (negative_scalars, non_negative_scalars): (Vec, Vec) = scalars + .iter() + .partition_map(|s| { + let mag = s.magnitude_as_u128(); + if !s.sign() { + Either::Left(mag) + } else { + Either::Right(mag) + } + }); + if serial { + return msm_serial::(&non_negative_bases, &non_negative_scalars) + - msm_serial::(&negative_bases, &negative_scalars); + } else { + let chunk_size = match preamble(&mut bases, &mut scalars, serial) { + Some(chunk_size) => chunk_size, + None => return V::zero(), + }; + + let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size) + .zip(cfg_chunks!(non_negative_scalars, chunk_size)) + .map(|(non_negative_bases, non_negative_scalars)| { + msm_serial::(non_negative_bases, non_negative_scalars) + }) + .sum(); + let negative_msm: V = cfg_chunks!(negative_bases, chunk_size) + .zip(cfg_chunks!(negative_scalars, chunk_size)) + .map(|(negative_bases, negative_scalars)| { + msm_serial::(negative_bases, negative_scalars) + }) + .sum(); + non_negative_msm - negative_msm + } +} + pub fn msm_u128( mut bases: &[V::MulBase], mut scalars: &[u128], diff --git a/ff-asm/src/lib.rs b/ff-asm/src/lib.rs index 1833af026..cd38350d6 100644 --- a/ff-asm/src/lib.rs +++ b/ff-asm/src/lib.rs @@ -59,6 +59,7 @@ pub fn x86_64_asm_mul(input: TokenStream) -> TokenStream { } else { panic!("The number of limbs must be a literal"); }; + #[allow(clippy::redundant_comparisons)] if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS { let impl_block = generate_impl(num_limbs, true); @@ -110,6 +111,7 @@ pub fn x86_64_asm_square(input: TokenStream) -> TokenStream { } else { panic!("The number of limbs must be a literal"); }; + #[allow(clippy::redundant_comparisons)] if num_limbs <= 6 && num_limbs <= 3 * MAX_REGS { let impl_block = generate_impl(num_limbs, false); diff --git a/ff/Cargo.toml b/ff/Cargo.toml index dc51b5deb..3d5f0b156 100644 --- a/ff/Cargo.toml +++ b/ff/Cargo.toml @@ -29,7 +29,7 @@ zeroize = { workspace = true, features = ["zeroize_derive"] } num-bigint.workspace = true digest = { workspace = true, features = ["alloc"] } itertools.workspace = true -allocative = { version = "0.3.4", optional = true } +allocative = "0.3.4" [dev-dependencies] ark-test-curves = { workspace = true, features = [ @@ -53,4 +53,3 @@ default = [] std = ["ark-std/std", "ark-serialize/std", "itertools/use_std"] parallel = ["std", "rayon", "ark-std/parallel", "ark-serialize/parallel"] asm = [] -allocative = 
["dep:allocative"] diff --git a/ff/src/biginteger/arithmetic.rs b/ff/src/biginteger/arithmetic.rs index ac15a26ae..6f6a67f26 100644 --- a/ff/src/biginteger/arithmetic.rs +++ b/ff/src/biginteger/arithmetic.rs @@ -123,6 +123,37 @@ pub fn mac_discard(a: u64, b: u64, c: u64, carry: &mut u64) { *carry = (tmp >> 64) as u64; } +/// Accumulate `limbs` into an N-limb accumulator starting at `lane_offset` (64-bit lanes), +/// returning the final carry. This is a helper for building wide accumulators. +#[inline(always)] +pub fn add_limbs_shifted_inplace( + acc: &mut [u64; N], + limbs: &[u64], + lane_offset: usize, +) -> u64 { + let mut carry = 0u64; + let mut i = 0usize; + while i < limbs.len() { + let idx = lane_offset + i; + if idx >= N { + break; + } + let tmp = (acc[idx] as u128) + (limbs[i] as u128) + (carry as u128); + acc[idx] = tmp as u64; + carry = (tmp >> 64) as u64; + i += 1; + } + // propagate carry across remaining lanes if any + let mut idx = lane_offset + i; + while carry != 0 && idx < N { + let tmp = (acc[idx] as u128) + (carry as u128); + acc[idx] = tmp as u64; + carry = (tmp >> 64) as u64; + idx += 1; + } + carry +} + macro_rules! mac_with_carry { ($a:expr, $b:expr, $c:expr, &mut $carry:expr$(,)?) => {{ let tmp = diff --git a/ff/src/biginteger/i8_or_i96.rs b/ff/src/biginteger/i8_or_i96.rs new file mode 100644 index 000000000..9d7d1e3b8 --- /dev/null +++ b/ff/src/biginteger/i8_or_i96.rs @@ -0,0 +1,662 @@ +use crate::biginteger::{S160, S224}; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; +use allocative::Allocative; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; + +/// Compact signed integer optimized for the common `i8` case, widening to a 96-bit +/// split representation when needed (low 64 bits in `large_lo`, next 32 bits in `large_hi`). +/// +/// ## Design goals: +/// - Set fields so that this fits in 16 bytes +/// - Encode the vast majority of values as `i8` for space/time locality. +/// - Keep all operations `const fn` so macros and static tables can fold at compile-time. +/// - After every operation, results are canonicalized to the smallest fitting form: +/// if a result fits in `i8`, it is stored in `small_i8`; otherwise it is stored +/// as a 96-bit split in `large_lo`/`large_hi`. +/// +/// ## Layout and Semantics +/// +/// The 96-bit value is stored in two's complement format, split across two fields: +/// - `large_hi: i32`: The upper 32 bits, which includes the sign bit of the 96-bit integer. +/// - `large_lo: u64`: The lower 64 bits, treated as an unsigned block of bits. +/// +/// The full value can be reconstructed using the formula: +/// `value = (large_hi as i128) << 64 | (large_lo as i128)` +/// This is equivalent to sign-extending `large_hi` and zero-extending `large_lo`. +/// +/// ## Notes: +/// - Arithmetic uses `i128` for intermediate computation but the representation is signed 96-bit. +/// Results are canonicalized to the smallest fitting form. If a result does not fit in 96 bits, +/// it is truncated to signed 96-bit two's complement (wrapping modulo 2^96). +/// - The `neg` implementation avoids `i8` overflow by widening `i8::MIN` to the wide form. +/// - Conversions are total: `to_i128()` always returns the exact value. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] +pub struct I8OrI96 { + /// The lower 64 bits of the constant value. 
+ large_lo: u64, + /// The upper 32 (signed) bits above `large_lo` (bits 64..95) + large_hi: i32, + /// Small constants that fit in i8 (-128 to 127) + pub small_i8: i8, + /// Whether the constant value is small (i8) + pub is_small: bool, +} + +impl I8OrI96 { + /// Returns zero encoded as `I8(0)`. + pub const fn zero() -> Self { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: 0, + } + } + + /// Returns one encoded as `I8(1)`. + pub const fn one() -> Self { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: 1, + } + } + + /// Construct from `i8` without widening. + pub const fn from_i8(value: i8) -> Self { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: value, + } + } + + /// Construct from `i128`, canonicalizing to `I8` if it fits. + /// Assumes the value fits in 96 bits (i64 + i32) + pub const fn from_i128(value: i128) -> Self { + if value >= i8::MIN as i128 && value <= i8::MAX as i128 { + I8OrI96 { + large_lo: 0, + large_hi: 0, + is_small: true, + small_i8: value as i8, + } + } else { + // Store as 96-bit signed split: low 64 bits and next 32 bits + I8OrI96 { + large_lo: value as u64, + large_hi: (value >> 64) as i32, + is_small: false, + small_i8: 0, + } + } + } + + /// Mutate in-place from `i8` without widening. Only updates `small_i8` and `is_small`. + #[inline] + pub const fn set_from_i8(&mut self, value: i8) { + self.small_i8 = value; + self.is_small = true; + } + + /// Mutate in-place from `i128`, canonicalizing to `i8` when it fits. + /// Assumes the value fits in 96 bits (i64 + i32). Minimizes writes. + #[inline] + pub const fn set_from_i128(&mut self, value: i128) { + if value >= i8::MIN as i128 && value <= i8::MAX as i128 { + self.small_i8 = value as i8; + self.is_small = true; + } else { + self.large_lo = value as u64; + self.large_hi = (value >> 64) as i32; + self.is_small = false; + } + } + + /// Exact conversion to `i128`. + #[inline] + pub const fn to_i128(&self) -> i128 { + if self.is_small { + self.small_i8 as i128 + } else { + // The `large_lo` (u64) is zero-extended to i128, and `large_hi` (i32) is sign-extended. + // This correctly reconstructs the 96-bit signed value. + (self.large_lo as i128) | ((self.large_hi as i128) << 64) + } + } + + /// Absolute value as unsigned magnitude. + pub const fn unsigned_abs(&self) -> u128 { + let v = self.to_i128(); + v.unsigned_abs() + } + + /// Returns true if the value equals zero. + #[inline] + pub const fn is_zero(&self) -> bool { + if self.is_small { + self.small_i8 == 0 + } else { + self.large_lo == 0 && self.large_hi == 0 + } + } + + /// Returns true if the value is encoded as `I128`. + #[inline] + pub const fn is_large(&self) -> bool { + !self.is_small + } + + /// Add two constants, returning a canonicalized result. + /// + /// Fast-path: if both operands are `I8`, perform `i8` addition directly. + /// If the `i8` addition overflows, it falls back to the `i128` slow path. + #[inline] + pub const fn add(self, other: I8OrI96) -> I8OrI96 { + let mut out = self; + out.add_assign(&other); + out + } + + /// In-place addition assignment: `self = self + other`. + /// Preserves fast path and falls back to `i128` on `i8` overflow. 
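+    /// For example, `100 + 100` overflows `i8`, so the sum `200` is re-encoded in the wide 96-bit form.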
+ #[inline] + pub const fn add_assign(&mut self, other: &I8OrI96) { + if self.is_small && other.is_small { + let (sum, overflow) = self.small_i8.overflowing_add(other.small_i8); + if !overflow { + self.set_from_i8(sum); + return; + } + } + let sum = self.to_i128() + other.to_i128(); + self.set_from_i128(sum); + } + + /// Multiply two constants, returning a canonicalized result. + /// + /// Fast-path: if both operands are `I8`, perform `i8` multiplication directly. + /// If `i8` multiplication overflows, it falls back to the `i128` slow path. + #[inline] + pub const fn mul(self, other: I8OrI96) -> I8OrI96 { + let mut out = self; + out.mul_assign(&other); + out + } + + /// In-place multiplication assignment: `self = self * other`. + /// Preserves fast path and falls back to `i128` on `i8` overflow. + #[inline] + pub const fn mul_assign(&mut self, other: &I8OrI96) { + if self.is_small && other.is_small { + let (prod, overflow) = self.small_i8.overflowing_mul(other.small_i8); + if !overflow { + self.set_from_i8(prod); + return; + } + } + let prod = self.to_i128() * other.to_i128(); + self.set_from_i128(prod); + } + + /// Arithmetic negation with canonicalization. + /// + /// Special-cases `I8(i8::MIN)` to avoid overflow by widening to `I128`. + /// In-place arithmetic negation. Preserves `i8::MIN` widening behavior. + #[inline] + pub const fn neg(&mut self) { + if self.is_small { + let v = self.small_i8; + if v == i8::MIN { + self.set_from_i128(-(v as i128)); + } else { + self.set_from_i8(-v); + } + } else { + self.set_from_i128(-self.to_i128()); + } + } + + /// Subtraction returning a new value. Delegates to `sub_assign`. + #[inline] + pub const fn sub(self, other: I8OrI96) -> I8OrI96 { + let mut out = self; + out.sub_assign(&other); + out + } + + /// In-place subtraction assignment: `self = self - other`. + /// Fast-path: if both operands are `I8`, perform `i8` subtraction directly. + /// If `i8` subtraction overflows, it falls back to `i128` slow path. + #[inline] + pub const fn sub_assign(&mut self, other: &I8OrI96) { + if self.is_small && other.is_small { + let (diff, overflow) = self.small_i8.overflowing_sub(other.small_i8); + if !overflow { + self.set_from_i8(diff); + return; + } + } + let diff = self.to_i128() - other.to_i128(); + self.set_from_i128(diff); + } +} + +impl Add for I8OrI96 { + type Output = I8OrI96; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + I8OrI96::add(self, rhs) + } +} + +impl AddAssign for I8OrI96 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.add_assign(&rhs) + } +} + +impl Mul for I8OrI96 { + type Output = I8OrI96; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + I8OrI96::mul(self, rhs) + } +} + +impl MulAssign for I8OrI96 { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + self.mul_assign(&rhs) + } +} + +impl Sub for I8OrI96 { + type Output = I8OrI96; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + I8OrI96::sub(self, rhs) + } +} + +impl SubAssign for I8OrI96 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign(&rhs) + } +} + +impl Ord for I8OrI96 { + #[inline] + fn cmp(&self, other: &Self) -> ark_std::cmp::Ordering { + use ark_std::cmp::Ordering; + + // Fast path for when both values are small. + if self.is_small && other.is_small { + return self.small_i8.cmp(&other.small_i8); + } + + // Deconstruct into (hi, lo) parts to perform a 96-bit two's complement comparison. + // If a value is small, we convert it to its large representation on the fly. 
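+        // e.g. -5i8 widens to (hi, lo) = (-1, 0xFFFF_FFFF_FFFF_FFFB), while non-negative values widen with hi = 0.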
+ let (self_hi, self_lo) = if self.is_small { + let val = self.small_i8 as i128; + ((val >> 64) as i32, val as u64) + } else { + (self.large_hi, self.large_lo) + }; + + let (other_hi, other_lo) = if other.is_small { + let val = other.small_i8 as i128; + ((val >> 64) as i32, val as u64) + } else { + (other.large_hi, other.large_lo) + }; + + // Compare the high parts first. If they differ, that determines the order. + match self_hi.cmp(&other_hi) { + Ordering::Equal => { + // If high parts are the same, the order is determined by the low parts. + self_lo.cmp(&other_lo) + }, + order => order, + } + } +} + +impl PartialOrd for I8OrI96 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: bool) -> Self { + if value { + I8OrI96::one() + } else { + I8OrI96::zero() + } + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i8) -> Self { + I8OrI96::from_i8(value) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i16) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i32) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i64) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: i128) -> Self { + I8OrI96::from_i128(value) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: isize) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u8) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u16) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u32) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: u64) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +impl From for I8OrI96 { + #[inline] + fn from(value: usize) -> Self { + I8OrI96::from_i128(value as i128) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TryFromU128Error; + +impl core::fmt::Display for TryFromU128Error { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "u128 does not fit in signed 96-bit range") + } +} + +impl core::convert::TryFrom for I8OrI96 { + type Error = TryFromU128Error; + + #[inline] + fn try_from(value: u128) -> Result { + // Signed 96-bit maximum is 2^95 - 1 + const I96_POS_MAX: u128 = (1u128 << 95) - 1; + if value <= i8::MAX as u128 { + Ok(I8OrI96::from_i8(value as i8)) + } else if value <= I96_POS_MAX { + Ok(I8OrI96 { + large_lo: value as u64, + large_hi: (value >> 64) as i32, + is_small: false, + small_i8: 0, + }) + } else { + Err(TryFromU128Error) + } + } +} + +impl Mul for I8OrI96 { + type Output = S224; + + #[inline] + fn mul(self, rhs: S160) -> Self::Output { + // Determine sign of self + let self_is_positive = if self.is_small { + self.small_i8 >= 0 + } else { + self.large_hi >= 0 + }; + + // Extract rhs magnitude limbs + let rhs_lo = rhs.magnitude_lo(); + let b0 = rhs_lo[0]; + let b1 = rhs_lo[1]; + let b2 = rhs.magnitude_hi() as u64; // widen for math + let b1_is_zero = b1 == 0; + let b2_is_zero = b2 == 0; + let rhs_is_zero = (b0 | b1 | b2) == 0; + + if rhs_is_zero { + return S224::zero(); + } + + // Compute absolute magnitude of self as 96-bit split (x0: u64, x1: u32) + let (x0, x1_u32) = if self.is_small { + let v = self.small_i8; + 
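+            // Negate through i16 so that i8::MIN (-128) can take its magnitude without overflowing i8.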
let k = if v < 0 { + (-(v as i16)) as u64 + } else { + v as u64 + }; + (k, 0u32) + } else { + let hi = self.large_hi; + if hi >= 0 { + (self.large_lo, hi as u32) + } else { + // Two's-complement absolute: (~lo, ~hi) + 1 + let inv_lo = !self.large_lo; + let inv_hi = !(hi as u32); + let sum_lo = inv_lo.wrapping_add(1); + let carry = (sum_lo == 0) as u32; + let sum_hi = inv_hi.wrapping_add(carry); + (sum_lo, sum_hi) + } + }; + + // Compute magnitude product truncated to 224 bits: [r0,r1,r2] + hi32 + let (r0, r1, r2, hi32) = if x1_u32 == 0 { + // Fast path: scalar (<= 8-bit) * 160-bit + let k = x0; + let mut c0 = 0u64; + let r0 = mac_with_carry!(0u64, b0, k, &mut c0); + + if b1_is_zero { + if b2_is_zero { + // Only 64-bit rhs + let r1 = c0; + (r0, r1, 0u64, 0u32) + } else { + // 128-bit rhs via b2 only + let r1 = c0; + let mut hi = 0u64; + let r2 = mac_with_carry!(0u64, b2, k, &mut hi); + let hi32 = hi as u32; + (r0, r1, r2, hi32) + } + } else if b2_is_zero { + // 128-bit rhs via b1 only + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, b1, k, &mut c1); + let r2 = c1; + (r0, r1, r2, 0u32) + } else { + // Full 160-bit rhs + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, b1, k, &mut c1); + + let mut c2 = 0u64; + let mut r2 = mac_with_carry!(0u64, b2, k, &mut c2); + r2 = adc!(r2, c1, &mut c2); + let hi32 = c2 as u32; + (r0, r1, r2, hi32) + } + } else { + // General 96-bit (2 limbs: x0, x1_u32) times 160-bit (3 limbs: b0, b1, b2) + let x1 = x1_u32 as u64; + + let mut c0 = 0u64; + let r0 = mac_with_carry!(0u64, x0, b0, &mut c0); + + if b1_is_zero { + if b2_is_zero { + // Only 64-bit rhs + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, x1, b0, &mut c1); + let r2 = c1; + (r0, r1, r2, 0u32) + } else { + // No b1, but have b2 + let mut c1 = c0; + let r1 = mac_with_carry!(0u64, x1, b0, &mut c1); + + let mut c2 = c1; + let r2 = mac_with_carry!(0u64, x0, b2, &mut c2); + + let r3_low = ((c2 as u128) + + crate::biginteger::arithmetic::widening_mul(x1, b2)) as u64; + let hi32 = (r3_low & 0xFFFF_FFFF) as u32; + (r0, r1, r2, hi32) + } + } else if b2_is_zero { + // No b2, but have b1 + let mut c1 = c0; + let mut r1 = mac_with_carry!(0u64, x0, b1, &mut c1); + r1 = mac_with_carry!(r1, x1, b0, &mut c1); + + let mut c2 = c1; + let r2 = mac_with_carry!(0u64, x1, b1, &mut c2); + let hi32 = c2 as u32; + (r0, r1, r2, hi32) + } else { + // Full 160-bit rhs + let mut c1 = c0; + let mut r1 = mac_with_carry!(0u64, x0, b1, &mut c1); + r1 = mac_with_carry!(r1, x1, b0, &mut c1); + + let mut c2 = c1; + let mut r2 = mac_with_carry!(0u64, x0, b2, &mut c2); + r2 = mac_with_carry!(r2, x1, b1, &mut c2); + + let r3_low = ((c2 as u128) + + crate::biginteger::arithmetic::widening_mul(x1, b2)) as u64; + let hi32 = (r3_low & 0xFFFF_FFFF) as u32; + (r0, r1, r2, hi32) + } + }; + + // Combine sign; canonicalize zero to positive + let mut is_positive = !(self_is_positive ^ rhs.is_positive()); + if (r0 | r1 | r2) == 0 && hi32 == 0 { + is_positive = true; + } + + S224::new([r0, r1, r2], hi32, is_positive) + } +} + +// ------------------------------------------------------------------------------------------------ +// Canonical serialization +// ------------------------------------------------------------------------------------------------ + +impl CanonicalSerialize for I8OrI96 { + #[inline] + fn serialize_with_mode( + &self, + mut w: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // Print whether it is small (computed canonically from value), and the numeric value + // as 96-bit two's complement represented by 
(hi: i32, lo: u64), derived from to_i128(). + let v = self.to_i128(); + let is_small_value = v >= i8::MIN as i128 && v <= i8::MAX as i128; + let is_small_u8: u8 = if is_small_value { 1 } else { 0 }; + is_small_u8.serialize_with_mode(&mut w, compress)?; + + let hi: i32 = (v >> 64) as i32; + let lo: u64 = v as u64; + hi.serialize_with_mode(&mut w, compress)?; + lo.serialize_with_mode(w, compress) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + (self.is_small as u8).serialized_size(compress) + + (0i32).serialized_size(compress) + + (0u64).serialized_size(compress) + } +} + +impl CanonicalDeserialize for I8OrI96 { + #[inline] + fn deserialize_with_mode( + mut r: R, + compress: Compress, + validate: Validate, + ) -> Result { + let _is_small = u8::deserialize_with_mode(&mut r, compress, validate)?; + let hi = i32::deserialize_with_mode(&mut r, compress, validate)?; + let lo = u64::deserialize_with_mode(r, compress, validate)?; + let v: i128 = ((hi as i128) << 64) | (lo as i128); + Ok(I8OrI96::from_i128(v)) + } +} + +impl Valid for I8OrI96 { + #[inline] + fn check(&self) -> Result<(), SerializationError> { + // All bit patterns of the struct represent either a small i8 or a 96-bit two's complement value. + // No additional invariants beyond that; always valid. + Ok(()) + } +} diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index 62f7bc658..210f9c433 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -2,6 +2,7 @@ use crate::{ bits::{BitIteratorBE, BitIteratorLE}, const_for, UniformRand, }; +use allocative::Allocative; #[allow(unused)] use ark_ff_macros::unroll_for_loops; use ark_serialize::{ @@ -13,8 +14,8 @@ use ark_std::{ fmt::{Debug, Display, UpperHex}, io::{Read, Write}, ops::{ - BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr, - ShrAssign, + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, + ShlAssign, Shr, ShrAssign, Sub, SubAssign, }, rand::{ distributions::{Distribution, Standard}, @@ -22,6 +23,7 @@ use ark_std::{ }, str::FromStr, vec::*, + Zero, }; use num_bigint::BigUint; use zeroize::Zeroize; @@ -29,7 +31,16 @@ use zeroize::Zeroize; #[macro_use] pub mod arithmetic; -#[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize)] +pub mod signed; +pub use signed::{SignedBigInt, S128, S192, S256, S64}; + +pub mod signed_hi_32; +pub use signed_hi_32::{SignedBigIntHi32, S160, S224, S96}; + +pub mod i8_or_i96; +pub use i8_or_i96::I8OrI96; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize, Allocative)] pub struct BigInt(pub [u64; N]); impl Default for BigInt { @@ -285,17 +296,176 @@ impl BigInt { /// leading zeros in the most significant limb. #[doc(hidden)] pub const fn num_spare_bits(self) -> u32 { - // Count the leading zeros in the most significant limb - let msb = self.0[N - 1]; - let mut count = 0; - let mut mask = 1u64 << 63; // Start with the highest bit - - while count < 64 && (msb & mask) == 0 { - count += 1; - mask >>= 1; + // Fast path: directly use the intrinsic on the most significant limb + self.0[N - 1].leading_zeros() + } + + /// Truncated-width addition: compute self + other into P limbs. + /// + /// - Semantics: returns the low P limbs of the sum; higher limbs are discarded. + /// - Precondition (debug-only): right operand width M must be <= P. + /// - Debug contract: panics in debug if an addition carry would spill beyond P limbs. + #[inline] + pub fn add_trunc(&self, other: &BigInt) -> BigInt

<P> {
+        debug_assert!(M <= P, "add_trunc: right operand wider than result width P");
+        let mut acc = BigInt::<P>::zero();
+        let copy_len = core::cmp::min(P, N);
+        acc.0[..copy_len].copy_from_slice(&self.0[..copy_len]);
+        acc.add_assign_trunc::<M>(other);
+        acc
+    }
+
+    /// Truncated-width subtraction: compute self - other into P limbs.
+    ///
+    /// - Semantics: returns the low P limbs of the difference; higher borrow is discarded.
+    /// - Precondition (debug-only): right operand width M must be <= P.
+    /// - Debug contract: panics in debug if a borrow would spill beyond P limbs.
+    #[inline]
+    pub fn sub_trunc<const M: usize, const P: usize>(&self, other: &BigInt<M>) -> BigInt<P> {
+        debug_assert!(M <= P, "sub_trunc: right operand wider than result width P");
+        let mut acc = BigInt::<P>::zero();
+        let copy_len = core::cmp::min(P, N);
+        acc.0[..copy_len].copy_from_slice(&self.0[..copy_len]);
+        acc.sub_assign_trunc::<M>(other);
+        acc
+    }
+
+    /// Truncated-width multiplication: compute self * other and fit into P limbs; overflow is ignored.
+    #[inline]
+    pub fn mul_trunc<const M: usize, const P: usize>(&self, other: &BigInt<M>) -> BigInt<P> {
+        let mut res = BigInt::<P>
::zero(); + // Use core fused multiply engine specialized on M for unrolling + self.fm_limbs_into::(&other.0, &mut res, false); + res + } + + /// Truncated-width addition that mutates self: self += other, keeping N limbs (self's width). + /// + /// - Semantics: computes (self + other) mod 2^(64*N). + /// - Precondition (debug-only): right operand width M must be <= N. + /// - Debug contract: panics in debug if a carry would spill beyond N limbs. + #[inline] + #[unroll_for_loops(9)] + pub fn add_assign_trunc(&mut self, other: &BigInt) { + debug_assert!( + M <= N, + "add_assign_trunc: right operand wider than self width N" + ); + let mut carry = 0u64; + for i in 0..N { + let rhs = if i < M { other.0[i] } else { 0 }; + self.0[i] = adc!(self.0[i], rhs, &mut carry); + } + debug_assert!( + carry == 0, + "add_assign_trunc overflow: carry beyond N limbs" + ); + } + + /// Truncated-width subtraction that mutates self: self -= other, keeping N limbs (self's width). + /// + /// Semantics: computes (self - other) mod 2^(64*N). + /// Precondition (debug-only): right operand width M must be <= N. + /// Debug contract: panics in debug if a borrow would spill beyond N limbs. + #[inline] + #[unroll_for_loops(9)] + pub fn sub_assign_trunc(&mut self, other: &BigInt) { + debug_assert!( + M <= N, + "sub_assign_trunc: right operand wider than self width N" + ); + let mut borrow = 0u64; + for i in 0..N { + let rhs = if i < M { other.0[i] } else { 0 }; + self.0[i] = sbb!(self.0[i], rhs, &mut borrow); + } + debug_assert!( + borrow == 0, + "sub_assign_trunc underflow: borrow beyond N limbs" + ); + } + + /// Truncated-width multiplication that mutates self: self = (self * other) mod 2^(64*N). + /// Keeps exactly N limbs (self's width). Overflow beyond N limbs is ignored. + #[inline] + pub fn mul_assign_trunc(&mut self, other: &BigInt) { + // Fast paths + if self.is_zero() || other.is_zero() { + for i in 0..N { + self.0[i] = 0; + } + return; } + let left = *self; // snapshot original multiplicand + // zero self to use as accumulator buffer + for i in 0..N { + self.0[i] = 0; + } + // Accumulate left * other directly into self within width N; propagate carries within N + left.fm_limbs_into::(&other.0, self, true); + } - count + /// Fused multiply-add with truncation: acc += self * other, fitting into P limbs; overflow is ignored. + /// This is a generic version for arbitrary limb widths of `self` and `other`. + #[inline] + pub fn fmadd_trunc( + &self, + other: &BigInt, + acc: &mut BigInt

<P>,
+    ) {
+        let i_limit = core::cmp::min(N, P);
+        for i in 0..i_limit {
+            let mut carry = 0u64;
+            let j_limit = core::cmp::min(M, P - i);
+            for j in 0..j_limit {
+                acc.0[i + j] = mac_with_carry!(acc.0[i + j], self.0[i], other.0[j], &mut carry);
+            }
+            if i + j_limit < P {
+                let (new_val, _of) = acc.0[i + j_limit].overflowing_add(carry);
+                acc.0[i + j_limit] = new_val;
+            }
+        }
+    }
+
+    /// Accumulate with a compile-time-known count of multiplier limbs M to enable unrolling.
+    #[inline]
+    #[unroll_for_loops(10)]
+    pub(crate) fn fm_limbs_into<const M: usize, const P: usize>(
+        &self,
+        other_limbs: &[u64; M],
+        acc: &mut BigInt<P>
, + carry_propagate: bool, + ) { + for j in 0..M { + let mul_limb = other_limbs[j]; + if mul_limb == 0 { + // Skip zero multiplier limb + // (cannot use `continue` here due to unroll macro limitations) + } else { + let base = j; + let mut carry = 0u64; + for i in 0..N { + let idx = base + i; + if idx < P { + acc.0[idx] = mac_with_carry!(acc.0[idx], self.0[i], mul_limb, &mut carry); + } + } + let next = base + N; + if next < P { + let (v, mut of) = acc.0[next].overflowing_add(carry); + acc.0[next] = v; + if carry_propagate && of { + let mut k = next + 1; + while of && k < P { + let (nv, nof) = acc.0[k].overflowing_add(1); + acc.0[k] = nv; + of = nof; + k += 1; + } + } + } + } + } } #[inline] @@ -354,6 +524,38 @@ impl BigInt { crate::const_helpers::R2Buffer::([0u64; N], [0u64; N], 1); const_modulo!(two_pow_n_times_64_square, self) } + + /// Zero-extend a smaller BigInt into BigInt (little-endian limbs). + /// Debug-asserts that M <= N. + #[inline] + pub fn zero_extend_from(smaller: &BigInt) -> BigInt { + debug_assert!( + M <= N, + "cannot zero-extend: source has more limbs than destination" + ); + let mut limbs = [0u64; N]; + let copy_len = if M < N { M } else { N }; + limbs[..copy_len].copy_from_slice(&smaller.0[..copy_len]); + BigInt::(limbs) + } +} + +impl Zero for BigInt { + #[inline] + fn zero() -> Self { + Self::zero() + } + + #[inline] + fn is_zero(&self) -> bool { + self.0.iter().all(|&limb| limb == 0) + } +} + +impl From<[u64; N]> for BigInt { + fn from(limbs: [u64; N]) -> Self { + BigInt(limbs) + } } impl BigInteger for BigInt { @@ -443,6 +645,93 @@ impl BigInteger for BigInt { } } + #[inline] + #[unroll_for_loops(8)] + fn mul_u64_in_place(&mut self, other: u64) { + // special cases for 0 and 1 + // if other == 0 || self.is_zero() { + // *self = Self::zero(); + // return; + // } else if other == 1 { + // return; + // } + // Use the same low-level multiply-accumulate primitive that already + // benefits from x86 optimizations in this crate. + let mut carry = 0u64; + for i in 0..N { + self.0[i] = mac_with_carry!(0u64, self.0[i], other, &mut carry); + } + // Overflow is ignored by contract; assert in debug to catch misuse. + debug_assert!(carry == 0, "Overflow in BigInt::mul_u64_in_place"); + } + + #[inline] + #[unroll_for_loops(8)] + fn mul_u64_w_carry(&self, other: u64) -> BigInt { + // ensure NPLUS1 is the correct size + debug_assert!(NPLUS1 == N + 1); + // special cases for 0 and 1 + // if other == 0 || self.is_zero() { + // return BigInt::::zero(); + // } else if other == 1 { + // let mut res = BigInt::::zero(); + // for i in 0..N { + // res.0[i] = self.0[i]; + // } + // return res; + // } + // Use the same multiply-accumulate primitive and capture the final carry + let mut res = BigInt::::zero(); + let mut carry = 0u64; + for i in 0..N { + res.0[i] = mac_with_carry!(0u64, self.0[i], other, &mut carry); + } + res.0[N] = carry; + res + } + + #[inline] + #[unroll_for_loops(8)] + fn mul_u128_w_carry( + &self, + other: u128, + ) -> BigInt { + // NPLUS1 is N + 1, NPLUS2 is N + 2 + debug_assert!(NPLUS1 == N + 1); + debug_assert!(NPLUS2 == N + 2); + // special cases for 0 and 1 + if other == 0 || self.is_zero() { + return BigInt::::zero(); + } else if other == 1 { + let mut res = BigInt::::zero(); + for i in 0..N { + res.0[i] = self.0[i]; + } + return res; + } + // Split other into two u64s and accumulate directly into the result buffer. 
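+        // Since other = other_hi * 2^64 + other_lo, the product is self * other_lo + (self * other_hi) * 2^64,
+        // which is why the second pass below accumulates into result limbs shifted up by one.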
+ let other_lo = other as u64; + let other_hi = (other >> 64) as u64; + + let mut res = BigInt::::zero(); + + // First pass: res[i] += self[i] * other_lo + let mut carry = 0u64; + for i in 0..N { + res.0[i] = mac_with_carry!(res.0[i], self.0[i], other_lo, &mut carry); + } + res.0[N] = carry; + + // Second pass: res[i+1] += self[i] * other_hi + let mut carry2 = 0u64; + for i in 0..N { + res.0[i + 1] = mac_with_carry!(res.0[i + 1], self.0[i], other_hi, &mut carry2); + } + res.0[N + 1] = carry2; + + res + } + #[inline] fn mul(&self, other: &Self) -> (Self, Self) { if self.is_zero() || other.is_zero() { @@ -538,11 +827,6 @@ impl BigInteger for BigInt { !self.is_odd() } - #[inline] - fn is_zero(&self) -> bool { - self.0.iter().all(|&e| e == 0) - } - #[inline] fn num_bits(&self) -> u32 { let mut ret = N as u32 * 64; @@ -938,6 +1222,71 @@ impl Not for BigInt { } } +// Arithmetic with truncating semantics for BigInt of different widths +// Note: we cannot let the output have arbitrary width due to Rust's type limitation +// So we set the output width to be the same as the width of the left operand +impl Add> for BigInt { + type Output = BigInt; + + fn add(self, rhs: BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_trunc::(&rhs) + } +} + +impl Add<&BigInt> for BigInt { + type Output = BigInt; + fn add(self, rhs: &BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_trunc::(rhs) + } +} + +impl Sub> for BigInt { + type Output = BigInt; + + fn sub(self, rhs: BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_trunc::(&rhs) + } +} + +impl Sub<&BigInt> for BigInt { + type Output = BigInt; + fn sub(self, rhs: &BigInt) -> Self::Output { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_trunc::(rhs) + } +} + +impl AddAssign> for BigInt { + fn add_assign(&mut self, rhs: BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_assign_trunc::(&rhs); + } +} + +impl AddAssign<&BigInt> for BigInt { + fn add_assign(&mut self, rhs: &BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.add_assign_trunc::(rhs); + } +} + +impl SubAssign> for BigInt { + fn sub_assign(&mut self, rhs: BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_assign_trunc::(&rhs); + } +} + +impl SubAssign<&BigInt> for BigInt { + fn sub_assign(&mut self, rhs: &BigInt) { + debug_assert!(N >= M, "right operand cannot be wider than left operand"); + self.sub_assign_trunc::(rhs); + } +} + /// Compute the signed modulo operation on a u64 representation, returning the result. /// If n % modulus > modulus / 2, return modulus - n /// # Example @@ -986,6 +1335,7 @@ pub trait BigInteger: + 'static + UniformRand + Zeroize + + Zero + AsMut<[u64]> + AsRef<[u64]> + From @@ -1110,6 +1460,18 @@ pub trait BigInteger: #[deprecated(since = "0.4.2", note = "please use the operator `<<` instead")] fn muln(&mut self, amt: u32); + /// NEW! Multiplies self by a u64 in place. Overflow is ignored. + fn mul_u64_in_place(&mut self, other: u64); + + /// NEW! Multiplies self by a u64, returning a bigint with one extra limb to hold overflow. + fn mul_u64_w_carry(&self, other: u64) -> BigInt; + + /// NEW! Multiplies self by a u128, returning a bigint with two extra limbs to hold overflow. 
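+    /// (An N-limb value times a 128-bit value is below 2^(64*N + 128), so N + 2 limbs always suffice.)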
+ fn mul_u128_w_carry( + &self, + other: u128, + ) -> BigInt; + /// Multiplies this [`BigInteger`] by another `BigInteger`, storing the result in `self`. /// Overflow is ignored. /// @@ -1238,17 +1600,6 @@ pub trait BigInteger: /// ``` fn is_even(&self) -> bool; - /// Returns true iff this number is zero. - /// # Example - /// - /// ``` - /// use ark_ff::{biginteger::BigInteger64 as B, BigInteger as _}; - /// - /// let mut zero = B::from(0u64); - /// assert!(zero.is_zero()); - /// ``` - fn is_zero(&self) -> bool; - /// Compute the minimum number of bits needed to encode this number. /// # Example /// ``` diff --git a/ff/src/biginteger/signed.rs b/ff/src/biginteger/signed.rs new file mode 100644 index 000000000..42b20d00a --- /dev/null +++ b/ff/src/biginteger/signed.rs @@ -0,0 +1,745 @@ +use crate::biginteger::{BigInt, BigInteger}; +use allocative::Allocative; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; +use ark_std::Zero; +use core::cmp::Ordering; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// A signed big integer using arkworks BigInt for magnitude and a sign bit. +/// +/// Notes: +/// - Zero is not canonicalized: a zero magnitude can be paired with either sign. +/// Structural equality distinguishes `+0` and `-0` (since the sign bit differs). +/// - Ordering treats `+0` and `-0` as equal: comparisons return `Ordering::Equal` when +/// both magnitudes are zero regardless of sign. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] +pub struct SignedBigInt { + pub magnitude: BigInt, + pub is_positive: bool, +} + +impl Default for SignedBigInt { + #[inline] + fn default() -> Self { + Self::zero() + } +} + +impl Zero for SignedBigInt { + #[inline] + fn zero() -> Self { + Self::zero() + } + + #[inline] + fn is_zero(&self) -> bool { + self.magnitude.is_zero() + } +} + +pub type S64 = SignedBigInt<1>; +pub type S128 = SignedBigInt<2>; +pub type S192 = SignedBigInt<3>; +pub type S256 = SignedBigInt<4>; + +impl SignedBigInt { + #[inline] + fn cmp_magnitude_mixed(&self, rhs: &SignedBigInt) -> Ordering { + let max_limbs = if N > M { N } else { M }; + let mut i = max_limbs; + while i > 0 { + let idx = i - 1; + let a = if idx < N { self.magnitude.0[idx] } else { 0u64 }; + let b = if idx < M { rhs.magnitude.0[idx] } else { 0u64 }; + if a > b { + return Ordering::Greater; + } + if a < b { + return Ordering::Less; + } + i -= 1; + } + Ordering::Equal + } + /// Construct from limbs and sign; limbs are little-endian. + #[inline] + pub fn new(limbs: [u64; N], is_positive: bool) -> Self { + Self { + magnitude: BigInt::new(limbs), + is_positive, + } + } + + /// Construct from an existing BigInt magnitude and sign. + #[inline] + pub fn from_bigint(magnitude: BigInt, is_positive: bool) -> Self { + Self { + magnitude, + is_positive, + } + } + + /// Zero value with a positive sign (negative zero allowed elsewhere). + #[inline] + pub fn zero() -> Self { + Self { + magnitude: BigInt::from(0u64), + is_positive: true, + } + } + + /// One with a positive sign. + #[inline] + pub fn one() -> Self { + Self { + magnitude: BigInt::from(1u64), + is_positive: true, + } + } + + /// Borrow the magnitude (absolute value). + #[inline] + pub fn as_magnitude(&self) -> &BigInt { + &self.magnitude + } + + /// Return the magnitude limbs by value (copy). + #[inline] + pub fn magnitude_limbs(&self) -> [u64; N] { + self.magnitude.0 + } + + /// Borrow the magnitude limbs as a slice (avoids copying the array). 
+ #[inline] + pub fn magnitude_slice(&self) -> &[u64] { + self.magnitude.as_ref() + } + + /// Return true iff the value is non-negative. + #[inline] + pub fn sign(&self) -> bool { + self.is_positive + } + + /// Compute self + other modulo 2^(64*N); carry beyond N limbs is dropped. + #[inline] + pub fn add(mut self, other: Self) -> Self { + self += other; + self + } + + /// Compute self - other modulo 2^(64*N); borrow beyond N limbs is dropped. + #[inline] + pub fn sub(mut self, other: Self) -> Self { + self -= other; + self + } + + /// Compute self * other and keep only the low N limbs; high limbs are discarded. + #[inline] + pub fn mul(mut self, other: Self) -> Self { + self *= other; + self + } + + /// Flip the sign; zero is not canonicalized (negative zero may occur). + #[inline] + pub fn neg(self) -> Self { + Self::from_bigint(self.magnitude, !self.is_positive) + } + + // ===== in-place helpers ===== + /// In-place addition with sign handling; drops overflow beyond N limbs. + #[inline(always)] + fn add_assign_in_place(&mut self, rhs: &Self) { + if self.is_positive == rhs.is_positive { + let _carry = self.magnitude.add_with_carry(&rhs.magnitude); + // overflow ignored by design + } else { + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); + }, + Ordering::Less => { + // Minimize copies: move rhs magnitude into place and subtract old self + let old = core::mem::replace(&mut self.magnitude, rhs.magnitude); + let _borrow = self.magnitude.sub_with_borrow(&old); + self.is_positive = rhs.is_positive; + }, + } + } + } + + /// In-place subtraction with sign handling; drops borrow beyond N limbs. + #[inline(always)] + fn sub_assign_in_place(&mut self, rhs: &Self) { + // Implement directly to avoid temporary construction + if self.is_positive != rhs.is_positive { + // Signs differ -> add magnitudes; sign remains self.is_positive + let _carry = self.magnitude.add_with_carry(&rhs.magnitude); + } else { + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let _borrow = self.magnitude.sub_with_borrow(&rhs.magnitude); + // sign stays the same + }, + Ordering::Less => { + // Result takes rhs magnitude minus self magnitude, sign flips + let old = core::mem::replace(&mut self.magnitude, rhs.magnitude); + let _borrow = self.magnitude.sub_with_borrow(&old); + self.is_positive = !self.is_positive; + }, + } + } + } + + /// In-place multiply using low-limb product only; updates sign, discards high limbs. + #[inline(always)] + fn mul_assign_in_place(&mut self, rhs: &Self) { + let low = self.magnitude.mul_low(&rhs.magnitude); + self.magnitude = low; + self.is_positive = self.is_positive == rhs.is_positive; + } + + /// Zero-extend a smaller-width signed big integer into N limbs (little-endian). + /// Preserves the sign bit; only the magnitude is widened by zero-extension. + /// Debug-asserts that M <= N. + #[inline] + pub fn zero_extend_from(smaller: &SignedBigInt) -> SignedBigInt { + debug_assert!(M <= N, "cannot zero-extend: source has more limbs than destination"); + let widened_mag = BigInt::::zero_extend_from::(&smaller.magnitude); + SignedBigInt::from_bigint(widened_mag, smaller.is_positive) + } +} + +impl SignedBigInt { + // ===== truncated-width operations ===== + + /// Truncated add: compute (self + rhs) and fit into M limbs; overflow is ignored. 
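+    /// Signs follow sign-magnitude rules: like signs add magnitudes; unlike signs subtract the smaller
+    /// magnitude from the larger and keep the sign of the operand with the larger magnitude.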
+ #[inline] + pub fn add_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt { + if self.is_positive == rhs.is_positive { + let mag = self.magnitude.add_trunc::(&rhs.magnitude); + return SignedBigInt:: { magnitude: mag, is_positive: self.is_positive }; + } + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let mag = self.magnitude.sub_trunc::(&rhs.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: self.is_positive } + }, + Ordering::Less => { + let mag = rhs.magnitude.sub_trunc::(&self.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: rhs.is_positive } + }, + } + } + + /// Truncated sub: compute (self - rhs) and fit into M limbs; overflow is ignored. + #[inline] + pub fn sub_trunc(&self, rhs: &SignedBigInt) -> SignedBigInt { + if self.is_positive != rhs.is_positive { + let mag = self.magnitude.add_trunc::(&rhs.magnitude); + return SignedBigInt:: { magnitude: mag, is_positive: self.is_positive }; + } + match self.magnitude.cmp(&rhs.magnitude) { + Ordering::Greater | Ordering::Equal => { + let mag = self.magnitude.sub_trunc::(&rhs.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: self.is_positive } + }, + Ordering::Less => { + let mag = rhs.magnitude.sub_trunc::(&self.magnitude); + SignedBigInt:: { magnitude: mag, is_positive: !self.is_positive } + }, + } + } + + /// Truncated mixed-width addition: compute (self + rhs) where rhs can have a + /// different limb count, and fit into P limbs; overflow is ignored. + #[inline] + pub fn add_trunc_mixed( + &self, + rhs: &SignedBigInt, + ) -> SignedBigInt

<P> {
+        if self.is_positive == rhs.is_positive {
+            let mag = self.magnitude.add_trunc::<M, P>(&rhs.magnitude);
+            return SignedBigInt::<P> { magnitude: mag, is_positive: self.is_positive };
+        }
+        match self.cmp_magnitude_mixed(rhs) {
+            Ordering::Greater | Ordering::Equal => {
+                let mag = self.magnitude.sub_trunc::<M, P>(&rhs.magnitude);
+                SignedBigInt::<P> { magnitude: mag, is_positive: self.is_positive }
+            },
+            Ordering::Less => {
+                let mag = rhs.magnitude.sub_trunc::<N, P>(&self.magnitude);
+                SignedBigInt::<P> { magnitude: mag, is_positive: rhs.is_positive }
+            },
+        }
+    }
+
+    /// Truncated mul: compute self * rhs and fit into P limbs; no assumption on P; overflow ignored.
+    #[inline]
+    pub fn mul_trunc<const M: usize, const P: usize>(
+        &self,
+        rhs: &SignedBigInt<M>,
+    ) -> SignedBigInt<P> {
+        let mag = self.magnitude.mul_trunc::<M, P>(&rhs.magnitude);
+        let sign = self.is_positive == rhs.is_positive;
+        SignedBigInt::<P> {
+            magnitude: mag,
+            is_positive: sign,
+        }
+    }
+
+    /// Fused multiply-add: acc += self * rhs, fitted into P limbs; overflow is ignored.
+    #[inline]
+    pub fn fmadd_trunc<const M: usize, const P: usize>(
+        &self,
+        rhs: &SignedBigInt<M>,
+        acc: &mut SignedBigInt<P>
, + ) { + let prod_mag = self.magnitude.mul_trunc::(&rhs.magnitude); + let prod_sign = self.is_positive == rhs.is_positive; + if acc.is_positive == prod_sign { + let _ = acc.magnitude.add_with_carry(&prod_mag); + } else { + match acc.magnitude.cmp(&prod_mag) { + Ordering::Greater | Ordering::Equal => { + let _ = acc.magnitude.sub_with_borrow(&prod_mag); + }, + Ordering::Less => { + let old = core::mem::replace(&mut acc.magnitude, prod_mag); + let _ = acc.magnitude.sub_with_borrow(&old); + acc.is_positive = prod_sign; + }, + } + } + } +} + +impl SignedBigInt { + // ===== generic conversions ===== + + /// Construct from u64 with positive sign. + #[inline] + pub fn from_u64(value: u64) -> Self { + Self::from_bigint(BigInt::from(value), true) + } + + /// Construct from (u64, sign); sign=true is non-negative. + #[inline] + pub fn from_u64_with_sign(value: u64, is_positive: bool) -> Self { + Self::from_bigint(BigInt::from(value), is_positive) + } + + /// Construct from i64; magnitude is |value|, sign reflects value>=0. + #[inline] + pub fn from_i64(value: i64) -> Self { + if value >= 0 { + Self::from_bigint(BigInt::from(value as u64), true) + } else { + // wrapping_neg handles i64::MIN + Self::from_bigint(BigInt::from(value.wrapping_neg() as u64), false) + } + } + + /// Construct from u128 with positive sign (N must be >= 2 in debug builds). + #[inline] + pub fn from_u128(value: u128) -> Self { + debug_assert!(N >= 2, "from_u128 requires at least 2 limbs"); + Self::from_bigint(BigInt::from(value), true) + } + + /// Construct from i128; magnitude is |value|, sign reflects value>=0 (N must be >= 2 in debug builds). + #[inline] + pub fn from_i128(value: i128) -> Self { + debug_assert!(N >= 2, "from_i128 requires at least 2 limbs"); + if value >= 0 { + Self::from_bigint(BigInt::from(value as u128), true) + } else { + let mag = (value as i128).unsigned_abs(); + Self::from_bigint(BigInt::from(mag), false) + } + } + + /// Truncated mixed-width subtraction: compute (self - rhs) where rhs can have a + /// different limb count, and fit into P limbs; overflow is ignored. + #[inline] + pub fn sub_trunc_mixed( + &self, + rhs: &SignedBigInt, + ) -> SignedBigInt

<P> {
+        if self.is_positive != rhs.is_positive {
+            let mag = self.magnitude.add_trunc::<M, P>(&rhs.magnitude);
+            return SignedBigInt::<P> { magnitude: mag, is_positive: self.is_positive };
+        }
+        match self.cmp_magnitude_mixed(rhs) {
+            Ordering::Greater | Ordering::Equal => {
+                let mag = self.magnitude.sub_trunc::<M, P>(&rhs.magnitude);
+                SignedBigInt::<P> { magnitude: mag, is_positive: self.is_positive }
+            },
+            Ordering::Less => {
+                let mag = rhs.magnitude.sub_trunc::<N, P>(&self.magnitude);
+                SignedBigInt::<P>
{ magnitude: mag, is_positive: !self.is_positive } + }, + } + } +} + +impl From for SignedBigInt { + /// From: positive sign; higher limbs are zeroed. + #[inline] + fn from(value: u64) -> Self { + Self::from_u64(value) + } +} + +impl From for SignedBigInt { + /// From: sign from value; magnitude is |value|; higher limbs are zeroed. + #[inline] + fn from(value: i64) -> Self { + Self::from_i64(value) + } +} + +impl From<(u64, bool)> for SignedBigInt { + /// From<(u64,bool)>: (magnitude, is_positive); higher limbs are zeroed. + #[inline] + fn from(value_and_sign: (u64, bool)) -> Self { + Self::from_u64_with_sign(value_and_sign.0, value_and_sign.1) + } +} + +impl From for SignedBigInt { + /// From: positive sign; debug-assert N >= 2; higher limbs are zeroed. + #[inline] + fn from(value: u128) -> Self { + debug_assert!(N >= 2, "From requires at least 2 limbs"); + Self::from_u128(value) + } +} + +impl From for SignedBigInt { + /// From: sign from value; debug-assert N >= 2; magnitude is |value|. + #[inline] + fn from(value: i128) -> Self { + debug_assert!(N >= 2, "From requires at least 2 limbs"); + Self::from_i128(value) + } +} + +// Specializations for common sizes +impl S64 { + /// Convert to i128; any u64 magnitude fits for both signs. + #[inline] + pub fn to_i128(&self) -> i128 { + let magnitude = self.magnitude.0[0]; + if self.is_positive { + magnitude as i128 + } else { + -(magnitude as i128) + } + } + + /// Return the magnitude as u64 + #[inline] + pub fn magnitude_as_u64(&self) -> u64 { + self.magnitude.0[0] + } +} + +impl S128 { + /// Convert to i128 using 2^127 bounds: positive requires mag <= i128::MAX; negative allows mag == 2^127. + #[inline] + pub fn to_i128(&self) -> Option { + let hi = self.magnitude.0[1]; + let lo = self.magnitude.0[0]; + let hi_top_bit = hi >> 63; // bit 127 + if self.is_positive { + if hi_top_bit != 0 { + return None; + } + let mag = ((hi as u128) << 64) | (lo as u128); + Some(mag as i128) + } else { + if hi_top_bit == 0 { + let mag = ((hi as u128) << 64) | (lo as u128); + Some(-(mag as i128)) + } else if hi == (1u64 << 63) && lo == 0 { + Some(i128::MIN) + } else { + None + } + } + } + + /// Return the magnitude as u128 + #[inline] + pub fn magnitude_as_u128(&self) -> u128 { + (self.magnitude.0[1] as u128) << 64 | (self.magnitude.0[0] as u128) + } + + /// Construct from u128 and sign + #[inline] + pub fn from_u128_and_sign(value: u128, is_positive: bool) -> Self { + Self::new([value as u64, (value >> 64) as u64], is_positive) + } + + /// Exact product of u64 and i64 into S128 (u64 × s64 -> s128) + #[inline] + pub fn from_u64_mul_i64(u: u64, s: i64) -> Self { + let mag = (u as u128) * (s.unsigned_abs() as u128); + Self::from_u128_and_sign(mag, s >= 0) + } + + /// Exact product of i64 and u64 into S128 (s64 × u64 -> s128) + #[inline] + pub fn from_i64_mul_u64(s: i64, u: u64) -> Self { + Self::from_u64_mul_i64(u, s) + } + + /// Exact product of two u64 into S128 (u64 × u64 -> s128, non-negative) + #[inline] + pub fn from_u64_mul_u64(a: u64, b: u64) -> Self { + let mag = (a as u128) * (b as u128); + Self::from_u128_and_sign(mag, true) + } + + /// Exact product of two i64 into S128 (s64 × s64 -> s128) + #[inline] + pub fn from_i64_mul_i64(a: i64, b: i64) -> Self { + let mag = (a.unsigned_abs() as u128) * (b.unsigned_abs() as u128); + let is_positive = (a >= 0) == (b >= 0); + Self::from_u128_and_sign(mag, is_positive) + } +} + +/// Helper function for single u64 signed arithmetic +/// Adds two signed u64 values (given as magnitude+sign) modulo 2^64; returns 
(magnitude, sign). +#[inline] +pub fn add_with_sign_u64(a_mag: u64, a_pos: bool, b_mag: u64, b_pos: bool) -> (u64, bool) { + let a = SignedBigInt::<1>::from_u64_with_sign(a_mag, a_pos); + let b = SignedBigInt::<1>::from_u64_with_sign(b_mag, b_pos); + let result = a + b; + (result.magnitude.0[0], result.is_positive) +} + +// =============================================== +// Standard operator trait implementations +// =============================================== + +impl Add for SignedBigInt { + type Output = Self; + + #[inline] + fn add(mut self, rhs: Self) -> Self::Output { + self.add_assign_in_place(&rhs); + self + } +} + +impl Sub for SignedBigInt { + type Output = Self; + + #[inline] + fn sub(mut self, rhs: Self) -> Self::Output { + self.sub_assign_in_place(&rhs); + self + } +} + +impl Mul for SignedBigInt { + type Output = Self; + + #[inline] + fn mul(mut self, rhs: Self) -> Self::Output { + self.mul_assign_in_place(&rhs); + self + } +} + +impl Neg for SignedBigInt { + type Output = Self; + + #[inline] + fn neg(self) -> Self::Output { + SignedBigInt::neg(self) + } +} + +impl AddAssign for SignedBigInt { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.add_assign_in_place(&rhs); + } +} + +impl SubAssign for SignedBigInt { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign_in_place(&rhs); + } +} + +impl MulAssign for SignedBigInt { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + self.mul_assign_in_place(&rhs); + } +} + +// Reference variants for efficiency +impl Add<&SignedBigInt> for SignedBigInt { + type Output = SignedBigInt; + + #[inline] + fn add(mut self, rhs: &SignedBigInt) -> Self::Output { + self.add_assign_in_place(rhs); + self + } +} + +impl Sub<&SignedBigInt> for SignedBigInt { + type Output = SignedBigInt; + + #[inline] + fn sub(mut self, rhs: &SignedBigInt) -> Self::Output { + self.sub_assign_in_place(rhs); + self + } +} + +impl Mul<&SignedBigInt> for SignedBigInt { + type Output = SignedBigInt; + + #[inline] + fn mul(mut self, rhs: &SignedBigInt) -> Self::Output { + self.mul_assign_in_place(rhs); + self + } +} + +impl AddAssign<&SignedBigInt> for SignedBigInt { + #[inline] + fn add_assign(&mut self, rhs: &SignedBigInt) { + self.add_assign_in_place(rhs); + } +} + +impl SubAssign<&SignedBigInt> for SignedBigInt { + #[inline] + fn sub_assign(&mut self, rhs: &SignedBigInt) { + self.sub_assign_in_place(rhs); + } +} + +impl MulAssign<&SignedBigInt> for SignedBigInt { + #[inline] + fn mul_assign(&mut self, rhs: &SignedBigInt) { + self.mul_assign_in_place(rhs); + } +} + +// By-ref binary operator variants to avoid copying both operands +impl core::ops::Add for &SignedBigInt { + type Output = SignedBigInt; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.add_assign_in_place(rhs); + out + } +} + +impl core::ops::Sub for &SignedBigInt { + type Output = SignedBigInt; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.sub_assign_in_place(rhs); + out + } +} + +impl core::ops::Mul for &SignedBigInt { + type Output = SignedBigInt; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.mul_assign_in_place(rhs); + out + } +} + +// =============================================== +// Ordering and canonical serialization +// =============================================== + +impl core::cmp::PartialOrd for SignedBigInt { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl core::cmp::Ord for SignedBigInt { + 
#[inline] + fn cmp(&self, other: &Self) -> Ordering { + // Treat +0 and -0 as equal in ordering semantics + if self.magnitude.is_zero() && other.magnitude.is_zero() { + return Ordering::Equal; + } + match (self.is_positive, other.is_positive) { + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + _ => { + let ord = self.magnitude.cmp(&other.magnitude); + if self.is_positive { ord } else { ord.reverse() } + }, + } + } +} + +impl CanonicalSerialize for SignedBigInt { + #[inline] + fn serialize_with_mode( + &self, + mut w: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // encode sign as a single byte then magnitude + (self.is_positive as u8).serialize_with_mode(&mut w, compress)?; + self.magnitude.serialize_with_mode(w, compress) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + (self.is_positive as u8).serialized_size(compress) + + self.magnitude.serialized_size(compress) + } +} + +impl CanonicalDeserialize for SignedBigInt { + #[inline] + fn deserialize_with_mode( + mut r: R, + compress: Compress, + validate: Validate, + ) -> Result { + let sign_u8 = u8::deserialize_with_mode(&mut r, compress, validate)?; + let mag = BigInt::::deserialize_with_mode(r, compress, validate)?; + Ok(SignedBigInt { + magnitude: mag, + is_positive: sign_u8 != 0, + }) + } +} + +impl Valid for SignedBigInt { + #[inline] + fn check(&self) -> Result<(), SerializationError> { + self.magnitude.check() + } +} diff --git a/ff/src/biginteger/signed_hi_32.rs b/ff/src/biginteger/signed_hi_32.rs new file mode 100644 index 000000000..ae41fd37b --- /dev/null +++ b/ff/src/biginteger/signed_hi_32.rs @@ -0,0 +1,762 @@ +use allocative::Allocative; +use ark_std::cmp::Ordering; +use ark_std::vec::Vec; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use crate::biginteger::{BigInt, SignedBigInt, S64, S128}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; + +/// Compact signed big-integer parameterized by limb count `N` (total width = `N*64 + 32` bits). +/// +/// Representation (sign-magnitude): +/// - `magnitude_lo: [u64; N]` holds the low limbs in little-endian order (index 0 is least significant). +/// - `magnitude_hi: u32` holds the high 32-bit tail of the magnitude. +/// - `is_positive: bool` is the sign flag. The magnitude stores the absolute value. +/// +/// Arithmetic semantics: +/// - Addition, subtraction, and multiplication operate on magnitudes modulo `2^(64*N + 32)` +/// and then set the sign via standard sign rules. +/// - Zero is not normalized: a zero magnitude can be paired with either sign. Equality is structural, +/// so `+0 != -0`. Callers that require canonical zero should normalize externally. +/// +/// Notes: +/// - Zero is not normalized: a zero magnitude can be positive or negative. Structural equality +/// distinguishes `+0` and `-0`, but ordering treats them as equal. +/// - Specialized fast paths exist for `N ∈ {0,1,2}`; larger `N` uses a generic path. 
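// A minimal usage sketch of the layout described above (illustrative only; it uses
// the `S160` alias and the `new` constructor introduced just below):
//
//     let _x = S160::new([5, 1], 0, true);    // +(5 + 1 * 2^64), 32-bit head = 0
//     let pos_zero = S160::new([0, 0], 0, true);
//     let neg_zero = S160::new([0, 0], 0, false);
//     assert_ne!(pos_zero, neg_zero);         // structural equality: +0 != -0
//     assert_eq!(pos_zero.cmp(&neg_zero), core::cmp::Ordering::Equal); // ordering: equal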
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Allocative)] +pub struct SignedBigIntHi32 { + /// Little-endian low limbs: limb 0 = low 64 bits, limb 1 = next 64 bits, and so on + magnitude_lo: [u64; N], + /// Top 32 bits + magnitude_hi: u32, + /// Whether the value is non-negative + is_positive: bool, +} + +pub type S96 = SignedBigIntHi32<1>; +pub type S160 = SignedBigIntHi32<2>; +pub type S224 = SignedBigIntHi32<3>; + +// ------------------------------------------------------------------------------------------------ +// Implementation +// ------------------------------------------------------------------------------------------------ + +impl SignedBigIntHi32 { + /// Creates a new `SignedBigIntHi32`. + /// + /// The sign is not normalized: a zero magnitude can be positive or negative. + pub const fn new(magnitude_lo: [u64; N], magnitude_hi: u32, is_positive: bool) -> Self { + Self { + magnitude_lo, + magnitude_hi, + is_positive, + } + } + + /// Returns the value `0`. + pub const fn zero() -> Self { + Self { + magnitude_lo: [0; N], + magnitude_hi: 0, + is_positive: true, + } + } + + /// Returns the value `1`. + pub fn one() -> Self { + let mut magnitude_lo = [0; N]; + let magnitude_hi; + + if N == 0 { + magnitude_hi = 1; + } else { + magnitude_lo[0] = 1; + magnitude_hi = 0; + } + + Self { + magnitude_lo, + magnitude_hi, + is_positive: true, + } + } + + // ------------------------------------------------------------------------------------------------ + // Accessors + // ------------------------------------------------------------------------------------------------ + + /// Returns the low limbs of the magnitude. + pub const fn magnitude_lo(&self) -> &[u64; N] { + &self.magnitude_lo + } + + /// Returns the high 32 bits of the magnitude. + pub const fn magnitude_hi(&self) -> u32 { + self.magnitude_hi + } + + /// Returns the sign flag (`true` for a positive sign). + /// Note: zero is not canonicalized; a zero magnitude can have either sign. + pub const fn is_positive(&self) -> bool { + self.is_positive + } + + /// Returns `true` if the number is zero. 
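    // (Only the magnitude is inspected here, so both +0 and -0 report true.)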
+ pub const fn is_zero(&self) -> bool { + let mut lo_is_zero = true; + let mut i = 0; + while i < N { + if self.magnitude_lo[i] != 0 { + lo_is_zero = false; + break; + } + i += 1; + } + self.magnitude_hi == 0 && lo_is_zero + } + + // ------------------------------------------------------------------------------------------------ + // Private arithmetic helpers + // ------------------------------------------------------------------------------------------------ + + fn compare_magnitudes(&self, other: &Self) -> Ordering { + if self.magnitude_hi != other.magnitude_hi { + return self.magnitude_hi.cmp(&other.magnitude_hi); + } + for i in (0..N).rev() { + if self.magnitude_lo[i] != other.magnitude_lo[i] { + return self.magnitude_lo[i].cmp(&other.magnitude_lo[i]); + } + } + Ordering::Equal + } + + fn add_assign_in_place(&mut self, rhs: &Self) { + if self.is_positive == rhs.is_positive { + let (lo, hi, _carry) = self.add_magnitudes_with_carry(rhs); + self.magnitude_lo = lo; + self.magnitude_hi = hi; + } else { + match self.compare_magnitudes(rhs) { + Ordering::Greater | Ordering::Equal => { + let (lo, hi, _borrow) = self.sub_magnitudes_with_borrow(rhs); + self.magnitude_lo = lo; + self.magnitude_hi = hi; + }, + Ordering::Less => { + let (lo, hi, _borrow) = rhs.sub_magnitudes_with_borrow(self); + self.magnitude_lo = lo; + self.magnitude_hi = hi; + self.is_positive = rhs.is_positive; + }, + } + } + } + + fn sub_assign_in_place(&mut self, rhs: &Self) { + let neg_rhs = -*rhs; + self.add_assign_in_place(&neg_rhs); + } + + fn mul_magnitudes(&self, other: &Self) -> ([u64; N], u32) { + // Fast paths for small N to avoid heap allocation and loops + if N == 0 { + let a2 = self.magnitude_hi as u64; + let b2 = other.magnitude_hi as u64; + let prod = a2.wrapping_mul(b2); + let hi = (prod & 0xFFFF_FFFF) as u32; + let lo: [u64; N] = [0u64; N]; + return (lo, hi); + } + + if N == 1 { + let a0 = self.magnitude_lo[0]; + let a1 = self.magnitude_hi as u64; // 32-bit value widened + let b0 = other.magnitude_lo[0]; + let b1 = other.magnitude_hi as u64; // 32-bit value widened + + let t0 = (a0 as u128) * (b0 as u128); + let lo0 = t0 as u64; + + let cross = (t0 >> 64) + (a0 as u128) * (b1 as u128) + (a1 as u128) * (b0 as u128); + + let hi = (cross as u64 & 0xFFFF_FFFF) as u32; + let mut lo = [0u64; N]; + lo[0] = lo0; + return (lo, hi); + } + + if N == 2 { + let a0 = self.magnitude_lo[0]; + let a1 = self.magnitude_lo[1]; + let a2 = self.magnitude_hi as u64; // 32-bit value widened + let b0 = other.magnitude_lo[0]; + let b1 = other.magnitude_lo[1]; + let b2 = other.magnitude_hi as u64; // 32-bit value widened + + // word 0 + let t0 = (a0 as u128) * (b0 as u128); + let r0 = t0 as u64; + let carry0 = t0 >> 64; + + // word 1 + let sum1 = carry0 + (a0 as u128) * (b1 as u128) + (a1 as u128) * (b0 as u128); + let r1 = sum1 as u64; + let carry1 = sum1 >> 64; + + // word 2 + let sum2 = carry1 + + (a0 as u128) * (b2 as u128) + + (a1 as u128) * (b1 as u128) + + (a2 as u128) * (b0 as u128); + let r2 = sum2 as u64; + let _carry2 = (sum2 >> 64) as u64; + + // For a 160-bit result, the head (bits 128..159) is the low 32 bits of word2. + let hi = (r2 & 0xFFFF_FFFF) as u32; + let mut lo = [0u64; N]; + lo[0] = r0; + lo[1] = r1; + return (lo, hi); + } + + // General path + // Product of (N*64 + 32)-bit numbers fits in (2*N*64 + 64) bits. + // Allocate 2*N + 2 u64 limbs to safely propagate carries; we'll truncate to N u64 + 32 bits. 
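        // Worked example of the schoolbook pass below: with a = 2^64 + 3 and
        // b = 2^64 + 5 (any N >= 3 takes this path), each step adds
        // self_limbs[i] * other_limbs[j] into prod[i + j] and carries upward,
        // leaving prod = [15, 8, 1, 0, ...]; the final truncation then keeps
        // prod[0..N] plus the low 32 bits of prod[N].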
+ let mut prod = vec![0u64; 2 * N + 2]; + + let self_limbs: Vec = self + .magnitude_lo + .iter() + .cloned() + .chain(core::iter::once(self.magnitude_hi as u64)) + .collect(); + + let other_limbs: Vec = other + .magnitude_lo + .iter() + .cloned() + .chain(core::iter::once(other.magnitude_hi as u64)) + .collect(); + + for i in 0..self_limbs.len() { + let mut carry: u128 = 0; + for j in 0..other_limbs.len() { + let idx = i + j; + let p = (self_limbs[i] as u128) * (other_limbs[j] as u128) + + (prod[idx] as u128) + + carry; + prod[idx] = p as u64; + carry = p >> 64; + } + if carry > 0 { + let spill = i + other_limbs.len(); + if spill < prod.len() { + prod[spill] = prod[spill].wrapping_add(carry as u64); + } + // else: spill is beyond the truncated width; ignore (mod 2^(64*N+32)). + } + } + + // Truncate and split into lo and hi (keep only the low N u64 limbs and the low 32 bits of limb N) + let mut magnitude_lo = [0u64; N]; + if N > 0 { + magnitude_lo.copy_from_slice(&prod[0..N]); + } + let magnitude_hi = (prod[N] & 0xFFFF_FFFF) as u32; + + (magnitude_lo, magnitude_hi) + } + + // Returns final carry bit. + fn add_magnitudes_with_carry(&self, other: &Self) -> ([u64; N], u32, bool) { + let mut magnitude_lo = [0; N]; + let mut carry: u128 = 0; + + for i in 0..N { + let sum = (self.magnitude_lo[i] as u128) + (other.magnitude_lo[i] as u128) + carry; + magnitude_lo[i] = sum as u64; + carry = sum >> 64; + } + + let sum_hi = (self.magnitude_hi as u128) + (other.magnitude_hi as u128) + carry; + let magnitude_hi = sum_hi as u32; + + let final_carry = (sum_hi >> 32) != 0; + (magnitude_lo, magnitude_hi, final_carry) + } + + // Returns final borrow bit. + fn sub_magnitudes_with_borrow(&self, other: &Self) -> ([u64; N], u32, bool) { + let mut magnitude_lo = [0u64; N]; + let mut borrow = false; + + for i in 0..N { + let (d1, b1) = self.magnitude_lo[i].overflowing_sub(other.magnitude_lo[i]); + let (d2, b2) = d1.overflowing_sub(borrow as u64); + magnitude_lo[i] = d2; + borrow = b1 || b2; + } + + let (hi1, b1) = self.magnitude_hi.overflowing_sub(other.magnitude_hi); + let (hi2, b2) = hi1.overflowing_sub(borrow as u32); + let final_borrow = b1 || b2; + + (magnitude_lo, hi2, final_borrow) + } + + /// Return the unsigned magnitude as a BigInt with N+1 limbs (little-endian), + /// packing `magnitude_lo` followed by `magnitude_hi` (widened to u64). + /// This ignores the sign; pair with `is_positive()` if you need a signed value. + #[inline] + pub fn magnitude_as_bigint_nplus1(&self) -> BigInt { + debug_assert!(NPLUS1 == N + 1, "NPLUS1 must be N+1 for SignedBigIntHi32 magnitude pack"); + let mut limbs = [0u64; NPLUS1]; + if N > 0 { + limbs[..N].copy_from_slice(&self.magnitude_lo); + } + limbs[N] = self.magnitude_hi as u64; + BigInt::(limbs) + } + + /// Zero-extend a smaller-width SignedBigIntHi32 into width N (little-endian). + /// Moves the 32-bit head of the smaller value into the next low 64-bit limb on widen, + /// and clears the head in the widened representation to preserve the numeric value. + /// Debug-asserts that M <= N. 
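    // Illustrative sketch, using the S96/S160 aliases defined above:
    //     let narrow = S96::new([7], 9, true);              // +(7 + 9 * 2^64)
    //     let wide: S160 = S160::zero_extend_from(&narrow);
    //     // wide.magnitude_lo() == &[7, 9] and wide.magnitude_hi() == 0: the same
    //     // numeric value, with the 32-bit head folded into limb M = 1.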
+ #[inline] + pub fn zero_extend_from(smaller: &SignedBigIntHi32) -> SignedBigIntHi32 { + debug_assert!(M <= N, "cannot zero-extend: source has more limbs than destination"); + if N == M { + return SignedBigIntHi32::::new( + // copy to avoid borrowing issues + { + let mut lo = [0u64; N]; + if N > 0 { + lo.copy_from_slice(smaller.magnitude_lo()); + } + lo + }, + smaller.magnitude_hi(), + smaller.is_positive(), + ); + } + // N > M + let mut lo = [0u64; N]; + if M > 0 { + lo[..M].copy_from_slice(smaller.magnitude_lo()); + } + // Place the 32-bit head into limb M + lo[M] = smaller.magnitude_hi() as u64; + SignedBigIntHi32::::new(lo, 0u32, smaller.is_positive()) + } + + /// Convert this hi-32 representation into a standard SignedBigInt with N+1 limbs. + /// Packs the low limbs verbatim and writes the 32-bit head into the highest limb. + /// Debug-asserts that NPLUS1 == N + 1. + #[inline] + pub fn to_signed_bigint_nplus1(&self) -> SignedBigInt { + debug_assert!(NPLUS1 == N + 1, "to_signed_bigint_nplus1 requires NPLUS1 = N + 1"); + let mut limbs = [0u64; NPLUS1]; + if N > 0 { + limbs[..N].copy_from_slice(self.magnitude_lo()); + } + limbs[N] = self.magnitude_hi() as u64; + let mag = BigInt::(limbs); + SignedBigInt::from_bigint(mag, self.is_positive()) + } +} + +// ------------------------------------------------------------------------------------------------ +// Operator traits +// ------------------------------------------------------------------------------------------------ + +impl Neg for SignedBigIntHi32 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::new(self.magnitude_lo, self.magnitude_hi, !self.is_positive) + } +} + +impl Add for SignedBigIntHi32 { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self::Output { + self.add_assign_in_place(&rhs); + self + } +} + +impl AddAssign for SignedBigIntHi32 { + fn add_assign(&mut self, rhs: Self) { + self.add_assign_in_place(&rhs); + } +} + +impl Sub for SignedBigIntHi32 { + type Output = Self; + + fn sub(mut self, rhs: Self) -> Self::Output { + self.sub_assign_in_place(&rhs); + self + } +} + +impl SubAssign for SignedBigIntHi32 { + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign_in_place(&rhs); + } +} + +impl MulAssign for SignedBigIntHi32 { + fn mul_assign(&mut self, rhs: Self) { + *self = self.mul(&rhs); + } +} + +// Reference variants for efficiency +impl Add<&SignedBigIntHi32> for SignedBigIntHi32 { + type Output = SignedBigIntHi32; + + #[inline] + fn add(mut self, rhs: &SignedBigIntHi32) -> Self::Output { + self.add_assign_in_place(rhs); + self + } +} + +impl Sub<&SignedBigIntHi32> for SignedBigIntHi32 { + type Output = SignedBigIntHi32; + + #[inline] + fn sub(mut self, rhs: &SignedBigIntHi32) -> Self::Output { + self.sub_assign_in_place(rhs); + self + } +} + +impl Mul<&SignedBigIntHi32> for SignedBigIntHi32 { + type Output = SignedBigIntHi32; + + #[inline] + fn mul(self, rhs: &SignedBigIntHi32) -> Self::Output { + let (lo, hi) = self.mul_magnitudes(rhs); + let is_positive = !(self.is_positive ^ rhs.is_positive); + Self::new(lo, hi, is_positive) + } +} + +impl AddAssign<&SignedBigIntHi32> for SignedBigIntHi32 { + #[inline] + fn add_assign(&mut self, rhs: &SignedBigIntHi32) { + self.add_assign_in_place(rhs); + } +} + +impl SubAssign<&SignedBigIntHi32> for SignedBigIntHi32 { + #[inline] + fn sub_assign(&mut self, rhs: &SignedBigIntHi32) { + self.sub_assign_in_place(rhs); + } +} + +impl MulAssign<&SignedBigIntHi32> for SignedBigIntHi32 { + #[inline] + fn mul_assign(&mut self, rhs: &SignedBigIntHi32) { + *self = 
self.mul(rhs); + } +} + +// By-ref binary operator variants to avoid copying both operands +impl<'a, const N: usize> Add for &'a SignedBigIntHi32 { + type Output = SignedBigIntHi32; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.add_assign_in_place(rhs); + out + } +} + +impl<'a, const N: usize> Sub for &'a SignedBigIntHi32 { + type Output = SignedBigIntHi32; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + let mut out = *self; + out.sub_assign_in_place(rhs); + out + } +} + +impl<'a, const N: usize> Mul for &'a SignedBigIntHi32 { + type Output = SignedBigIntHi32; + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + let (lo, hi) = self.mul_magnitudes(rhs); + let is_positive = !(self.is_positive ^ rhs.is_positive); + SignedBigIntHi32::new(lo, hi, is_positive) + } +} + +// ------------------------------------------------------------------------------------------------ +// S160-specific inherent constructors (ergonomic helpers) +// ------------------------------------------------------------------------------------------------ + +impl S160 { + /// Construct from the signed difference of two u64 values: returns |a - b| with sign a>=b. + #[inline] + pub fn from_diff_u64(a: u64, b: u64) -> Self { + let mag = a.abs_diff(b); + let is_positive = a >= b; + S160::new([mag, 0], 0, is_positive) + } + + /// Construct from a u128 magnitude and an explicit sign. + #[inline] + pub fn from_magnitude_u128(mag: u128, is_positive: bool) -> Self { + let lo = mag as u64; + let hi = (mag >> 64) as u64; + S160::new([lo, hi], 0, is_positive) + } + + /// Construct from the signed difference of two u128 values: returns |u1 - u2| with sign u1>=u2. + #[inline] + pub fn from_diff_u128(u1: u128, u2: u128) -> Self { + if u1 >= u2 { + S160::from_magnitude_u128(u1 - u2, true) + } else { + S160::from_magnitude_u128(u2 - u1, false) + } + } + + /// Construct from the sum of two u128 values, preserving carry into the top 32-bit head. + #[inline] + pub fn from_sum_u128(u1: u128, u2: u128) -> Self { + let u1_lo = u1 as u64; + let u1_hi = (u1 >> 64) as u64; + let u2_lo = u2 as u64; + let u2_hi = (u2 >> 64) as u64; + let (sum_lo, carry0) = u1_lo.overflowing_add(u2_lo); + let (sum_hi1, carry1) = u1_hi.overflowing_add(u2_hi); + let (sum_hi, carry2) = sum_hi1.overflowing_add(if carry0 { 1 } else { 0 }); + let carry_out = (carry1 as u8 | carry2 as u8) != 0; + S160::new([sum_lo, sum_hi], if carry_out { 1 } else { 0 }, true) + } + + /// Construct from (u128 - i128) with full-width integer semantics. 
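    // For example (exercising both branches of the implementation below):
    //     assert_eq!(S160::from_u128_minus_i128(10, 3), S160::from_magnitude_u128(7, true));
    //     assert_eq!(S160::from_u128_minus_i128(3, -4), S160::from_magnitude_u128(7, true));
    //     assert_eq!(S160::from_u128_minus_i128(3, 10), S160::from_magnitude_u128(7, false));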
+ #[inline] + pub fn from_u128_minus_i128(u: u128, i: i128) -> Self { + if i >= 0 { + S160::from_diff_u128(u, i as u128) + } else { + let abs_i: u128 = i.unsigned_abs(); + S160::from_sum_u128(u, abs_i) + } + } +} + +// ------------------------------------------------------------------------------------------------ +// Ordering and canonical serialization +// ------------------------------------------------------------------------------------------------ + +impl core::cmp::PartialOrd for SignedBigIntHi32 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl core::cmp::Ord for SignedBigIntHi32 { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + if self.is_zero() && other.is_zero() { + return Ordering::Equal; + } + match (self.is_positive, other.is_positive) { + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + _ => { + let ord = self.compare_magnitudes(other); + if self.is_positive { ord } else { ord.reverse() } + }, + } + } +} + +impl CanonicalSerialize for SignedBigIntHi32 { + #[inline] + fn serialize_with_mode( + &self, + mut w: W, + compress: Compress, + ) -> Result<(), SerializationError> { + // Encode sign, then (hi, lo) + (self.is_positive as u8).serialize_with_mode(&mut w, compress)?; + (self.magnitude_hi as i32).serialize_with_mode(&mut w, compress)?; + for i in 0..N { + self.magnitude_lo[i].serialize_with_mode(&mut w, compress)?; + } + Ok(()) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + (self.is_positive as u8).serialized_size(compress) + + (self.magnitude_hi as i32).serialized_size(compress) + + (0u64).serialized_size(compress) * N + } +} + +impl CanonicalDeserialize for SignedBigIntHi32 { + #[inline] + fn deserialize_with_mode( + mut r: R, + compress: Compress, + validate: Validate, + ) -> Result { + let sign_u8 = u8::deserialize_with_mode(&mut r, compress, validate)?; + let hi = i32::deserialize_with_mode(&mut r, compress, validate)?; + let mut lo = [0u64; N]; + for i in 0..N { + lo[i] = u64::deserialize_with_mode(&mut r, compress, validate)?; + } + Ok(SignedBigIntHi32::new(lo, hi as u32, sign_u8 != 0)) + } +} + +impl Valid for SignedBigIntHi32 { + #[inline] + fn check(&self) -> Result<(), SerializationError> { + // No additional invariants beyond structural fields + Ok(()) + } +} + +// ------------------------------------------------------------------------------------------------ +// Symmetric mul: S160 * I8OrI96 -> S224 (for ergonomics) +// ------------------------------------------------------------------------------------------------ + +impl core::ops::Mul for S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: crate::biginteger::I8OrI96) -> Self::Output { + rhs * self + } +} + +impl core::ops::Mul<&crate::biginteger::I8OrI96> for S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: &crate::biginteger::I8OrI96) -> Self::Output { + (*rhs) * self + } +} + +impl core::ops::Mul for &S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: crate::biginteger::I8OrI96) -> Self::Output { + rhs * *self + } +} + +impl core::ops::Mul<&crate::biginteger::I8OrI96> for &S160 { + type Output = S224; + #[inline] + fn mul(self, rhs: &crate::biginteger::I8OrI96) -> Self::Output { + (*rhs) * *self + } +} + +// ------------------------------------------------------------------------------------------------ +// From traits +// ------------------------------------------------------------------------------------------------ + +impl From for S96 { + #[inline] + fn 
from(val: i64) -> Self { + Self::new([val.unsigned_abs()], 0, val.is_positive()) + } +} + +impl From for S96 { + #[inline] + fn from(val: u64) -> Self { + Self::new([val], 0, true) + } +} + +impl From for S96 { + #[inline] + fn from(val: S64) -> Self { + Self::new([val.magnitude.0[0]], 0, val.is_positive) + } +} + +impl From for S160 { + #[inline] + fn from(val: i64) -> Self { + Self::new([val.unsigned_abs(), 0], 0, val.is_positive()) + } +} + +impl From for S160 { + #[inline] + fn from(val: u64) -> Self { + Self::new([val, 0], 0, true) + } +} + +impl From for S160 { + #[inline] + fn from(val: S64) -> Self { + Self::new([val.magnitude.0[0], 0], 0, val.is_positive) + } +} + +impl From for S160 { + #[inline] + fn from(val: u128) -> Self { + let lo = val as u64; + let hi = (val >> 64) as u64; + Self::new([lo, hi], 0, true) + } +} + +impl From for S160 { + #[inline] + fn from(val: i128) -> Self { + let is_positive = val.is_positive(); + let mag = val.unsigned_abs(); + let lo = mag as u64; + let hi = (mag >> 64) as u64; + Self::new([lo, hi], 0, is_positive) + } +} + +impl From for S160 { + #[inline] + fn from(val: S128) -> Self { + Self::new([val.magnitude.0[0], val.magnitude.0[1]], 0, val.is_positive) + } +} + +impl From for crate::biginteger::BigInt { + #[inline] + #[allow(unsafe_code)] + fn from(val: S224) -> Self { + if N != 4 { + panic!("FromS224 for BigInt only supports N=4, got N={N}"); + } + let lo = val.magnitude_lo(); + let hi = val.magnitude_hi() as u64; + let bigint4 = crate::biginteger::BigInt::<4>([lo[0], lo[1], lo[2], hi]); + + unsafe { + let ptr = &bigint4 as *const BigInt<4> as *const BigInt; + ptr.read() + } + } +} diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 4b3fa54b3..1aad5211d 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -1,281 +1,1130 @@ -use crate::{biginteger::BigInteger, UniformRand}; -use num_bigint::BigUint; - -// Test elementary math operations for BigInteger. -fn biginteger_arithmetic_test(a: B, b: B, zero: B, max: B) { - // zero == zero - assert_eq!(zero, zero); - - // zero.is_zero() == true - assert_eq!(zero.is_zero(), true); - - // a == a - assert_eq!(a, a); - - // a + 0 = a - let mut a0_add = a; - let carry = a0_add.add_with_carry(&zero); - assert_eq!(a0_add, a); - assert_eq!(carry, false); - - // a - 0 = a - let mut a0_sub = a; - let borrow = a0_sub.sub_with_borrow(&zero); - assert_eq!(a0_sub, a); - assert_eq!(borrow, false); - - // a - a = 0 - let mut aa_sub = a; - let borrow = aa_sub.sub_with_borrow(&a); - assert_eq!(aa_sub, zero); - assert_eq!(borrow, false); - - // a + b = b + a - let mut ab_add = a; - let ab_carry = ab_add.add_with_carry(&b); - let mut ba_add = b; - let ba_carry = ba_add.add_with_carry(&a); - assert_eq!(ab_add, ba_add); - assert_eq!(ab_carry, ba_carry); - - // a * 1 = a - let mut a_mul1 = a; - a_mul1 <<= 0; - assert_eq!(a_mul1, a); - - // a * 2 = a + a - let mut a_mul2 = a; - a_mul2.mul2(); - let mut a_plus_a = a; - let carry_a_plus_a = a_plus_a.add_with_carry(&a); // Won't assert anything about carry bit. - assert_eq!(a_mul2, a_plus_a); - - // a * 1 = a - assert_eq!(a.mul_low(&B::from(1u64)), a); - - // a * 2 = a - assert_eq!(a.mul_low(&B::from(2u64)), a_plus_a); - - // a * b = b * a - assert_eq!(a.mul_low(&b), b.mul_low(&a)); - - // a * 2 * b * 0 = 0 - assert!(a.mul_low(&zero).is_zero()); - - // a * 2 * ... 
* 2 = a * 2^n - let mut a_mul_n = a; - for _ in 0..20 { - a_mul_n = a_mul_n.mul_low(&B::from(2u64)); - } - assert_eq!(a_mul_n, a << 20); - - // a * 0 = (0, 0) - assert_eq!(a.mul(&zero), (zero, zero)); - - // a * 1 = (a, 0) - assert_eq!(a.mul(&B::from(1u64)), (a, zero)); - - // a * 1 = 0 (high part of the result) - assert_eq!(a.mul_high(&B::from(1u64)), (zero)); - - // a * 0 = 0 (high part of the result) - assert!(a.mul_high(&zero).is_zero()); - - // If a + a has a carry - if carry_a_plus_a { - // a + a has a carry: high part of a * 2 is not zero - assert_ne!(a.mul_high(&B::from(2u64)), zero); - } else { - // a + a has no carry: high part of a * 2 is zero - assert_eq!(a.mul_high(&B::from(2u64)), zero); - } - - // max + max = max * 2 - let mut max_plus_max = max; - max_plus_max.add_with_carry(&max); - assert_eq!(max.mul(&B::from(2u64)), (max_plus_max, B::from(1u64))); - assert_eq!(max.mul_high(&B::from(2u64)), B::from(1u64)); -} +#[cfg(test)] +pub mod tests { -fn biginteger_shr() { - let mut rng = ark_std::test_rng(); - let a = B::rand(&mut rng); - assert_eq!(a >> 0, a); - - // Binary simple test - let a = B::from(256u64); - assert_eq!(a >> 2, B::from(64u64)); - - // Test saturated underflow - let a = B::from(1u64); - assert_eq!(a >> 5, B::from(0u64)); - - // Test null bits - let a = B::rand(&mut rng); - let b = a >> 3; - assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 1), false); - assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 2), false); - assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 3), false); -} + use crate::{ + biginteger::{BigInteger, SignedBigInt}, + UniformRand, + }; + use ark_std::Zero; + use num_bigint::BigUint; -fn biginteger_shl() { - let mut rng = ark_std::test_rng(); - let a = B::rand(&mut rng); - assert_eq!(a << 0, a); - - // Binary simple test - let a = B::from(64u64); - assert_eq!(a << 2, B::from(256u64)); - - // Testing saturated overflow - let a = B::rand(&mut rng); - assert_eq!(a << ((B::NUM_LIMBS as u32) * 64), B::from(0u64)); - - // Test null bits - let a = B::rand(&mut rng); - let b = a << 3; - assert_eq!(b.get_bit(0), false); - assert_eq!(b.get_bit(1), false); - assert_eq!(b.get_bit(2), false); -} + // Test elementary math operations for BigInteger. + fn biginteger_arithmetic_test(a: B, b: B, zero: B, max: B) { + // zero == zero + assert_eq!(zero, zero); -// Test for BigInt's bitwise operations -fn biginteger_bitwise_ops_test() { - let mut rng = ark_std::test_rng(); - - // Test XOR - // a xor a = 0 - let a = B::rand(&mut rng); - assert_eq!(a ^ &a, B::from(0_u64)); - - // Testing a xor b xor b - let a = B::rand(&mut rng); - let b = B::rand(&mut rng); - let xor_ab = a ^ b; - assert_eq!(xor_ab ^ b, a); - - // Test OR - // a or a = a - let a = B::rand(&mut rng); - assert_eq!(a | &a, a); - - // Testing a or b or b - let a = B::rand(&mut rng); - let b = B::rand(&mut rng); - let or_ab = a | b; - assert_eq!(or_ab | &b, a | b); - - // Test AND - // a and a = a - let a = B::rand(&mut rng); - assert_eq!(a & (&a), a); - - // Testing a and a and b. 
- let a = B::rand(&mut rng); - let b = B::rand(&mut rng); - let b_clone = b.clone(); - let and_ab = a & b; - assert_eq!(and_ab & b_clone, a & b); - - // Testing De Morgan's law - let a = 0x1234567890abcdef_u64; - let b = 0xfedcba0987654321_u64; - let de_morgan_lhs = B::from(!(a | b)); - let de_morgan_rhs = B::from(!a) & B::from(!b); - assert_eq!(de_morgan_lhs, de_morgan_rhs); -} + // zero.is_zero() == true + assert_eq!(zero.is_zero(), true); -// Test correctness of BigInteger's bit values -fn biginteger_bits_test() { - let mut one = B::from(1u64); - // 0th bit of BigInteger representing 1 is 1 - assert!(one.get_bit(0)); - // 1st bit of BigInteger representing 1 is not 1 - assert!(!one.get_bit(1)); - one <<= 5; - let thirty_two = one; - // 0th bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(0)); - // 1st bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(1)); - // 2nd bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(2)); - // 3rd bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(3)); - // 4th bit of BigInteger representing 32 is not 1 - assert!(!thirty_two.get_bit(4)); - // 5th bit of BigInteger representing 32 is 1 - assert!(thirty_two.get_bit(5), "{:?}", thirty_two); - - // Generates a random BigInteger and tests bit construction methods. - let mut rng = ark_std::test_rng(); - let a: B = UniformRand::rand(&mut rng); - assert_eq!(B::from_bits_be(&a.to_bits_be()), a); - assert_eq!(B::from_bits_le(&a.to_bits_le()), a); -} + // a == a + assert_eq!(a, a); -// Test conversion from BigInteger to BigUint -fn biginteger_conversion_test() { - let mut rng = ark_std::test_rng(); + // a + 0 = a + let mut a0_add = a; + let carry = a0_add.add_with_carry(&zero); + assert_eq!(a0_add, a); + assert_eq!(carry, false); - let x: B = UniformRand::rand(&mut rng); - let x_bigint: BigUint = x.into(); - let x_recovered = B::try_from(x_bigint).ok().unwrap(); + // a - 0 = a + let mut a0_sub = a; + let borrow = a0_sub.sub_with_borrow(&zero); + assert_eq!(a0_sub, a); + assert_eq!(borrow, false); - assert_eq!(x, x_recovered); -} + // a - a = 0 + let mut aa_sub = a; + let borrow = aa_sub.sub_with_borrow(&a); + assert_eq!(aa_sub, zero); + assert_eq!(borrow, false); -// Wrapper test function for BigInteger -fn test_biginteger(max: B, zero: B) { - let mut rng = ark_std::test_rng(); - let a: B = UniformRand::rand(&mut rng); - let b: B = UniformRand::rand(&mut rng); - biginteger_arithmetic_test(a, b, zero, max); - biginteger_bits_test::(); - biginteger_conversion_test::(); - biginteger_bitwise_ops_test::(); - biginteger_shr::(); - biginteger_shl::(); -} + // a + b = b + a + let mut ab_add = a; + let ab_carry = ab_add.add_with_carry(&b); + let mut ba_add = b; + let ba_carry = ba_add.add_with_carry(&a); + assert_eq!(ab_add, ba_add); + assert_eq!(ab_carry, ba_carry); -#[test] -fn test_biginteger64() { - use crate::biginteger::BigInteger64 as B; - test_biginteger(B::new([u64::MAX; 1]), B::new([0u64; 1])); -} + // a * 1 = a + let mut a_mul1 = a; + a_mul1 <<= 0; + assert_eq!(a_mul1, a); -#[test] -fn test_biginteger128() { - use crate::biginteger::BigInteger128 as B; - test_biginteger(B::new([u64::MAX; 2]), B::new([0u64; 2])); -} + // a * 2 = a + a + let mut a_mul2 = a; + a_mul2.mul2(); + let mut a_plus_a = a; + let carry_a_plus_a = a_plus_a.add_with_carry(&a); // Won't assert anything about carry bit. 
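        // (mul2 doubles in place modulo 2^(64 * NUM_LIMBS), so it should agree with
        // the wrapped a + a whether or not the dropped carry was set.)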
+ assert_eq!(a_mul2, a_plus_a); -#[test] -fn test_biginteger256() { - use crate::biginteger::BigInteger256 as B; - test_biginteger(B::new([u64::MAX; 4]), B::new([0u64; 4])); -} + // a * 1 = a + assert_eq!(a.mul_low(&B::from(1u64)), a); -#[test] -fn test_biginteger384() { - use crate::biginteger::BigInteger384 as B; - test_biginteger(B::new([u64::MAX; 6]), B::new([0u64; 6])); -} + // a * 2 = a + assert_eq!(a.mul_low(&B::from(2u64)), a_plus_a); -#[test] -fn test_biginteger448() { - use crate::biginteger::BigInteger448 as B; - test_biginteger(B::new([u64::MAX; 7]), B::new([0u64; 7])); -} + // a * b = b * a + assert_eq!(a.mul_low(&b), b.mul_low(&a)); -#[test] -fn test_biginteger768() { - use crate::biginteger::BigInteger768 as B; - test_biginteger(B::new([u64::MAX; 12]), B::new([0u64; 12])); -} + // a * 2 * b * 0 = 0 + assert!(a.mul_low(&zero).is_zero()); + + // a * 2 * ... * 2 = a * 2^n + let mut a_mul_n = a; + for _ in 0..20 { + a_mul_n = a_mul_n.mul_low(&B::from(2u64)); + } + assert_eq!(a_mul_n, a << 20); + + // a * 0 = (0, 0) + assert_eq!(a.mul(&zero), (zero, zero)); + + // a * 1 = (a, 0) + assert_eq!(a.mul(&B::from(1u64)), (a, zero)); + + // a * 1 = 0 (high part of the result) + assert_eq!(a.mul_high(&B::from(1u64)), (zero)); + + // a * 0 = 0 (high part of the result) + assert!(a.mul_high(&zero).is_zero()); + + // If a + a has a carry + if carry_a_plus_a { + // a + a has a carry: high part of a * 2 is not zero + assert_ne!(a.mul_high(&B::from(2u64)), zero); + } else { + // a + a has no carry: high part of a * 2 is zero + assert_eq!(a.mul_high(&B::from(2u64)), zero); + } + + // max + max = max * 2 + let mut max_plus_max = max; + max_plus_max.add_with_carry(&max); + assert_eq!(max.mul(&B::from(2u64)), (max_plus_max, B::from(1u64))); + assert_eq!(max.mul_high(&B::from(2u64)), B::from(1u64)); + } + + #[test] + fn test_i8_or_i96_mul_s160_edges() { + use crate::biginteger::{I8OrI96, S160, S224}; + + // Helper to convert S224 to BigUint reference value (unsigned magnitude) and sign + fn s224_to_biguint_and_sign(v: &S224) -> (num_bigint::BigUint, bool) { + let lo = v.magnitude_lo(); + let hi32 = v.magnitude_hi() as u64; + let mut limbs = [0u64; 4]; + limbs[0] = lo[0]; + limbs[1] = lo[1]; + limbs[2] = lo[2]; + limbs[3] = hi32; + (num_bigint::BigUint::from(crate::biginteger::BigInt::<4>(limbs)), v.is_positive()) + } + + // Case 1: small i8 * b1-only rhs + let k = I8OrI96::from_i8(7); // x1 = 0 + let rhs = S160::new([0, 5], 0, true); // b0=0, b1=5, b2=0 + let out = k * rhs; + let (mag_bu, sign) = s224_to_biguint_and_sign(&out); + let expected = num_bigint::BigUint::from(7u64) * (num_bigint::BigUint::from(5u64) << 64) + % (num_bigint::BigUint::from(1u8) << 224); + assert!(sign); + assert_eq!(mag_bu, expected); + + // Case 2: large x with x1!=0, b2!=0, b1==0 (hits hi32 path) + // x = 2^80 + 3 => hi32 = 2^(80-64)=2^16, lo=3 + let x = I8OrI96::from_i128(((1i128) << 80) + 3); + let rhs2 = S160::new([11, 0], 9, true); // b0=11, b1=0, b2=9 + let out2 = x * rhs2; + let (mag_bu2, sign2) = s224_to_biguint_and_sign(&out2); + let x_bi = (num_bigint::BigUint::from(1u8) << 80) + num_bigint::BigUint::from(3u8); + let rhs_bi = (num_bigint::BigUint::from(9u8) << 128) + (num_bigint::BigUint::from(11u8)); + let exp2 = x_bi * rhs_bi % (num_bigint::BigUint::from(1u8) << 224); + assert!(sign2); + assert_eq!(mag_bu2, exp2); + + // Case 3: negative small i8 * nonzero rhs, zero result canonicalizes to positive + let k3 = I8OrI96::from_i8(-1); + let rhs3 = S160::zero(); + let out3 = k3 * rhs3; + let (_, sign3) = 
s224_to_biguint_and_sign(&out3); + assert!(sign3); + } + + #[test] + fn test_s160_mul_s160_hi32_consistency() { + use crate::biginteger::{BigInt, S160}; + + // Spot-check a configuration that exercises the hi32 accumulation path + let a = S160::new([1u64 << 63, 0], 1, true); // a2=1, a0 has high bit + let b = S160::new([0, 1u64 << 63], 1, true); // b2=1, b1 has high bit + let got = &a * &b; // S160 result + + // Convert S160 value to BigUint: pack [lo0, lo1, hi32] into BigInt<3> + let mut pack = [0u64; 3]; + pack[0] = got.magnitude_lo()[0]; + pack[1] = got.magnitude_lo()[1]; + pack[2] = got.magnitude_hi() as u64; + let got_bu = num_bigint::BigUint::from(BigInt::<3>(pack)); + + // Reference BigUint modulo 2^160 + let a_bu = (num_bigint::BigUint::from(a.magnitude_lo()[1]) << 64) + + num_bigint::BigUint::from(a.magnitude_lo()[0]) + + (num_bigint::BigUint::from(a.magnitude_hi() as u64) << 128); + let b_bu = (num_bigint::BigUint::from(b.magnitude_lo()[1]) << 64) + + num_bigint::BigUint::from(b.magnitude_lo()[0]) + + (num_bigint::BigUint::from(b.magnitude_hi() as u64) << 128); + let prod = (a_bu * b_bu) % (num_bigint::BigUint::from(1u8) << 160); + + assert_eq!(got_bu, prod); + } + + fn biginteger_shr() { + let mut rng = ark_std::test_rng(); + let a = B::rand(&mut rng); + assert_eq!(a >> 0, a); + + // Binary simple test + let a = B::from(256u64); + assert_eq!(a >> 2, B::from(64u64)); + + // Test saturated underflow + let a = B::from(1u64); + assert_eq!(a >> 5, B::from(0u64)); + + // Test null bits + let a = B::rand(&mut rng); + let b = a >> 3; + assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 1), false); + assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 2), false); + assert_eq!(b.get_bit(B::NUM_LIMBS * 64 - 3), false); + } + + fn biginteger_shl() { + let mut rng = ark_std::test_rng(); + let a = B::rand(&mut rng); + assert_eq!(a << 0, a); + + // Binary simple test + let a = B::from(64u64); + assert_eq!(a << 2, B::from(256u64)); + + // Testing saturated overflow + let a = B::rand(&mut rng); + assert_eq!(a << ((B::NUM_LIMBS as u32) * 64), B::from(0u64)); + + // Test null bits + let a = B::rand(&mut rng); + let b = a << 3; + assert_eq!(b.get_bit(0), false); + assert_eq!(b.get_bit(1), false); + assert_eq!(b.get_bit(2), false); + } + + // Test for BigInt's bitwise operations + fn biginteger_bitwise_ops_test() { + let mut rng = ark_std::test_rng(); + + // Test XOR + // a xor a = 0 + let a = B::rand(&mut rng); + assert_eq!(a ^ &a, B::from(0_u64)); + + // Testing a xor b xor b + let a = B::rand(&mut rng); + let b = B::rand(&mut rng); + let xor_ab = a ^ b; + assert_eq!(xor_ab ^ b, a); + + // Test OR + // a or a = a + let a = B::rand(&mut rng); + assert_eq!(a | &a, a); + + // Testing a or b or b + let a = B::rand(&mut rng); + let b = B::rand(&mut rng); + let or_ab = a | b; + assert_eq!(or_ab | &b, a | b); + + // Test AND + // a and a = a + let a = B::rand(&mut rng); + assert_eq!(a & (&a), a); + + // Testing a and a and b. 
+ let a = B::rand(&mut rng); + let b = B::rand(&mut rng); + let b_clone = b.clone(); + let and_ab = a & b; + assert_eq!(and_ab & b_clone, a & b); + + // Testing De Morgan's law + let a = 0x1234567890abcdef_u64; + let b = 0xfedcba0987654321_u64; + let de_morgan_lhs = B::from(!(a | b)); + let de_morgan_rhs = B::from(!a) & B::from(!b); + assert_eq!(de_morgan_lhs, de_morgan_rhs); + } + + // Test correctness of BigInteger's bit values + fn biginteger_bits_test() { + let mut one = B::from(1u64); + // 0th bit of BigInteger representing 1 is 1 + assert!(one.get_bit(0)); + // 1st bit of BigInteger representing 1 is not 1 + assert!(!one.get_bit(1)); + one <<= 5; + let thirty_two = one; + // 0th bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(0)); + // 1st bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(1)); + // 2nd bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(2)); + // 3rd bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(3)); + // 4th bit of BigInteger representing 32 is not 1 + assert!(!thirty_two.get_bit(4)); + // 5th bit of BigInteger representing 32 is 1 + assert!(thirty_two.get_bit(5), "{:?}", thirty_two); + + // Generates a random BigInteger and tests bit construction methods. + let mut rng = ark_std::test_rng(); + let a: B = UniformRand::rand(&mut rng); + assert_eq!(B::from_bits_be(&a.to_bits_be()), a); + assert_eq!(B::from_bits_le(&a.to_bits_le()), a); + } + + // Test conversion from BigInteger to BigUint + fn biginteger_conversion_test() { + let mut rng = ark_std::test_rng(); + + let x: B = UniformRand::rand(&mut rng); + let x_bigint: BigUint = x.into(); + let x_recovered = B::try_from(x_bigint).ok().unwrap(); + + assert_eq!(x, x_recovered); + } + + // Wrapper test function for BigInteger + fn test_biginteger(max: B, zero: B) { + let mut rng = ark_std::test_rng(); + let a: B = UniformRand::rand(&mut rng); + let b: B = UniformRand::rand(&mut rng); + biginteger_arithmetic_test(a, b, zero, max); + biginteger_bits_test::(); + biginteger_conversion_test::(); + biginteger_bitwise_ops_test::(); + biginteger_shr::(); + biginteger_shl::(); + } + + #[test] + fn test_biginteger64() { + use crate::biginteger::BigInteger64 as B; + test_biginteger(B::new([u64::MAX; 1]), B::new([0u64; 1])); + } + + #[test] + fn test_biginteger128() { + use crate::biginteger::BigInteger128 as B; + test_biginteger(B::new([u64::MAX; 2]), B::new([0u64; 2])); + } + + #[test] + fn test_biginteger256() { + use crate::biginteger::BigInteger256 as B; + test_biginteger(B::new([u64::MAX; 4]), B::new([0u64; 4])); + } + + #[test] + fn test_biginteger384() { + use crate::biginteger::BigInteger384 as B; + test_biginteger(B::new([u64::MAX; 6]), B::new([0u64; 6])); + } + + #[test] + fn test_biginteger448() { + use crate::biginteger::BigInteger448 as B; + test_biginteger(B::new([u64::MAX; 7]), B::new([0u64; 7])); + } + + #[test] + fn test_biginteger768() { + use crate::biginteger::BigInteger768 as B; + test_biginteger(B::new([u64::MAX; 12]), B::new([0u64; 12])); + } + + #[test] + fn test_biginteger832() { + use crate::biginteger::BigInteger832 as B; + test_biginteger(B::new([u64::MAX; 13]), B::new([0u64; 13])); + } + + // Tests for NEW functions + use crate::biginteger::BigInteger256; + + #[test] + fn test_mul_u64_in_place() { + let mut a = BigInteger256::from(0x123456789ABCDEFu64); + let b = 0x987654321u64; + + // Test against reference implementation + let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(b); + 
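        // (Both factors are far below 2^64, so the exact product fits in 256 bits and
        // the in-place, truncating multiply should equal the BigUint reference.)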
a.mul_u64_in_place(b); + assert_eq!(BigUint::from(a), expected); + + // Test zero multiplication + let mut zero = BigInteger256::zero(); + zero.mul_u64_in_place(12345); + assert!(zero.is_zero()); + + // Test multiplication by zero + let mut a = BigInteger256::from(12345u64); + a.mul_u64_in_place(0); + assert!(a.is_zero()); + + // Test multiplication by one + let orig = BigInteger256::from(0xDEADBEEFu64); + let mut a = orig; + a.mul_u64_in_place(1); + assert_eq!(a, orig); + } + + #[test] + fn test_mul_u64_w_carry() { + let a = BigInteger256::from(u64::MAX); + let b = u64::MAX; + + // Test against reference implementation + let expected = BigUint::from(u64::MAX) * BigUint::from(u64::MAX); + let result = a.mul_u64_w_carry::<5>(b); + assert_eq!(BigUint::from(result), expected); + + // Test with small numbers + let a = BigInteger256::from(12345u64); + let b = 67890u64; + let expected = BigUint::from(12345u64) * BigUint::from(67890u64); + let result = a.mul_u64_w_carry::<5>(b); + assert_eq!(BigUint::from(result), expected); + + // Test zero cases + let zero = BigInteger256::zero(); + let result = zero.mul_u64_w_carry::<5>(12345); + assert!(result.is_zero()); + + let a = BigInteger256::from(12345u64); + let result = a.mul_u64_w_carry::<5>(0); + assert!(result.is_zero()); + + // Test multiplication by one + let a = BigInteger256::from(0xDEADBEEFu64); + let result = a.mul_u64_w_carry::<5>(1); + let expected_bytes = a.to_bytes_le(); + let result_bytes = result.to_bytes_le(); + assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); + } + + #[test] + fn test_fmu64a() { + let a = BigInteger256::from(12345u64); + let b = 67890u64; + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + + // Perform fused multiply-accumulate (no carry propagation in highest limb) + a.fm_limbs_into::<1, 5>(&[b], &mut acc, false); + + // Compare against separate multiply and add + let expected_mul = BigUint::from(12345u64) * BigUint::from(67890u64); + let expected_total = expected_mul + BigUint::from(11111u64); + assert_eq!(BigUint::from(acc), expected_total); + + // Test zero cases + let zero = BigInteger256::zero(); + let mut acc = BigInteger256::from(12345u64).mul_u64_w_carry::<5>(1); + let acc_copy = acc; + zero.fm_limbs_into::<1, 5>(&[67890], &mut acc, false); + assert_eq!(acc, acc_copy); // Should be unchanged + + // Test multiplication by zero + let a = BigInteger256::from(12345u64); + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + let acc_copy = acc; + a.fm_limbs_into::<1, 5>(&[0], &mut acc, false); + assert_eq!(acc, acc_copy); // Should be unchanged + + // Test multiplication by one (should be just addition) + let a = BigInteger256::from(12345u64); + let mut acc = BigInteger256::from(11111u64).mul_u64_w_carry::<5>(1); + a.fm_limbs_into::<1, 5>(&[1], &mut acc, false); + let expected = BigUint::from(12345u64) + BigUint::from(11111u64); + assert_eq!(BigUint::from(acc), expected); + } + + #[test] + fn test_mul_u128_w_carry() { + let a = BigInteger256::from(0x123456789ABCDEFu64); + let b = 0x987654321DEADBEEFu128; + + // Test against reference implementation + let expected = BigUint::from(0x123456789ABCDEFu64) * BigUint::from(0x987654321DEADBEEFu128); + let result = a.mul_u128_w_carry::<5, 6>(b); + assert_eq!(BigUint::from(result), expected); -#[test] -fn test_biginteger832() { - use crate::biginteger::BigInteger832 as B; - test_biginteger(B::new([u64::MAX; 13]), B::new([0u64; 13])); + // Test with u64 value (should be same as mul_u64_w_carry) + let b_u64 = 
0x987654321u64; + let result_u128 = a.mul_u128_w_carry::<5, 6>(b_u64 as u128); + let result_u64 = a.mul_u64_w_carry::<5>(b_u64); + + // Compare first 5 limbs (u64 result size) + for i in 0..5 { + assert_eq!(result_u128.0[i], result_u64.0[i]); + } + assert_eq!(result_u128.0[5], 0); // Extra limb should be zero + + // Test zero cases + let zero = BigInteger256::zero(); + let result = zero.mul_u128_w_carry::<5, 6>(12345); + assert!(result.is_zero()); + + let a = BigInteger256::from(12345u64); + let result = a.mul_u128_w_carry::<5, 6>(0); + assert!(result.is_zero()); + + // Test multiplication by one + let a = BigInteger256::from(0xDEADBEEFu64); + let result = a.mul_u128_w_carry::<5, 6>(1); + let expected_bytes = a.to_bytes_le(); + let result_bytes = result.to_bytes_le(); + assert_eq!(&result_bytes[..expected_bytes.len()], &expected_bytes[..]); + } + + #[test] + fn test_fm128a_basic_and_edges() { + use crate::biginteger::BigInteger256 as B; + // Basic reference check against BigUint + let a = B::from(0x123456789ABCDEFu64); + let b = 0x987654321DEADBEEFu128; + let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); // zero-extended accumulator (6 limbs) + a.fm_limbs_into::<2, 6>(&[b as u64, (b >> 64) as u64], &mut acc, true); + let expected = num_bigint::BigUint::from(0x123456789ABCDEFu64) + * num_bigint::BigUint::from(0x987654321DEADBEEFu128); + assert_eq!(num_bigint::BigUint::from(acc), expected); + + // Zero multiplier: no change + let a = B::from(12345u64); + let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); + let acc_copy = acc; + a.fm_limbs_into::<2, 6>(&[0u64, 0u64], &mut acc, true); + assert_eq!(acc, acc_copy); + + // One multiplier: reduces to addition + let a = B::from(12345u64); + let mut acc = B::from(11111u64).mul_u128_w_carry::<5, 6>(1); + a.fm_limbs_into::<2, 6>(&[1u64, 0u64], &mut acc, true); + let expected = num_bigint::BigUint::from(12345u64) + num_bigint::BigUint::from(11111u64); + assert_eq!(num_bigint::BigUint::from(acc), expected); + + // Overflow propagation from limb N into highest limb + let a = B::new([u64::MAX; 4]); + let mut acc = B::zero().mul_u128_w_carry::<5, 6>(1); + // Pre-fill limb N to force overflow when adding the final carry from low pass + acc.0[4] = u64::MAX; // limb N + acc.0[5] = 0; // highest limb + // cause carry=1 from low pass (a * 2) + a.fm_limbs_into::<2, 6>(&[2u64, 0u64], &mut acc, true); + // Expect highest limb incremented by 1 due to overflow from limb N + assert_eq!(acc.0[5], 1); + } + + #[test] + fn test_overflow_behavior_fmu64a() { + // Test that overflow in the highest limb wraps around as documented + let a = BigInteger256::new([u64::MAX; 4]); + let mut acc = BigInteger256::new([0, 0, 0, 0]).mul_u64_w_carry::<5>(1); + acc.0[4] = u64::MAX; // Set highest limb to max + + // This should cause overflow in the highest limb + a.fm_limbs_into::<1, 5>(&[2u64], &mut acc, false); + + // The overflow should wrap around + // u64::MAX * 2 = 2^65 - 2, which when added to u64::MAX = 2^65 + u64::MAX - 2 + // This wraps to u64::MAX - 2 with a carry of 1 that itself wraps + assert_eq!(acc.0[4], u64::MAX.wrapping_add(1)); // Wrapped result + } + + #[test] + fn test_edge_cases_large_numbers() { + // Test with maximum values + let max_bi = BigInteger256::new([u64::MAX; 4]); + + // mul_u64_w_carry with max values + let result = max_bi.mul_u64_w_carry::<5>(u64::MAX); + let expected = BigUint::from(max_bi) * BigUint::from(u64::MAX); + assert_eq!(BigUint::from(result), expected); + + // mul_u128_w_carry with max values + let result = 
max_bi.mul_u128_w_carry::<5, 6>(u128::MAX); + let expected = BigUint::from(max_bi) * BigUint::from(u128::MAX); + assert_eq!(BigUint::from(result), expected); + } + + #[test] + fn test_fmu64a_into_nplus4_correctness_and_edges() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0xDEADBEEFCAFEBABEu64); + let other = 0xFEDCBA9876543210u64; + let mut acc = BigInt::<8>::zero(); // N+4 accumulator for N=4 + + // Reference: (a * other + acc_before) mod 2^(64*(N+4)) + let before = BigUint::from(acc.clone()); + a.fm_limbs_into::<1, 8>(&[other], &mut acc, true); + let mut expected = BigUint::from(a); + expected *= BigUint::from(other); + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Zero multiplier is no-op + let mut acc2 = acc.clone(); + a.fm_limbs_into::<1, 8>(&[0u64], &mut acc2, true); + assert_eq!(acc2, acc); + + // One multiplier reduces to addition + let mut acc3 = BigInt::<8>::zero(); + acc3.0[0] = 11111; + let before3 = BigUint::from(acc3.clone()); + a.fm_limbs_into::<1, 8>(&[1u64], &mut acc3, true); + let mut expected3 = BigUint::from(a); + expected3 += before3; + expected3 %= &modulus; + assert_eq!(BigUint::from(acc3), expected3); + + // Force cascading carry across N..=N+3 + let a = B::new([u64::MAX; 4]); + let mut acc4 = BigInt::<8>::zero(); + acc4.0[4] = u64::MAX; // limb N + acc4.0[5] = u64::MAX; // limb N+1 + acc4.0[6] = u64::MAX; // limb N+2 + acc4.0[7] = 0; // limb N+3 (top) + // Use multiplier 2 so the low pass produces a carry=1 + a.fm_limbs_into::<1, 8>(&[2u64], &mut acc4, true); + assert_eq!(acc4.0[7], 1); + } + + #[test] + fn test_fm2x64a_into_nplus4_correctness() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0x1234567890ABCDEFu64); + let other = [0x0FEDCBA987654321u64, 0x0011223344556677u64]; + let mut acc = BigInt::<8>::zero(); + + let before = BigUint::from(acc.clone()); + a.fm_limbs_into::<2, 8>(&other, &mut acc, true); + + // Expected: a * (lo + (hi << 64)) + acc_before mod 2^(64*8) + let hi = BigUint::from(other[1]); + let lo = BigUint::from(other[0]); + let factor = (hi << 64) + lo; + let mut expected = BigUint::from(a); + expected *= factor; + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Zero limbs are no-op + let mut acc2 = acc.clone(); + a.fm_limbs_into::<2, 8>(&[0u64, 0u64], &mut acc2, true); + assert_eq!(acc2, acc); + } + + #[test] + fn test_fm3x64a_into_nplus4_correctness() { + use crate::biginteger::{BigInt, BigInteger256 as B}; + let a = B::from(0x0F0E0D0C0B0A0908u64); + let other = [ + 0x89ABCDEF01234567u64, + 0x76543210FEDCBA98u64, + 0x1122334455667788u64, + ]; + let mut acc = BigInt::<8>::zero(); + + let before = BigUint::from(acc.clone()); + a.fm_limbs_into::<3, 8>(&other, &mut acc, true); + + // Expected: a * (o0 + (o1<<64) + (o2<<128)) + acc_before mod 2^(64*8) + let term0 = BigUint::from(other[0]); + let term1 = BigUint::from(other[1]) << 64; + let term2 = BigUint::from(other[2]) << 128; + let factor = term0 + term1 + term2; + let mut expected = BigUint::from(a); + expected *= factor; + expected += before; + let modulus = BigUint::from(1u8) << (64 * 8); + expected %= &modulus; + assert_eq!(BigUint::from(acc.clone()), expected); + + // Edge: ensure offset accumulation lands in correct limbs + // Fill acc with a pattern, then accumulate using only the highest limb to ensure writes start at index 2 + 
let a = B::from(3u64); + let mut acc2 = BigInt::<8>::zero(); + acc2.0[0] = 5; + acc2.0[1] = 7; + let other2 = [0, 0, 2]; // Only offset by 2 limbs + let before2 = BigUint::from(acc2.clone()); + a.fm_limbs_into::<3, 8>(&other2, &mut acc2, true); + let mut expected2 = BigUint::from(a); + expected2 *= BigUint::from(2u64) << 128; + expected2 += before2; + let modulus = BigUint::from(1u8) << (64 * 8); + expected2 %= &modulus; + assert_eq!(BigUint::from(acc2), expected2); + } + + // ============================== + // SignedBigInt tests + // ============================== + + #[test] + fn test_signed_construction() { + // zero and one + let z = SignedBigInt::<1>::zero(); + assert!(z.is_zero()); + assert!(z.is_positive); + let o = SignedBigInt::<1>::one(); + assert!(!o.is_zero()); + assert!(o.is_positive); + + // from_u64 + let p = SignedBigInt::<1>::from_u64(42); + assert_eq!(p.magnitude.0[0], 42); + assert!(p.is_positive); + let n = SignedBigInt::<1>::from((42u64, false)); + assert_eq!(n.magnitude.0[0], 42); + assert!(!n.is_positive); + } + + #[test] + fn test_signed_add_sub_mul_neg() { + let a = SignedBigInt::<1>::from_u64(10); + let b = SignedBigInt::<1>::from_u64(5); + assert_eq!((a + b).magnitude.0[0], 15); + assert_eq!((a - b).magnitude.0[0], 5); + assert_eq!((a * b).magnitude.0[0], 50); + let neg = -a; + assert_eq!(neg.magnitude.0[0], 10); + assert!(!neg.is_positive); + + // opposite signs + let x = SignedBigInt::<1>::from_u64(30); + let y = SignedBigInt::<1>::from((20u64, false)); + let r = x + y; // 30 - 20 + assert!(r.is_positive); + assert_eq!(r.magnitude.0[0], 10); + + let x2 = SignedBigInt::<1>::from((20u64, false)); + let y2 = SignedBigInt::<1>::from_u64(30); + let r2 = x2 + y2; // -20 + 30 + assert!(r2.is_positive); + assert_eq!(r2.magnitude.0[0], 10); + } + + #[test] + fn test_signed_to_i128_and_mag_helpers() { + let p = SignedBigInt::<1>::from_u64(100); + assert_eq!(p.to_i128(), 100); + let n = SignedBigInt::<1>::from((100u64, false)); + assert_eq!(n.to_i128(), -100); + + let d = SignedBigInt::<2>::from_u128(0x1234_5678_9abc_def0_1111_2222_3333_4444u128); + assert_eq!(d.magnitude.0[0], 0x1111_2222_3333_4444); + assert_eq!(d.magnitude.0[1], 0x1234_5678_9abc_def0); + // Positive below 2^127 should convert + let expected_i128 = 0x1234_5678_9abc_def0_1111_2222_3333_4444u128 as i128; + assert_eq!(d.to_i128(), Some(expected_i128)); + + // Positive at 2^127 should fail + let too_big_pos = SignedBigInt::<2>::from_u128(1u128 << 127); + assert_eq!(too_big_pos.to_i128(), None); + + let small = SignedBigInt::<2>::new([100, 0], true); + assert_eq!(small.to_i128(), Some(100)); + assert_eq!(small.magnitude_as_u128(), 100u128); + } + + #[test] + fn test_add_with_sign_u64_helper() { + let (mag, sign) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, true); + assert_eq!(mag, 15); + assert!(sign); + let (mag2, sign2) = crate::biginteger::signed::add_with_sign_u64(10, true, 5, false); + assert_eq!(mag2, 5); + assert!(sign2); + let (mag3, sign3) = crate::biginteger::signed::add_with_sign_u64(5, true, 10, false); + assert_eq!(mag3, 5); + assert!(!sign3); + } + + #[test] + fn test_signed_truncated_add_sub() { + use crate::biginteger::SignedBigInt as S; + let a = S::<2>::from_u128(0x0000_0000_0000_0001_ffff_ffff_ffff_fffe); + let b = S::<2>::from_u128(0x0000_0000_0000_0001_0000_0000_0000_0001); + // Add and truncate to 1 limb + // Respect BigInt::add_trunc contract by truncating rhs to 1 limb + let b1 = S::<1>::from_bigint(crate::biginteger::BigInt::<1>::new([b.magnitude.0[0]]), 
b.is_positive); + let r1 = a.add_trunc_mixed::<1, 1>(&b1); + // expected low limb wrap of the low words, ignoring carry to limb1 + let expected_low = (0xffff_ffff_ffff_fffeu64).wrapping_add(0x0000_0000_0000_0001u64); + assert_eq!(r1.magnitude.0[0], expected_low); + assert!(r1.is_positive); + + // Different signs: subtraction path (use N=1 throughout so M<=P inside) + let a = S::<1>::from_u64(0x2); + let b = S::<1>::from(-3i64); // -3 + let r2 = a.add_trunc::<1>(&b); // 2 + (-3) = -1, truncated to 64-bit + assert_eq!(r2.magnitude.0[0], 1); + assert!(!r2.is_positive); + + // sub_trunc uses add_trunc internally + let x = S::<1>::from_u64(10); + let y = S::<1>::from_u64(7); + let r3 = x.sub_trunc::<1>(&y); + assert_eq!(r3.magnitude.0[0], 3); + assert!(r3.is_positive); + } + + #[test] + fn test_signed_truncated_mul_and_fmadd() { + use crate::biginteger::SignedBigInt as S; + // 128-bit x 64-bit -> truncated to 2 limbs (128-bit) + let a = S::<2>::from_u128(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128); + let b = S::<1>::from_u64(0x2); + let p = a.mul_trunc::<1, 2>(&b); + // Expected low 128 bits of the product + let expected = num_bigint::BigUint::from(0x0000_0000_0000_0001_FFFF_FFFF_FFFF_FFFFu128) + * num_bigint::BigUint::from(2u64); + let got = num_bigint::BigUint::from(p.magnitude); + assert_eq!( + got, + expected & ((num_bigint::BigUint::from(1u8) << 128) - 1u8) + ); + assert!(p.is_positive); + + // fmadd into 1-limb accumulator (truncate to 64 bits) + let a = S::<1>::from_u64(0xFFFF_FFFF_FFFF_FFFF); + let b = S::<1>::from_u64(0x2); + let mut acc = S::<1>::from_u64(1); + a.fmadd_trunc::<1, 1>(&b, &mut acc); // acc = 1 + (a*b) mod 2^64 with sign + + // a*b = (2^64 - 1)*2 = 2^65 - 2 => low 64 = (2^64 - 2) + let expected_low = (u64::MAX).wrapping_sub(1); + assert_eq!(acc.magnitude.0[0], expected_low.wrapping_add(1)); + } + + #[test] + fn test_signed_truncated_add_sub_mixed() { + use crate::biginteger::SignedBigInt as S; + // Same sign, different widths, ensure carry handling and sign preservation + let a = S::<2>::from_u128(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFF); + let b = S::<1>::from_u64(0x0000_0000_0000_0002); + let r = a.add_trunc_mixed::<1, 2>(&b); // 128-bit result + let expected = num_bigint::BigUint::from(0x0000_0000_0000_0002_FFFF_FFFF_FFFF_FFFFu128) + + num_bigint::BigUint::from(2u64); + assert_eq!(num_bigint::BigUint::from(r.magnitude), expected); + assert!(r.is_positive); + + // Different signs, |a| > |b|: result sign should be sign(a) + let a2 = S::<2>::from_u128(5000); + let b2 = S::<1>::from((3000u64, false)); // -3000 + let r2 = a2.add_trunc_mixed::<1, 2>(&b2); + assert!(r2.is_positive); + assert_eq!(r2.magnitude.0[0], 2000); + + // Different signs, |b| > |a|: result sign should be sign(b) + let a3 = S::<2>::from_u128(1000); + let b3 = S::<1>::from((3000u64, false)); // -3000 + let r3 = a3.add_trunc_mixed::<1, 2>(&b3); + assert!(!r3.is_positive); + assert_eq!(r3.magnitude.0[0], 2000); + + // sub_trunc_mixed basic checks + let a4 = S::<2>::from_u128(10000); + let b4 = S::<1>::from_u64(9999); + let r4 = a4.sub_trunc_mixed::<1, 2>(&b4); + assert!(r4.is_positive); + assert_eq!(r4.magnitude.0[0], 1); + + let a5 = S::<2>::from_u128(1000); + let b5 = S::<1>::from_u64(5000); + let r5 = a5.sub_trunc_mixed::<1, 2>(&b5); + assert!(!r5.is_positive); + assert_eq!(r5.magnitude.0[0], 4000); + } + + #[test] + fn test_signed_fmadd_trunc_mixed_width_and_signs() { + use crate::biginteger::SignedBigInt as S; + // Case 1: same sign => pure addition of magnitudes + let a = 
S::<2>::from_u128(30000); + let b = S::<1>::from_u64(7); + let mut acc = S::<2>::from_u128(1000000); + a.fmadd_trunc::<1, 2>(&b, &mut acc); // acc += 210000 + assert!(acc.is_positive); + assert_eq!( + acc.magnitude.0[0] as u128 + ((acc.magnitude.0[1] as u128) << 64), + 1210000u128 + ); + + // Case 2: different sign, |prod| < |acc| => sign preserved + let a2 = S::<2>::from_u128(30000); + let b2 = S::<1>::from((7u64, false)); // -7 + let mut acc2 = S::<2>::from_u128(1000000); + a2.fmadd_trunc::<1, 2>(&b2, &mut acc2); // acc2 -= 210000 => 790000 + assert!(acc2.is_positive); + assert_eq!( + acc2.magnitude.0[0] as u128 + ((acc2.magnitude.0[1] as u128) << 64), + 790000u128 + ); + + // Case 3: different sign, |prod| > |acc| => sign flips to prod_sign + let a3 = S::<2>::from_u128(300); + let b3 = S::<1>::from((7u64, false)); // -7 => prod = -2100 + let mut acc3 = S::<2>::from_u128(1000); + a3.fmadd_trunc::<1, 2>(&b3, &mut acc3); // 1000 - 2100 = -1100 + assert!(!acc3.is_positive); + assert_eq!(acc3.magnitude.0[0], 1100); + } + + #[test] + fn test_prop_add_sub_trunc_mixed_random() { + use crate::biginteger::SignedBigInt as S; + use ark_std::rand::Rng; + let mut rng = ark_std::test_rng(); + + // Helper to validate a single pair for given consts + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); + let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); + let a_pos = (rng.gen::() & 1) == 1; + let b_pos = (rng.gen::() & 1) == 1; + let a = S::<$n>::from_bigint(a_mag, a_pos); + let b = S::<$m>::from_bigint(b_mag, b_pos); + + // add_trunc_mixed + let r_add = a.add_trunc_mixed::<$m, $p>(&b); + let a_bu = num_bigint::BigUint::from(a.magnitude); + let b_bu = num_bigint::BigUint::from(b.magnitude); + let (exp_add_mag, exp_add_pos) = if a_pos == b_pos { + (&a_bu + &b_bu, a_pos) + } else if a_bu >= b_bu { + (&a_bu - &b_bu, a_pos) + } else { + (&b_bu - &a_bu, b_pos) + }; + let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); + let exp_add_mag_mod = exp_add_mag % &modulus; + assert_eq!(num_bigint::BigUint::from(r_add.magnitude), exp_add_mag_mod); + if exp_add_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(r_add.is_positive, exp_add_pos); + } + + // sub_trunc_mixed: a - b + let r_sub = a.sub_trunc_mixed::<$m, $p>(&b); + let (exp_sub_mag, exp_sub_pos) = if a_pos != b_pos { + (&a_bu + &b_bu, a_pos) + } else if a_bu >= b_bu { + (&a_bu - &b_bu, a_pos) + } else { + (&b_bu - &a_bu, !a_pos) + }; + let exp_sub_mag_mod = exp_sub_mag % &modulus; + assert_eq!(num_bigint::BigUint::from(r_sub.magnitude), exp_sub_mag_mod); + if exp_sub_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(r_sub.is_positive, exp_sub_pos); + } + } + }}; + } + + // Ensure P >= M to satisfy internal add_trunc constraints + run_case!(2, 3, 3, 200); + run_case!(3, 1, 2, 200); + run_case!(1, 2, 2, 200); + } + + #[test] + fn test_prop_fmadd_trunc_random() { + use crate::biginteger::SignedBigInt as S; + use ark_std::rand::Rng; + let mut rng = ark_std::test_rng(); + + macro_rules! 
run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let a_mag: crate::biginteger::BigInt<$n> = UniformRand::rand(&mut rng); + let b_mag: crate::biginteger::BigInt<$m> = UniformRand::rand(&mut rng); + let acc_mag: crate::biginteger::BigInt<$p> = UniformRand::rand(&mut rng); + let a_pos = (rng.gen::() & 1) == 1; + let b_pos = (rng.gen::() & 1) == 1; + let acc_pos = (rng.gen::() & 1) == 1; + let a = S::<$n>::from_bigint(a_mag, a_pos); + let b = S::<$m>::from_bigint(b_mag, b_pos); + let mut acc = S::<$p>::from_bigint(acc_mag, acc_pos); + + // expected via BigUint with truncation of the product BEFORE combining signs + let a_bu = num_bigint::BigUint::from(a.magnitude); + let b_bu = num_bigint::BigUint::from(b.magnitude); + let acc_bu = num_bigint::BigUint::from(acc.magnitude); + let modulus = num_bigint::BigUint::from(1u8) << (64 * $p); + let prod_mod = (&a_bu * &b_bu) % &modulus; + let prod_pos = a_pos == b_pos; + let (exp_mag_mod, exp_pos) = if acc_pos == prod_pos { + ((acc_bu + &prod_mod) % &modulus, acc_pos) + } else if acc_bu >= prod_mod { + (acc_bu - &prod_mod, acc_pos) + } else { + (prod_mod - &acc_bu, prod_pos) + }; + + a.fmadd_trunc::<$m, $p>(&b, &mut acc); + + assert_eq!(num_bigint::BigUint::from(acc.magnitude), exp_mag_mod); + if exp_mag_mod != num_bigint::BigUint::from(0u8) { + assert_eq!(acc.is_positive, exp_pos); + } + } + }}; + } + + run_case!(2, 1, 2, 200); + run_case!(3, 2, 2, 200); + } + + // ============================== + // Tests for add_trunc and add_assign_trunc (unsigned BigInt) + // ============================== + + #[test] + fn test_add_trunc_correctness_random() { + use crate::biginteger::BigInt; + let mut rng = ark_std::test_rng(); + + macro_rules! run_case { + ($n:expr, $m:expr, $p:expr, $iters:expr) => {{ + for _ in 0..$iters { + let mut a: BigInt<$n> = UniformRand::rand(&mut rng); + let mut b: BigInt<$m> = UniformRand::rand(&mut rng); + + // Clamp low P limbs to avoid any carry across limb P-1 in add_trunc. 
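+                    // (Illustrative reasoning, not part of the add_trunc contract: after the >>= 1
+                    // clamp each low limb is below 2^63, so every limb-wise sum stays below 2^64 and
+                    // no carry can be generated in, or cross out of, the low P limbs.)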
+ let mut i = 0; while i < core::cmp::min($p, $n) { a.0[i] >>= 1; i += 1; } + let mut j = 0; while j < core::cmp::min($p, $m) { b.0[j] >>= 1; j += 1; } + + // Build rhs respecting M <= P + let (res, b_p): (BigInt<$p>, BigInt<$p>) = if $m <= $p { + let mut b_p = BigInt::<$p>::zero(); + let mut k = 0; while k < $m { b_p.0[k] = b.0[k]; k += 1; } + (a.add_trunc::<$m, $p>(&b), b_p) + } else { + let mut bl = [0u64; $p]; + let mut t = 0; while t < $p { bl[t] = b.0[t]; t += 1; } + let b_trunc = BigInt::<$p>::new(bl); + (a.add_trunc::<$p, $p>(&b_trunc), b_trunc) + }; + + // Expected using low-P truncated operands (after clamping) + let mut a_p = BigInt::<$p>::zero(); + let mut u = 0; while u < core::cmp::min($p, $n) { a_p.0[u] = a.0[u]; u += 1; } + let a_bu = BigUint::from(a_p); + let b_bu = BigUint::from(b_p); + let modulus = BigUint::from(1u8) << (64 * $p); + let expected = (a_bu + b_bu) % &modulus; + assert_eq!(BigUint::from(res), expected); + } + }}; + } + + // Same-width + run_case!(4, 4, 4, 200); + // Mixed widths with P chosen to satisfy M <= P + run_case!(4, 2, 3, 200); + run_case!(2, 4, 4, 200); + } + + #[test] + fn test_add_assign_trunc_correctness_and_zeroing() { + use crate::biginteger::BigInt; + let mut rng = ark_std::test_rng(); + + // Case 1: N = 4, M = 4 (no truncation when P=N); compare against add_trunc and add_with_carry + for _ in 0..200 { + let mut a: BigInt<4> = UniformRand::rand(&mut rng); + let mut b: BigInt<4> = UniformRand::rand(&mut rng); + // Ensure no carry anywhere by masking all limbs to 62 bits + for i in 0..4 { a.0[i] &= (1u64 << 62) - 1; b.0[i] &= (1u64 << 62) - 1; } + let r_trunc = a.add_trunc::<4, 4>(&b); + let mut a2 = a; + a2.add_assign_trunc::<4>(&b); + assert_eq!(a2, r_trunc); + + // Regular add_with_carry should match lower 4 limbs modulo 2^(256) + let mut a3 = a; + a3.add_with_carry(&b); + assert_eq!(a3, r_trunc); + } + + // Case 2: N = 4, M = 4, P = 3 (truncation): low P limbs must match add_trunc + for _ in 0..200 { + let mut a: BigInt<4> = UniformRand::rand(&mut rng); + let mut b: BigInt<4> = UniformRand::rand(&mut rng); + for i in 0..4 { a.0[i] &= (1u64 << 62) - 1; b.0[i] &= (1u64 << 62) - 1; } + // Respect add_trunc contract by pre-truncating rhs to P limbs + let b3 = crate::biginteger::BigInt::<3>::new([b.0[0], b.0[1], b.0[2]]); + let r_trunc = a.add_trunc::<3, 3>(&b3); + let mut a2 = a; + a2.add_assign_trunc::<4>(&b); + // Low 3 limbs match result + for i in 0..3 { + assert_eq!(a2.0[i], r_trunc.0[i]); + } + } + + // Case 3: Mixed widths N = 4, M = 2, P = 3: low P limbs must match add_trunc + for _ in 0..200 { + let mut a: BigInt<4> = UniformRand::rand(&mut rng); + let b: BigInt<2> = UniformRand::rand(&mut rng); + a.0[3] >>= 1; + let r_trunc = a.add_trunc::<2, 3>(&b); + let mut a2 = a; + a2.add_assign_trunc::<2>(&b); + for i in 0..3 { + assert_eq!(a2.0[i], r_trunc.0[i]); + } + } + } + + #[test] + fn test_add_trunc_and_add_assign_trunc_overflow_edges() { + use crate::biginteger::BigInt; + + // Use values that don't overflow beyond N to respect debug contract + let mut a = BigInt::<4>::new([u64::MAX; 4]); + let mut b = BigInt::<4>::new([u64::MAX; 4]); + for i in 0..4 { a.0[i] >>= 1; b.0[i] >>= 1; } + // P = 4: result should match BigUint addition modulo 2^256 + // add_assign_trunc debug-overflow behavior cannot be reliably asserted in this + // environment without std; we validate the non-mutating truncated result above. 
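+        // (With every limb halved, a + b stays below 2^256, so the additions exercised below
+        // cannot trip the debug overflow checks.)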
+ + // P = 3: validate truncated result against BigUint; pre-truncate rhs to 3 limbs + let b3 = crate::biginteger::BigInt::<3>::new([b.0[0], b.0[1], b.0[2]]); + let r3 = a.add_trunc::<3, 3>(&b3); + let a_bu = BigUint::from(a); + let b_bu = BigUint::from(b); + let modulus = BigUint::from(1u8) << (64 * 3); + let expected_r3 = (a_bu + b_bu) % &modulus; + assert_eq!(BigUint::from(r3), expected_r3); + } } diff --git a/ff/src/fields/models/fp/mod.rs b/ff/src/fields/models/fp/mod.rs index 342788f8d..993880be5 100644 --- a/ff/src/fields/models/fp/mod.rs +++ b/ff/src/fields/models/fp/mod.rs @@ -2,11 +2,11 @@ use crate::{ AdditiveGroup, BigInt, BigInteger, FftField, Field, LegendreSymbol, One, PrimeField, SqrtPrecomputation, Zero, }; -#[cfg(feature = "allocative")] use allocative::Allocative; use ark_serialize::{ - buffer_byte_size, CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize, + CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize, CanonicalSerializeWithFlags, Compress, EmptyFlags, Flags, SerializationError, Valid, Validate, + buffer_byte_size, }; use ark_std::{ cmp::*, @@ -99,13 +99,18 @@ pub trait FpConfig: Send + Sync + 'static + Sized { /// this range. fn from_bigint(other: BigInt) -> Option>; + /// Construct a field element from an integer in the range + /// `0..(Self::MODULUS - 1)`. Returns `None` if the integer is outside + /// this range (but do not do any Reductions) + fn from_bigint_unchecked(other: BigInt) -> Option>; + /// Convert a field element to an integer in the range `0..(Self::MODULUS - /// 1)`. fn into_bigint(other: Fp) -> BigInt; /// Creates a field element from a `u64`. /// Returns `None` if the `u64` is larger than or equal to the modulus. - fn from_u64(val: u64) -> Option>; + fn from_u64(val: u64) -> Option>; } /// Represents an element of the prime field F_p, where `p == P::MODULUS`. @@ -120,7 +125,6 @@ pub struct Fp, const N: usize>( #[doc(hidden)] pub PhantomData
<P>
, ); -#[cfg(feature = "allocative")] impl, const N: usize> Allocative for Fp { fn visit<'a, 'b: 'a>(&self, _visitor: &'a mut allocative::Visitor<'b>) {} } @@ -374,8 +378,13 @@ impl, const N: usize> PrimeField for Fp { } #[inline] - fn from_u64(r: u64) -> Option { - P::from_u64(r) + fn from_bigint_unchecked(r: BigInt) -> Option { + P::from_bigint_unchecked(r) + } + + #[inline] + fn from_u64(r: u64) -> Option { + P::from_u64::(r) } } @@ -435,11 +444,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i128) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -466,11 +471,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i64) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -487,11 +488,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i32) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -508,11 +505,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i16) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -529,11 +522,7 @@ impl, const N: usize> From for Fp { impl, const N: usize> From for Fp { fn from(other: i8) -> Self { let abs = Self::from(other.unsigned_abs()); - if other.is_positive() { - abs - } else { - -abs - } + if other.is_positive() { abs } else { -abs } } } @@ -556,7 +545,7 @@ impl, const N: usize> ark_std::rand::distributions::Distribution< u64::MAX >> shave_bits }; - if let Some(val) = tmp.0 .0.last_mut() { + if let Some(val) = tmp.0.0.last_mut() { *val &= mask } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index eac675038..d2de50090 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -429,34 +429,43 @@ pub trait MontConfig: 'static + Sync + Send + Sized { } } - fn from_i128(r: i128) -> Option, N>> { + fn from_i128( + r: i128, + ) -> Option, N>> { // TODO: small table for signed values? - Some(Fp::new_unchecked(Self::R).mul_i128(r)) + Some(Fp::new_unchecked(Self::R).mul_i128::(r)) } - fn from_u128(r: u128) -> Option, N>> { + fn from_u128( + r: u128, + ) -> Option, N>> { if r < PRECOMP_TABLE_SIZE as u128 { Some(Self::SMALL_ELEMENT_MONTGOMERY_PRECOMP[r as usize]) } else { // Multiply R (one in Montgomery form) with the u128 - Some(Fp::new_unchecked(Self::R).mul_u128(r)) + Some(Fp::new_unchecked(Self::R).mul_u128::(r)) } } - fn from_i64(r: i64) -> Option, N>> { + fn from_i64(r: i64) -> Option, N>> { // TODO: small table for signed values? 
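+        // (For context: R is the Montgomery form of 1, so scaling it by the plain integer value
+        // lands directly in Montgomery form; mul_i64 takes care of the sign.)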
- Some(Fp::new_unchecked(Self::R).mul_i64(r)) + Some(Fp::new_unchecked(Self::R).mul_i64::(r)) } - fn from_u64(r: u64) -> Option, N>> { + fn from_u64(r: u64) -> Option, N>> { + debug_assert!(NPLUS1 == N + 1); if r < PRECOMP_TABLE_SIZE as u64 { Some(Self::SMALL_ELEMENT_MONTGOMERY_PRECOMP[r as usize]) } else { // Multiply R (one in Montgomery form) with the u64 - Some(Fp::new_unchecked(Self::R).mul_u64(r)) + Some(Fp::new_unchecked(Self::R).mul_u64::(r)) } } + fn from_bigint_unchecked(r: BigInt) -> Option, N>> { + Some(Fp::new_unchecked(r)) + } + fn from_bigint(r: BigInt) -> Option, N>> { let mut r = Fp::new_unchecked(r); if r.is_zero() { @@ -469,6 +478,55 @@ pub trait MontConfig: 'static + Sync + Send + Sized { } } + /// Construct from a smaller-width BigInt by zero-extending into N limbs. + /// Returns None if the resulting N-limb value is >= modulus. + #[inline] + fn from_bigint_mixed(r: BigInt) -> Fp, N> { + debug_assert!(M <= N, "from_bigint_mixed requires M <= N"); + let r_n = BigInt::::zero_extend_from::(&r); + Self::from_bigint(r_n).expect("from_bigint_mixed: value >= modulus") + } + + /// Construct from a signed big integer with M 64-bit limbs (sign-magnitude). + /// Returns None if |x| >= modulus. + #[inline] + fn from_signed_bigint( + x: crate::biginteger::SignedBigInt, + ) -> Fp, N> { + // if x.is_zero() { + // return Fp::zero(); + // } + let fe = Self::from_bigint_mixed::(x.magnitude); + if x.is_positive { + fe + } else { + -fe + } + } + + /// Construct from a signed big integer with high 32-bit tail and K low 64-bit limbs. + /// KPLUS1 must be K+1; the magnitude packs as [lo[0..K], hi32 as u64]. + /// Returns None if |x| >= modulus. + #[inline] + fn from_signed_bigint_hi32( + x: crate::biginteger::SignedBigIntHi32, + ) -> Fp, N> { + debug_assert!( + KPLUS1 == K + 1, + "from_signed_bigint_hi32 requires KPLUS1 = K + 1" + ); + // if x.is_zero() { + // return Fp::zero(); + // } + let mag = x.magnitude_as_bigint_nplus1::(); + let fe = Self::from_bigint_mixed::(mag); + if x.is_positive() { + fe + } else { + -fe + } + } + #[inline] #[cfg_attr(not(target_family = "wasm"), unroll_for_loops(12))] #[cfg_attr(target_family = "wasm", unroll_for_loops(6))] @@ -791,14 +849,18 @@ impl, const N: usize> FpConfig for MontBackend { T::from_bigint(r) } + fn from_bigint_unchecked(r: BigInt) -> Option> { + T::from_bigint_unchecked(r) + } + #[inline] #[allow(clippy::modulo_one)] fn into_bigint(a: Fp) -> BigInt { T::into_bigint(a) } - fn from_u64(r: u64) -> Option> { - T::from_u64(r) + fn from_u64(r: u64) -> Option> { + T::from_u64::(r) } } @@ -835,6 +897,134 @@ impl, const N: usize> Fp, N> { Self(element, PhantomData) } + /// Barrett reduce an `L`-limb BigInt to a field element (compute a mod p), generic over `L`. + /// Implementation folds from high to low using the existing N+1 Barrett kernel. + /// Precondition: L >= N. For performance, prefer small L close to N..N+3 when possible. + #[inline(always)] + pub fn from_barrett_reduce(unreduced: BigInt) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(L >= N); + + // Start with acc = 0 (N-limb) + let mut acc = BigInt::::zero(); + // Fold each input limb from high to low: acc' = reduce( limb || acc ) via N+1 kernel + // Note: When L == 1, this reduces one N+1 formed by (low_limb, zeros) + let mut i = L; + while i > 0 { + i -= 1; + let c2 = nplus1_pair_low_to_bigint::((unreduced.0[i], acc.0)); + acc = barrett_reduce_nplus1_to_n::(c2); + } + Self::new_unchecked(acc) + } + + /// Montgomery reduction for arbitrary input width L >= 2N. 
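+    /// (Put differently: given the 2N-limb integer product of two Montgomery residues, this
+    /// reduction returns the Montgomery form of their field product.)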
+ /// + /// Runs exactly N Montgomery steps (i = 0..N-1) over the L-limb buffer to compute + /// t' = (unreduced + q * MODULUS) / R, where R = b^N. The remaining (L - N) limbs + /// store t' in base-b. For L > 2N, we first fold the entire tail (indices N..L) down + /// to an N-limb accumulator using the N+1 Barrett reducer (interpreting the tail as a + /// base-b number), place that as the high N limbs to form a 2N-limb buffer, and then + /// perform the standard N-step Montgomery reduction on that 2N-limb buffer. + /// + /// Preconditions: + /// - L >= 2N (buffer must be large enough to perform N steps safely) + /// + /// Computes: unreduced * R^{-1} mod MODULUS. + #[inline(always)] + pub fn from_montgomery_reduce( + unreduced: BigInt, + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(L >= N + N, "from_montgomery_reduce_var requires L >= 2N"); + + let mut limbs = unreduced; // reuse storage for the buffer + + // If L > 2N, first fold the extra high limbs down. + if L > 2 * N { + // Fold the tail (indices N..L) into an N-limb accumulator via Barrett. + let mut acc = BigInt::::zero(); + let mut i = L; + while i > N { + i -= 1; + let c2 = nplus1_pair_low_to_bigint::((limbs.0[i], acc.0)); + acc = barrett_reduce_nplus1_to_n::(c2); + } + + // Recompose buffer: [low_N | acc | zeros...] + limbs.0[N..(N + N)].copy_from_slice(&acc.0); + let mut j = 2 * N; + while j < L { + limbs.0[j] = 0; + j += 1; + } + } + + // Phase 2: run exactly N Montgomery steps on the 2N-limb buffer. + let carry = Self::montgomery_reduce_in_place::(&mut limbs); + + // Extract result and finalize. + let mut result_limbs = [0u64; N]; + result_limbs.copy_from_slice(&limbs.0[N..(N + N)]); + let mut result = Self::new_unchecked(BigInt::(result_limbs)); + if T::MODULUS_HAS_SPARE_BIT { + result.subtract_modulus(); + } else { + result.subtract_modulus_with_carry(carry != 0); + } + result + } + + /// Construct a new field element from a BigInt + /// which is in montgomery form and just needs to be reduced + /// via a barrett reduction. + #[inline(always)] + pub fn from_unchecked_nplus1(element: BigInt<{ NPLUS1 }>) -> Self { + debug_assert!(NPLUS1 == N + 1); + let r = barrett_reduce_nplus1_to_n::(element); + Self::new_unchecked(r) + } + + /// Construct a new field element from a BigInt + /// which is in montgomery form and just needs to be reduced + /// via a barrett reduction. + #[inline] + pub fn from_unchecked_nplus2( + element: BigInt, + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!(NPLUS2 == N + 2); + let c1 = BigInt::(element.0[1..NPLUS2].try_into().unwrap()); // c1 has N+1 limbs + let r1 = barrett_reduce_nplus1_to_n::(c1); // r1 = c1 mod p ([u64; N]) + // Round 2: Reduce c2 = c_lo[0] + r1 * r. + let c2 = nplus1_pair_low_to_bigint::((element.0[0], r1.0)); // c2 has N+1 limbs + let r2 = barrett_reduce_nplus1_to_n::(c2); // r2 = c2 mod p = c mod p ([u64; N]) + Self::new_unchecked(r2) + } + + /// Construct from a smaller-width BigInt by zero-extending into N limbs. + /// Panics if the resulting value is >= modulus. + #[inline] + pub fn from_bigint_mixed(r: BigInt) -> Self { + T::from_bigint_mixed::(r) + } + + /// Construct from a signed big integer (sign-magnitude with M limbs). + /// Panics if |x| >= modulus. + #[inline] + pub fn from_signed_bigint(x: crate::biginteger::SignedBigInt) -> Self { + T::from_signed_bigint::(x) + } + + /// Construct from a signed big integer with high 32-bit tail and K low 64-bit limbs. + /// KPLUS1 must be K+1. Panics if |x| >= modulus. 
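+    /// (Illustrative example: with K = 1 the magnitude is lo[0] + (hi32 << 64), taken with the
+    /// stored sign.)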
+ #[inline] + pub fn from_signed_bigint_hi32( + x: crate::biginteger::SignedBigIntHi32, + ) -> Self { + T::from_signed_bigint_hi32::(x) + } + const fn const_is_zero(&self) -> bool { self.0.const_is_zero() } @@ -849,9 +1039,8 @@ impl, const N: usize> Fp, N> { } /// Interpret a set of limbs (along with a sign) as a field element. - /// For *internal* use only; please use the `ark_ff::MontFp` macro instead - /// of this method - #[doc(hidden)] + /// The input limbs are interpreted little-endian. For public use; prefer + /// the `ark_ff::MontFp` macro for constant contexts. pub const fn from_sign_and_limbs(is_positive: bool, limbs: &[u64]) -> Self { let mut repr = BigInt::([0; N]); assert!(limbs.len() <= N); @@ -912,30 +1101,252 @@ impl, const N: usize> Fp, N> { } } + /// Multiply-assign by a RHS that is zero in its low N-2 limbs in Montgomery limbs, + /// and whose highest two limbs are provided by `hi` (low 64 bits map to limb N-2, + /// high 64 bits map to limb N-1). This is equivalent to K=2 non-zero high limbs. + #[inline] + pub const fn mul_assign_hi_u128(&mut self, hi: u128) { + // Construct a synthetic RHS by using the const CIOS with K=2, passing limbs directly. + // Leverage existing const CIOS specialized by K via a tiny adapter. + *self = self.const_cios_mul_rhs_hi2(hi as u64, (hi >> 64) as u64); + } + + /// Returns self * rhs_high_limbs, where RHS is zero in low N-2 limbs and has its top two + /// limbs provided by `hi` (low 64 -> limb N-2, high 64 -> limb N-1). Equivalent to K=2. + /// At the cost 2 extra words of storage uses no bit shift instructions to extract higher limbs + /// as in mul_hi_u128 + #[inline] + pub const fn mul_hi_bigint_u128(self, big_int_repre: [u64; 4]) -> Self { + self.const_cios_mul_rhs_hi2(big_int_repre[2], big_int_repre[3]) + } + /// Returns self * rhs_high_limbs, where RHS is zero in low N-2 limbs and has its top two + /// limbs provided by `hi` (low 64 -> limb N-2, high 64 -> limb N-1). Equivalent to K=2. + #[inline] + pub const fn mul_hi_u128(self, hi: u128) -> Self { + self.const_cios_mul_rhs_hi2(hi as u64, (hi >> 64) as u64) + } + + /// Const-capable CIOS fastpath specialized for exactly two high limbs (K=2), passed + /// directly as u64s instead of via an Fp operand. Assumes all lower limbs are zero. 
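+    /// (Sketch of the intent: the implicit RHS equals limb_n2 * 2^(64*(N-2)) + limb_n1 * 2^(64*(N-1)),
+    /// so only the last two CIOS rounds do any work and the earlier rounds are skipped.)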
+ #[inline] + #[allow(unused_assignments)] + const fn const_cios_mul_rhs_hi2(self, limb_n2: u64, limb_n1: u64) -> Self { + let mut r = [0u64; N]; + // i = N-2 + if N >= 2 { + let mut carry1 = 0u64; + r[0] = mac!(r[0], (self.0).0[0], limb_n2, &mut carry1); + let k = r[0].wrapping_mul(T::INV); + let mut carry2 = 0u64; + let _discard = mac!(r[0], k, T::MODULUS.0[0], &mut carry2); + crate::const_for!((j in 1..N) { + let new_rj = mac_with_carry!(r[j], (self.0).0[j], limb_n2, &mut carry1); + let new_rj_minus_1 = mac_with_carry!(new_rj, k, T::MODULUS.0[j], &mut carry2); + r[j] = new_rj; + r[j - 1] = new_rj_minus_1; + }); + r[N - 1] = carry1.wrapping_add(carry2); + } + // i = N-1 + { + let mut carry1 = 0u64; + r[0] = mac!(r[0], (self.0).0[0], limb_n1, &mut carry1); + let k = r[0].wrapping_mul(T::INV); + let mut carry2 = 0u64; + let _discard = mac!(r[0], k, T::MODULUS.0[0], &mut carry2); + crate::const_for!((j in 1..N) { + let new_rj = mac_with_carry!(r[j], (self.0).0[j], limb_n1, &mut carry1); + let new_rj_minus_1 = mac_with_carry!(new_rj, k, T::MODULUS.0[j], &mut carry2); + r[j] = new_rj; + r[j - 1] = new_rj_minus_1; + }); + r[N - 1] = carry1.wrapping_add(carry2); + } + let mut out = Self::new_unchecked(crate::BigInt::(r)); + out = out.const_subtract_modulus(); + out + } + + /// Multiply-assign by a BigInt that populates the highest K Montgomery limbs of the RHS, + /// with all lower limbs zero. Lower 64 bits of `rhs_hi.0[0]` map to limb N-K, etc. + #[inline] + pub const fn mul_assign_hi_bigint(&mut self, rhs_hi: &crate::BigInt) { + if T::CAN_USE_NO_CARRY_MUL_OPT { + *self = self.const_cios_mul_rhs_hi::(rhs_hi); + } else { + let (carry, res) = self.mul_without_cond_subtract_rhs_hi::(rhs_hi); + *self = res; + if T::MODULUS_HAS_SPARE_BIT { + self.const_subtract_modulus(); + } else { + self.const_subtract_modulus_with_carry(carry); + } + } + } + + /// Returns self * BigInt that populates the highest K Montgomery limbs of the RHS, + /// with all lower limbs zero. + #[inline] + pub const fn mul_hi_bigint(self, rhs_hi: &crate::BigInt) -> Self { + if T::CAN_USE_NO_CARRY_MUL_OPT { + self.const_cios_mul_rhs_hi::(rhs_hi) + } else { + let (carry, res) = self.mul_without_cond_subtract_rhs_hi::(rhs_hi); + if T::MODULUS_HAS_SPARE_BIT { + res.const_subtract_modulus() + } else { + res.const_subtract_modulus_with_carry(carry) + } + } + } + + /// Const-capable CIOS kernel for a RHS with exactly K non-zero HIGH limbs provided via BigInt. + #[inline] + #[allow(unused_assignments)] + const fn const_cios_mul_rhs_hi(self, rhs_hi: &crate::BigInt) -> Self { + let mut r = [0u64; N]; + // Iterate high columns: t indexes 0..K-1 mapping to global i = N-K+t + crate::const_for!((t in 0..K) { + let b_i = rhs_hi.0[t]; + let mut carry1 = 0u64; + r[0] = mac!(r[0], (self.0).0[0], b_i, &mut carry1); + let k = r[0].wrapping_mul(T::INV); + let mut carry2 = 0u64; + let _discard = mac!(r[0], k, T::MODULUS.0[0], &mut carry2); + crate::const_for!((j in 1..N) { + let new_rj = mac_with_carry!(r[j], (self.0).0[j], b_i, &mut carry1); + let new_rj_minus_1 = mac_with_carry!(new_rj, k, T::MODULUS.0[j], &mut carry2); + r[j] = new_rj; + r[j - 1] = new_rj_minus_1; + }); + r[N - 1] = carry1.wrapping_add(carry2); + }); + let mut out = Self::new_unchecked(crate::BigInt::(r)); + out = out.const_subtract_modulus(); + out + } + + /// Two-phase (schoolbook+REDC) multiply with a RHS whose highest K limbs are provided + /// in `rhs_hi` and lower limbs are zero. 
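+    /// (Fallback path for moduli where the no-carry CIOS optimisation does not apply; the caller
+    /// still performs the final conditional subtraction.)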
+ #[inline] + const fn mul_without_cond_subtract_rhs_hi( + mut self, + rhs_hi: &crate::BigInt, + ) -> (bool, Self) { + let (mut lo, mut hi) = ([0u64; N], [0u64; N]); + // Schoolbook: only columns j in [N-K, N) + crate::const_for!((i in 0..N) { + let mut carry = 0u64; + crate::const_for!((t in 0..K) { + let j = N - K + t; + let b = rhs_hi.0[t]; + let k = i + j; + if k >= N { + hi[k - N] = mac_with_carry!(hi[k - N], (self.0).0[i], b, &mut carry); + } else { + lo[k] = mac_with_carry!(lo[k], (self.0).0[i], b, &mut carry); + } + }); + hi[i] = carry; + }); + // REDC: only i in [N-K, N) + let mut carry2 = 0u64; + crate::const_for!((i in 0..N) { + if i < N - K { /* skip */ } else { + let tmp = lo[i].wrapping_mul(T::INV); + let mut carry; + mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry); + crate::const_for!((j in 1..N) { + let k = i + j; + if k >= N { + hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry); + } else { + lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry); + } + }); + hi[i] = adc!(hi[i], carry, &mut carry2); + } + }); + crate::const_for!((i in 0..N) { (self.0).0[i] = hi[i]; }); + (carry2 != 0, self) + } + + /// Montgomery reduction for 2N-limb inputs (standard Montgomery reduction) + /// Takes a 2N-limb BigInt that represents a product in "unreduced" form + /// and reduces it to N limbs in Montgomery form. + /// Keep this for now for backwards compatibility. + #[inline(always)] + pub fn montgomery_reduce_2n(input: BigInt) -> Self { + debug_assert!(TWON == 2 * N, "montgomery_reduce_2n requires TWON == 2N"); + let mut limbs = input; + let carry = Self::montgomery_reduce_in_place::(&mut limbs); + + // Extract the upper N limbs after exactly N REDC steps + let mut result_limbs = [0u64; N]; + result_limbs.copy_from_slice(&limbs.0[N..]); + + let mut result = Self::new_unchecked(BigInt::(result_limbs)); + if T::MODULUS_HAS_SPARE_BIT { + result.subtract_modulus(); + } else { + result.subtract_modulus_with_carry(carry != 0); + } + result + } + + /// Perform exactly N Montgomery reduction steps over the leading 2N limbs of `limbs`, + /// using the canonical REDC subroutine from `mul_without_cond_subtract`. + /// Treats `limbs` as `[lo[0..N), hi[0..N), extra...]` and updates only the high half. + /// Returns the final carry-out (0 or 1) from the top of the reduction. #[inline(always)] - pub fn mul_u64(self, other: u64) -> Self { - // Stage 1: Bignum Multiplication - // Compute c = self.0 * other. Result c has N+1 limbs. - let c = bigint_mul_by_u64(&self.0 .0, other); + #[unroll_for_loops(12)] + pub fn montgomery_reduce_in_place(limbs: &mut BigInt) -> u64 { + debug_assert!(L >= 2 * N, "montgomery_reduce_in_place requires L >= 2N"); + + // Work directly on the buffer to avoid copies: split into lo and hi views. 
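+        // (Each iteration i below picks tmp = lo[i] * INV so that lo[i] + tmp * MODULUS vanishes
+        // modulo 2^64; the carries are folded into the high half, and after N iterations the high
+        // half holds the input divided by R, which the caller brings below the modulus with a
+        // final conditional subtraction.)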
+ let (lo, rest) = limbs.0.split_at_mut(N); + let hi = &mut rest[..N]; + + // Montgomery reduction (canonical form) + let mut carry2 = 0u64; + for i in 0..N { + let tmp = lo[i].wrapping_mul(T::INV); + let mut carry; + mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry); + for j in 1..N { + let k = i + j; + if k >= N { + let idx = k - N; + hi[idx] = mac_with_carry!(hi[idx], tmp, T::MODULUS.0[j], &mut carry); + } else { + lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry); + } + } + hi[i] = adc!(hi[i], carry, &mut carry2); + } - // Stage 2: Barrett Reduction - let r = barrett_reduce_nplus1_to_n::(c); + carry2 + } - // Use the final r_n_limbs which holds the correct N-limb result - Self::new_unchecked(BigInt::(r)) + #[inline(always)] + pub fn mul_u64(self, other: u64) -> Self { + debug_assert!(NPLUS1 == N + 1); + let c: BigInt = BigInt::mul_u64_w_carry(&self.0, other); // multiply + Self::from_unchecked_nplus1(c) // reduce and return the result } /// Multiply by an i64. Invokes `mul_u64` if the input is positive, /// otherwise negates the result of `mul_u64` of the absolute value. #[inline(always)] - pub fn mul_i64(self, other: i64) -> Self { + pub fn mul_i64(self, other: i64) -> Self { + debug_assert!(NPLUS1 == N + 1); if other >= 0 { // Multiply by the positive value directly - self.mul_u64(other as u64) + self.mul_u64::(other as u64) } else { // Multiply by the absolute value and then negate the result // (-other) cannot overflow since other is not i64::MIN - -(self.mul_u64((-other) as u64)) + -(self.mul_u64::((-other) as u64)) } } @@ -943,15 +1354,15 @@ impl, const N: usize> Fp, N> { /// Uses optimized mul_u64 if the absolute value of the input fits within u64, /// otherwise falls back to the two-step Barrett reduction (`mul_u128_aux`). #[inline(always)] - pub fn mul_i128(self, other: i128) -> Self { + pub fn mul_i128(self, other: i128) -> Self { if other >= 0 { let other_u128 = other as u128; if other_u128 <= u64::MAX as u128 { // Positive value fits in u64 - self.mul_u64(other_u128 as u64) + self.mul_u64::(other_u128 as u64) } else { // Positive value requires u128 path - self.mul_u128_aux(other_u128) + self.mul_u128_aux::(other_u128) } } else { // Negative value, compute absolute value as u128 @@ -959,10 +1370,10 @@ impl, const N: usize> Fp, N> { let abs_other = (-other) as u128; if abs_other <= u64::MAX as u128 { // Absolute value fits in u64 - -(self.mul_u64(abs_other as u64)) + -(self.mul_u64::(abs_other as u64)) } else { // Absolute value requires u128 path - -(self.mul_u128_aux(abs_other)) + -(self.mul_u128_aux::(abs_other)) } } } @@ -971,37 +1382,20 @@ impl, const N: usize> Fp, N> { /// Uses optimized mul_u64 if the input fits within u64, /// otherwise falls back to standard multiplication. #[inline(always)] - pub fn mul_u128(self, other: u128) -> Self { + pub fn mul_u128(self, other: u128) -> Self { if other >> 64 == 0 { - self.mul_u64(other as u64) + self.mul_u64::(other as u64) } else { - self.mul_u128_aux(other) + self.mul_u128_aux::(other) } } /// Fallback option for mul_u128: if the input does not fit within u64, /// we perform a more expensive procedure with 2 rounds of Barrett reduction. #[inline(always)] - pub fn mul_u128_aux(self, other: u128) -> Self { - // Stage 1: Bignum Multiplication - // Compute c = self.0 * other. 
Result c has N+2 limbs: (c_lo: [u64; 2], c_hi: [u64; N]) - let (c_lo, c_hi) = bigint_mul_by_u128(&self.0, other); - - // Stage 2: Two rounds of Barrett Reduction using the modular subroutine - - // Round 1: Reduce the top N+1 limbs c1 = floor(c / r). - // c1 has low limb c_lo[1] and high N limbs c_hi. - // Input to barrett_reduce is (u64, [u64; N]). - let c1 = (c_lo[1], c_hi); - let r1 = barrett_reduce_nplus1_to_n::(c1); // r1 = c1 mod p ([u64; N]) - - // Round 2: Reduce c2 = c_lo[0] + r1 * r. - // c2 has low limb c_lo[0] and high N limbs r1.0. - // Input to barrett_reduce is (u64, [u64; N]). - let c2 = (c_lo[0], r1); // Pass r1 directly as the high N limbs array - let r2 = barrett_reduce_nplus1_to_n::(c2); // r2 = c2 mod p = c mod p ([u64; N]) - - Self::new_unchecked(BigInt::(r2)) + pub fn mul_u128_aux(self, other: u128) -> Self { + let c = BigInt::mul_u128_w_carry::(&self.0, other); // mul + Self::from_unchecked_nplus2::(c) // Reduce and return the result } const fn const_is_valid(&self) -> bool { @@ -1034,332 +1428,140 @@ impl, const N: usize> Fp, N> { const fn sub_with_borrow(a: &BigInt, b: &BigInt) -> BigInt { a.const_sub_with_borrow(b).0 } -} - -/// Multiply a N-limb big integer with a u64, producing a N+1 limb result, -/// represented as a tuple of a u64 low limb and an array of N high limbs. -#[unroll_for_loops(8)] -#[inline(always)] -fn bigint_mul_by_u64(val: &[u64; N], other: u64) -> (u64, [u64; N]) { - let mut result_hi = [0u64; N]; - let mut carry: u64; // Start with carry = 0 - - // Calculate the full 128-bit product of the lowest limb - let prod_lo: u128 = (val[0] as u128) * (other as u128); - let result_lo = prod_lo as u64; // Lowest limb of the result - carry = (prod_lo >> 64) as u64; // Carry into the high part - - // Iterate through the remaining limbs of the input BigInt - for i in 1..N { - // Calculate the full 128-bit product of the current limb and the u64 multiplier - let prod_hi: u128 = (val[i] as u128) * (other as u128) + (carry as u128); - result_hi[i - 1] = prod_hi as u64; // Store in result_hi[0] to result_hi[N-2] - carry = (prod_hi >> 64) as u64; - } - - // After the loop, the final carry is the highest limb (N-th limb of the high part) - result_hi[N - 1] = carry; - - (result_lo, result_hi) -} - -/// Multiply a N+1 limb big integer (represented as low N limbs and high u64) with a u64, -/// producing a N+1 limb result in the same format. -/// Also returns a boolean indicating if there was a carry out (overflow). -#[unroll_for_loops(8)] -#[inline(always)] -fn bigint_plus_one_mul_by_u64( - val: ([u64; N], u64), - other: u64, -) -> (([u64; N], u64), bool) { - let (val_lo_n, val_hi) = val; - let mut result_lo_n = [0u64; N]; - let mut carry: u128 = 0; // Use u128 for intermediate carry - // Stage 1: Multiply the low N limbs - for i in 0..N { - let prod: u128 = (val_lo_n[i] as u128) * (other as u128) + carry; - result_lo_n[i] = prod as u64; - carry = prod >> 64; - } - - // Stage 2: Multiply the high limb - let prod_hi: u128 = (val_hi as u128) * (other as u128) + carry; - let result_hi = prod_hi as u64; - let final_carry = prod_hi >> 64; - - // Final carry indicates overflow - let overflow = final_carry != 0; - ((result_lo_n, result_hi), overflow) -} + /// Helper function: multiply a BigInt by u64 and accumulate into BigInt + /// This avoids creating temporary BigInt objects. 
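+    /// (In effect: acc += a * b, where the extra limb of acc absorbs the product's high carry;
+    /// the debug_assert below checks that the accumulated sum still fits.)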
+ #[inline(always)] + #[unroll_for_loops(8)] + fn mul_u64_accumulate(acc: &mut BigInt, a: &BigInt, b: u64) { + debug_assert!(NPLUS1 == N + 1); + use crate::biginteger::arithmetic as fa; -/// Subtract two N+1 limb big integers (represented as low N limbs and high u64). -/// Returns the N+1 limb result and a boolean indicating if a borrow occurred. -#[unroll_for_loops(8)] -#[inline(always)] -fn sub_bigint_plus_one( - a: ([u64; N], u64), - b: ([u64; N], u64), -) -> (([u64; N], u64), bool) { - let (mut a_lo_n, mut a_hi) = a; - let (b_lo_n, b_hi) = b; - let mut borrow: u64 = 0; // sbb uses u64 for borrow + let mut carry = 0u64; + for i in 0..N { + acc.0[i] = fa::mac_with_carry(acc.0[i], a.0[i], b, &mut carry); + } - // Subtract low N limbs - for i in 0..N { - // Updates a_lo_n[i] in place and returns the new borrow - borrow = fa::sbb(&mut a_lo_n[i], b_lo_n[i], borrow); + // Add final carry to the high limb + let final_carry = fa::adc(&mut acc.0[N], carry, 0); + debug_assert!(final_carry == 0, "overflow in mul_u64_accumulate"); } - // Subtract high u64 limb - borrow = fa::sbb(&mut a_hi, b_hi, borrow); - - // Final borrow indicates if the result is negative (b > a) - let final_borrow_occurred = borrow != 0; - - ((a_lo_n, a_hi), final_borrow_occurred) -} + /// Compute a linear combination of field elements with u64 coefficients. + /// Performs unreduced accumulation in BigInt, then one final reduction. + /// This is more efficient than individual multiplications and additions. + #[inline(always)] + pub fn linear_combination_u64(pairs: &[(Self, u64)]) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!( + !pairs.is_empty(), + "linear_combination_u64 requires at least one pair" + ); -/// Subtract two N+1 limb big integers where `a` is (u64, [u64; N]) and `b` is ([u64; N], u64). -/// Returns the N+1 limb result as ([u64; N], u64) and a boolean indicating if a borrow occurred. -#[unroll_for_loops(8)] -#[inline(always)] -fn sub_bigint_plus_one_prime( - a: (u64, [u64; N]), // Format: (low_limb, high_n_limbs) - b: ([u64; N], u64), // Format: (low_n_limbs, high_limb) -) -> (([u64; N], u64), bool) { - let (a_lo, a_hi_n) = a; - let (b_lo_n, b_hi) = b; - let mut result_lo_n = [0u64; N]; - let mut borrow: u64 = 0; + // Start with first term + let mut acc = pairs[0].0 .0.mul_u64_w_carry::(pairs[0].1); - // Subtract low limb: result_lo_n[0] = a_lo - b_lo_n[0] - borrow (initial borrow = 0) - result_lo_n[0] = a_lo; // Initialize result limb with a_lo - borrow = fa::sbb(&mut result_lo_n[0], b_lo_n[0], borrow); // result_lo_n[0] -= b_lo_n[0] + borrow + // Accumulate remaining terms using multiply-accumulate to avoid temporaries + for (a, b) in &pairs[1..] { + Self::mul_u64_accumulate::(&mut acc, &a.0, *b); + } - // Subtract middle limbs (if N > 1): result_lo_n[i] = a_hi_n[i-1] - b_lo_n[i] - borrow - // This loop covers indices i = 1 to N-1. - // It uses a_hi_n limbs from index 0 to N-2. - for i in 1..N { - result_lo_n[i] = a_hi_n[i - 1]; // Initialize result limb with corresponding a limb - borrow = fa::sbb(&mut result_lo_n[i], b_lo_n[i], borrow); // result_lo_n[i] -= b_lo_n[i] + borrow + Self::from_unchecked_nplus1::(acc) } - // Subtract high limb: result_hi = a_hi_n[N-1] - b_hi - borrow - let mut result_hi = a_hi_n[N - 1]; // Initialize result limb with last a limb - borrow = fa::sbb(&mut result_hi, b_hi, borrow); // result_hi -= b_hi + borrow + /// Compute a linear combination with separate positive and negative terms. 
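+    /// (Illustrative usage, const-generic width argument omitted:
+    /// linear_combination_i64(&[(a, 2), (b, 5)], &[(c, 3)]) evaluates 2*a + 5*b - 3*c with a
+    /// single reduction at the end.)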
+ /// Each term is multiplied by a u64 coefficient, then positive and negative + /// sums are computed separately and subtracted. One final reduction is performed. + #[inline(always)] + pub fn linear_combination_i64( + pos: &[(Self, u64)], + neg: &[(Self, u64)], + ) -> Self { + debug_assert!(NPLUS1 == N + 1); + debug_assert!( + !pos.is_empty(), + "linear_combination_i64 requires at least one positive term" + ); + debug_assert!( + !neg.is_empty(), + "linear_combination_i64 requires at least one negative term" + ); - let final_borrow_occurred = borrow != 0; + // Compute unreduced positive sum + let mut pos_lc = pos[0].0 .0.mul_u64_w_carry::(pos[0].1); + for (a, b) in &pos[1..] { + Self::mul_u64_accumulate::(&mut pos_lc, &a.0, *b); + } - ((result_lo_n, result_hi), final_borrow_occurred) -} + // Compute unreduced negative sum + let mut neg_lc = neg[0].0 .0.mul_u64_w_carry::(neg[0].1); + for (a, b) in &neg[1..] { + Self::mul_u64_accumulate::(&mut neg_lc, &a.0, *b); + } -/// Compare two N+1 limb big integers (represented as low N limbs and high u64). -#[unroll_for_loops(8)] -#[inline(always)] -fn compare_bigint_plus_one( - a: ([u64; N], u64), - b: ([u64; N], u64), -) -> core::cmp::Ordering { - // Compare high u64 limb first - if a.1 > b.1 { - return core::cmp::Ordering::Greater; - } else if a.1 < b.1 { - return core::cmp::Ordering::Less; - } - // High limbs are equal, compare the low N limbs from most significant (N-1) down to 0 - for i in (0..N).rev() { - if a.0[i] > b.0[i] { - return core::cmp::Ordering::Greater; - } else if a.0[i] < b.0[i] { - return core::cmp::Ordering::Less; + // Subtract and reduce once + match pos_lc.cmp(&neg_lc) { + core::cmp::Ordering::Greater => { + let borrow = pos_lc.sub_with_borrow(&neg_lc); + debug_assert!(!borrow, "borrow in linear_combination_i64"); + Self::from_unchecked_nplus1::(pos_lc) + }, + core::cmp::Ordering::Less => { + let borrow = neg_lc.sub_with_borrow(&pos_lc); + debug_assert!(!borrow, "borrow in linear_combination_i64"); + -Self::from_unchecked_nplus1::(neg_lc) + }, + core::cmp::Ordering::Equal => Self::zero(), } } - // All limbs are equal - return core::cmp::Ordering::Equal; } -/// Multiply a N-limb big integer with a u128, producing a N+2 limb result, -/// represented as a tuple of an array of 2 low limbs and an array of N high limbs. -#[unroll_for_loops(8)] #[inline(always)] -fn bigint_mul_by_u128(val: &BigInt, other: u128) -> ([u64; 2], [u64; N]) { - let other_lo = other as u64; - let other_hi = (other >> 64) as u64; - - // Compute partial products - // p1 = val * other_lo -> (N+1) limbs: (p1_lo: u64, p1_hi: [u64; N]) - let (p1_lo, p1_hi) = bigint_mul_by_u64(&val.0, other_lo); - // p2 = val * other_hi -> (N+1) limbs: (p2_lo: u64, p2_hi: [u64; N]) - let (p2_lo, p2_hi) = bigint_mul_by_u64(&val.0, other_hi); - - // Calculate the final result r = p1 + (p2 << 64) limb by limb. 
- // p1 : [p1_lo, p1_hi[0], ..., p1_hi[N-1]] - // p2 << 64 : [0, p2_lo, p2_hi[0], ..., p2_hi[N-1]] - // Sum (r) : [r_lo[0], r_lo[1], r_hi[0], ..., r_hi[N-1]] (N+2 limbs) - - let mut r_lo = [0u64; 2]; - let mut r_hi = [0u64; N]; - let mut carry: u64 = 0; - - // r_lo[0] = p1_lo + 0 + carry (carry is initially 0) - r_lo[0] = p1_lo; - // carry = 0; // Initial carry is 0 - - // Calculate r_lo[1] = p1_hi[0] + p2_lo + carry (limb 1) - r_lo[1] = p1_hi[0]; // Initialize with p1 limb - carry = fa::adc(&mut r_lo[1], p2_lo, carry); // Add p2 limb and carry - - // Calculate r_hi[0] to r_hi[N-1] (limbs 2 to N+1) - for i in 0..N { - let p1_limb = if i + 1 < N { p1_hi[i + 1] } else { 0 }; // Limb p1[i+2] - let p2_limb = p2_hi[i]; // Limb p2[i+1] - - // r_hi[i] = p1_limb + p2_limb + carry - r_hi[i] = p1_limb; // Initialize with p1 limb - carry = fa::adc(&mut r_hi[i], p2_limb, carry); // Add p2 limb and carry - } - - // The final carry MUST be zero for the result to fit in N+2 limbs. - debug_assert!(carry == 0, "Overflow in bigint_mul_by_u128"); - - (r_lo, r_hi) +fn nplus1_pair_high_to_bigint( + r_tmp: ([u64; N], u64), +) -> BigInt { + debug_assert!(NPLUS1 == N + 1); + let mut limbs = [0u64; NPLUS1]; + limbs[..N].copy_from_slice(&r_tmp.0); + limbs[N] = r_tmp.1; + BigInt::(limbs) } -/// Old conditional subtraction logic for Barrett reduction -/// Takes an N+1 limb intermediate result `r_tmp` and returns the N-limb final result. #[inline(always)] -fn _barrett_cond_subtract_old, const N: usize>( - r_tmp: ([u64; N], u64), -) -> [u64; N] { - let mut current_r = r_tmp; // Working variable in ([u64; N], u64) format - - if T::MODULUS_NUM_SPARE_BITS >= 1 { - // Case S >= 1 - if T::MODULUS_NUM_SPARE_BITS >= 2 { - // Optimization for S >= 2: r_tmp = c - m*2p < 4p - // High limb of current_r should initially be 0 - debug_assert!( - current_r.1 == 0, - "High limb of r_tmp should be zero when S >= 2" - ); - - // Conditional subtraction 1 (if r >= 2P) using N+1 compare/sub - if compare_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1) - != core::cmp::Ordering::Less - { - let (sub_res, sub_borrow) = - sub_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1); - debug_assert!(!sub_borrow, "Borrow should not occur subtracting 2P (S>=2)"); - current_r = sub_res; - } - // Conditional subtraction 2 (if r >= P) using N+1 compare/sub - if compare_bigint_plus_one(current_r, T::MODULUS_NPLUS1) != core::cmp::Ordering::Less { - let (sub_res, sub_borrow) = sub_bigint_plus_one(current_r, T::MODULUS_NPLUS1); - debug_assert!(!sub_borrow, "Borrow should not occur subtracting P (S>=2)"); - current_r = sub_res; - } - // Result must fit in N limbs now - debug_assert!( - current_r.1 == 0, - "High limb must be zero after final subtraction (S>=2)" - ); - current_r.0 // Return N low limbs - } else { - // Case S == 1: r_tmp = c - m*2p might temporarily exceed N limbs initially - - // Conditional subtraction 1: if r >= 2p - if compare_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1) - != core::cmp::Ordering::Less - { - // Subtract 2P using N+1 limb subtraction - let (sub_res, sub_borrow) = - sub_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1); - // After subtracting 2P, the result MUST fit in N limbs - debug_assert!( - sub_res.1 == 0, - "High limb must be 0 after 2P subtraction when S=1" - ); - debug_assert!( - !sub_borrow, - "Borrow should not occur when subtracting 2P for S=1" - ); - current_r = sub_res; // Update current_r (now guaranteed to have high limb 0) - } - // At this point, current_r < 2P and its high limb is 0. 
- debug_assert!( - current_r.1 == 0, - "High limb should be 0 before N-limb subtraction (S=1)" - ); - let mut r_n_limbs = BigInt::(current_r.0); // Extract N low limbs - - // Conditional subtraction 2 (if r >= P) using N limbs directly - if r_n_limbs >= T::MODULUS { - // Compare N limbs - r_n_limbs.sub_with_borrow(&T::MODULUS); // Subtract N limbs. Ignore borrow. - } - r_n_limbs.0 // Return N low limbs - } - } else { - // Case S == 0: Use (N+1)-limb helpers throughout - - // Conditional subtraction 1: if r >= 2p - if compare_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1) - != core::cmp::Ordering::Less - { - let (sub_res, sub_borrow) = sub_bigint_plus_one(current_r, T::MODULUS_TIMES_2_NPLUS1); - debug_assert!(!sub_borrow, "Borrow should not occur subtracting 2P (S=0)"); - current_r = sub_res; - } - // Now current_r = c mod 2p, represented as ([u64; N], u64) - - // Conditional subtraction 2: if r >= p - if compare_bigint_plus_one(current_r, T::MODULUS_NPLUS1) != core::cmp::Ordering::Less { - // if r >= p - let (sub_res, sub_borrow) = sub_bigint_plus_one(current_r, T::MODULUS_NPLUS1); - // Result MUST fit in N limbs now - debug_assert!( - sub_res.1 == 0, - "High limb must be zero after subtracting P (S=0)" - ); - debug_assert!( - !sub_borrow, - "Borrow should not occur when subtracting P for S=0" - ); - current_r = sub_res; - } - // At this point, current_r < P and its high limb must be 0. - debug_assert!( - current_r.1 == 0, - "High limb must be zero after final subtraction (S=0)" - ); - current_r.0 // Return N low limbs - } +fn nplus1_pair_low_to_bigint( + r_tmp: (u64, [u64; N]), +) -> BigInt { + debug_assert!(NPLUS1 == N + 1); + let mut limbs = [0u64; NPLUS1]; + limbs[0] = r_tmp.0; + limbs[1..NPLUS1].copy_from_slice(&r_tmp.1); + BigInt::(limbs) } /// Conditional subtraction logic for Barrett reduction, trading an extra comparison for a conditional subtraction. /// Includes optimizations based on MODULUS_NUM_SPARE_BITS. -/// Takes an N+1 limb intermediate result `r_tmp` (in `([u64; N], u64)` format) and returns the N-limb final result. +/// Takes an N+1 limb intermediate result `r_tmp` and returns the N-limb final result. #[unroll_for_loops(4)] #[inline(always)] -fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64)) -> [u64; N] { - let final_limbs: [u64; N]; - let r_n = BigInt::(r_tmp.0); // N low limbs as BigInt - let r_hi = r_tmp.1; // High limb - +fn barrett_cond_subtract, const N: usize, const NPLUS1: usize>( + r_tmp: BigInt, +) -> BigInt { + debug_assert!(NPLUS1 == N + 1); // Compare with 2p let compare_2p = if T::MODULUS_NUM_SPARE_BITS == 0 { // S = 0: Must use N+1 compare - compare_bigint_plus_one(r_tmp, T::MODULUS_TIMES_2_NPLUS1) + r_tmp.cmp(&nplus1_pair_high_to_bigint::( + T::MODULUS_TIMES_2_NPLUS1, + )) } else { // S >= 1: 2p fits N limbs (mostly). Compare N limbs. // We assume r_tmp's high limb is 0 here if S >= 1. 
debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 1 before 2p comparison" ); let p2_n = BigInt::(T::MODULUS_TIMES_2_NPLUS1.0); - r_n.cmp(&p2_n) + BigInt::(r_tmp.0[0..N].try_into().unwrap()).cmp(&p2_n) // Compare N limbs }; if compare_2p != core::cmp::Ordering::Less { @@ -1367,15 +1569,17 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 // Compare with 3p let compare_3p = if T::MODULUS_NUM_SPARE_BITS < 2 { // S < 2 (S=0 or S=1): Need N+1 compare - compare_bigint_plus_one(r_tmp, T::MODULUS_TIMES_3_NPLUS1) + r_tmp.cmp(&nplus1_pair_high_to_bigint::( + T::MODULUS_TIMES_3_NPLUS1, + )) } else { // S >= 2: 3p fits N limbs. Compare N limbs. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 2 before 3p comparison" ); let p3_n = BigInt::(T::MODULUS_TIMES_3_NPLUS1.0); - r_n.cmp(&p3_n) + BigInt::(r_tmp.0[0..N].try_into().unwrap()).cmp(&p3_n) // Compare N limbs }; if compare_3p != core::cmp::Ordering::Less { @@ -1384,23 +1588,27 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 if T::MODULUS_NUM_SPARE_BITS >= 2 { // S >= 2: 3p fits N limbs. Use N-limb sub. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 2 for 3p subtraction" ); let p3_n = BigInt::(T::MODULUS_TIMES_3_NPLUS1.0); + let r_n = BigInt::(r_tmp.0[0..N].try_into().unwrap()); + // Subtract 3p from r_n + // Use const_sub_with_borrow to avoid borrow checking issues + // This is safe because we know r_n >= 3p from the comparison above. let (res_n, borrow_n) = r_n.const_sub_with_borrow(&p3_n); debug_assert!(!borrow_n, "Borrow should not occur subtracting 3p (S>=2)"); - final_limbs = res_n.0; + return res_n; // Return the N-limb result directly } else { // S < 2: Use N+1 limb sub. - let ((res_n_limbs, res_hi_limb), borrow_n1) = - sub_bigint_plus_one(r_tmp, T::MODULUS_TIMES_3_NPLUS1); - debug_assert!(!borrow_n1, "Borrow should not occur subtracting 3p (S<2)"); + let p3_n1 = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_3_NPLUS1); + let (res_n1, borrow) = r_tmp.const_sub_with_borrow(&p3_n1); + debug_assert!(!borrow, "Borrow should not occur subtracting 3p (S<2)"); debug_assert!( - res_hi_limb == 0, + res_n1.0[N] == 0, "High limb must be zero after subtracting 3p" ); - final_limbs = res_n_limbs; + return BigInt::(res_n1.0[0..N].try_into().unwrap()); } } else { // 2p <= r_tmp < 3p @@ -1408,23 +1616,24 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 if T::MODULUS_NUM_SPARE_BITS >= 1 { // S >= 1: 2p fits N limbs (mostly). Use N-limb sub. debug_assert!( - r_hi == 0, + r_tmp.0[N] == 0, "High limb expected to be 0 if S >= 1 for 2p subtraction" ); let p2_n = BigInt::(T::MODULUS_TIMES_2_NPLUS1.0); + let r_n = BigInt::(r_tmp.0[0..N].try_into().unwrap()); let (res_n, borrow_n) = r_n.const_sub_with_borrow(&p2_n); debug_assert!(!borrow_n, "Borrow should not occur subtracting 2p (S>=1)"); - final_limbs = res_n.0; + return res_n; // Return the N-limb result directly } else { // S == 0: Use N+1 limb sub. 
- let ((res_n_limbs, res_hi_limb), borrow_n1) = - sub_bigint_plus_one(r_tmp, T::MODULUS_TIMES_2_NPLUS1); - debug_assert!(!borrow_n1, "Borrow should not occur subtracting 2p (S=0)"); + let p2_n1 = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_2_NPLUS1); + let (res_n1, borrow) = r_tmp.const_sub_with_borrow(&p2_n1); + debug_assert!(!borrow, "Borrow should not occur subtracting 2p (S=0)"); debug_assert!( - res_hi_limb == 0, + res_n1.0[N] == 0, "High limb must be zero after subtracting 2p" ); - final_limbs = res_n_limbs; + return BigInt::(res_n1.0[0..N].try_into().unwrap()); } } } else { @@ -1433,11 +1642,15 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 let compare_p = if T::MODULUS_NUM_SPARE_BITS >= 1 { // S >= 1: Use N-limb compare. // Assume r_tmp high limb is 0 because r_tmp < 2p and 2p fits N limbs (mostly) if S >= 1 - debug_assert!(r_hi == 0, "High limb expected to be 0 if S >= 1 and r < 2p"); - r_n.cmp(&T::MODULUS) // Compare N limbs + debug_assert!( + r_tmp.0[N] == 0, + "High limb expected to be 0 if S >= 1 before p comparison" + ); + let p_n = BigInt::(T::MODULUS.0); + BigInt::(r_tmp.0[0..N].try_into().unwrap()).cmp(&p_n) // Compare N limbs } else { // S == 0: Use N+1 limb compare. - compare_bigint_plus_one(r_tmp, T::MODULUS_NPLUS1) + r_tmp.cmp(&nplus1_pair_high_to_bigint::(T::MODULUS_NPLUS1)) }; if compare_p != core::cmp::Ordering::Less { @@ -1445,33 +1658,64 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 // Subtract p if T::MODULUS_NUM_SPARE_BITS >= 1 { // S >= 1: Use N-limb sub. - debug_assert!( - r_hi == 0, - "High limb expected to be 0 if S >= 1 for p subtraction" - ); - let (res_n, borrow_n) = r_n.const_sub_with_borrow(&T::MODULUS); + let p_n = BigInt::(T::MODULUS.0); + let r_n = BigInt::(r_tmp.0[0..N].try_into().unwrap()); + let (res_n, borrow_n) = r_n.const_sub_with_borrow(&p_n); debug_assert!(!borrow_n, "Borrow should not occur subtracting p (S>=1)"); - final_limbs = res_n.0; + return res_n; // Return the N-limb result directly } else { // S == 0: Use N+1 limb sub. - let ((res_n_limbs, res_hi_limb), borrow_n1) = - sub_bigint_plus_one(r_tmp, T::MODULUS_NPLUS1); - debug_assert!(!borrow_n1, "Borrow should not occur subtracting p (S=0)"); + let p_n1 = nplus1_pair_high_to_bigint::(T::MODULUS_NPLUS1); + let (res_n1, borrow) = r_tmp.const_sub_with_borrow(&p_n1); + debug_assert!(!borrow, "Borrow should not occur subtracting p (S=0)"); debug_assert!( - res_hi_limb == 0, + res_n1.0[N] == 0, "High limb must be zero after subtracting p" ); - final_limbs = res_n_limbs; + return BigInt::(res_n1.0[0..N].try_into().unwrap()); } } else { // r_tmp < p // Subtract 0 (No-op) // Result must already fit in N limbs. Assert high limb is 0. - debug_assert!(r_hi == 0, "High limb must be zero when r_tmp < p"); - final_limbs = r_n.0; // Use the low N limbs directly + debug_assert!(r_tmp.0[N] == 0, "High limb must be zero when r_tmp < p"); + return BigInt::(r_tmp.0[0..N].try_into().unwrap()); } } - final_limbs +} + +/// Subtract two N+1 limb big integers where `a` is (u64, [u64; N]) and `b` is ([u64; N], u64). +/// Returns the N+1 limb result as ([u64; N], u64) and a boolean indicating if a borrow occurred. 
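+/// (The operands are kept in these mixed layouts because that is how the Barrett kernel below
+/// assembles them before subtracting.)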
+#[unroll_for_loops(8)] +#[inline(always)] +fn sub_bigint_plus_one_prime( + a: (u64, [u64; N]), // Format: (low_limb, high_n_limbs) + b: ([u64; N], u64), // Format: (low_n_limbs, high_limb) +) -> (([u64; N], u64), bool) { + let (a_lo, a_hi_n) = a; + let (b_lo_n, b_hi) = b; + let mut result_lo_n = [0u64; N]; + let mut borrow: u64 = 0; + + // Subtract low limb: result_lo_n[0] = a_lo - b_lo_n[0] - borrow (initial borrow = 0) + result_lo_n[0] = a_lo; // Initialize result limb with a_lo + borrow = fa::sbb(&mut result_lo_n[0], b_lo_n[0], borrow); // result_lo_n[0] -= b_lo_n[0] + borrow + + // Subtract middle limbs (if N > 1): result_lo_n[i] = a_hi_n[i-1] - b_lo_n[i] - borrow + // This loop covers indices i = 1 to N-1. + // It uses a_hi_n limbs from index 0 to N-2. + for i in 1..N { + result_lo_n[i] = a_hi_n[i - 1]; // Initialize result limb with corresponding a limb + borrow = fa::sbb(&mut result_lo_n[i], b_lo_n[i], borrow); // result_lo_n[i] -= b_lo_n[i] + borrow + } + + // Subtract high limb: result_hi = a_hi_n[N-1] - b_hi - borrow + let mut result_hi = a_hi_n[N - 1]; // Initialize result limb with last a limb + borrow = fa::sbb(&mut result_hi, b_hi, borrow); // result_hi -= b_hi + borrow + + let final_borrow_occurred = borrow != 0; + + ((result_lo_n, result_hi), final_borrow_occurred) } /// Helper function to perform Barrett reduction from N+1 limbs to N limbs. @@ -1480,38 +1724,50 @@ fn barrett_cond_subtract, const N: usize>(r_tmp: ([u64; N], u64 /// Output is the N-limb result `[u64; N]`. #[unroll_for_loops(4)] #[inline(always)] -fn barrett_reduce_nplus1_to_n, const N: usize>(c: (u64, [u64; N])) -> [u64; N] { - let (c_lo, c_hi) = c; // c_lo is the lowest limb, c_hi holds the top N limbs - +fn barrett_reduce_nplus1_to_n, const N: usize, const NPLUS1: usize>( + c: BigInt, +) -> BigInt { + debug_assert!(NPLUS1 == N + 1, "NPLUS1 must be N + 1 for this function"); // Compute tilde_c = floor(c / R') = floor(c / 2^MODULUS_BITS) // This involves the top two limbs of the N+1 limb number `c`. - // The highest limb is c_hi[N-1]. The second highest is c_hi[N-2]. // Assume that `N >= 1` let tilde_c: u64 = if T::MODULUS_HAS_SPARE_BIT { - let high_limb = c_hi[N - 1]; - let second_high_limb = if N > 1 { c_hi[N - 2] } else { c_lo }; // Use c_lo if N=1 + let high_limb = c.0[N]; + let second_high_limb = c.0[N - 1]; // N is at least 1, so this is safe (high_limb << T::MODULUS_NUM_SPARE_BITS) + (second_high_limb >> (64 - T::MODULUS_NUM_SPARE_BITS)) } else { - c_hi[N - 1] // If no spare bits, tilde_c is just the highest limb + c.0[N] // If no spare bits, tilde_c is just the highest limb }; // Estimate m = floor( (tilde_c * BARRETT_MU) / r ) // where r = 2^64 let m: u64 = ((tilde_c as u128 * T::BARRETT_MU as u128) >> 64) as u64; + // unroll T::MODULUS_TIMES_2_NPLUS1 from ([u64; N], u64) to BigInt + let mut m2p = nplus1_pair_high_to_bigint::(T::MODULUS_TIMES_2_NPLUS1); // Compute m * 2p (N+1 limbs) - let (m_times_2p, m2p_overflow) = bigint_plus_one_mul_by_u64::(T::MODULUS_TIMES_2_NPLUS1, m); - // If m * 2p overflows N+1 limbs, the logic might be flawed or input c was too large. - debug_assert!(!m2p_overflow, "Overflow calculating m * 2p"); + m2p.mul_u64_in_place(m); + // I really have no idea why the following sequence of operations + // is significantly faster than a simple BigInt sub operation. 
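+    // (For reference: m approximates floor(c / 2p), so the subtraction below leaves a remainder
+    // congruent to c modulo 2p that barrett_cond_subtract can finish reducing.)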
// Compute r_tmp = c - m * 2p (result is ([u64; N], u64)) - let (r_tmp, r_tmp_borrow) = sub_bigint_plus_one_prime(c, m_times_2p); + let m_times_2p = ( + m2p.0[0..N].try_into().unwrap(), // Convert to ([u64; N], u64) + m2p.0[N], // High limb remains as u64 + ); + let (r_tmp, r_tmp_borrow) = + sub_bigint_plus_one_prime((c.0[0], c.0[1..N + 1].try_into().unwrap()), m_times_2p); // A borrow here implies c was smaller than m*2p, which shouldn't happen with correct m. debug_assert!(!r_tmp_borrow, "Borrow occurred calculating c - m*2p"); - - // Use the optimized conditional subtraction which expects ([u64; N], u64) - barrett_cond_subtract::(r_tmp) + // Change formats again! + let r_tmp_bigint = nplus1_pair_high_to_bigint::(r_tmp); + // Alternative simple BigInt subtraction (much slower for some reason): + /*let (r_tmp_bigint, r_borrow) = c.const_sub_with_borrow(&m2p); + debug_assert!(!r_borrow, "Borrow occurred calculating c - m*2p");*/ + + // Use the optimized conditional subtraction to go from N+1 limbs to N limbs. + barrett_cond_subtract::(r_tmp_bigint) } #[cfg(test)] @@ -1521,6 +1777,10 @@ mod test { use ark_test_curves::bn254::Fr; use num_bigint::{BigInt, BigUint, Sign}; use rand::Rng; + // constants for the number of limbs in bn254 + const N: usize = 4; + const NPLUS1: usize = N + 1; + const NPLUS2: usize = N + 2; #[test] fn test_mul_u64_random() { @@ -1534,7 +1794,7 @@ mod test { let expected_c = a * Fr::from(b_bigint); // Actual result using the function under test - let result_c = a.mul_u64(b_val); + let result_c = a.mul_u64::(b_val); assert_eq!( result_c, @@ -1562,7 +1822,7 @@ mod test { }; // Actual result using the function under test - let result_c = a.mul_i64(b_val); + let result_c = a.mul_i64::(b_val); assert_eq!( result_c, expected_c, @@ -1584,7 +1844,7 @@ mod test { let expected_c = a * Fr::from(b_bigint); // Actual result using the function under test - let result_c = a.mul_u128(b_val); + let result_c = a.mul_u128::(b_val); assert_eq!( result_c, expected_c, @@ -1612,7 +1872,7 @@ mod test { }; // Actual result using the function under test - let result_c = a.mul_i128(b_val); + let result_c = a.mul_i128::(b_val); assert_eq!( result_c, expected_c, @@ -1622,13 +1882,16 @@ mod test { } } + // Removed trailing-zero API tests due to API consolidation + #[test] fn test_mont_macro_correctness() { - // This test succeeds **only** on the secp256k1 curve. + // This test succeeds only on the secp256k1 curve. let (is_positive, limbs) = str_to_limbs_u64( "111192936301596926984056301862066282284536849596023571352007112326586892541694", ); - let t = Fr::from_sign_and_limbs(is_positive, &limbs); + // Use secp256k1::Fr here (do not use the bn254 alias `Fr` above). 
+ let t = ark_test_curves::secp256k1::Fr::from_sign_and_limbs(is_positive, &limbs); let result: BigUint = t.into(); let expected = BigUint::from_str( @@ -1657,4 +1920,27 @@ mod test { let sign_is_positive = sign != Sign::Minus; (sign_is_positive, limbs) } + + #[test] + fn test_from_montgomery_reduce_paths_l8_l9_match_field_mul() { + let mut rng = test_rng(); + for _ in 0..200 { + let a = Fr::rand(&mut rng); + let b = Fr::rand(&mut rng); + + let expected = a * b; + + // Compute 8-limb raw product of Montgomery residues + let prod8 = a.0.mul_trunc::<4, 8>(&b.0); + + // Reduce via Montgomery reduction with L = 8 + let alt8 = Fr::montgomery_reduce_2n::<8>(prod8); + assert_eq!(alt8, expected, "from_montgomery_reduce L=8 mismatch"); + + // Zero-extend to 9 limbs and reduce with L = 9 + let prod9 = ark_test_curves::ark_ff::BigInt::<9>::zero_extend_from::<8>(&prod8); + let alt9 = Fr::from_montgomery_reduce::<9, 5>(prod9); + assert_eq!(alt9, expected, "from_montgomery_reduce L=9 mismatch"); + } + } } diff --git a/ff/src/fields/prime.rs b/ff/src/fields/prime.rs index 28b896e59..9ae26d83e 100644 --- a/ff/src/fields/prime.rs +++ b/ff/src/fields/prime.rs @@ -57,9 +57,12 @@ pub trait PrimeField: /// Converts an element of the prime field into an integer in the range 0..(p - 1). fn into_bigint(self) -> Self::BigInt; - /// Creates a field element from a `u64`. + /// Construct a prime field element from an integer in the range 0..(p - 1) (no reductions) + fn from_bigint_unchecked(repr: Self::BigInt) -> Option; + + /// Creates a field element from a `u64`. /// Returns `None` if the `u64` is larger than or equal to the modulus. - fn from_u64(val: u64) -> Option; + fn from_u64(val: u64) -> Option; /// Reads bytes in big-endian, and converts them to a field element. /// If the integer represented by `bytes` is larger than the modulus `p`, this method diff --git a/ff/src/lib.rs b/ff/src/lib.rs index 6b10158f4..7464724ea 100644 --- a/ff/src/lib.rs +++ b/ff/src/lib.rs @@ -20,7 +20,7 @@ extern crate educe; pub mod biginteger; pub use biginteger::{ signed_mod_reduction, BigInt, BigInteger, BigInteger128, BigInteger256, BigInteger320, - BigInteger384, BigInteger448, BigInteger64, BigInteger768, BigInteger832, + BigInteger384, BigInteger448, BigInteger64, BigInteger768, BigInteger832, SignedBigInt, }; #[macro_use] diff --git a/jolt-optimizations/Cargo.toml b/jolt-optimizations/Cargo.toml index 49aba6276..e644c37e4 100644 --- a/jolt-optimizations/Cargo.toml +++ b/jolt-optimizations/Cargo.toml @@ -57,4 +57,4 @@ name = "batch_addition" harness = false [[example]] -name = "memory_test" \ No newline at end of file +name = "memory_test" diff --git a/jolt-optimizations/benches/dory_all.rs b/jolt-optimizations/benches/dory_all.rs index 5015658ff..8e4866b2c 100644 --- a/jolt-optimizations/benches/dory_all.rs +++ b/jolt-optimizations/benches/dory_all.rs @@ -1,7 +1,7 @@ #![allow(non_snake_case)] use ark_bn254::{Fr, G1Projective, G2Projective}; -use ark_ec::{AdditiveGroup, PrimeGroup}; +use ark_ec::PrimeGroup; use ark_ff::PrimeField; use ark_std::UniformRand; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; diff --git a/jolt-optimizations/benches/dory_utils.rs b/jolt-optimizations/benches/dory_utils.rs index 9e86aef9b..89a867471 100644 --- a/jolt-optimizations/benches/dory_utils.rs +++ b/jolt-optimizations/benches/dory_utils.rs @@ -2,15 +2,13 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use std::time::Instant; use ark_bn254::{Fr, G2Affine, G2Projective}; 
-use ark_ec::PrimeGroup; -use ark_ec::{AdditiveGroup, AffineRepr}; +use ark_ec::{AffineRepr, PrimeGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::test_rng; use jolt_optimizations::{ vector_scalar_mul_add, vector_scalar_mul_add_online, vector_scalar_mul_add_precomputed, - vector_scalar_mul_v_add_g_online, vector_scalar_mul_v_add_g_precomputed, VectorScalarMulData, - VectorScalarMulVData, + VectorScalarMulData, }; fn bench_vector_scalar_mul_add(c: &mut Criterion) { diff --git a/jolt-optimizations/benches/g1_scalar_multiplication.rs b/jolt-optimizations/benches/g1_scalar_multiplication.rs index d2b302cea..0f48a9f6e 100644 --- a/jolt-optimizations/benches/g1_scalar_multiplication.rs +++ b/jolt-optimizations/benches/g1_scalar_multiplication.rs @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use rayon::prelude::*; use ark_bn254::{Fr, G1Affine, G1Projective}; -use ark_ec::{AdditiveGroup, AffineRepr, PrimeGroup}; +use ark_ec::{AffineRepr, PrimeGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::test_rng; diff --git a/jolt-optimizations/benches/scalar_multiplication.rs b/jolt-optimizations/benches/scalar_multiplication.rs index 30a4f3e3b..ab9db9f9d 100644 --- a/jolt-optimizations/benches/scalar_multiplication.rs +++ b/jolt-optimizations/benches/scalar_multiplication.rs @@ -2,14 +2,13 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use rayon::prelude::*; use ark_bn254::{Fr, G2Affine, G2Projective}; -use ark_ec::PrimeGroup; -use ark_ec::{AdditiveGroup, AffineRepr}; +use ark_ec::{AffineRepr, PrimeGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::test_rng; use jolt_optimizations::{ glv_four_precompute, glv_four_precompute_windowed2_signed, glv_four_scalar_mul, - glv_four_scalar_mul_online, glv_four_scalar_mul_windowed2_signed, + glv_four_scalar_mul_windowed2_signed, }; fn bench_scalar_multiplication(c: &mut Criterion) { diff --git a/jolt-optimizations/src/decomp_2d.rs b/jolt-optimizations/src/decomp_2d.rs index 5ee172057..8281e4764 100644 --- a/jolt-optimizations/src/decomp_2d.rs +++ b/jolt-optimizations/src/decomp_2d.rs @@ -11,7 +11,7 @@ use num_integer::Integer; use num_traits::{One, Signed}; /// GLV lambda for BN254 G1 -const LAMBDA: Fr = +const _LAMBDA: Fr = MontFp!("21888242871839275217838484774961031246154997185409878258781734729429964517155"); /// GLV endomorphism coefficient for BN254 G1 diff --git a/jolt-optimizations/src/fq12_poly.rs b/jolt-optimizations/src/fq12_poly.rs new file mode 100644 index 000000000..9a5683ef4 --- /dev/null +++ b/jolt-optimizations/src/fq12_poly.rs @@ -0,0 +1,133 @@ +//! 
Fq12 polynomial operations and conversions for BN254 +use ark_bn254::{Fq, Fq12}; +use ark_ff::{Field, One, Zero}; + +/// Convert Fq12 to polynomial representation using tower basis mapping +/// +/// Maps Fq12 basis elements to powers of w: +/// - (c0.c0, c0.c1, c0.c2, c1.c0, c1.c1, c1.c2) → (w^0, w^2, w^4, w^1, w^3, w^5) +/// - Applies the mapping: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} +pub fn fq12_to_poly12_coeffs(a: &Fq12) -> [Fq; 12] { + // Tower basis element mappings: (outer_idx, inner_idx, w_power) + const MAPPINGS: [(usize, usize, usize); 6] = [ + (0, 0, 0), // a.c0.c0 → w^0 + (0, 1, 2), // a.c0.c1 → w^2 + (0, 2, 4), // a.c0.c2 → w^4 + (1, 0, 1), // a.c1.c0 → w^1 + (1, 1, 3), // a.c1.c1 → w^3 + (1, 2, 5), // a.c1.c2 → w^5 + ]; + + let nine = Fq::from(9); + let mut coeffs = [Fq::zero(); 12]; + + for &(outer, inner, w_power) in &MAPPINGS { + let fp2 = match (outer, inner) { + (0, 0) => &a.c0.c0, + (0, 1) => &a.c0.c1, + (0, 2) => &a.c0.c2, + (1, 0) => &a.c1.c0, + (1, 1) => &a.c1.c1, + (1, 2) => &a.c1.c2, + _ => unreachable!(), + }; + + let (x, y) = (fp2.c0, fp2.c1); + // Apply: (x + y·u)·w^k = (x - 9y)·w^k + y·w^{k+6} + coeffs[w_power] += x - nine * y; + coeffs[w_power + 6] += y; + } + + coeffs +} + +/// Coefficients for the minimal polynomial g(X) = X^12 - 18 X^6 + 82 +const G_COEFF_0: u64 = 82; +const G_COEFF_6: u64 = 18; + +/// Evaluate g(X) = X^12 - 18 X^6 + 82 at a given point r +pub fn g_eval(r: &Fq) -> Fq { + let r6 = (r.square() * r).square(); // r^6 = (r^2 * r)^2 + let r12 = r6.square(); + r12 - Fq::from(G_COEFF_6) * r6 + Fq::from(G_COEFF_0) +} + +/// Horner evaluation for arbitrary-degree poly +pub fn eval_poly_vec(coeffs: &[Fq], r: &Fq) -> Fq { + coeffs.iter().rev().fold(Fq::zero(), |acc, c| acc * r + c) +} + +/// Build the coefficients for g(X) = X^12 - 18 X^6 + 82 +pub fn g_coeffs() -> Vec { + let mut g = vec![Fq::zero(); 13]; + g[0] = Fq::from(G_COEFF_0); + g[6] = -Fq::from(G_COEFF_6); + g[12] = Fq::one(); + g +} + +/// Compute the multilinear extension (MLE) of a univariate polynomial. +pub fn to_multilinear_evals(coeffs: &[Fq; 12]) -> Vec { + // Evaluate polynomial at points 0..16 + (0..16) + .map(|i| { + let x = Fq::from(i as u64); + eval_poly_vec(&coeffs[..], &x) + }) + .collect() +} + +/// Convert Fq12 element to multilinear extension evaluations. +/// First converts to polynomial representation, then computes MLE. +pub fn fq12_to_multilinear_evals(a: &Fq12) -> Vec { + to_multilinear_evals(&fq12_to_poly12_coeffs(a)) +} + +/// Evaluate a multilinear polynomial at a given point. 
+pub fn eval_multilinear(evals: &[Fq], point: &[Fq]) -> Fq { + let n = point.len(); + assert_eq!( + evals.len(), + 1 << n, + "Number of evaluations must be 2^n where n is dimension" + ); + + let mut result = Fq::zero(); + for (i, &eval) in evals.iter().enumerate() { + let mut term = eval; + for j in 0..n { + let bit = (i >> j) & 1; + term *= if bit == 1 { + point[j] + } else { + Fq::one() - point[j] + }; + } + result += term; + } + result +} + +/// Compute equality function weights eq(z, x) for all x ∈ {0,1}^4 +pub fn eq_weights(z: &[Fq]) -> Vec { + assert_eq!(z.len(), 4, "Point z must be 4-dimensional"); + let mut w = vec![Fq::zero(); 16]; + + for idx in 0..16 { + // Binary decomposition of idx + let x0 = Fq::from((idx & 1) as u64); + let x1 = Fq::from(((idx >> 1) & 1) as u64); + let x2 = Fq::from(((idx >> 2) & 1) as u64); + let x3 = Fq::from(((idx >> 3) & 1) as u64); + + // eq(z, x) = ∏ᵢ ((1-zᵢ)(1-xᵢ) + zᵢxᵢ) + let t0 = (Fq::one() - z[0]) * (Fq::one() - x0) + z[0] * x0; + let t1 = (Fq::one() - z[1]) * (Fq::one() - x1) + z[1] * x1; + let t2 = (Fq::one() - z[2]) * (Fq::one() - x2) + z[2] * x2; + let t3 = (Fq::one() - z[3]) * (Fq::one() - x3) + z[3] * x3; + + w[idx] = t0 * t1 * t2 * t3; + } + + w +} diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs index 5ba72de8f..2f40915e6 100644 --- a/jolt-optimizations/src/lib.rs +++ b/jolt-optimizations/src/lib.rs @@ -15,8 +15,10 @@ pub mod decomp_4d; pub mod dory_g1; pub mod dory_g2; pub mod dory_utils; +pub mod fq12_poly; pub mod frobenius; pub mod glv_two; +pub mod witness_gen; mod glv_four; pub use glv_four::{ @@ -54,3 +56,10 @@ pub use dory_g2::{ }; pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi}; + +pub use fq12_poly::{ + eq_weights, eval_multilinear, fq12_to_multilinear_evals, fq12_to_poly12_coeffs, g_coeffs, + g_eval, to_multilinear_evals, +}; + +pub use witness_gen::{get_g_mle, h_tilde_at_point, ExponentiationSteps}; diff --git a/jolt-optimizations/src/witness_gen.rs b/jolt-optimizations/src/witness_gen.rs new file mode 100644 index 000000000..d5098982f --- /dev/null +++ b/jolt-optimizations/src/witness_gen.rs @@ -0,0 +1,203 @@ +use crate::fq12_poly::{eq_weights, eval_multilinear, fq12_to_multilinear_evals, g_coeffs}; +use ark_bn254::{Fq, Fq12, Fr}; +use ark_ff::{BigInteger, Field, One, PrimeField, Zero}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +/// square-and-multiply witness generation +#[derive(Clone, Debug, Default, CanonicalSerialize, CanonicalDeserialize)] +pub struct ExponentiationSteps { + pub base: Fq12, // A (base) + pub exponent: Fr, // e (exponent) + pub result: Fq12, // Final result A^e + pub rho_mles: Vec>, // MLEs of ρ_0, ρ_1, ..., ρ_t + pub quotient_mles: Vec>, // MLEs of Q_1, Q_2, ..., Q_t + pub bits: Vec, // b_1, b_2, ..., b_t +} + +impl ExponentiationSteps { + /// Generate MLE witness for base^exponent using MSB-first square-and-multiply + pub fn new(base: Fq12, exponent: Fr) -> Self { + let bits_le = exponent.into_bigint().to_bits_le(); + + let msb_idx = match bits_le.iter().rposition(|&b| b) { + None => { + return Self { + base, + exponent, + result: Fq12::one(), + rho_mles: vec![fq12_to_multilinear_evals(&Fq12::one())], // ρ_0 = 1 + quotient_mles: vec![], + bits: vec![], + }; + }, + Some(i) => i, + }; + + // Special case: exponent == 1 ⇒ result = base + if msb_idx == 0 && bits_le[0] { + return Self { + base, + exponent, + result: base, + rho_mles: vec![ + fq12_to_multilinear_evals(&Fq12::one()), // ρ_0 + fq12_to_multilinear_evals(&base), // ρ_1 + ], 
+ quotient_mles: vec![], + bits: vec![true], + }; + } + + let bits_msb: Vec = (0..=msb_idx).rev().map(|i| bits_le[i]).collect(); + + let mut rho = Fq12::one(); + let mut rho_mles = vec![fq12_to_multilinear_evals(&rho)]; + let mut quotient_mles = vec![]; + let mut bits = vec![]; + + for &b in &bits_msb { + bits.push(b); + + let rho_prev = rho; // ρ_{i-1} + let rho_sq = rho_prev.square(); // ρ_{i-1}² + let rho_i = if b { rho_sq * base } else { rho_sq }; // ρ_i + + // One quotient per step for: ρ_i(X) - ρ_{i-1}(X)² * A(X)^{b} = Q_i(X) g(X) + let q_i = compute_step_quotient_msb(rho_prev, rho_i, base, b); + quotient_mles.push(q_i); + + rho = rho_i; + rho_mles.push(fq12_to_multilinear_evals(&rho)); + } + + Self { + base, + exponent, + result: rho, + rho_mles, + quotient_mles, + bits, + } + } + + /// Verify that the final result matches base^exponent + pub fn verify_result(&self) -> bool { + self.result == self.base.pow(self.exponent.into_bigint()) + } + + /// Verify constraint at a Boolean cube point + /// Checks that the constraint holds at cube vertices where it was constructed to be zero + pub fn verify_constraint_at_cube_point(&self, step: usize, cube_index: usize) -> bool { + if step == 0 || step > self.quotient_mles.len() || cube_index >= 16 { + return false; + } + let point = index_to_boolean_point(cube_index); + + // Evaluate MLEs + let rho_prev = eval_mle_at_boolean_point(&self.rho_mles[step - 1], &point); + let rho_curr = eval_mle_at_boolean_point(&self.rho_mles[step], &point); + let quotient = eval_mle_at_boolean_point(&self.quotient_mles[step - 1], &point); + + let base_mle = fq12_to_multilinear_evals(&self.base); + let base_eval = eval_mle_at_boolean_point(&base_mle, &point); + let g_mle = get_g_mle(); + let g_eval = eval_mle_at_boolean_point(&g_mle, &point); + + // Compute constraint: ρ_i - ρ_{i-1}² * base^{b_i} - Q_i * g + let bit = self.bits[step - 1]; + let base_power = if bit { base_eval } else { Fq::one() }; + let constraint = rho_curr - rho_prev.square() * base_power - quotient * g_eval; + constraint.is_zero() + } + + pub fn num_steps(&self) -> usize { + self.quotient_mles.len() + } +} + +/// Compute quotient MLE +fn compute_step_quotient_msb(rho_prev: Fq12, rho_i: Fq12, base: Fq12, bit: bool) -> Vec { + let rho_prev_mle = fq12_to_multilinear_evals(&rho_prev); + let rho_i_mle = fq12_to_multilinear_evals(&rho_i); + let base_mle = fq12_to_multilinear_evals(&base); + + let g_mle = get_g_mle(); + + // Compute the quotient MLE pointwise: Q_i(x) = (ρ_i(x) - ρ_{i-1}(x)² * base(x)^{b_i}) / g(x) + let mut quotient_mle = vec![Fq::zero(); 16]; + for j in 0..16 { + let rho_prev_sq = rho_prev_mle[j].square(); + let base_power = if bit { base_mle[j] } else { Fq::one() }; + let expected = rho_prev_sq * base_power; + + // Q_i(x) = (ρ_i(x) - expected) / g(x) + if !g_mle[j].is_zero() { + quotient_mle[j] = (rho_i_mle[j] - expected) / g_mle[j]; + } + } + + quotient_mle +} + +/// Get g as MLE evaluations over the Boolean cube {0,1}^4 +pub fn get_g_mle() -> Vec { + use crate::fq12_poly::eval_poly_vec; + let g_vec = g_coeffs(); + + (0..16) + .map(|i| { + let x = Fq::from(i as u64); + eval_poly_vec(&g_vec[..], &x) + }) + .collect() +} + +/// Convert a cube index (0..15) to a Boolean point in {0,1}^4 +pub(crate) fn index_to_boolean_point(index: usize) -> Vec { + vec![ + Fq::from((index & 1) as u64), // bit 0 + Fq::from(((index >> 1) & 1) as u64), // bit 1 + Fq::from(((index >> 2) & 1) as u64), // bit 2 + Fq::from(((index >> 3) & 1) as u64), // bit 3 + ] +} + +/// Evaluate an MLE at a Boolean cube 
point +/// For Boolean points, this is equivalent to indexing but makes the evaluation explicit +fn eval_mle_at_boolean_point(mle: &[Fq], point: &[Fq]) -> Fq { + // For Boolean points, we could just index, but using eval_multilinear + // makes it clear we're doing MLE evaluation + eval_multilinear(mle, point) +} + +/// H(x) = ρᵢ(x) - ρᵢ₋₁(x)² · A(x)^{bᵢ} - Qᵢ(x) · g(x) for x ∈ {0,1}^4 +/// H̃(z) = Σ_{x∈{0,1}^4} eq(z,x) · H(x) +pub fn h_tilde_at_point( + rho_prev_mle: &[Fq], + rho_curr_mle: &[Fq], + base_mle: &[Fq], + q_mle: &[Fq], + g_mle: &[Fq], + bit: bool, + z: &[Fq], +) -> Fq { + assert_eq!(rho_prev_mle.len(), 16); + assert_eq!(rho_curr_mle.len(), 16); + assert_eq!(base_mle.len(), 16); + assert_eq!(q_mle.len(), 16); + assert_eq!(g_mle.len(), 16); + assert_eq!(z.len(), 4); + + let w = eq_weights(z); + let mut acc = Fq::zero(); + + for j in 0..16 { + // Compute H(x_j) where x_j is the j-th hypercube vertex + let prod = rho_prev_mle[j].square() * if bit { base_mle[j] } else { Fq::one() }; + let h_x = rho_curr_mle[j] - prod - q_mle[j] * g_mle[j]; + + acc += h_x * w[j]; + } + + acc +} diff --git a/jolt-optimizations/tests/integration_tests.rs b/jolt-optimizations/tests/integration_tests.rs index db9aac73a..216ea3313 100644 --- a/jolt-optimizations/tests/integration_tests.rs +++ b/jolt-optimizations/tests/integration_tests.rs @@ -1,6 +1,6 @@ use ark_bn254::{Fr, G2Affine, G2Projective}; use ark_ec::PrimeGroup; -use ark_ec::{AdditiveGroup, AffineRepr, CurveGroup}; +use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{PrimeField, UniformRand}; use ark_std::{test_rng, Zero}; use num_bigint::BigInt; diff --git a/jolt-optimizations/tests/mle_tests.rs b/jolt-optimizations/tests/mle_tests.rs new file mode 100644 index 000000000..c507aec64 --- /dev/null +++ b/jolt-optimizations/tests/mle_tests.rs @@ -0,0 +1,143 @@ +use ark_bn254::Fq; +use ark_ff::{Field, One, UniformRand, Zero}; +use ark_std::test_rng; +use jolt_optimizations::{ + eval_multilinear, + fq12_poly::{eval_poly_vec, to_multilinear_evals}, +}; + +/// Generate random polynomial coefficients for testing +fn random_poly12_coeffs() -> [Fq; 12] { + let mut rng = test_rng(); + let mut coeffs = [Fq::zero(); 12]; + for c in coeffs.iter_mut() { + *c = Fq::rand(&mut rng); + } + coeffs +} + +#[test] +fn test_mle_agreement_with_univariate() { + let coeffs = random_poly12_coeffs(); + let mle_evals = to_multilinear_evals(&coeffs); + + for i in 0..16 { + let x = Fq::from(i as u64); + let univariate_eval = eval_poly_vec(&coeffs[..], &x); + + assert_eq!( + univariate_eval, mle_evals[i], + "MLE evaluation doesn't match univariate at point {}", + i + ); + let binary_point = vec![ + Fq::from((i & 1) as u64), + Fq::from(((i >> 1) & 1) as u64), + Fq::from(((i >> 2) & 1) as u64), + Fq::from(((i >> 3) & 1) as u64), + ]; + let mle_eval = eval_multilinear(&mle_evals, &binary_point); + + assert_eq!( + univariate_eval, mle_eval, + "eval_multilinear doesn't agree with univariate at point {}", + i + ); + } +} + +#[test] +fn test_mle_is_multilinear() { + let mut rng = test_rng(); + let coeffs = random_poly12_coeffs(); + let mle_evals = to_multilinear_evals(&coeffs); + + for var_idx in 0..4 { + let point = vec![ + Fq::rand(&mut rng), + Fq::rand(&mut rng), + Fq::rand(&mut rng), + Fq::rand(&mut rng), + ]; + + let mut p0 = point.clone(); + p0[var_idx] = Fq::zero(); + let eval0 = eval_multilinear(&mle_evals, &p0); + + let mut p1 = point.clone(); + p1[var_idx] = Fq::one(); + let eval1 = eval_multilinear(&mle_evals, &p1); + + let t = Fq::rand(&mut rng); + let mut pt = 
point.clone(); + pt[var_idx] = t; + let eval_t = eval_multilinear(&mle_evals, &pt); + let expected = eval0 * (Fq::one() - t) + eval1 * t; + + assert_eq!( + eval_t, expected, + "MLE is not linear in variable {}", + var_idx + ); + } +} + +#[test] +fn test_mle_special_cases() { + let zero_coeffs = [Fq::zero(); 12]; + let mle = to_multilinear_evals(&zero_coeffs); + assert!( + mle.iter().all(|&x| x.is_zero()), + "Zero polynomial MLE should be all zeros" + ); + + let const_val = Fq::from(42u64); + let mut const_coeffs = [Fq::zero(); 12]; + const_coeffs[0] = const_val; + let mle = to_multilinear_evals(&const_coeffs); + assert!( + mle.iter().all(|&x| x == const_val), + "Constant polynomial MLE should be constant" + ); + + let mut linear_coeffs = [Fq::zero(); 12]; + linear_coeffs[1] = Fq::one(); + let mle = to_multilinear_evals(&linear_coeffs); + for i in 0..16 { + assert_eq!( + mle[i], + Fq::from(i as u64), + "Linear polynomial p(x)=x should evaluate to {} at {}", + i, + i + ); + } + let mut quad_coeffs = [Fq::zero(); 12]; + quad_coeffs[2] = Fq::one(); + let mle = to_multilinear_evals(&quad_coeffs); + for i in 0..16 { + let expected = Fq::from((i * i) as u64); + assert_eq!( + mle[i], expected, + "Quadratic polynomial p(x)=x² should evaluate correctly at {}", + i + ); + } +} + +#[test] +fn test_mle_high_degree() { + let mut coeffs = [Fq::zero(); 12]; + coeffs[11] = Fq::one(); + + let mle = to_multilinear_evals(&coeffs); + for i in 0..16 { + let x = Fq::from(i as u64); + let expected = x.pow([11u64]); + assert_eq!( + mle[i], expected, + "p(x) = x^11 should evaluate correctly at {}", + i + ); + } +} diff --git a/jolt-optimizations/tests/witness_test.rs b/jolt-optimizations/tests/witness_test.rs new file mode 100644 index 000000000..9a02838d7 --- /dev/null +++ b/jolt-optimizations/tests/witness_test.rs @@ -0,0 +1,246 @@ +use ark_bn254::{Fq, Fq12, Fr}; +use ark_ff::{One, UniformRand, Zero}; +use ark_std::test_rng; +use jolt_optimizations::{ + fq12_to_multilinear_evals, get_g_mle, h_tilde_at_point, ExponentiationSteps, +}; + +#[test] +fn test_witness_generation_and_constraints() { + let mut rng = test_rng(); + + for test_idx in 0..100 { + let base = Fq12::rand(&mut rng); + + let exponent = if test_idx == 0 { + Fr::from(5u64) + } else if test_idx == 1 { + Fr::from(100u64) + } else { + Fr::rand(&mut rng) + }; + + let witness = ExponentiationSteps::new(base, exponent); + + assert!( + witness.verify_result(), + "Final result should match base^exponent" + ); + + assert_eq!( + witness.rho_mles.len(), + witness.quotient_mles.len() + 1, + "Should have one more rho than quotients" + ); + + // Verify all MLEs have correct dimension (16 = 2^4 evaluations) + for mle in &witness.rho_mles { + assert_eq!(mle.len(), 16, "Rho MLEs should have 16 evaluations"); + } + for mle in &witness.quotient_mles { + assert_eq!(mle.len(), 16, "Quotient MLEs should have 16 evaluations"); + } + + // The constraint should be zero at all 16 cube vertices + for cube_idx in 0..16 { + for step in 1..=witness.num_steps() { + assert!( + witness.verify_constraint_at_cube_point(step, cube_idx), + "Constraint failed at step {} for cube point {}", + step, + cube_idx + ); + } + } + } +} + +#[test] +fn test_trivial_cases() { + let mut rng = test_rng(); + let base = Fq12::rand(&mut rng); + + // Test exponent = 0 + let witness_zero = ExponentiationSteps::new(base, Fr::from(0u64)); + assert_eq!(witness_zero.result, Fq12::one()); + assert!(witness_zero.verify_result()); + assert_eq!(witness_zero.bits.len(), 0); + assert_eq!(witness_zero.rho_mles.len(), 
1); + assert_eq!(witness_zero.quotient_mles.len(), 0); + + // Test exponent = 1 + let witness_one = ExponentiationSteps::new(base, Fr::from(1u64)); + assert_eq!(witness_one.result, base); + assert!(witness_one.verify_result()); + assert_eq!(witness_one.bits, vec![true]); + assert_eq!(witness_one.rho_mles.len(), 2); + // Test small known values to verify bit sequence + let witness_five = ExponentiationSteps::new(base, Fr::from(5u64)); + assert_eq!(witness_five.bits, vec![true, false, true]); + assert!(witness_five.verify_result()); + + let witness_ten = ExponentiationSteps::new(base, Fr::from(10u64)); + assert_eq!(witness_ten.bits, vec![true, false, true, false]); + assert!(witness_ten.verify_result()); +} + +#[test] +fn test_witness_soundness() { + let mut rng = test_rng(); + + // Test soundness: tampering with witness + for test_idx in 0..20 { + let base = Fq12::rand(&mut rng); + let exponent = if test_idx == 0 { + Fr::from(10u64) + } else { + Fr::rand(&mut rng) + }; + + let mut witness = ExponentiationSteps::new(base, exponent); + + // Verify original witness is valid + assert!(witness.verify_result()); + let mut all_valid = true; + for cube_idx in 0..16 { + for step in 1..=witness.num_steps() { + if !witness.verify_constraint_at_cube_point(step, cube_idx) { + all_valid = false; + } + } + } + assert!(all_valid, "Original witness should be valid"); + + // Test 1: Tampering with ρ values + if witness.rho_mles.len() > 1 { + let tamper_idx = 1 + (test_idx % (witness.rho_mles.len() - 1)); + let point_idx = test_idx % 16; + let original = witness.rho_mles[tamper_idx][point_idx]; + + witness.rho_mles[tamper_idx][point_idx] += Fq::from(1u64); + + let mut soundness_broken = false; + for cube_idx in 0..16 { + if tamper_idx <= witness.num_steps() { + if !witness.verify_constraint_at_cube_point(tamper_idx, cube_idx) { + soundness_broken = true; + break; + } + } + if tamper_idx > 0 && tamper_idx + 1 <= witness.num_steps() { + if !witness.verify_constraint_at_cube_point(tamper_idx + 1, cube_idx) { + soundness_broken = true; + break; + } + } + } + + assert!(soundness_broken, "Tampering with ρ should break soundness"); + witness.rho_mles[tamper_idx][point_idx] = original; + } + + // Test 2: Tampering with quotient + if !witness.quotient_mles.is_empty() { + let q_idx = test_idx % witness.quotient_mles.len(); + let point_idx = (test_idx * 7) % 16; + let original = witness.quotient_mles[q_idx][point_idx]; + + witness.quotient_mles[q_idx][point_idx] += Fq::from(1u64); + + let soundness_broken = !witness.verify_constraint_at_cube_point(q_idx + 1, point_idx); + + assert!( + soundness_broken, + "Tampering with quotient should break soundness" + ); + witness.quotient_mles[q_idx][point_idx] = original; + } + + // Test 3: Flipping bits + if !witness.bits.is_empty() { + let bit_idx = test_idx % witness.bits.len(); + witness.bits[bit_idx] = !witness.bits[bit_idx]; + + let mut soundness_broken = false; + for cube_idx in 0..16 { + if !witness.verify_constraint_at_cube_point(bit_idx + 1, cube_idx) { + soundness_broken = true; + break; + } + } + + assert!(soundness_broken, "Flipping bits should break soundness"); + witness.bits[bit_idx] = !witness.bits[bit_idx]; + } + + // Test 4: Tampering with final result + let original_result = witness.result; + witness.result = witness.result + Fq12::one(); + assert!( + !witness.verify_result(), + "Modified result should fail verification" + ); + witness.result = original_result; + } +} + +#[test] +fn test_constraint_at_random_field_element() { + let mut rng = test_rng(); + + // 
Create witness for a simple exponentiation + let base = Fq12::rand(&mut rng); + let exponent = Fr::from(10000300u64); + let witness = ExponentiationSteps::new(base, exponent); + + let base_mle = fq12_to_multilinear_evals(&base); + let g_mle = get_g_mle(); + + for test_idx in 0..10000 { + let z: Vec = (0..4) + .map(|_| { + let mut val = Fq::rand(&mut rng); + // Ensure it's not 0 or 1 (not on hypercube) + while val == Fq::zero() || val == Fq::one() { + val = Fq::rand(&mut rng); + } + val + }) + .collect(); + + let step = 1 + (test_idx % witness.num_steps()); + let bit = witness.bits[step - 1]; + + let h = h_tilde_at_point( + &witness.rho_mles[step - 1], + &witness.rho_mles[step], + &base_mle, + &witness.quotient_mles[step - 1], + &g_mle, + bit, + &z, + ); + + // H̃(z) must be 0 at random z + assert!( + h.is_zero(), + "H̃(z) must be 0 at random z (test {}, step {}). Got: {:?}", + test_idx, + step, + h + ); + } + + for step in 1..=witness.num_steps() { + for cube_idx in 0..16 { + assert!( + witness.verify_constraint_at_cube_point(step, cube_idx), + "Constraint should be zero at hypercube point {} step {}", + cube_idx, + step + ); + } + } + + println!("✓ Verified: Constraints are zero on hypercube (sanity check)"); +} diff --git a/test-curves/Cargo.toml b/test-curves/Cargo.toml index 08d5380e7..5fef59054 100644 --- a/test-curves/Cargo.toml +++ b/test-curves/Cargo.toml @@ -88,3 +88,8 @@ harness = false name = "bn254" path = "benches/bn254.rs" harness = false + +[[bench]] +name = "bigint" +path = "benches/bigint.rs" +harness = false diff --git a/test-curves/benches/bigint.rs b/test-curves/benches/bigint.rs new file mode 100644 index 000000000..f0d7ab1f9 --- /dev/null +++ b/test-curves/benches/bigint.rs @@ -0,0 +1,165 @@ +// Benchmark for BigInt operations +#[cfg(feature = "bn254")] +use ark_ff::{BigInt, BigInteger}; +#[cfg(feature = "bn254")] +use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; +#[cfg(feature = "bn254")] +use criterion::{criterion_group, criterion_main, Criterion}; + +#[cfg(feature = "bn254")] +fn bigint_add_bench(c: &mut Criterion) { + const SAMPLES: usize = 1000; + // Use a fixed seed for reproducibility + let mut rng = StdRng::seed_from_u64(0u64); + + // Generate random BigInt<4> instances for benchmarking + let a_bigints = (0..SAMPLES) + .map(|_| { + BigInt::<4>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) + .collect::>(); + + let b_bigints = (0..SAMPLES) + .map(|_| { + BigInt::<4>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) + .collect::>(); + + let mut group = c.benchmark_group("BigInt<4> Addition Comparison"); + + // Benchmark add_trunc with same limb count (4 -> 4) + group.bench_function("add_trunc<4, 4>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_bigints[i].add_trunc::<4, 4>(&b_bigints[i])) + }) + }); + + // Benchmark add_trunc with truncation (4 -> 3 limbs) + group.bench_function("add_trunc<4, 3>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_bigints[i].add_trunc::<4, 3>(&b_bigints[i])) + }) + }); + + // Benchmark add_trunc with expansion (4 -> 5 limbs) + group.bench_function("add_trunc<4, 5>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(a_bigints[i].add_trunc::<4, 5>(&b_bigints[i])) + }) + }); + + // Benchmark regular addition using add_with_carry + group.bench_function("add_with_carry (regular add)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % 
SAMPLES; + let mut result = a_bigints[i]; + let carry = result.add_with_carry(&b_bigints[i]); + criterion::black_box((result, carry)) + }) + }); + + // Benchmark regular addition that ignores carry (for fair comparison) + group.bench_function("add_with_carry (ignore carry)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_with_carry(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Benchmark add_assign_trunc with same limb count (4 -> 4) + group.bench_function("add_assign_trunc<4, 4>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_assign_trunc::<4, 4>(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Benchmark add_assign_trunc with truncation (4 -> 3 limbs) + group.bench_function("add_assign_trunc<4, 3>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_assign_trunc::<4, 3>(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Benchmark add_assign_trunc with expansion (4 -> 5 limbs) + group.bench_function("add_assign_trunc<4, 5>", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = a_bigints[i]; + result.add_assign_trunc::<4, 5>(&b_bigints[i]); + criterion::black_box(result) + }) + }); + + // Test case: addition that would overflow to compare truncation behavior + let max_bigints = (0..SAMPLES) + .map(|_| BigInt::<4>([u64::MAX, u64::MAX, u64::MAX, u64::MAX])) + .collect::>(); + + group.bench_function("add_trunc overflow case", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + // This will overflow and be truncated + criterion::black_box(max_bigints[i].add_trunc::<4, 4>(&max_bigints[i])) + }) + }); + + group.bench_function("add_with_carry overflow case", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = max_bigints[i]; + let carry = result.add_with_carry(&max_bigints[i]); + criterion::black_box((result, carry)) + }) + }); + + group.bench_function("add_assign_trunc overflow case", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let mut result = max_bigints[i]; + result.add_assign_trunc::<4, 4>(&max_bigints[i]); + criterion::black_box(result) + }) + }); + + group.finish(); +} + +#[cfg(feature = "bn254")] +criterion_group!(benches, bigint_add_bench); +#[cfg(feature = "bn254")] +criterion_main!(benches); + +#[cfg(not(feature = "bn254"))] +fn main() {} diff --git a/test-curves/benches/small_mul.rs b/test-curves/benches/small_mul.rs index d79b9c530..0bed530ab 100644 --- a/test-curves/benches/small_mul.rs +++ b/test-curves/benches/small_mul.rs @@ -1,121 +1,364 @@ -use ark_ff::UniformRand; +// This bench prefers bn254; if not enabled, provide a no-op main +#[cfg(feature = "bn254")] +use ark_ff::{BigInteger, UniformRand}; +#[cfg(feature = "bn254")] use ark_std::rand::{rngs::StdRng, Rng, SeedableRng}; -use ark_test_curves::bn254::{Fr, FrConfig}; +#[cfg(feature = "bn254")] +use ark_test_curves::bn254::Fr; +#[cfg(feature = "bn254")] use criterion::{criterion_group, criterion_main, Criterion}; // Hack: copy over the helper functions from the Montgomery backend to be benched +#[cfg(feature = "bn254")] fn mul_small_bench(c: &mut Criterion) { const SAMPLES: usize = 1000; // Use a fixed seed for reproducibility let mut rng = StdRng::seed_from_u64(0u64); - let a_s = (0..SAMPLES) - .map(|_| Fr::rand(&mut rng)) - .collect::>(); - let a_limbs_s = 
a_s.iter().map(|a| a.0.0).collect::>(); + let a_s = (0..SAMPLES).map(|_| Fr::rand(&mut rng)).collect::>(); + // let a_limbs_s = a_s.iter().map(|a| a.0.0).collect::>(); - let b_u64_s = (0..SAMPLES) - .map(|_| rng.gen::()) - .collect::>(); + let b_u64_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); // Convert u64 to Fr for standard multiplication benchmark let b_fr_s = b_u64_s.iter().map(|&b| Fr::from(b)).collect::>(); let b_u64_as_u128_s = b_u64_s.iter().map(|&b| b as u128).collect::>(); - let b_i64_s = (0..SAMPLES) - .map(|_| rng.gen::()) + let b_i64_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); + + let b_u128_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); + + let b_i128_s = (0..SAMPLES).map(|_| rng.gen::()).collect::>(); + + // Generate another set of random Fr elements for addition + let c_s = (0..SAMPLES).map(|_| Fr::rand(&mut rng)).collect::>(); + + // Generate test data for reduction benchmarks + use ark_ff::BigInt; + // Extract BigInt<4> from Fr elements for mul_u64_w_carry benchmark + let a_bigints = a_s.iter().map(|a| a.0).collect::>(); + + // For Montgomery reduction: 2N-limb inputs (N=4 for bn254, so 2N=8) + let bigint_2n_s = (0..SAMPLES) + .map(|_| { + BigInt::<8>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); - let b_u128_s = (0..SAMPLES) - .map(|_| rng.gen::()) + // For Barrett reductions: N+1, N+2, N+3 limb inputs + let bigint_nplus1_s = (0..SAMPLES) + .map(|_| { + BigInt::<5>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); - let b_i128_s = (0..SAMPLES) - .map(|_| rng.gen::()) + let bigint_nplus2_s = (0..SAMPLES) + .map(|_| { + BigInt::<6>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); - // Generate another set of random Fr elements for addition - let c_s = (0..SAMPLES) - .map(|_| Fr::rand(&mut rng)) + let bigint_nplus3_s = (0..SAMPLES) + .map(|_| { + BigInt::<7>([ + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ]) + }) .collect::>(); let mut group = c.benchmark_group("Fr Arithmetic Comparison"); + // Uncommented to compare with mul_u64_w_carry + Barrett reduction group.bench_function("mul_u64", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u64(b_u64_s[i])) + // bn254 Fr has N=4 limbs => N+1 = 5 + criterion::black_box(a_s[i].mul_u64::<5>(b_u64_s[i])) }) }); - group.bench_function("mul_i64", |bench| { + // Benchmark just the multiplication phase (without Barrett reduction) + group.bench_function("mul_u64_w_carry (multiplication only)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_i64(b_i64_s[i])) + // This is just the multiplication step, returns BigInt<5> + criterion::black_box(a_bigints[i].mul_u64_w_carry::<5>(b_u64_s[i])) }) }); - // Note: results might be worse than in real applications due to branch prediction being wrong - // 50% of the time - group.bench_function("mul_u128", |bench| { + // group.bench_function("mul_i64", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_i64::<5>(b_i64_s[i])) + // }) + // }); + + // // Note: results might be worse than in real applications due to branch prediction being wrong + // // 50% of the time + // group.bench_function("mul_u128", |bench| { + // let mut i = 0; + // 
bench.iter(|| { + // i = (i + 1) % SAMPLES; + // // bn254 Fr has N=4 limbs => N+1 = 5, N+2 = 6 + // criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u128_s[i])) + // }) + // }); + + // group.bench_function("mul_i128", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_i128::<5, 6>(b_i128_s[i])) + // }) + // }); + + group.bench_function("standard mul (Fr * Fr)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u128(b_u128_s[i])) + criterion::black_box(a_s[i] * b_fr_s[i]) }) }); - group.bench_function("mul_i128", |bench| { + // Bench specialized high-limb RHS fastpaths (K = 1, 2) + // Construct BigInt with random high limbs for K=1 and K=2 + let b_k1_bigint = (0..SAMPLES) + .map(|_| BigInt::<1>([rng.gen::()])) + .collect::>(); + let b_k2_bigint = (0..SAMPLES) + .map(|_| BigInt::<2>([rng.gen::(), rng.gen::()])) + .collect::>(); + + // group.bench_function("mul_assign_hi_bigint::<1>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // let mut x = a_s[i]; + // x.mul_assign_hi_bigint::<1>(&b_k1_bigint[i]); + // criterion::black_box(x) + // }) + // }); + + // group.bench_function("mul_assign_hi_bigint::<2>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // let mut x = a_s[i]; + // x.mul_assign_hi_bigint::<2>(&b_k2_bigint[i]); + // criterion::black_box(x) + // }) + // }); + + // group.bench_function("mul_hi_bigint::<1>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_hi_bigint::<1>(&b_k1_bigint[i])) + // }) + // }); + + // group.bench_function("mul_hi_bigint::<2>", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_hi_bigint::<2>(&b_k2_bigint[i])) + // }) + // }); + + // group.bench_function("mul_u128 (u64 inputs)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // // Call mul_u128 but provide a u64 input cast to u128 + // criterion::black_box(a_s[i].mul_u128::<5, 6>(b_u64_as_u128_s[i])) + // }) + // }); + + // Benchmark the auxiliary function directly (assuming it's made public) + // Note: Requires mul_u128_aux to be pub in montgomery_backend.rs + // Need to import it if not already done via wildcard/specific import + // Let's assume it's accessible via a_s[i].mul_u128_aux(...) 
for now + // group.bench_function("mul_u128_aux (u128 inputs)", |bench| { + // let mut i = 0; + // bench.iter(|| { + // i = (i + 1) % SAMPLES; + // criterion::black_box(a_s[i].mul_u128_aux::<5, 6>(b_u128_s[i])) + // }) + // }); + + group.bench_function("Addition (Fr + Fr)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_i128(b_i128_s[i])) + criterion::black_box(a_s[i] + c_s[i]) }) }); - group.bench_function("standard mul (Fr * Fr)", |bench| { + // Reduction benchmarks + group.bench_function("montgomery_reduce_in_place core (L=8)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i] * b_fr_s[i]) + let mut x = bigint_2n_s[i]; + criterion::black_box(Fr::montgomery_reduce_in_place::<8>(&mut x)) }) }); - // Benchmark mul_u128 specifically with inputs known to fit in u64 - group.bench_function("mul_u128 (u64 inputs)", |bench| { + group.bench_function("from_montgomery_reduce (L=2N)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - // Call mul_u128 but provide a u64 input cast to u128 - criterion::black_box(a_s[i].mul_u128(b_u64_as_u128_s[i])) + criterion::black_box(Fr::from_montgomery_reduce::<8, 5>(bigint_2n_s[i])) }) }); - // Benchmark the auxiliary function directly (assuming it's made public) - // Note: Requires mul_u128_aux to be pub in montgomery_backend.rs - // Need to import it if not already done via wildcard/specific import - // Let's assume it's accessible via a_s[i].mul_u128_aux(...) for now - group.bench_function("mul_u128_aux (u128 inputs)", |bench| { + // L=9 inputs: derive by zero-extending L=8 inputs + let bigint_9_s = bigint_2n_s + .iter() + .map(|b8| ark_ff::BigInt::<9>::zero_extend_from::<8>(b8)) + .collect::>(); + + group.bench_function("montgomery_reduce_in_place core (L=9)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i].mul_u128_aux(b_u128_s[i])) + let mut x = bigint_9_s[i]; + criterion::black_box(Fr::montgomery_reduce_in_place::<9>(&mut x)) }) }); - group.bench_function("Addition (Fr + Fr)", |bench| { + group.bench_function("from_montgomery_reduce (L=9)", |bench| { let mut i = 0; bench.iter(|| { i = (i + 1) % SAMPLES; - criterion::black_box(a_s[i] + c_s[i]) + criterion::black_box(Fr::from_montgomery_reduce::<9, 5>(bigint_9_s[i])) + }) + }); + + // Barrett reductions + group.bench_function("from_barrett_reduce (L=5)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_barrett_reduce::<5, 5>(bigint_nplus1_s[i])) + }) + }); + + group.bench_function("from_barrett_reduce (L=6)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_barrett_reduce::<6, 5>(bigint_nplus2_s[i])) + }) + }); + + group.bench_function("from_barrett_reduce (L=7)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + criterion::black_box(Fr::from_barrett_reduce::<7, 5>(bigint_nplus3_s[i])) + }) + }); + + // Linear combination benchmarks + group.bench_function("linear_combination_u64 (2 terms)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let pairs = [(a_s[i], b_u64_s[i]), (c_s[i], b_u64_s[(i + 1) % SAMPLES])]; + criterion::black_box(Fr::linear_combination_u64::<5>(&pairs)) + }) + }); + + group.bench_function("linear_combination_u64 (4 terms)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let pairs = [ + (a_s[i], b_u64_s[i]), + (c_s[i], b_u64_s[(i + 1) % SAMPLES]), + (a_s[(i + 2) 
% SAMPLES], b_u64_s[(i + 2) % SAMPLES]), + (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES]), + ]; + criterion::black_box(Fr::linear_combination_u64::<5>(&pairs)) + }) + }); + + group.bench_function("linear_combination_i64 (2+2 terms)", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let pos = [(a_s[i], b_u64_s[i]), (c_s[i], b_u64_s[(i + 1) % SAMPLES])]; + let neg = [ + (a_s[(i + 2) % SAMPLES], b_u64_s[(i + 2) % SAMPLES]), + (c_s[(i + 3) % SAMPLES], b_u64_s[(i + 3) % SAMPLES]), + ]; + criterion::black_box(Fr::linear_combination_i64::<5>(&pos, &neg)) + }) + }); + + // Comparison: naive approach vs linear combination (using mul_u64 for fair comparison) + group.bench_function("naive 2-term combination", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let term1 = a_s[i].mul_u64::<5>(b_u64_s[i]); + let term2 = c_s[i].mul_u64::<5>(b_u64_s[(i + 1) % SAMPLES]); + criterion::black_box(term1 + term2) + }) + }); + + group.bench_function("naive 4-term combination", |bench| { + let mut i = 0; + bench.iter(|| { + i = (i + 1) % SAMPLES; + let term1 = a_s[i].mul_u64::<5>(b_u64_s[i]); + let term2 = c_s[i].mul_u64::<5>(b_u64_s[(i + 1) % SAMPLES]); + let term3 = a_s[(i + 2) % SAMPLES].mul_u64::<5>(b_u64_s[(i + 2) % SAMPLES]); + let term4 = c_s[(i + 3) % SAMPLES].mul_u64::<5>(b_u64_s[(i + 3) % SAMPLES]); + criterion::black_box(term1 + term2 + term3 + term4) }) }); group.finish(); } +#[cfg(feature = "bn254")] criterion_group!(benches, mul_small_bench); -criterion_main!(benches); \ No newline at end of file +#[cfg(feature = "bn254")] +criterion_main!(benches); + +#[cfg(not(feature = "bn254"))] +fn main() {} diff --git a/test-curves/src/bn254/fq.rs b/test-curves/src/bn254/fq.rs index 001d94836..6bddf9bc0 100644 --- a/test-curves/src/bn254/fq.rs +++ b/test-curves/src/bn254/fq.rs @@ -10,4 +10,4 @@ pub struct FqConfig; pub type Fq = Fp256>; pub const FQ_ONE: Fq = ark_ff::MontFp!("1"); -pub const FQ_ZERO: Fq = ark_ff::MontFp!("0"); \ No newline at end of file +pub const FQ_ZERO: Fq = ark_ff::MontFp!("0"); diff --git a/test-curves/src/bn254/fr.rs b/test-curves/src/bn254/fr.rs index 4caef8e7c..4de077431 100644 --- a/test-curves/src/bn254/fr.rs +++ b/test-curves/src/bn254/fr.rs @@ -14,4 +14,4 @@ pub struct FrConfig; pub type Fr = Fp256>; pub const FR_ONE: Fr = ark_ff::MontFp!("1"); -pub const FR_ZERO: Fr = ark_ff::MontFp!("0"); \ No newline at end of file +pub const FR_ZERO: Fr = ark_ff::MontFp!("0"); diff --git a/test-curves/src/bn254/g1.rs b/test-curves/src/bn254/g1.rs index 2b3c5a0c5..608278db8 100644 --- a/test-curves/src/bn254/g1.rs +++ b/test-curves/src/bn254/g1.rs @@ -1,8 +1,6 @@ -use ark_ec::models::short_weierstrass::{ - Affine, Projective, SWCurveConfig, -}; +use ark_ec::models::short_weierstrass::{Affine, Projective, SWCurveConfig}; use ark_ec::CurveConfig; -use ark_ff::{Field, MontFp, Zero, AdditiveGroup}; +use ark_ff::{AdditiveGroup, Field, MontFp, Zero}; use crate::bn254::{Fq, Fr}; // Assuming Fq is defined in fq.rs diff --git a/test-curves/src/bn254/test.rs b/test-curves/src/bn254/test.rs index 51a9c691e..176467a7a 100644 --- a/test-curves/src/bn254/test.rs +++ b/test-curves/src/bn254/test.rs @@ -2,7 +2,9 @@ use ark_ec::{ models::short_weierstrass::SWCurveConfig, // Keep this as G1 is SW pairing::Pairing, - AffineRepr, CurveGroup, PrimeGroup, + AffineRepr, + CurveGroup, + PrimeGroup, }; use ark_ff::{Field, One, UniformRand, Zero}; use ark_std::{rand::Rng, test_rng}; diff --git a/test-templates/src/msm.rs b/test-templates/src/msm.rs index 
5db14ced6..2b091ed38 100644 --- a/test-templates/src/msm.rs +++ b/test-templates/src/msm.rs @@ -81,31 +81,31 @@ pub fn test_var_base_msm_specialized() { let v = (0..SAMPLES).map(|_| bool::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u1(g.as_slice(), v.as_slice()); + let fast = G::msm_u1(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u8::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u8(g.as_slice(), v.as_slice()); + let fast = G::msm_u8(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u16::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u16(g.as_slice(), v.as_slice()); + let fast = G::msm_u16(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u32::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u32(g.as_slice(), v.as_slice()); + let fast = G::msm_u32(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); let v = (0..SAMPLES).map(|_| u64::rand(rng)).collect::>(); let v_fe = v.iter().map(|&b| F::::from(b)).collect::>(); let naive = naive_var_base_msm::(g.as_slice(), v_fe.as_slice()); - let fast = G::msm_u64(g.as_slice(), v.as_slice()); + let fast = G::msm_u64(g.as_slice(), v.as_slice(), false); assert_eq!(naive, fast); }
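
The `test_var_base_msm_specialized` hunk above threads the new `serial: bool` flag through every specialized MSM entry point. A hypothetical call site (not part of the patch) is sketched below; it assumes `msm_u64` is exposed through `VariableBaseMSM` with the `(bases, scalars, serial)` shape used by that test, and cross-checks the non-serial path against a naive fold.

// Illustrative sketch, not part of the patch. Assumes msm_u64(bases, scalars, serial)
// is available on VariableBaseMSM, matching the call shape in
// test_var_base_msm_specialized above.
use ark_ec::{CurveGroup, VariableBaseMSM};
use ark_ff::UniformRand;
use ark_std::{rand::Rng, test_rng, Zero};
use ark_test_curves::bn254::{Fr, G1Projective};

fn main() {
    let mut rng = test_rng();
    let proj = (0..32)
        .map(|_| G1Projective::rand(&mut rng))
        .collect::<Vec<_>>();
    let bases = G1Projective::normalize_batch(&proj); // affine bases
    let scalars = (0..32).map(|_| rng.gen::<u64>()).collect::<Vec<u64>>();

    // Naive reference: sum of base_i * scalar_i over the full field.
    let naive = bases
        .iter()
        .zip(&scalars)
        .fold(G1Projective::zero(), |acc, (b, &s)| acc + *b * Fr::from(s));

    // Specialized u64 MSM; `false` selects the non-serial (chunked) path.
    let fast = G1Projective::msm_u64(&bases, &scalars, false);
    assert_eq!(naive, fast);
}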
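
The `small_mul` bench above standardizes on the new carry-width generics: bn254 `Fr` has N = 4 limbs, so every small-scalar call uses `::<5>` (N + 1). As a quick orientation (not part of the patch), the sketch below spells out the equivalences those benches and `test_mul_u64_random` rely on: `mul_u64::<5>` matches an ordinary field multiplication by `Fr::from(b)`, and `linear_combination_u64::<5>` matches the naive sum of `mul_u64` terms.

// Illustrative sketch, not part of the patch. Assumes the mul_u64 and
// linear_combination_u64 methods introduced in this diff, with the const
// generic 5 = N + 1 for bn254 Fr (N = 4 limbs), as in the benches above.
use ark_ff::UniformRand;
use ark_std::rand::{rngs::StdRng, Rng, SeedableRng};
use ark_test_curves::bn254::Fr;

fn main() {
    let mut rng = StdRng::seed_from_u64(0u64);
    let (a, c) = (Fr::rand(&mut rng), Fr::rand(&mut rng));
    let (x, y) = (rng.gen::<u64>(), rng.gen::<u64>());

    // mul_u64::<5> agrees with a full field multiplication by the lifted scalar.
    assert_eq!(a.mul_u64::<5>(x), a * Fr::from(x));

    // linear_combination_u64::<5> agrees with the naive two-term combination.
    let pairs = [(a, x), (c, y)];
    assert_eq!(
        Fr::linear_combination_u64::<5>(&pairs),
        a.mul_u64::<5>(x) + c.mul_u64::<5>(y)
    );
}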
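
The `fq12_poly` module above encodes an `Fq12` element as 16 evaluations of a 4-variable multilinear polynomial, and `eq_weights` provides the standard equality kernel eq(z, x). A useful identity tying the two together (not stated explicitly in the patch) is that the eq-weighted sum over the Boolean cube reproduces `eval_multilinear` at any point z. The sketch below (illustrative only, assuming the `jolt_optimizations` re-exports added in `lib.rs` above) checks this on a random point.

// Illustrative sketch, not part of the patch. Checks that eval_multilinear(evals, z)
// equals sum_x eq(z, x) * evals[x], using the fq12_poly exports added above.
use ark_bn254::{Fq, Fq12};
use ark_ff::{UniformRand, Zero};
use ark_std::test_rng;
use jolt_optimizations::{eq_weights, eval_multilinear, fq12_to_multilinear_evals};

fn main() {
    let mut rng = test_rng();
    let a = Fq12::rand(&mut rng);

    // 16 = 2^4 evaluations of the MLE encoding `a` over the Boolean cube.
    let evals = fq12_to_multilinear_evals(&a);

    // Random (non-Boolean) evaluation point z in Fq^4.
    let z: Vec<Fq> = (0..4).map(|_| Fq::rand(&mut rng)).collect();

    // Direct MLE evaluation vs. the eq-weighted sum over the cube.
    let direct = eval_multilinear(&evals, &z);
    let weighted = eq_weights(&z)
        .iter()
        .zip(evals.iter())
        .fold(Fq::zero(), |acc, (w, e)| acc + *w * *e);
    assert_eq!(direct, weighted);
}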