From ba1235fd3fdff9ca5a9939237f9dd565b99971c8 Mon Sep 17 00:00:00 2001
From: markosg04
Date: Mon, 4 Aug 2025 10:55:59 -0400
Subject: [PATCH] batched batch addition

Add `batch_g1_additions_multi`, which sums several index sets over the same
bases while sharing one Montgomery batch inversion per reduction round, plus
a benchmark and tests for it. Pairs the affine chord formula cannot handle
(identity operand, equal x-coordinates) fall back to the generic group law.

---
Note: line breaks, indentation of context lines and generic type arguments
were reconstructed from a whitespace-collapsed copy of this patch. If a hunk
is rejected, apply with `git apply --ignore-whitespace`. The `index` hashes
are those of the original commit.

 jolt-optimizations/benches/batch_addition.rs |  57 ++++++-
 jolt-optimizations/src/batch_addition.rs     | 155 ++++++++++++++++++-
 jolt-optimizations/src/lib.rs                |   2 +-
 3 files changed, 205 insertions(+), 9 deletions(-)

diff --git a/jolt-optimizations/benches/batch_addition.rs b/jolt-optimizations/benches/batch_addition.rs
index 597ac241f..92e534b83 100644
--- a/jolt-optimizations/benches/batch_addition.rs
+++ b/jolt-optimizations/benches/batch_addition.rs
@@ -3,7 +3,7 @@ use ark_ec::{AffineRepr, CurveGroup};
 use ark_std::rand::RngCore;
 use ark_std::UniformRand;
 use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
-use jolt_optimizations::batch_g1_additions;
+use jolt_optimizations::{batch_g1_additions, batch_g1_additions_multi};
 use rayon::prelude::*;
 
 fn naive_parallel_sum(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
@@ -18,7 +18,7 @@ fn bench_batch_addition(c: &mut Criterion) {
     let mut rng = ark_std::test_rng();
 
     // Test different sizes
-    for size in [1 << 20].iter() {
+    for size in [1 << 15].iter() {
         let bases: Vec<G1Affine> = (0..*size).map(|_| G1Affine::rand(&mut rng)).collect();
 
         // Use half the points
@@ -38,5 +38,56 @@ fn bench_batch_addition(c: &mut Criterion) {
     group.finish();
 }
 
-criterion_group!(benches, bench_batch_addition);
+fn bench_batch_addition_multi(c: &mut Criterion) {
+    let mut group = c.benchmark_group("batch_g1_addition_multi");
+    let mut rng = ark_std::test_rng();
+
+    let base_size = 1 << 19;
+    let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();
+
+    for num_batches in [10].iter() {
+        let batch_size = 1 << 16;
+
+        let indices_sets: Vec<Vec<usize>> = (0..*num_batches)
+            .map(|_| {
+                (0..batch_size)
+                    .map(|_| (rng.next_u64() as usize) % base_size)
+                    .collect()
+            })
+            .collect();
+
+        group.bench_with_input(
+            BenchmarkId::new("multi_batch_shared", num_batches),
+            num_batches,
+            |b, _| {
+                b.iter(|| black_box(batch_g1_additions_multi(&bases, &indices_sets)));
+            },
+        );
+
+        group.bench_with_input(
+            BenchmarkId::new("parallel_naive_sum", num_batches),
+            num_batches,
+            |b, _| {
+                b.iter(|| {
+                    black_box(
+                        indices_sets
+                            .par_iter()
+                            .map(|indices| {
+                                // Naive parallel sum for each batch
+                                indices.par_iter().map(|&idx| bases[idx]).reduce(
+                                    || G1Affine::zero(),
+                                    |acc, point| (acc + point).into_affine(),
+                                )
+                            })
+                            .collect::<Vec<_>>(),
+                    )
+                });
+            },
+        );
+    }
+
+    group.finish();
+}
+
+criterion_group!(benches, bench_batch_addition, bench_batch_addition_multi);
 criterion_main!(benches);
diff --git a/jolt-optimizations/src/batch_addition.rs b/jolt-optimizations/src/batch_addition.rs
index 7b1808440..1c8dbf561 100644
--- a/jolt-optimizations/src/batch_addition.rs
+++ b/jolt-optimizations/src/batch_addition.rs
@@ -2,7 +2,7 @@
 //!
 //! Implements efficient batch addition of affine elliptic curve points
 //! using Montgomery's batch inversion trick to minimize field inversions.
-
+//! @TODO(markosg04) duplicate group elements?
 use ark_bn254::G1Affine;
 use ark_ec::AffineRepr;
 use rayon::prelude::*;
@@ -63,7 +63,7 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
                 let lambda = (p2.y - p1.y) * inv;
                 let x3 = lambda * lambda - p1.x - p2.x;
                 let y3 = lambda * (p1.x - x3) - p1.y;
-                G1Affine::new(x3, y3)
+                G1Affine::new_unchecked(x3, y3)
             })
             .collect();
 
@@ -78,6 +78,86 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
     points[0]
 }
 
+/// Sums several index-selected subsets of `bases` at once.
+///
+/// Every subset is reduced as a binary tree of pairwise affine additions. In
+/// each round the slope denominators of all pairs from all subsets are
+/// inverted together, so a round costs a single Montgomery batch inversion no
+/// matter how many subsets are being summed.
+///
+/// Pairs that the affine chord formula cannot handle (one operand is the
+/// identity, or both share an x-coordinate, i.e. `P + P` or `P + (-P)`) are
+/// added with the generic group law instead, so repeated indices are summed
+/// correctly.
+///
+/// # Arguments
+/// * `bases` - Slice of G1 affine points to select from
+/// * `indices_sets` - One index vector per requested sum
+///
+/// # Returns
+/// One sum per index set, in input order. An empty index set yields the identity.
+///
+/// # Panics
+/// Panics if any index is out of bounds for `bases`.
+pub fn batch_g1_additions_multi(bases: &[G1Affine], indices_sets: &[Vec<usize>]) -> Vec<G1Affine> {
+    // Gather the selected points. An empty set is seeded with the identity so
+    // that every working set always holds at least one point.
+    let mut working_sets: Vec<Vec<G1Affine>> = indices_sets
+        .par_iter()
+        .map(|indices| {
+            if indices.is_empty() {
+                vec![G1Affine::zero()]
+            } else {
+                indices.iter().map(|&i| bases[i]).collect()
+            }
+        })
+        .collect();
+
+    // Each round halves every set; stop once no set has a pair left.
+    while working_sets.iter().any(|set| set.len() > 1) {
+        // Slope denominators of every pair, set by set, in pair order.
+        let mut inverses: Vec<_> = working_sets
+            .iter()
+            .flat_map(|set| set.chunks_exact(2).map(|pair| pair[1].x - pair[0].x))
+            .collect();
+
+        // One inversion shared by all sets. Zero denominators stay zero; the
+        // pairs they belong to are routed to the generic addition below.
+        ark_ff::fields::batch_inversion(&mut inverses);
+
+        let mut remaining_inverses = inverses.as_slice();
+        let mut next_sets = Vec::with_capacity(working_sets.len());
+        for set in &working_sets {
+            let pairs = set.chunks_exact(2);
+            let leftover = pairs.remainder();
+            let (set_inverses, rest) = remaining_inverses.split_at(pairs.len());
+            remaining_inverses = rest;
+
+            let mut next_set = Vec::with_capacity((set.len() + 1) / 2);
+            next_set.extend(pairs.zip(set_inverses).map(|(pair, inv)| {
+                let (p1, p2) = (pair[0], pair[1]);
+                if p1.is_zero() || p2.is_zero() || p1.x == p2.x {
+                    // Chord formula is undefined here: use the generic group law.
+                    ark_ec::CurveGroup::into_affine(p1 + p2)
+                } else {
+                    let lambda = (p2.y - p1.y) * inv;
+                    let x3 = lambda * lambda - p1.x - p2.x;
+                    let y3 = lambda * (p1.x - x3) - p1.y;
+                    G1Affine::new_unchecked(x3, y3)
+                }
+            }));
+            // An unpaired trailing point is carried into the next round unchanged.
+            next_set.extend_from_slice(leftover);
+            next_sets.push(next_set);
+        }
+
+        working_sets = next_sets;
+    }
+
+    // Every set has been reduced to exactly one point.
+    working_sets.into_iter().map(|set| set[0]).collect()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -123,8 +203,8 @@ mod tests {
     fn test_stress_test_correctness() {
         let mut rng = ark_std::test_rng();
 
-        let base_size = 10000;
-        let indices_size = 5000;
+        let base_size = 100000;
+        let indices_size = 50000;
 
         let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();
 
@@ -134,7 +214,6 @@ mod tests {
 
         let batch_result = batch_g1_additions(&bases, &indices);
 
-        // Compute expected result using naive sequential addition
         let mut expected = G1Affine::zero();
         for &idx in &indices {
             expected = (expected + bases[idx]).into_affine();
@@ -145,4 +224,70 @@ mod tests {
             "Stress test failed: batch result doesn't match expected sum"
         );
     }
+
+    /// Reference sum: folds the selected points one by one with the generic group law.
+    fn reference_sum(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
+        indices
+            .iter()
+            .fold(G1Affine::zero(), |acc, &idx| (acc + bases[idx]).into_affine())
+    }
+
+    #[test]
+    fn test_batch_additions_multi_large() {
+        let mut rng = ark_std::test_rng();
+
+        let base_size = 10000;
+        let num_batches = 50;
+
+        let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();
+
+        let indices_sets: Vec<Vec<usize>> = (0..num_batches)
+            .map(|_| {
+                let size = (rng.next_u64() as usize) % 100 + 1;
+                (0..size)
+                    .map(|_| (rng.next_u64() as usize) % base_size)
+                    .collect()
+            })
+            .collect();
+
+        let batch_results = batch_g1_additions_multi(&bases, &indices_sets);
+
+        assert_eq!(batch_results.len(), indices_sets.len());
+        for (i, (result, indices)) in batch_results.iter().zip(indices_sets.iter()).enumerate() {
+            assert_eq!(
+                *result,
+                reference_sum(&bases, indices),
+                "Multi-batch sum mismatch at batch {}",
+                i
+            );
+        }
+    }
+
+    #[test]
+    fn test_batch_additions_multi_degenerate_pairs() {
+        let mut rng = ark_std::test_rng();
+        let p = G1Affine::rand(&mut rng);
+        let q = G1Affine::rand(&mut rng);
+        let bases = vec![p, q, -p];
+
+        let indices_sets: Vec<Vec<usize>> = vec![
+            vec![],           // empty set -> identity
+            vec![0, 0],       // P + P
+            vec![0, 2],       // P + (-P)
+            vec![0, 2, 1],    // identity produced in round one, then added to Q
+            vec![0, 0, 0, 0], // doubling in two consecutive rounds
+        ];
+
+        let batch_results = batch_g1_additions_multi(&bases, &indices_sets);
+
+        assert_eq!(batch_results.len(), indices_sets.len());
+        for (i, (result, indices)) in batch_results.iter().zip(indices_sets.iter()).enumerate() {
+            assert_eq!(
+                *result,
+                reference_sum(&bases, indices),
+                "Degenerate case {} mismatch",
+                i
+            );
+        }
+    }
 }
diff --git a/jolt-optimizations/src/lib.rs b/jolt-optimizations/src/lib.rs
index 503791f03..5ba72de8f 100644
--- a/jolt-optimizations/src/lib.rs
+++ b/jolt-optimizations/src/lib.rs
@@ -53,4 +53,4 @@ pub use dory_g2::{
     vector_add_scalar_mul_g2_windowed2_signed, vector_scalar_mul_add_gamma_g2_online,
 };
 
-pub use batch_addition::batch_g1_additions;
+pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi};