Skip to content

Commit 7f8db64

Browse files
authored
Merge pull request #11 from a16z/feat/batch-addition
batched batch addition
2 parents 119b733 + ba1235f commit 7f8db64

File tree

3 files changed

+178
-9
lines changed

3 files changed

+178
-9
lines changed

jolt-optimizations/benches/batch_addition.rs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use ark_ec::{AffineRepr, CurveGroup};
33
use ark_std::rand::RngCore;
44
use ark_std::UniformRand;
55
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
6-
use jolt_optimizations::batch_g1_additions;
6+
use jolt_optimizations::{batch_g1_additions, batch_g1_additions_multi};
77
use rayon::prelude::*;
88

99
fn naive_parallel_sum(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
@@ -18,7 +18,7 @@ fn bench_batch_addition(c: &mut Criterion) {
1818
let mut rng = ark_std::test_rng();
1919

2020
// Test different sizes
21-
for size in [1 << 20].iter() {
21+
for size in [1 << 15].iter() {
2222
let bases: Vec<G1Affine> = (0..*size).map(|_| G1Affine::rand(&mut rng)).collect();
2323

2424
// Use half the points
@@ -38,5 +38,56 @@ fn bench_batch_addition(c: &mut Criterion) {
3838
group.finish();
3939
}
4040

41-
criterion_group!(benches, bench_batch_addition);
41+
fn bench_batch_addition_multi(c: &mut Criterion) {
42+
let mut group = c.benchmark_group("batch_g1_addition_multi");
43+
let mut rng = ark_std::test_rng();
44+
45+
let base_size = 1 << 19;
46+
let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();
47+
48+
for num_batches in [10].iter() {
49+
let batch_size = 1 << 16;
50+
51+
let indices_sets: Vec<Vec<usize>> = (0..*num_batches)
52+
.map(|_| {
53+
(0..batch_size)
54+
.map(|_| (rng.next_u64() as usize) % base_size)
55+
.collect()
56+
})
57+
.collect();
58+
59+
group.bench_with_input(
60+
BenchmarkId::new("multi_batch_shared", num_batches),
61+
num_batches,
62+
|b, _| {
63+
b.iter(|| black_box(batch_g1_additions_multi(&bases, &indices_sets)));
64+
},
65+
);
66+
67+
group.bench_with_input(
68+
BenchmarkId::new("parallel_naive_sum", num_batches),
69+
num_batches,
70+
|b, _| {
71+
b.iter(|| {
72+
black_box(
73+
indices_sets
74+
.par_iter()
75+
.map(|indices| {
76+
// Naive parallel sum for each batch
77+
indices.par_iter().map(|&idx| bases[idx]).reduce(
78+
|| G1Affine::zero(),
79+
|acc, point| (acc + point).into_affine(),
80+
)
81+
})
82+
.collect::<Vec<_>>(),
83+
)
84+
});
85+
},
86+
);
87+
}
88+
89+
group.finish();
90+
}
91+
92+
criterion_group!(benches, bench_batch_addition, bench_batch_addition_multi);
4293
criterion_main!(benches);

jolt-optimizations/src/batch_addition.rs

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//!
33
//! Implements efficient batch addition of affine elliptic curve points
44
//! using Montgomery's batch inversion trick to minimize field inversions.
5-
5+
//! TODO(markosg04): handle duplicate group elements? Adding a point to itself
//! makes the chord-slope denominator (`p2.x - p1.x`) zero.
66
use ark_bn254::G1Affine;
77
use ark_ec::AffineRepr;
88
use rayon::prelude::*;
@@ -63,7 +63,7 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
6363
let lambda = (p2.y - p1.y) * inv;
6464
let x3 = lambda * lambda - p1.x - p2.x;
6565
let y3 = lambda * (p1.x - x3) - p1.y;
66-
G1Affine::new(x3, y3)
66+
G1Affine::new_unchecked(x3, y3)
6767
})
6868
.collect();
6969

@@ -78,6 +78,95 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
7878
points[0]
7979
}
8080

81+
/// Performs multiple batch additions of G1 affine points in parallel.
82+
///
83+
/// Given a slice of base points and multiple sets of indices, computes the sum
84+
/// for each set of indices. All additions across all batches share the same
85+
/// batch inversion.
86+
///
87+
/// # Arguments
88+
/// * `bases` - Slice of G1 affine points to select from
89+
/// * `indices_sets` - Vector of index vectors, each specifying which points to sum
90+
///
91+
/// # Returns
92+
/// Vector of sums, one for each index set
93+
pub fn batch_g1_additions_multi(bases: &[G1Affine], indices_sets: &[Vec<usize>]) -> Vec<G1Affine> {
94+
if indices_sets.is_empty() {
95+
return vec![];
96+
}
97+
98+
// Initialize working sets for each batch
99+
let mut working_sets: Vec<Vec<G1Affine>> = indices_sets
100+
.par_iter()
101+
.map(|indices| {
102+
if indices.is_empty() {
103+
vec![G1Affine::zero()]
104+
} else if indices.len() == 1 {
105+
vec![bases[indices[0]]]
106+
} else {
107+
indices.iter().map(|&i| bases[i]).collect()
108+
}
109+
})
110+
.collect();
111+
112+
// Continue until all sets have been reduced to a single point
113+
loop {
114+
// Count total number of pairs across all sets
115+
let total_pairs: usize = working_sets.iter().map(|set| set.len() / 2).sum();
116+
117+
if total_pairs == 0 {
118+
break;
119+
}
120+
121+
// Collect all denominators across all sets
122+
let mut all_denominators = Vec::with_capacity(total_pairs);
123+
let mut pair_info = Vec::with_capacity(total_pairs);
124+
125+
for (set_idx, set) in working_sets.iter().enumerate() {
126+
let pairs_in_set = set.len() / 2;
127+
for pair_idx in 0..pairs_in_set {
128+
let p1 = set[pair_idx * 2];
129+
let p2 = set[pair_idx * 2 + 1];
130+
all_denominators.push(p2.x - p1.x);
131+
pair_info.push((set_idx, pair_idx));
132+
}
133+
}
134+
135+
// Batch invert all denominators at once
136+
let mut inverses = all_denominators;
137+
ark_ff::fields::batch_inversion(&mut inverses);
138+
139+
// Apply additions using the inverted denominators
140+
let mut new_working_sets: Vec<Vec<G1Affine>> = working_sets
141+
.iter()
142+
.map(|set| Vec::with_capacity((set.len() + 1) / 2))
143+
.collect();
144+
145+
// Process additions and maintain order
146+
for ((set_idx, pair_idx), inv) in pair_info.iter().zip(inverses.iter()) {
147+
let set = &working_sets[*set_idx];
148+
let p1 = set[*pair_idx * 2];
149+
let p2 = set[*pair_idx * 2 + 1];
150+
let lambda = (p2.y - p1.y) * inv;
151+
let x3 = lambda * lambda - p1.x - p2.x;
152+
let y3 = lambda * (p1.x - x3) - p1.y;
153+
new_working_sets[*set_idx].push(G1Affine::new_unchecked(x3, y3));
154+
}
155+
156+
// Handle odd elements
157+
for (set_idx, set) in working_sets.iter().enumerate() {
158+
if set.len() % 2 == 1 {
159+
new_working_sets[set_idx].push(set[set.len() - 1]);
160+
}
161+
}
162+
163+
working_sets = new_working_sets;
164+
}
165+
166+
// Extract final results
167+
working_sets.into_iter().map(|set| set[0]).collect()
168+
}
169+
81170
#[cfg(test)]
82171
mod tests {
83172
use super::*;
@@ -123,8 +212,8 @@ mod tests {
123212
fn test_stress_test_correctness() {
124213
let mut rng = ark_std::test_rng();
125214

126-
let base_size = 10000;
127-
let indices_size = 5000;
215+
let base_size = 100000;
216+
let indices_size = 50000;
128217

129218
let bases: Vec<G1Affine> = (0..base_size).map(|_| G1Affine::rand(&mut rng)).collect();
130219

@@ -134,7 +223,6 @@ mod tests {
134223

135224
let batch_result = batch_g1_additions(&bases, &indices);
136225

137-
// Compute expected result using naive sequential addition
138226
let mut expected = G1Affine::zero();
139227
for &idx in &indices {
140228
expected = (expected + bases[idx]).into_affine();
@@ -145,4 +233,34 @@ mod tests {
145233
"Stress test failed: batch result doesn't match expected sum"
146234
);
147235
}
236+
237+
/// Checks that `batch_g1_additions_multi` agrees with `batch_g1_additions`
/// on 50 randomly sized index sets (1 to 100 indices each) over 10000 points.
#[test]
fn test_batch_additions_multi_large() {
    let mut rng = ark_std::test_rng();

    let base_size = 10000;
    let num_batches = 50;

    let mut bases: Vec<G1Affine> = Vec::with_capacity(base_size);
    for _ in 0..base_size {
        bases.push(G1Affine::rand(&mut rng));
    }

    // For each batch, draw its length first and then that many indices.
    let mut indices_sets: Vec<Vec<usize>> = Vec::with_capacity(num_batches);
    for _ in 0..num_batches {
        let len = (rng.next_u64() as usize) % 100 + 1;
        let mut picks = Vec::with_capacity(len);
        for _ in 0..len {
            picks.push((rng.next_u64() as usize) % base_size);
        }
        indices_sets.push(picks);
    }

    let batch_results = batch_g1_additions_multi(&bases, &indices_sets);

    // Each multi-batch sum must equal the sum computed one set at a time.
    let single_results = indices_sets
        .iter()
        .map(|indices| batch_g1_additions(&bases, indices));
    for (i, (multi_result, single_result)) in batch_results.iter().zip(single_results).enumerate() {
        assert_eq!(
            *multi_result, single_result,
            "Multi vs single mismatch at batch {}",
            i
        );
    }
}
148266
}

jolt-optimizations/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ pub use dory_g2::{
5353
vector_add_scalar_mul_g2_windowed2_signed, vector_scalar_mul_add_gamma_g2_online,
5454
};
5555

56-
pub use batch_addition::batch_g1_additions;
56+
pub use batch_addition::{batch_g1_additions, batch_g1_additions_multi};

0 commit comments

Comments
 (0)