
Commit 4ac8b20

fix generic mont reduce
1 parent 001264f commit 4ac8b20


3 files changed: +115, -52 lines


ff/src/biginteger/tests.rs

Lines changed: 0 additions & 1 deletion
@@ -1114,7 +1114,6 @@ pub mod tests {
         let mut b = BigInt::<4>::new([u64::MAX; 4]);
         for i in 0..4 { a.0[i] >>= 1; b.0[i] >>= 1; }
         // P = 4: result should match BigUint addition modulo 2^256
-        let r = a.add_trunc::<4, 4>(&b);
         // add_assign_trunc debug-overflow behavior cannot be reliably asserted in this
         // environment without std; we validate the non-mutating truncated result above.

ff/src/fields/models/fp/montgomery_backend.rs

Lines changed: 114 additions & 50 deletions
@@ -901,27 +901,55 @@ impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
         Self::new_unchecked(acc)
     }
 
-    /// Montgomery reduction of a BigInt to a field element (compute a * R^{-1} mod p).
+    /// Montgomery reduction for arbitrary input width L >= 2N.
     ///
-    /// Need to specify the number of limbs `L` in the BigInt, where `L > N`.
+    /// Runs exactly N Montgomery steps (i = 0..N-1) over the L-limb buffer to compute
+    /// t' = (unreduced + q * MODULUS) / R, where R = b^N. The remaining (L - N) limbs
+    /// store t' in base-b. For L > 2N, we first fold the entire tail (indices N..L) down
+    /// to an N-limb accumulator using the N+1 Barrett reducer (interpreting the tail as a
+    /// base-b number), place that as the high N limbs to form a 2N-limb buffer, and then
+    /// perform the standard N-step Montgomery reduction on that 2N-limb buffer.
+    ///
+    /// Preconditions:
+    /// - L >= 2N (buffer must be large enough to perform N steps safely)
+    ///
+    /// Computes: unreduced * R^{-1} mod MODULUS.
     #[inline(always)]
-    pub fn from_montgomery_reduce<const L: usize>(unreduced: BigInt<L>) -> Self {
-        debug_assert!(
-            L > N,
-            "from_montgomery_reduce requires L > N for a reduction to be necessary"
-        );
-        let mut limbs = unreduced;
-        let steps = L - N;
+    pub fn from_montgomery_reduce<const L: usize, const NPLUS1: usize>(
+        unreduced: BigInt<L>,
+    ) -> Self {
+        debug_assert!(NPLUS1 == N + 1);
+        debug_assert!(L >= N + N, "from_montgomery_reduce_var requires L >= 2N");
+
+        let mut limbs = unreduced; // reuse storage for the buffer
+
+        // If L > 2N, first fold the extra high limbs down.
+        if L > 2 * N {
+            // Fold the tail (indices N..L) into an N-limb accumulator via Barrett.
+            let mut acc = BigInt::<N>::zero();
+            let mut i = L;
+            while i > N {
+                i -= 1;
+                let c2 = nplus1_pair_low_to_bigint::<N, NPLUS1>((limbs.0[i], acc.0));
+                acc = barrett_reduce_nplus1_to_n::<T, N, NPLUS1>(c2);
+            }
 
-        let (carry, _steps_done) = Self::montgomery_steps_in_place::<L>(&mut limbs, steps);
+            // Recompose buffer: [low_N | acc | zeros...]
+            limbs.0[N..(N + N)].copy_from_slice(&acc.0);
+            let mut j = 2 * N;
+            while j < L {
+                limbs.0[j] = 0;
+                j += 1;
+            }
+        }
 
-        // The result is in the upper N limbs of the buffer.
-        let mut result_limbs = [0u64; N];
-        result_limbs.copy_from_slice(&limbs.0[steps..]);
+        // Phase 2: run exactly N Montgomery steps on the 2N-limb buffer.
+        let carry = Self::montgomery_reduce_in_place::<L>(&mut limbs);
 
+        // Extract result and finalize.
+        let mut result_limbs = [0u64; N];
+        result_limbs.copy_from_slice(&limbs.0[N..(N + N)]);
         let mut result = Self::new_unchecked(BigInt::<N>(result_limbs));
-
-        // Final conditional subtraction to bring the result into the canonical range.
         if T::MODULUS_HAS_SPARE_BIT {
             result.subtract_modulus();
         } else {
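
A short note on the identity behind the tail fold described in the new doc comment (notation from that comment, with b = 2^64 and R = b^N; this note is not part of the diff). Writing the L-limb input as

    unreduced = lo + tail * b^N,   lo = a_0 + a_1*b + ... + a_{N-1}*b^{N-1},   tail = a_N + a_{N+1}*b + ... + a_{L-1}*b^{L-1-N},

the while loop is Horner evaluation of tail from its most significant limb down: acc <- (acc * b + a_i) mod MODULUS, where each step fits in N+1 limbs (since acc < MODULUS < b^N) and is handled by the N+1-limb Barrett reducer. After the loop acc = tail mod MODULUS, so lo + acc * b^N is congruent to the original input modulo MODULUS and fits in 2N limbs, and the N REDC steps that follow return unreduced * R^{-1} mod MODULUS as documented.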
@@ -945,7 +973,7 @@ impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
     /// via a barrett reduction.
     #[inline]
     pub fn from_unchecked_nplus2<const NPLUS1: usize, const NPLUS2: usize>(
-        element: BigInt<{ NPLUS2 }>,
+        element: BigInt<NPLUS2>,
     ) -> Self {
         debug_assert!(NPLUS1 == N + 1);
         debug_assert!(NPLUS2 == N + 2);
@@ -1224,45 +1252,58 @@ impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
     /// Keep this for now for backwards compatibility.
     #[inline(always)]
     pub fn montgomery_reduce_2n<const TWON: usize>(input: BigInt<TWON>) -> Self {
-        Self::from_montgomery_reduce::<TWON>(input)
-    }
+        debug_assert!(TWON == 2 * N, "montgomery_reduce_2n requires TWON == 2N");
+        let mut limbs = input;
+        let carry = Self::montgomery_reduce_in_place::<TWON>(&mut limbs);
 
-    /// Perform one Montgomery reduction step at position `i` over a contiguous limb buffer.
-    /// Operates on a `BigInt<L>` that is treated as `[lo[0..N), hi[0..N), extra...]`.
-    /// Precondition (debug-asserted): `L >= N + i + 1` so all indices accessed are in-bounds.
-    /// Returns the carry-out from the top of this step.
-    #[inline(always)]
-    pub fn montgomery_step_once_at<const L: usize>(limbs: &mut BigInt<L>, i: usize) -> u64 {
-        debug_assert!(L >= N + i + 1, "montgomery_step_once_at: L too small for step i");
-        let limbs_slice = &mut limbs.0;
-        // Compute tmp = limbs[i] * INV (mod 2^64)
-        let tmp = limbs_slice[i].wrapping_mul(T::INV);
-        // Accumulate tmp * MODULUS into columns starting at i
-        let mut carry = 0u64;
-        fa::mac_discard(limbs_slice[i], tmp, T::MODULUS.0[0], &mut carry);
-        for j in 1..N {
-            let k = i + j;
-            limbs_slice[k] = mac_with_carry!(limbs_slice[k], tmp, T::MODULUS.0[j], &mut carry);
+        // Extract the upper N limbs after exactly N REDC steps
+        let mut result_limbs = [0u64; N];
+        result_limbs.copy_from_slice(&limbs.0[N..]);
+
+        let mut result = Self::new_unchecked(BigInt::<N>(result_limbs));
+        if T::MODULUS_HAS_SPARE_BIT {
+            result.subtract_modulus();
+        } else {
+            result.subtract_modulus_with_carry(carry != 0);
         }
-        // Propagate the final carry into limbs[i + N]
-        fa::adc(&mut limbs_slice[i + N], carry, 0)
+        result
     }
 
-    /// Perform up to `steps` Montgomery steps starting at i = 0 over an `L`-limb buffer.
-    /// Returns (last_carry, steps_done). In debug, asserts `L >= N + steps`; in release, saturates.
+    /// Perform exactly N Montgomery reduction steps over the leading 2N limbs of `limbs`,
+    /// using the canonical REDC subroutine from `mul_without_cond_subtract`.
+    /// Treats `limbs` as `[lo[0..N), hi[0..N), extra...]` and updates only the high half.
+    /// Returns the final carry-out (0 or 1) from the top of the reduction.
    #[inline(always)]
-    pub fn montgomery_steps_in_place<const L: usize>(
-        limbs: &mut BigInt<L>,
-        steps: usize,
-    ) -> (u64, usize) {
-        let max_steps = L.saturating_sub(N);
-        debug_assert!(steps <= max_steps, "steps exceed capacity: L < N + steps");
-        let steps_done = core::cmp::min(steps, max_steps);
-        let mut last_carry = 0u64;
-        for i in 0..steps_done {
-            last_carry = Self::montgomery_step_once_at::<L>(limbs, i);
-        }
-        (last_carry, steps_done)
+    pub fn montgomery_reduce_in_place<const L: usize>(limbs: &mut BigInt<L>) -> u64 {
+        debug_assert!(L >= 2 * N, "montgomery_reduce_in_place requires L >= 2N");
+
+        // Copy the leading 2N limbs into local halves to mirror the canonical subroutine.
+        let mut lo = [0u64; N];
+        let mut hi = [0u64; N];
+        lo.copy_from_slice(&limbs.0[0..N]);
+        hi.copy_from_slice(&limbs.0[N..(N + N)]);
+
+        // Montgomery reduction (canonical form)
+        let mut carry2 = 0u64;
+        crate::const_for!((i in 0..N) {
+            let tmp = lo[i].wrapping_mul(T::INV);
+            let mut carry;
+            mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry);
+            crate::const_for!((j in 1..N) {
+                let k = i + j;
+                if k >= N {
+                    hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry);
+                } else {
+                    lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry);
+                }
+            });
+            hi[i] = adc!(hi[i], carry, &mut carry2);
+        });
+
+        // Write the reduced high half back into the buffer; low half is discarded by callers.
+        limbs.0[N..(N + N)].copy_from_slice(&hi);
+
+        carry2
     }
 
     #[inline(always)]
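
The const_for! loop added above is the standard word-by-word REDC. The following standalone sketch (not part of this commit) mirrors that loop structure with a toy 2-limb odd modulus and plain u128 carries in place of the mac!/adc! macros; the modulus value and the helper names mac, adc, and neg_inv are illustrative assumptions only. It checks one easy identity: the 2N-limb buffer encoding MODULUS itself must reduce to MODULUS, the representative of MODULUS * R^{-1} = 0 (mod MODULUS) produced before the final conditional subtraction.

    // Standalone sketch of word-by-word REDC (toy parameters, not ark-ff code).
    const N: usize = 2;
    const MODULUS: [u64; N] = [0xffff_ffff_0000_0001, 0x0000_0000_0000_00ff]; // illustrative odd modulus

    /// a + b * c + carry: returns the low word and updates the carry (cf. mac_with_carry!).
    fn mac(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 {
        let t = a as u128 + (b as u128) * (c as u128) + *carry as u128;
        *carry = (t >> 64) as u64;
        t as u64
    }

    /// a + b + carry: returns the low word and updates the carry (cf. adc!).
    fn adc(a: u64, b: u64, carry: &mut u64) -> u64 {
        let t = a as u128 + b as u128 + *carry as u128;
        *carry = (t >> 64) as u64;
        t as u64
    }

    /// -MODULUS[0]^{-1} mod 2^64, derived by the same 63-step iteration ark-ff uses for T::INV.
    fn neg_inv() -> u64 {
        let mut inv = 1u64;
        for _ in 0..63 {
            inv = inv.wrapping_mul(inv);
            inv = inv.wrapping_mul(MODULUS[0]);
        }
        inv.wrapping_neg()
    }

    /// N REDC steps over a 2N-limb little-endian buffer: afterwards the high N limbs hold
    /// t * 2^{-64N} mod MODULUS (up to one conditional subtraction); returns the final carry.
    fn montgomery_reduce(t: &mut [u64; 2 * N]) -> u64 {
        let inv = neg_inv();
        let mut carry2 = 0u64;
        for i in 0..N {
            // q is chosen so that limb i of (t + q * MODULUS * 2^{64 i}) vanishes.
            let q = t[i].wrapping_mul(inv);
            let mut carry = 0u64;
            mac(t[i], q, MODULUS[0], &mut carry); // low word is zero by construction; keep the carry
            for j in 1..N {
                t[i + j] = mac(t[i + j], q, MODULUS[j], &mut carry);
            }
            t[i + N] = adc(t[i + N], carry, &mut carry2);
        }
        carry2
    }

    fn main() {
        // The buffer encoding the integer MODULUS reduces to MODULUS: it is the unique
        // multiple of MODULUS that the pre-subtraction bound allows.
        let mut buf = [MODULUS[0], MODULUS[1], 0, 0];
        let carry = montgomery_reduce(&mut buf);
        assert_eq!(carry, 0);
        assert_eq!([buf[N], buf[N + 1]], MODULUS);
        println!("REDC(MODULUS) == MODULUS (i.e. 0 mod MODULUS): ok");
    }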
@@ -1857,4 +1898,27 @@ mod test {
         let sign_is_positive = sign != Sign::Minus;
         (sign_is_positive, limbs)
     }
+
+    #[test]
+    fn test_from_montgomery_reduce_paths_l8_l9_match_field_mul() {
+        let mut rng = test_rng();
+        for _ in 0..200 {
+            let a = Fr::rand(&mut rng);
+            let b = Fr::rand(&mut rng);
+
+            let expected = a * b;
+
+            // Compute 8-limb raw product of Montgomery residues
+            let prod8 = a.0.mul_trunc::<4, 8>(&b.0);
+
+            // Reduce via Montgomery reduction with L = 8
+            let alt8 = Fr::montgomery_reduce_2n::<8>(prod8);
+            assert_eq!(alt8, expected, "from_montgomery_reduce L=8 mismatch");
+
+            // Zero-extend to 9 limbs and reduce with L = 9
+            let prod9 = ark_test_curves::ark_ff::BigInt::<9>::zero_extend_from::<8>(&prod8);
+            let alt9 = Fr::from_montgomery_reduce::<9, 5>(prod9);
+            assert_eq!(alt9, expected, "from_montgomery_reduce L=9 mismatch");
+        }
+    }
 }
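
For the concrete parameters this test exercises (a note, not part of the diff): Fr here is a 4-limb field, so N = 4, NPLUS1 = N + 1 = 5, and R = b^4 = 2^256. The 8-limb product has L = 8 = 2N, so montgomery_reduce_2n runs the four REDC steps directly; zero-extending to L = 9 > 2N makes from_montgomery_reduce::<9, 5> take the Barrett-fold path first. Both must agree with a * b because the raw product of the two Montgomery residues is (aR)(bR), and a single reduction divides out one factor of R: (aR)(bR) * R^{-1} = abR, which is exactly the Montgomery form of a * b.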

test-curves/benches/small_mul.rs

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ fn mul_small_bench(c: &mut Criterion) {
         let mut i = 0;
         bench.iter(|| {
            i = (i + 1) % SAMPLES;
-            criterion::black_box(Fr::from_montgomery_reduce::<8>(bigint_2n_s[i]))
+            criterion::black_box(Fr::from_montgomery_reduce::<8, 5>(bigint_2n_s[i]))
         })
     });
