57 changes: 54 additions & 3 deletions jolt-optimizations/benches/batch_addition.rs
@@ -3,7 +3,7 @@ use ark_ec::{AffineRepr, CurveGroup};
use ark_std::rand::RngCore;
use ark_std::UniformRand;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
-use jolt_optimizations::batch_g1_additions;
+use jolt_optimizations::{batch_g1_additions, batch_g1_additions_multi};
use rayon::prelude::*;

fn naive_parallel_sum(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
@@ -18,7 +18,7 @@ fn bench_batch_addition(c: &mut Criterion) {
    let mut rng = ark_std::test_rng();

    // Test different sizes
-    for size in [1 << 20].iter() {
+    for size in [1 << 15].iter() {
        let bases: Vec<G1Affine> = (0..*size).map(|_| G1Affine::rand(&mut rng)).collect();

        // Use half the points
@@ -38,5 +38,56 @@ fn bench_batch_addition(c: &mut Criterion) {
    group.finish();
}

-criterion_group!(benches, bench_batch_addition);
fn bench_batch_addition_multi(c: &mut Criterion) {
    let mut group = c.benchmark_group("batch_g1_addition_multi");
    let mut rng = ark_std::test_rng();

    let base_size = 1 << 19;
    let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();

    for num_batches in [10].iter() {
        let batch_size = 1 << 16;
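        // Each batch draws 2^16 indices uniformly (with replacement) from the
        // shared table of 2^19 base points, so all batches sum from the same bases.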

        let indices_sets: Vec<Vec<usize>> = (0..*num_batches)
            .map(|_| {
                (0..batch_size)
                    .map(|_| (rng.next_u64() as usize) % base_size)
                    .collect()
            })
            .collect();

        group.bench_with_input(
            BenchmarkId::new("multi_batch_shared", num_batches),
            num_batches,
            |b, _| {
                b.iter(|| black_box(batch_g1_additions_multi(&bases, &indices_sets)));
            },
        );

        group.bench_with_input(
            BenchmarkId::new("parallel_naive_sum", num_batches),
            num_batches,
            |b, _| {
                b.iter(|| {
                    black_box(
                        indices_sets
                            .par_iter()
                            .map(|indices| {
                                // Naive parallel sum for each batch
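                                // Each `into_affine()` normalizes out of
                                // projective form with a field inversion, one
                                // per addition; that per-addition inversion is
                                // what the batched version amortizes.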
                                indices.par_iter().map(|&idx| bases[idx]).reduce(
                                    || G1Affine::zero(),
                                    |acc, point| (acc + point).into_affine(),
                                )
                            })
                            .collect::<Vec<_>>(),
                    )
                });
            },
        );
    }

    group.finish();
}

+criterion_group!(benches, bench_batch_addition, bench_batch_addition_multi);
criterion_main!(benches);
128 changes: 123 additions & 5 deletions jolt-optimizations/src/batch_addition.rs
@@ -2,7 +2,7 @@
//!
//! Implements efficient batch addition of affine elliptic curve points
//! using Montgomery's batch inversion trick to minimize field inversions.

+//! @TODO(markosg04) duplicate group elements?
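//!
//! Illustrative sketch of the trick (not the exact code below): to invert
//! denominators d1, d2, d3 with a single field inversion, form the prefix
//! products p1 = d1, p2 = d1*d2, p3 = d1*d2*d3, invert only p3, and unwind:
//! inv(d3) = inv(p3) * p2, inv(p2) = inv(p3) * d3, inv(d2) = inv(p2) * p1,
//! inv(d1) = inv(p2) * d2. This is what `ark_ff::fields::batch_inversion`
//! implements: one inversion plus a few multiplications per element.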
use ark_bn254::G1Affine;
use ark_ec::AffineRepr;
use rayon::prelude::*;
@@ -63,7 +63,7 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
                let lambda = (p2.y - p1.y) * inv;
                let x3 = lambda * lambda - p1.x - p2.x;
                let y3 = lambda * (p1.x - x3) - p1.y;
-                G1Affine::new(x3, y3)
+                G1Affine::new_unchecked(x3, y3)
            })
            .collect();

@@ -78,6 +78,95 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
    points[0]
}

/// Performs multiple batch additions of G1 affine points in parallel.
///
/// Given a slice of base points and multiple sets of indices, computes the sum
/// for each set of indices. All additions across all batches share the same
/// batch inversion.
///
/// # Arguments
/// * `bases` - Slice of G1 affine points to select from
/// * `indices_sets` - Vector of index vectors, each specifying which points to sum
///
/// # Returns
/// Vector of sums, one for each index set
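///
/// # Example
/// A minimal sketch of the call shape (illustrative only; assumes `bases`
/// holds at least five points):
/// ```ignore
/// let sums = batch_g1_additions_multi(&bases, &[vec![0, 1, 2], vec![3, 4]]);
/// assert_eq!(sums.len(), 2);
/// ```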
pub fn batch_g1_additions_multi(bases: &[G1Affine], indices_sets: &[Vec<usize>]) -> Vec<G1Affine> {
    if indices_sets.is_empty() {
        return vec![];
    }

    // Initialize working sets for each batch
    let mut working_sets: Vec<Vec<G1Affine>> = indices_sets
        .par_iter()
        .map(|indices| {
            if indices.is_empty() {
                vec![G1Affine::zero()]
            } else if indices.len() == 1 {
                vec![bases[indices[0]]]
            } else {
                indices.iter().map(|&i| bases[i]).collect()
            }
        })
        .collect();

    // Continue until all sets have been reduced to a single point
    loop {
        // Count total number of pairs across all sets
        let total_pairs: usize = working_sets.iter().map(|set| set.len() / 2).sum();

        if total_pairs == 0 {
            break;
        }

        // Collect all denominators across all sets
        let mut all_denominators = Vec::with_capacity(total_pairs);
        let mut pair_info = Vec::with_capacity(total_pairs);

        for (set_idx, set) in working_sets.iter().enumerate() {
            let pairs_in_set = set.len() / 2;
            for pair_idx in 0..pairs_in_set {
                let p1 = set[pair_idx * 2];
                let p2 = set[pair_idx * 2 + 1];
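                // The affine chord addition below uses
                // lambda = (y2 - y1) / (x2 - x1); gathering every denominator
                // first lets one batch inversion serve all pairs in all sets.
                // This assumes p1.x != p2.x, i.e. no duplicate or negated
                // pairs (see the @TODO at the top of this file).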
                all_denominators.push(p2.x - p1.x);
                pair_info.push((set_idx, pair_idx));
            }
        }

        // Batch invert all denominators at once
        let mut inverses = all_denominators;
        ark_ff::fields::batch_inversion(&mut inverses);

        // Apply additions using the inverted denominators
        let mut new_working_sets: Vec<Vec<G1Affine>> = working_sets
            .iter()
            .map(|set| Vec::with_capacity((set.len() + 1) / 2))
            .collect();

        // Process additions and maintain order
        for ((set_idx, pair_idx), inv) in pair_info.iter().zip(inverses.iter()) {
            let set = &working_sets[*set_idx];
            let p1 = set[*pair_idx * 2];
            let p2 = set[*pair_idx * 2 + 1];
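            // Standard short-Weierstrass chord rule; the result is on the
            // curve by construction, so `new_unchecked` (which skips the
            // validity assertions `G1Affine::new` performs) is safe here.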
            let lambda = (p2.y - p1.y) * inv;
            let x3 = lambda * lambda - p1.x - p2.x;
            let y3 = lambda * (p1.x - x3) - p1.y;
            new_working_sets[*set_idx].push(G1Affine::new_unchecked(x3, y3));
        }

        // Handle odd elements
        for (set_idx, set) in working_sets.iter().enumerate() {
            if set.len() % 2 == 1 {
                new_working_sets[set_idx].push(set[set.len() - 1]);
            }
        }

        working_sets = new_working_sets;
    }

    // Extract final results
    working_sets.into_iter().map(|set| set[0]).collect()
}

#[cfg(test)]
mod tests {
    use super::*;
@@ -123,8 +212,8 @@ mod tests {
    fn test_stress_test_correctness() {
        let mut rng = ark_std::test_rng();

-        let base_size = 10000;
-        let indices_size = 5000;
+        let base_size = 100000;
+        let indices_size = 50000;

        let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();

@@ -134,7 +223,6 @@

        let batch_result = batch_g1_additions(&bases, &indices);

-        // Compute expected result using naive sequential addition
        let mut expected = G1Affine::zero();
        for &idx in &indices {
            expected = (expected + bases[idx]).into_affine();
@@ -145,4 +233,34 @@
"Stress test failed: batch result doesn't match expected sum"
);
}

    #[test]
    fn test_batch_additions_multi_large() {
        let mut rng = ark_std::test_rng();

        let base_size = 10000;
        let num_batches = 50;

        let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();

        let indices_sets: Vec<Vec<usize>> = (0..num_batches)
            .map(|_| {
                let size = (rng.next_u64() as usize) % 100 + 1;
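                // Between 1 and 100 indices per batch; duplicates are possible.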
                (0..size)
                    .map(|_| (rng.next_u64() as usize) % base_size)
                    .collect()
            })
            .collect();

        let batch_results = batch_g1_additions_multi(&bases, &indices_sets);

        for (i, (result, indices)) in batch_results.iter().zip(indices_sets.iter()).enumerate() {
            let single_result = batch_g1_additions(&bases, indices);
            assert_eq!(
                *result, single_result,
                "Multi vs single mismatch at batch {}",
                i
            );
        }
    }
}
2 changes: 1 addition & 1 deletion jolt-optimizations/src/lib.rs
@@ -53,4 +53,4 @@ pub use dory_g2::{
    vector_add_scalar_mul_g2_windowed2_signed, vector_scalar_mul_add_gamma_g2_online,
};

-pub use batch_addition::batch_g1_additions;
+pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi};